Fix matrix operator bug on - and /.
This commit is contained in:
@@ -158,27 +158,28 @@ namespace Aurora{
|
||||
}
|
||||
|
||||
inline Matrix operatorMxM_RR(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Aurora::Matrix &aMatrix,
|
||||
Aurora::Matrix &&aOther) {
|
||||
Aurora::Matrix &&aOther,bool oppside = false) {
|
||||
std::cout << "use right ref operation m" << std::endl;
|
||||
if (aMatrix.compareShape(aOther)) {
|
||||
int DimsStride = 1;
|
||||
double* X = oppside?aOther.getData():aMatrix.getData();
|
||||
double* Y = oppside?aMatrix.getData():aOther.getData();
|
||||
if (aMatrix.getValueType() != aOther.getValueType()) {
|
||||
//aOther is not a complex matrix
|
||||
if (aMatrix.getValueType() == Complex) {
|
||||
double *output = _MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(),
|
||||
DimsStride);
|
||||
double *output = _MxM_NC_Calc(aFuncD,aMatrix.getDataSize(), X, Y,DimsStride);
|
||||
return Matrix::New(output, aOther);
|
||||
}
|
||||
//aOther is a complex matrix, use aOther as output
|
||||
V_MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(),
|
||||
V_MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), X, Y, aOther.getData(),
|
||||
DimsStride);
|
||||
return aOther;
|
||||
} else if (aMatrix.getValueType() == Normal) {
|
||||
V_MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(),
|
||||
V_MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), X, Y, aOther.getData(),
|
||||
DimsStride);
|
||||
return aOther;
|
||||
} else {
|
||||
V_MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(),
|
||||
V_MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), X, Y, aOther.getData(),
|
||||
DimsStride);
|
||||
return aOther;
|
||||
}
|
||||
@@ -289,6 +290,7 @@ namespace Aurora {
|
||||
if (slices > 0)vector.push_back(slices);
|
||||
Matrix ret({data, free}, vector);
|
||||
if (type != Normal)ret.setValueType(type);
|
||||
ret.printf();
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -349,7 +351,7 @@ namespace Aurora {
|
||||
return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward<Matrix&&>(aMatrix));
|
||||
}
|
||||
Matrix operator+(Matrix &&aMatrix, const Matrix &aOther) {
|
||||
return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix));
|
||||
return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix),true);
|
||||
}
|
||||
|
||||
//operation -
|
||||
@@ -366,7 +368,7 @@ namespace Aurora {
|
||||
return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward<Matrix&&>(aMatrix));
|
||||
}
|
||||
Matrix operator-(Matrix &&aMatrix, const Matrix &aOther) {
|
||||
return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix));
|
||||
return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix),true);
|
||||
}
|
||||
|
||||
//operation *
|
||||
@@ -383,7 +385,7 @@ namespace Aurora {
|
||||
return operatorMxM_RR(&vdMulI,&vzMulI,*this,std::forward<Matrix&&>(aMatrix));
|
||||
}
|
||||
Matrix operator*(Matrix &&aMatrix, const Matrix &aOther) {
|
||||
return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward<Matrix&&>(aMatrix));
|
||||
return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward<Matrix&&>(aMatrix),true);
|
||||
}
|
||||
|
||||
//operation /
|
||||
@@ -400,7 +402,7 @@ namespace Aurora {
|
||||
return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward<Matrix&&>(aMatrix));
|
||||
}
|
||||
Matrix operator/(Matrix &&aMatrix, const Matrix &aOther) {
|
||||
return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix));
|
||||
return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix),true);
|
||||
}
|
||||
|
||||
//operator ^ (pow)
|
||||
@@ -449,7 +451,7 @@ namespace Aurora {
|
||||
std::vector<int> allDimIndex;
|
||||
int mode = 0;
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
if (dims[j]==$ && this->getDims()>j){
|
||||
if (dims[j]==$ && getDims()>j){
|
||||
++mode;
|
||||
allDimIndex.push_back(j);
|
||||
}
|
||||
@@ -481,7 +483,7 @@ namespace Aurora {
|
||||
//scalar mode or ALL $
|
||||
case 0:
|
||||
default: {
|
||||
return Matrix::MatrixSlice(1 , 1, startPointer,getValueType(), mode);
|
||||
return Matrix::MatrixSlice(1 , 1, startPointer,getValueType(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user