diff --git a/src/Matrix.cpp b/src/Matrix.cpp index 9642a75..a2fe4d4 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -158,27 +158,28 @@ namespace Aurora{ } inline Matrix operatorMxM_RR(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Aurora::Matrix &aMatrix, - Aurora::Matrix &&aOther) { + Aurora::Matrix &&aOther,bool oppside = false) { std::cout << "use right ref operation m" << std::endl; if (aMatrix.compareShape(aOther)) { int DimsStride = 1; + double* X = oppside?aOther.getData():aMatrix.getData(); + double* Y = oppside?aMatrix.getData():aOther.getData(); if (aMatrix.getValueType() != aOther.getValueType()) { //aOther is not a complex matrix if (aMatrix.getValueType() == Complex) { - double *output = _MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), - DimsStride); + double *output = _MxM_NC_Calc(aFuncD,aMatrix.getDataSize(), X, Y,DimsStride); return Matrix::New(output, aOther); } //aOther is a complex matrix, use aOther as output - V_MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(), + V_MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), X, Y, aOther.getData(), DimsStride); return aOther; } else if (aMatrix.getValueType() == Normal) { - V_MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(), + V_MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), X, Y, aOther.getData(), DimsStride); return aOther; } else { - V_MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(), + V_MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), X, Y, aOther.getData(), DimsStride); return aOther; } @@ -289,6 +290,7 @@ namespace Aurora { if (slices > 0)vector.push_back(slices); Matrix ret({data, free}, vector); if (type != Normal)ret.setValueType(type); + ret.printf(); return ret; } @@ -349,7 +351,7 @@ namespace Aurora { return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward(aMatrix)); } Matrix operator+(Matrix &&aMatrix, const Matrix &aOther) { - return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward(aMatrix)); + return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward(aMatrix),true); } //operation - @@ -366,7 +368,7 @@ namespace Aurora { return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward(aMatrix)); } Matrix operator-(Matrix &&aMatrix, const Matrix &aOther) { - return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward(aMatrix)); + return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward(aMatrix),true); } //operation * @@ -383,7 +385,7 @@ namespace Aurora { return operatorMxM_RR(&vdMulI,&vzMulI,*this,std::forward(aMatrix)); } Matrix operator*(Matrix &&aMatrix, const Matrix &aOther) { - return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward(aMatrix)); + return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward(aMatrix),true); } //operation / @@ -400,7 +402,7 @@ namespace Aurora { return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward(aMatrix)); } Matrix operator/(Matrix &&aMatrix, const Matrix &aOther) { - return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward(aMatrix)); + return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward(aMatrix),true); } //operator ^ (pow) @@ -449,7 +451,7 @@ namespace Aurora { std::vector allDimIndex; int mode = 0; for (int j = 0; j < 3; ++j) { - if (dims[j]==$ && this->getDims()>j){ + if (dims[j]==$ && getDims()>j){ ++mode; allDimIndex.push_back(j); } @@ -481,7 +483,7 @@ namespace Aurora { //scalar mode or ALL $ case 0: default: { - return Matrix::MatrixSlice(1 , 1, startPointer,getValueType(), mode); + return Matrix::MatrixSlice(1 , 1, startPointer,getValueType(), 0); } } } diff --git a/test/FunctionTester.cpp b/test/FunctionTester.cpp index 949e15f..faf10f9 100644 --- a/test/FunctionTester.cpp +++ b/test/FunctionTester.cpp @@ -133,9 +133,11 @@ TEST_F(FunctionTester, MatrixCreate) { TEST_F(FunctionTester, matrixSlice) { double * dataA =Aurora::malloc(8); double * dataB =Aurora::malloc(8); + double * dataC =Aurora::malloc(8); for (int i = 0; i < 8; ++i) { - dataA[i]=(double)(i-3); - dataB[i]=(double)(i+2); + dataA[i]=(double)(i); + dataB[i]=(double)(1); + dataC[i]=(double)(9-i); } Aurora::Matrix A = Aurora::Matrix::New(dataA,2,2,2); printf("A:\r\n"); @@ -143,32 +145,53 @@ TEST_F(FunctionTester, matrixSlice) { Aurora::Matrix B = Aurora::Matrix::New(dataB,2,2,2); printf("B:\r\n"); B.printf(); - A(Aurora::$,Aurora::$,1) = B(Aurora::$,Aurora::$,0); - printf("New A:\r\n"); - A.printf(); - printf("New B:\r\n"); - B.printf(); + Aurora::Matrix C = Aurora::Matrix::New(dataC,2,2,2); + printf("C:\r\n"); + C.printf(); + //2D slice + EXPECT_EQ(4.0,dataA[4]); + A(Aurora::$,Aurora::$,1) = B(0,Aurora::$,Aurora::$); + EXPECT_EQ(1,dataA[4]); + EXPECT_EQ(3,dataA[3]); + A(Aurora::$,1,Aurora::$) = B(Aurora::$,Aurora::$,0); + EXPECT_EQ(1.0,dataA[3]); + EXPECT_EQ(0.0,dataA[0]); + A(0,Aurora::$,Aurora::$) = B(Aurora::$,0,Aurora::$); + EXPECT_EQ(1.0,dataA[0]); + //vector slice + A(0,Aurora::$,0) = C(0,0,Aurora::$); + EXPECT_EQ(9.0,dataA[0]); + A(Aurora::$,0,0) = C(0,Aurora::$,1); + EXPECT_EQ(5.0,dataA[0]); + //error slice + EXPECT_EQ(1,A(Aurora::$,Aurora::$,Aurora::$).toMatrix().getDataSize() ); + auto D =C(0,0,0).toMatrix(); + EXPECT_EQ(1,D.getDataSize() ); + EXPECT_EQ(9,D.getData()[0] ); } -TEST_F(FunctionTester, RawDataMatrix) { - double * dataA =new double[8]; - double * dataB =new double[8]; - for (int i = 0; i < 8; ++i) { - dataA[i]=(double)(i-3); - dataB[i]=(double)(i+2); +TEST_F(FunctionTester, matrixOpertaor) { + //3D + { + double * dataA =new double[8]; + double * dataB =new double[8]; + for (int i = 0; i < 8; ++i) { + dataA[i]=(double)(i); + dataB[i]=(double)(2); + } + Aurora::Matrix A = Aurora::Matrix::fromRawData(dataA,2,2,2); + DISPLAY_MATRIX(A) + Aurora::Matrix B = Aurora::Matrix::fromRawData(dataB,2,2,2); + DISPLAY_MATRIX(B) + auto C = (A*B)-B; + DISPLAY_MATRIX(C) + EXPECT_EQ(C.getData()[2],2); + C = A*B/2.0; + EXPECT_EQ(C.getData()[2],2); + C = A*B*B/2.0; + EXPECT_EQ(C.getData()[2],4); } - Aurora::Matrix A = Aurora::Matrix::fromRawData(dataA,2,2,2); - printf("A:\r\n"); - A.printf(); - Aurora::Matrix B = Aurora::Matrix::copyFromRawData(dataB,2,2,2); - delete [] dataB; - printf("B:\r\n"); - B.printf(); - A(Aurora::$,Aurora::$,1) = B(Aurora::$,Aurora::$,0); - printf("New A:\r\n"); - A.printf(); - printf("New B:\r\n"); - B.printf(); + } TEST_F(FunctionTester, sign) {