Fix operator bug(+-*/) in matrix and

= in matrixslice.
This commit is contained in:
kradchen
2023-05-19 16:04:23 +08:00
parent e32367b7dc
commit ae27c18c13

View File

@@ -401,7 +401,15 @@ namespace Aurora {
//operation + //operation +
Matrix Matrix::operator+(double aScalar) const { return operatorMxA(&vdAddI, aScalar, *this);} Matrix Matrix::operator+(double aScalar) const { return operatorMxA(&vdAddI, aScalar, *this);}
Matrix operator+(double aScalar, const Matrix &matrix) {return matrix + aScalar;} Matrix operator+(double aScalar, const Matrix &matrix) {return matrix + aScalar;}
Matrix Matrix::operator+(const Matrix &matrix) const {return operatorMxM(vdAddI, vzAddI, *this, matrix);} Matrix Matrix::operator+(const Matrix &matrix) const {
if (isScalar()){
return getScalar()+matrix;
}
if (matrix.isScalar()){
return (*this)+matrix.getScalar();
}
return operatorMxM(vdAddI, vzAddI, *this, matrix);
}
Matrix &operator+(double aScalar, Matrix &&matrix) { Matrix &operator+(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdAddI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdAddI,aScalar, std::forward<Matrix&&>(matrix));
} }
@@ -409,16 +417,36 @@ namespace Aurora {
return operatorMxA_RR(&vdAddI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdAddI,aScalar, std::forward<Matrix&&>(matrix));
} }
Matrix Matrix::operator+(Matrix &&aMatrix) const { Matrix Matrix::operator+(Matrix &&aMatrix) const {
if (isScalar()){
return getScalar()+aMatrix;
}
if (aMatrix.isScalar()){
return (*this)+aMatrix.getScalar();
}
return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator+(Matrix &&aMatrix, Matrix &aOther) { Matrix operator+(Matrix &&aMatrix, Matrix &aOther) {
if (aOther.isScalar()){
return aOther.getScalar()+aMatrix;
}
if (aMatrix.isScalar()){
return aOther+aMatrix.getScalar();
}
return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operation - //operation -
Matrix Matrix::operator-(double aScalar) const { return operatorMxA(&vdSubI, aScalar, *this);} Matrix Matrix::operator-(double aScalar) const { return operatorMxA(&vdSubI, aScalar, *this);}
Matrix operator-(double aScalar, const Matrix &matrix) {return matrix - aScalar;} Matrix operator-(double aScalar, const Matrix &matrix) {return matrix - aScalar;}
Matrix Matrix::operator-(const Matrix &matrix) const {return operatorMxM(vdSubI, vzSubI, *this, matrix);} Matrix Matrix::operator-(const Matrix &aMatrix) const {
if (isScalar()){
return getScalar()-aMatrix;
}
if (aMatrix.isScalar()){
return (*this)-aMatrix.getScalar();
}
return operatorMxM(vdSubI, vzSubI, *this, aMatrix);
}
Matrix &operator-(double aScalar, Matrix &&matrix) { Matrix &operator-(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdSubI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdSubI,aScalar, std::forward<Matrix&&>(matrix));
} }
@@ -426,33 +454,73 @@ namespace Aurora {
return operatorMxA_RR(&vdSubI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdSubI,aScalar, std::forward<Matrix&&>(matrix));
} }
Matrix Matrix::operator-(Matrix &&aMatrix) const { Matrix Matrix::operator-(Matrix &&aMatrix) const {
if (isScalar()){
return getScalar()-aMatrix;
}
if (aMatrix.isScalar()){
return (*this)-aMatrix.getScalar();
}
return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator-(Matrix &&aMatrix, Matrix &aOther) { Matrix operator-(Matrix &&aMatrix, Matrix &aOther) {
if (aOther.isScalar()){
return aMatrix-aOther.getScalar();
}
if (aMatrix.isScalar()){
return aMatrix.getScalar()-aOther;
}
return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operation * //operation *
Matrix Matrix::operator*(double aScalar) const { return operatorMxA(&vdMulI, aScalar, *this);} Matrix Matrix::operator*(double aScalar) const { return operatorMxA(&vdMulI, aScalar, *this);}
Matrix operator*(double aScalar, const Matrix &matrix) {return matrix * aScalar;} Matrix operator*(double aScalar, const Matrix &matrix) {return matrix * aScalar;}
Matrix Matrix::operator*(const Matrix &matrix) const {return operatorMxM(vdMulI, vzMulI, *this, matrix);} Matrix Matrix::operator*(const Matrix &aMatrix) const {
if (isScalar()){
return getScalar()*aMatrix;
}
if (aMatrix.isScalar()){
return (*this)*aMatrix.getScalar();
}
return operatorMxM(vdMulI, vzMulI, *this, aMatrix);
}
Matrix &operator*(double aScalar, Matrix &&matrix) { Matrix &operator*(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdMulI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdMulI,aScalar, std::forward<Matrix&&>(matrix));
} }
Matrix &operator*(Matrix &&matrix,double aScalar) { Matrix &operator*(Matrix &&aMatrix,double aScalar) {
return operatorMxA_RR(&vdMulI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdMulI,aScalar, std::forward<Matrix&&>(aMatrix));
} }
Matrix Matrix::operator*(Matrix &&aMatrix) const { Matrix Matrix::operator*(Matrix &&aMatrix) const {
if (isScalar()){
return getScalar()*aMatrix;
}
if (aMatrix.isScalar()){
return (*this)*aMatrix.getScalar();
}
return operatorMxM_RR(&vdMulI,&vzMulI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdMulI,&vzMulI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator*(Matrix &&aMatrix, Matrix &aOther) { Matrix operator*(Matrix &&aMatrix, Matrix &aOther) {
if (aOther.isScalar()){
return aMatrix*aOther.getScalar();
}
if (aMatrix.isScalar()){
return aMatrix.getScalar()*aOther;
}
return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operation / //operation /
Matrix Matrix::operator/(double aScalar) const { return operatorMxA(&vdDivI, aScalar, *this);} Matrix Matrix::operator/(double aScalar) const { return operatorMxA(&vdDivI, aScalar, *this);}
Matrix operator/(double aScalar, const Matrix &matrix) {return matrix / aScalar;} Matrix operator/(double aScalar, const Matrix &matrix) {return matrix / aScalar;}
Matrix Matrix::operator/(const Matrix &matrix) const {return operatorMxM(vdDivI, vzDivI, *this, matrix);} Matrix Matrix::operator/(const Matrix &aMatrix) const {
if (isScalar()){
return getScalar()/aMatrix;
}
if (aMatrix.isScalar()){
return (*this)/aMatrix.getScalar();
}
return operatorMxM(vdDivI, vzDivI, *this, aMatrix);
}
Matrix &operator/(double aScalar, Matrix &&matrix) { Matrix &operator/(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdDivI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdDivI,aScalar, std::forward<Matrix&&>(matrix));
} }
@@ -460,10 +528,22 @@ namespace Aurora {
return operatorMxA_RR(&vdDivI,aScalar, std::forward<Matrix&&>(matrix)); return operatorMxA_RR(&vdDivI,aScalar, std::forward<Matrix&&>(matrix));
} }
Matrix Matrix::operator/(Matrix &&aMatrix) const { Matrix Matrix::operator/(Matrix &&aMatrix) const {
if (isScalar()){
return getScalar()/aMatrix;
}
if (aMatrix.isScalar()){
return (*this)/aMatrix.getScalar();
}
return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator/(Matrix &&aMatrix, Matrix &aOther) { Matrix operator/(Matrix &&aMatrix, Matrix &aOther) {
if (aOther.isScalar()){
return aMatrix/aOther.getScalar();
}
if (aMatrix.isScalar()){
return aMatrix.getScalar()/aOther;
}
return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operator ^ (pow) //operator ^ (pow)
@@ -1043,7 +1123,7 @@ namespace Aurora {
cblas_dcopy(mSize,slice.mData,slice.mStride,mData,mStride); cblas_dcopy(mSize,slice.mData,slice.mStride,mData,mStride);
} }
else { else {
cblas_dcopy(mSize*2, slice.mData, slice.mStride, mData, mStride); cblas_zcopy(mSize, (std::complex<double> *)slice.mData, slice.mStride, (std::complex<double> *)mData, mStride);
} }
break; break;
} }