Fix matrix operator bug on - and /. 3
This commit is contained in:
@@ -43,7 +43,7 @@ Aurora::Matrix Aurora::inv(const Aurora::Matrix &aMatrix) {
|
||||
int size = aMatrix.getDataSize();
|
||||
int *ipiv = new int[aMatrix.getDimSize(0)];
|
||||
auto result = malloc(size);
|
||||
cblas_dcopy(size,result, 1,aMatrix.getData(), 1);
|
||||
cblas_dcopy(size,aMatrix.getData(), 1,result, 1);
|
||||
LAPACKE_dgetrf(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getDimSize(0), result, aMatrix.getDimSize(0), ipiv);
|
||||
LAPACKE_dgetri(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), result, aMatrix.getDimSize(0), ipiv);
|
||||
delete[] ipiv;
|
||||
|
||||
@@ -18,10 +18,15 @@ namespace Aurora{
|
||||
|
||||
inline Matrix operatorMxA(CalcFuncD aFunc, double aScalar, const Matrix &aMatrix) {
|
||||
double *output = malloc(aMatrix.getDataSize(), aMatrix.getValueType() == Complex);
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData(), 1, &aScalar, 0, output, 1);
|
||||
if (aMatrix.getValueType() == Complex) {
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData() + 1, 1, &aScalar, 0, output + 1,
|
||||
1);
|
||||
//复数时,+和-需要特别操作,只影响实部
|
||||
if (aMatrix.getValueType() == Complex && (aFunc == &vdAddI || aFunc == &vdSubI)) {
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData() , 2, &aScalar, 0, output ,
|
||||
2);
|
||||
double zero = 0.0;
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData()+1 , 2, &zero, 0, output+1 ,
|
||||
2);
|
||||
} else{
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData(), 1, &aScalar, 0, output, 1);
|
||||
}
|
||||
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2),
|
||||
aMatrix.getValueType());
|
||||
@@ -30,13 +35,16 @@ namespace Aurora{
|
||||
inline Matrix &operatorMxA_RR(CalcFuncD aFunc, double aScalar, Aurora::Matrix &&aMatrix) {
|
||||
std::cout << "use right ref operation" << std::endl;
|
||||
std::cout << "before operation" << std::endl;
|
||||
aMatrix.printf();
|
||||
if (aMatrix.getValueType() == Complex) {
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData(), 1, &aScalar, 0,
|
||||
//针对实部操作
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData(), 2, &aScalar, 0,
|
||||
aMatrix.getData(),
|
||||
1);
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData() + 1, 1, &aScalar, 0,
|
||||
aMatrix.getData() + 1, 1);
|
||||
2);
|
||||
//乘法除法需特别操作,影响虚部
|
||||
if (aFunc == &vdDivI || aFunc == &vdMulI){
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData() + 1, 2, &aScalar, 0,
|
||||
aMatrix.getData() + 1, 2);
|
||||
}
|
||||
} else {
|
||||
aFunc(aMatrix.getDataSize(), aMatrix.getData(), 1, &aScalar, 0,
|
||||
aMatrix.getData(),
|
||||
@@ -52,7 +60,13 @@ namespace Aurora{
|
||||
CalcFuncD aFuncD,
|
||||
const int size, double* xC,double* yN,double *output, int DimsStride) {
|
||||
aFuncD(size, xC, DimsStride * 2, yN, 1, output, 2);
|
||||
aFuncD(size, xC + 1, DimsStride * 2, yN, 1, output + 1, 2);
|
||||
if (aFuncD == &vdDivI || aFuncD == &vdMulI){
|
||||
aFuncD(size, xC + 1, DimsStride * 2, yN, 1, output + 1, 2);
|
||||
}
|
||||
else{
|
||||
double zero = 0.0;
|
||||
aFuncD(size, xC+1 , DimsStride*2, &zero, 0, output + 1, 2);
|
||||
}
|
||||
}
|
||||
|
||||
inline double* _MxM_CN_Calc(
|
||||
@@ -66,9 +80,15 @@ namespace Aurora{
|
||||
|
||||
inline void V_MxM_NC_Calc(
|
||||
CalcFuncD aFuncD,
|
||||
const int size, double* xC,double* yN,double *output, int DimsStride) {
|
||||
aFuncD(size, xC, DimsStride, yN, 2, output, 2);
|
||||
aFuncD(size, xC , DimsStride, yN+ 1, 2, output + 1, 2);
|
||||
const int size, double* xN,double* yC,double *output, int DimsStride) {
|
||||
aFuncD(size, xN, DimsStride, yC, 2, output, 2);
|
||||
if (aFuncD == &vdDivI || aFuncD == &vdMulI){
|
||||
aFuncD(size, xN , DimsStride, yC+ 1, 2, output + 1, 2);
|
||||
}
|
||||
else{
|
||||
double zero = 0.0;
|
||||
aFuncD(size, yC+ 1, 2, &zero, 0, output + 1, 2);
|
||||
}
|
||||
}
|
||||
|
||||
inline double* _MxM_NC_Calc(
|
||||
@@ -290,7 +310,6 @@ namespace Aurora {
|
||||
if (slices > 0)vector.push_back(slices);
|
||||
Matrix ret({data, free}, vector);
|
||||
if (type != Normal)ret.setValueType(type);
|
||||
ret.printf();
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user