From ed7312992f6297521f5dc9612d50c089bd0e13d3 Mon Sep 17 00:00:00 2001 From: Krad Date: Thu, 20 Apr 2023 17:35:03 +0800 Subject: [PATCH] Calc fix and 2d functions. --- src/Function2D.cpp | 68 +++++++++++++ src/Function2D.h | 4 +- src/Matrix.cpp | 239 ++++++++++++++++++++++++++++++++------------- src/main.cxx | 13 ++- 4 files changed, 252 insertions(+), 72 deletions(-) diff --git a/src/Function2D.cpp b/src/Function2D.cpp index 7241bea..e0f7eea 100644 --- a/src/Function2D.cpp +++ b/src/Function2D.cpp @@ -1,4 +1,72 @@ +#include +#include "Function.h" #include "Function2D.h" +#include "mkl.h" + +double Aurora::immse(const Aurora::Matrix &aImageA, const Aurora::Matrix &aImageB) { + if (aImageA.getDims()!=2|| aImageB.getDims()!=2){ + std::cerr<<"Fail!immse args must all 2d matrix!"; + return 0.0; + } + if (!aImageB.compareShape(aImageA)){ + std::cerr<<"Fail!immse args must be same shape!"; + return 0.0; + } + if (aImageA.getValueType()!=Normal || aImageB.getValueType() != Normal) { + std::cerr << "Fail!immse args must be normal value type!"; + return 0.0; + } + int size = aImageA.getDataSize(); + auto temp = malloc(size); + vdSub(size, aImageA.getData(), aImageB.getData(), temp); + vdSqr(size, temp, temp); + double result = cblas_dasum(size, temp, 1) / (double) size; + free(temp); + return result; +} + +Aurora::Matrix Aurora::inv(const Aurora::Matrix &aMatrix) { + if (aMatrix.getDims() != 2) { + std::cerr << "Fail!inv args must be 2d matrix!"; + return aMatrix; + } + if (aMatrix.getDimSize(0) != aMatrix.getDimSize(1)) { + std::cerr << "Fail!inv args must be square matrix!"; + return aMatrix; + } + if (aMatrix.getValueType() != Normal) { + std::cerr << "Fail!inv args must be normal value type!"; + return aMatrix; + } + int size = aMatrix.getDataSize(); + int *ipiv = new int[aMatrix.getDimSize(0)]; + auto result = malloc(size); + cblas_dcopy(size,result, 1,aMatrix.getData(), 1); + LAPACKE_dgetrf(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getDimSize(0), result, aMatrix.getDimSize(0), ipiv); + LAPACKE_dgetri(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), result, aMatrix.getDimSize(0), ipiv); + delete[] ipiv; + return Matrix::New(result,aMatrix); +} + +Aurora::Matrix Aurora::inv(Aurora::Matrix&& aMatrix) { + if (aMatrix.getDims() != 2) { + std::cerr << "Fail!inv args must be 2d matrix!"; + return aMatrix; + } + if (aMatrix.getDimSize(0) != aMatrix.getDimSize(1)) { + std::cerr << "Fail!inv args must be square matrix!"; + return aMatrix; + } + if (aMatrix.getValueType() != Normal) { + std::cerr << "Fail!inv args must be normal value type!"; + return aMatrix; + } + int *ipiv = new int[aMatrix.getDimSize(0)]; + LAPACKE_dgetrf(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getDimSize(0), aMatrix.getData(), aMatrix.getDimSize(0), ipiv); + LAPACKE_dgetri(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getData(), aMatrix.getDimSize(0), ipiv); + delete[] ipiv; + return aMatrix; +} #include "Function1D.h" #include "Function.h" diff --git a/src/Function2D.h b/src/Function2D.h index 1d42f81..b8e1c05 100644 --- a/src/Function2D.h +++ b/src/Function2D.h @@ -6,7 +6,9 @@ namespace Aurora { - + double immse(const Matrix& aImageA, const Matrix& aImageB); + Matrix inv(const Matrix& aMatrix); + Matrix inv(Matrix&& aMatrix); Matrix interp2(const Matrix& aX, const Matrix& aY, const Matrix& aV, const Matrix& aX1, const Matrix& aY1, InterpnMethod aMethod); Matrix interpn(const Matrix& aX, const Matrix& aY, const Matrix& aV, const Matrix& aX1, const Matrix& aY1, InterpnMethod aMethod); diff --git a/src/Matrix.cpp b/src/Matrix.cpp index 0228ebf..2e5661b 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -27,37 +27,6 @@ namespace Aurora{ aMatrix.getValueType()); } - inline Matrix operatorMxM(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Matrix &aMatrix, - const Matrix &aOther) { - if (!aMatrix.compareShape(aOther))return Matrix(); - if (aMatrix.getValueType() != aOther.getValueType()) { - double *output = malloc(aMatrix.getDataSize(), true); - if (aMatrix.getValueType() == Complex) { - aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1, - output, 1); - aFuncD(aMatrix.getDataSize(), aMatrix.getData() + 1, 1, aOther.getData(), 1, - output + 1, - 1); - return Matrix::New(output, aMatrix); - } - aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1, output, - 1); - aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData() + 1, 1, - output + 1, 1); - return Matrix::New(output, aOther); - } else if (aMatrix.getValueType() == Normal) { - double *output = malloc(aMatrix.getDataSize()); - aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1, output, - 1); - return Matrix::New(output, aMatrix); - } else { - double *output = malloc(aMatrix.getDataSize(), true); - aFuncZ(aMatrix.getDataSize(), (std::complex *) aMatrix.getData(), 1, - (std::complex *) aOther.getData(), 1, (std::complex *) output, 1); - return Matrix::New(output, aOther); - } - } - inline Matrix &operatorMxA_RR(CalcFuncD aFunc, double aScalar, Aurora::Matrix &&aMatrix) { std::cout << "use right ref operation" << std::endl; std::cout << "before operation" << std::endl; @@ -70,45 +39,172 @@ namespace Aurora{ aMatrix.getData() + 1, 1); } else { aFunc(aMatrix.getDataSize(), aMatrix.getData(), 1, &aScalar, 0, - aMatrix.getData(), - 1); + aMatrix.getData(), + 1); } std::cout << "after operation" << std::endl; aMatrix.printf(); return aMatrix; } + + inline void V_MxM_CN_Calc( + CalcFuncD aFuncD, + const int size, double* xC,double* yN,double *output, int DimsStride) { + aFuncD(size, xC, DimsStride * 2, yN, 1, output, 2); + aFuncD(size, xC + 1, DimsStride * 2, yN, 1, output + 1, 2); + } + + inline double* _MxM_CN_Calc( + CalcFuncD aFuncD, + const int size, double* xC,double* yN, int dimsStride) + { + double *output = malloc(size, true); + V_MxM_CN_Calc(aFuncD, size, xC, yN, output, dimsStride); + return output; + } + + inline void V_MxM_NC_Calc( + CalcFuncD aFuncD, + const int size, double* xC,double* yN,double *output, int DimsStride) { + aFuncD(size, xC, DimsStride, yN, 2, output, 2); + aFuncD(size, xC , DimsStride, yN+ 1, 2, output + 1, 2); + } + + inline double* _MxM_NC_Calc( + CalcFuncD aFuncD, + const int size, double* xN,double* yC, int dimsStride) + { + double *output = malloc(size, true); + V_MxM_NC_Calc(aFuncD, size, xN, yC, output, dimsStride); + return output; + } + + inline void V_MxM_NN_Calc( + CalcFuncD aFuncD, + const int size, double* x,double* y,double* output, int DimsStride) { + aFuncD(size, x, DimsStride, y, 1, output,1); + } + + inline double* _MxM_NN_Calc( + CalcFuncD aFuncD, + const int size, double* x,double* y, int DimsStride) { + double *output = malloc(size); + V_MxM_NN_Calc(aFuncD, size, x, y, output, DimsStride); + return output; + } + + inline void V_MxM_CC_Calc( + CalcFuncZ aFuncZ, const int size, double* x,double* y,double* output, + int DimsStride) { + aFuncZ(size, (std::complex *) x, DimsStride, + (std::complex *) y, 1, (std::complex *) output, 1); + } + + inline double* _MxM_CC_Calc( + CalcFuncZ aFuncZ, const int size, double* x,double* y, + int DimsStride) { + double *output = malloc(size, true); + V_MxM_CC_Calc(aFuncZ, size, x, y, output, DimsStride); + return output; + } + + inline Matrix operatorMxM(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Matrix &aMatrix, + const Matrix &aOther) { + // 2v2,1v1,3v3 + if (aMatrix.compareShape(aOther)) { + int DimsStride = 1; + double *data = nullptr; + if (aMatrix.getValueType() != aOther.getValueType()) { + if (aMatrix.getValueType() == Normal) { + data = _MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), + DimsStride); + return Matrix::New(data, aOther); + } else { + data = _MxM_CN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), + DimsStride); + return Matrix::New(data, aMatrix); + } + + } else if (aMatrix.getValueType() == Normal) { + data = _MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), DimsStride); + return Matrix::New(data, aMatrix); + } else { + data = _MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), DimsStride); + return Matrix::New(data, aMatrix); + } + } + //0v3, 0v2 + else if (aMatrix.getDataSize()==1){ + if (aMatrix.getValueType() ==Normal)return operatorMxA(aFuncD,aMatrix.getData()[0],aOther); + else{ + std::cerr<<"M * M fail, Complex scalar * not support now!"< *) aMatrix.getData(), 1, - (std::complex *) aOther.getData(), 1, (std::complex *) aOther.getData(), 1); - return aOther; + } + //0v3, 0v2 + else if (aMatrix.getDataSize()==1){ + if (aMatrix.getValueType() ==Normal){ + return operatorMxA(aFuncD,aMatrix.getData()[0],std::forward(aOther)); + } + else{ + std::cerr<<"M * M fail, Complex scalar * not support now!"<