Fix MatrixSlice scalar mode bug.

This commit is contained in:
Krad
2023-04-20 11:21:11 +08:00
parent 044de7be5f
commit 286008c4ab
3 changed files with 66 additions and 23 deletions

View File

@@ -335,8 +335,8 @@ namespace Aurora {
double *startPointer = getData() + (rowStride * rowOffset
+ colStride * colOffset
+ sliceStride * sliceOffset) * getValueType();
int size1 = getDimSize(allDimIndex[0]);
int stride1 = strides[allDimIndex[0]];
int size1 = allDimIndex.empty()?1:getDimSize(allDimIndex[0]);
int stride1 = allDimIndex.empty()?1:strides[allDimIndex[0]];
switch (mode) {
//matrix mode
case 2:{
@@ -479,6 +479,49 @@ namespace Aurora {
return *this;
}
Matrix::MatrixSlice &Matrix::MatrixSlice::operator=(double value) {
if(!mData){
std::cerr <<"Assign value fail!Des data pointer is null!";
return *this;
}
if (mSliceMode!=0) {
std::cerr <<"Assign value fail!Des slicemode is"<<mSliceMode<<", not scalar mode!";
return *this;
}
if (mSize!=1) {
std::cerr <<"Assign value fail!Des size:"<<mSize<<", not scalar mode!";
return *this;
}
if (mType!=Normal) {
std::cerr <<"Assign value fail!Des type is complex!";
return *this;
}
mData[0]=value;
return *this;
}
Matrix::MatrixSlice &Matrix::MatrixSlice::operator=(std::complex<double> value) {
if(!mData){
std::cerr <<"Assign value fail!Des data pointer is null!";
return *this;
}
if (mSliceMode!=0) {
std::cerr <<"Assign value fail!Des slicemode is"<<mSliceMode<<", not scalar mode!";
return *this;
}
if (mSize!=1) {
std::cerr <<"Assign value fail!Des size:"<<mSize<<", not scalar mode!";
return *this;
}
if (mType!=Complex) {
std::cerr <<"Assign value fail!Des type is not complex!";
return *this;
}
mData[0]=value.real();
mData[1]=value.imag();
return *this;
}
Matrix Matrix::MatrixSlice::toMatrix() const {
double * data = (double *) mkl_malloc(mSize*(mSize2>0?mSize2:1) * sizeof(double)*mType, 64);

View File

@@ -2,6 +2,7 @@
#define MATRIX_H
#include <memory>
#include <complex>
#include <vector>
@@ -22,6 +23,8 @@ namespace Aurora {
MatrixSlice(int aSize,int aStride, double* aData,ValueType aType = Normal,int SliceMode = 1,int aSize2 = 0, int aStride2 = 0);
MatrixSlice& operator=(const MatrixSlice& slice);
MatrixSlice& operator=(const Matrix& matrix);
MatrixSlice& operator=(double value);
MatrixSlice& operator=(std::complex<double> value);
Matrix toMatrix() const;
private:
int mSliceMode = 0;//0 scalar, 1 vector, 2 Matrix

View File

@@ -59,13 +59,10 @@ int main() {
Aurora::Matrix C = Aurora::Matrix::New(dataC, 2, 2);
printf("C:\r\n");
C.printf();
A(Aurora::$,Aurora::$,0) = C;
A(1, 1, 1) = 1024.0;
printf("New A:\r\n");
A.printf();
return 0;
}
{
}{
double * dataA =new double[9];
double * dataB =new double[9];
for (int i = 0; i < 9; ++i) {