Calc fix and 2d functions.

This commit is contained in:
Krad
2023-04-20 17:35:03 +08:00
parent 65dd582f77
commit ed7312992f
4 changed files with 252 additions and 72 deletions

View File

@@ -1,4 +1,72 @@
#include <iostream>
#include "Function.h"
#include "Function2D.h"
#include "mkl.h"
double Aurora::immse(const Aurora::Matrix &aImageA, const Aurora::Matrix &aImageB) {
if (aImageA.getDims()!=2|| aImageB.getDims()!=2){
std::cerr<<"Fail!immse args must all 2d matrix!";
return 0.0;
}
if (!aImageB.compareShape(aImageA)){
std::cerr<<"Fail!immse args must be same shape!";
return 0.0;
}
if (aImageA.getValueType()!=Normal || aImageB.getValueType() != Normal) {
std::cerr << "Fail!immse args must be normal value type!";
return 0.0;
}
int size = aImageA.getDataSize();
auto temp = malloc(size);
vdSub(size, aImageA.getData(), aImageB.getData(), temp);
vdSqr(size, temp, temp);
double result = cblas_dasum(size, temp, 1) / (double) size;
free(temp);
return result;
}
Aurora::Matrix Aurora::inv(const Aurora::Matrix &aMatrix) {
if (aMatrix.getDims() != 2) {
std::cerr << "Fail!inv args must be 2d matrix!";
return aMatrix;
}
if (aMatrix.getDimSize(0) != aMatrix.getDimSize(1)) {
std::cerr << "Fail!inv args must be square matrix!";
return aMatrix;
}
if (aMatrix.getValueType() != Normal) {
std::cerr << "Fail!inv args must be normal value type!";
return aMatrix;
}
int size = aMatrix.getDataSize();
int *ipiv = new int[aMatrix.getDimSize(0)];
auto result = malloc(size);
cblas_dcopy(size,result, 1,aMatrix.getData(), 1);
LAPACKE_dgetrf(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getDimSize(0), result, aMatrix.getDimSize(0), ipiv);
LAPACKE_dgetri(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), result, aMatrix.getDimSize(0), ipiv);
delete[] ipiv;
return Matrix::New(result,aMatrix);
}
Aurora::Matrix Aurora::inv(Aurora::Matrix&& aMatrix) {
if (aMatrix.getDims() != 2) {
std::cerr << "Fail!inv args must be 2d matrix!";
return aMatrix;
}
if (aMatrix.getDimSize(0) != aMatrix.getDimSize(1)) {
std::cerr << "Fail!inv args must be square matrix!";
return aMatrix;
}
if (aMatrix.getValueType() != Normal) {
std::cerr << "Fail!inv args must be normal value type!";
return aMatrix;
}
int *ipiv = new int[aMatrix.getDimSize(0)];
LAPACKE_dgetrf(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getDimSize(0), aMatrix.getData(), aMatrix.getDimSize(0), ipiv);
LAPACKE_dgetri(LAPACK_ROW_MAJOR, aMatrix.getDimSize(0), aMatrix.getData(), aMatrix.getDimSize(0), ipiv);
delete[] ipiv;
return aMatrix;
}
#include "Function1D.h"
#include "Function.h"

View File

@@ -6,7 +6,9 @@
namespace Aurora {
double immse(const Matrix& aImageA, const Matrix& aImageB);
Matrix inv(const Matrix& aMatrix);
Matrix inv(Matrix&& aMatrix);
Matrix interp2(const Matrix& aX, const Matrix& aY, const Matrix& aV, const Matrix& aX1, const Matrix& aY1, InterpnMethod aMethod);
Matrix interpn(const Matrix& aX, const Matrix& aY, const Matrix& aV, const Matrix& aX1, const Matrix& aY1, InterpnMethod aMethod);

View File

@@ -27,37 +27,6 @@ namespace Aurora{
aMatrix.getValueType());
}
inline Matrix operatorMxM(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Matrix &aMatrix,
const Matrix &aOther) {
if (!aMatrix.compareShape(aOther))return Matrix();
if (aMatrix.getValueType() != aOther.getValueType()) {
double *output = malloc(aMatrix.getDataSize(), true);
if (aMatrix.getValueType() == Complex) {
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1,
output, 1);
aFuncD(aMatrix.getDataSize(), aMatrix.getData() + 1, 1, aOther.getData(), 1,
output + 1,
1);
return Matrix::New(output, aMatrix);
}
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1, output,
1);
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData() + 1, 1,
output + 1, 1);
return Matrix::New(output, aOther);
} else if (aMatrix.getValueType() == Normal) {
double *output = malloc(aMatrix.getDataSize());
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1, output,
1);
return Matrix::New(output, aMatrix);
} else {
double *output = malloc(aMatrix.getDataSize(), true);
aFuncZ(aMatrix.getDataSize(), (std::complex<double> *) aMatrix.getData(), 1,
(std::complex<double> *) aOther.getData(), 1, (std::complex<double> *) output, 1);
return Matrix::New(output, aOther);
}
}
inline Matrix &operatorMxA_RR(CalcFuncD aFunc, double aScalar, Aurora::Matrix &&aMatrix) {
std::cout << "use right ref operation" << std::endl;
std::cout << "before operation" << std::endl;
@@ -70,45 +39,172 @@ namespace Aurora{
aMatrix.getData() + 1, 1);
} else {
aFunc(aMatrix.getDataSize(), aMatrix.getData(), 1, &aScalar, 0,
aMatrix.getData(),
1);
aMatrix.getData(),
1);
}
std::cout << "after operation" << std::endl;
aMatrix.printf();
return aMatrix;
}
inline void V_MxM_CN_Calc(
CalcFuncD aFuncD,
const int size, double* xC,double* yN,double *output, int DimsStride) {
aFuncD(size, xC, DimsStride * 2, yN, 1, output, 2);
aFuncD(size, xC + 1, DimsStride * 2, yN, 1, output + 1, 2);
}
inline double* _MxM_CN_Calc(
CalcFuncD aFuncD,
const int size, double* xC,double* yN, int dimsStride)
{
double *output = malloc(size, true);
V_MxM_CN_Calc(aFuncD, size, xC, yN, output, dimsStride);
return output;
}
inline void V_MxM_NC_Calc(
CalcFuncD aFuncD,
const int size, double* xC,double* yN,double *output, int DimsStride) {
aFuncD(size, xC, DimsStride, yN, 2, output, 2);
aFuncD(size, xC , DimsStride, yN+ 1, 2, output + 1, 2);
}
inline double* _MxM_NC_Calc(
CalcFuncD aFuncD,
const int size, double* xN,double* yC, int dimsStride)
{
double *output = malloc(size, true);
V_MxM_NC_Calc(aFuncD, size, xN, yC, output, dimsStride);
return output;
}
inline void V_MxM_NN_Calc(
CalcFuncD aFuncD,
const int size, double* x,double* y,double* output, int DimsStride) {
aFuncD(size, x, DimsStride, y, 1, output,1);
}
inline double* _MxM_NN_Calc(
CalcFuncD aFuncD,
const int size, double* x,double* y, int DimsStride) {
double *output = malloc(size);
V_MxM_NN_Calc(aFuncD, size, x, y, output, DimsStride);
return output;
}
inline void V_MxM_CC_Calc(
CalcFuncZ aFuncZ, const int size, double* x,double* y,double* output,
int DimsStride) {
aFuncZ(size, (std::complex<double> *) x, DimsStride,
(std::complex<double> *) y, 1, (std::complex<double> *) output, 1);
}
inline double* _MxM_CC_Calc(
CalcFuncZ aFuncZ, const int size, double* x,double* y,
int DimsStride) {
double *output = malloc(size, true);
V_MxM_CC_Calc(aFuncZ, size, x, y, output, DimsStride);
return output;
}
inline Matrix operatorMxM(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Matrix &aMatrix,
const Matrix &aOther) {
// 2v2,1v1,3v3
if (aMatrix.compareShape(aOther)) {
int DimsStride = 1;
double *data = nullptr;
if (aMatrix.getValueType() != aOther.getValueType()) {
if (aMatrix.getValueType() == Normal) {
data = _MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(),
DimsStride);
return Matrix::New(data, aOther);
} else {
data = _MxM_CN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(),
DimsStride);
return Matrix::New(data, aMatrix);
}
} else if (aMatrix.getValueType() == Normal) {
data = _MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), DimsStride);
return Matrix::New(data, aMatrix);
} else {
data = _MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), DimsStride);
return Matrix::New(data, aMatrix);
}
}
//0v3, 0v2
else if (aMatrix.getDataSize()==1){
if (aMatrix.getValueType() ==Normal)return operatorMxA(aFuncD,aMatrix.getData()[0],aOther);
else{
std::cerr<<"M * M fail, Complex scalar * not support now!"<<std::endl;
return Matrix();
}
}
//3v0, 2v0
else if (aOther.getDataSize()==1){
if (aOther.getValueType() ==Normal)return operatorMxA(aFuncD,aOther.getData()[0],aMatrix);
else{
std::cerr<<"M * M fail, Complex scalar * not support now!"<<std::endl;
return Matrix();
}
}
//other
else {
std::cerr<<"M * M Shape error!If you want do vector * Matrix, use repmat to replicate the vector"<<std::endl;
return Matrix();
}
}
inline Matrix operatorMxM_RR(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Aurora::Matrix &aMatrix,
Aurora::Matrix &&aOther) {
if (!aMatrix.compareShape(aOther))return Matrix();
std::cout << "use right ref operation m" << std::endl;
if (aMatrix.getValueType() != aOther.getValueType()) {
//aOther is not a complex matrix
if (aMatrix.getValueType() == Complex) {
double *output = malloc(aMatrix.getDataSize(), true);
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1,
output, 1);
aFuncD(aMatrix.getDataSize(), aMatrix.getData() + 1,1, aOther.getData(), 1,
output + 1,
1);
return Matrix::New(output, aOther);
if (aMatrix.compareShape(aOther)) {
int DimsStride = 1;
if (aMatrix.getValueType() != aOther.getValueType()) {
//aOther is not a complex matrix
if (aMatrix.getValueType() == Complex) {
double *output = _MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(),
DimsStride);
return Matrix::New(output, aOther);
}
//aOther is a complex matrix, use aOther as output
V_MxM_NC_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(),
DimsStride);
return aOther;
} else if (aMatrix.getValueType() == Normal) {
V_MxM_NN_Calc(aFuncD, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(),
DimsStride);
return aOther;
} else {
V_MxM_CC_Calc(aFuncZ, aMatrix.getDataSize(), aMatrix.getData(), aOther.getData(), aOther.getData(),
DimsStride);
return aOther;
}
//aOther is a complex matrix, use aOther as output
aFuncD(aMatrix.getDataSize(), aMatrix.getData(),1, aOther.getData(), 1,
aOther.getData(),
1);
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData() + 1, 1,
aOther.getData() + 1, 1);
return aOther;
} else if (aMatrix.getValueType() == Normal) {
aFuncD(aMatrix.getDataSize(), aMatrix.getData(), 1, aOther.getData(), 1,
aOther.getData(),
1);
return aOther;
} else {
aFuncZ(aMatrix.getDataSize(), (std::complex<double> *) aMatrix.getData(), 1,
(std::complex<double> *) aOther.getData(), 1, (std::complex<double> *) aOther.getData(), 1);
return aOther;
}
//0v3, 0v2
else if (aMatrix.getDataSize()==1){
if (aMatrix.getValueType() ==Normal){
return operatorMxA(aFuncD,aMatrix.getData()[0],std::forward<Aurora::Matrix &&>(aOther));
}
else{
std::cerr<<"M * M fail, Complex scalar * not support now!"<<std::endl;
return Matrix();
}
}
//3v0, 2v0
else if (aOther.getDataSize()==1){
if (aOther.getValueType() ==Normal)return operatorMxA(aFuncD,aOther.getData()[0],aMatrix);
else{
std::cerr<<"M * M fail, Complex scalar * not support now!"<<std::endl;
return Matrix();
}
}
//other
else {
std::cerr<<"M * M Shape error!If you want do vector * Matrix, use repmat to replicate the vector"<<std::endl;
return Matrix();
}
}
};
@@ -159,7 +255,17 @@ namespace Aurora {
bool Matrix::compareShape(const Matrix &other) const {
if (mInfo.size() != other.mInfo.size()) return false;
if (mInfo.size() != other.mInfo.size()) {
// all vector compare length
if (mInfo.size() + other.mInfo.size() == 3){
return getDataSize()== other.getDataSize();
}
// 2 and 3
else if (mInfo.size() + other.mInfo.size() == 5){
return (mInfo.size()==3&& mInfo[2]==1)|| (other.mInfo.size()==3&& other.mInfo[2]==1);
}
return false;
}
for (int i = 0; i < mInfo.size(); ++i) {
if (mInfo[i] != other.mInfo[i]) return false;
}
@@ -388,23 +494,24 @@ namespace Aurora {
return *this;
}
if(!slice.mData){
std::cerr <<"Assign value fail!Src data pointer is null!";
std::cerr <<"Assign value fail!Src data pointer is null!"<<std::endl;
return *this;
}
if (slice.mSliceMode!=mSliceMode) {
std::cerr <<"Assign value fail!Src slice(dims count:"<< slice.mSliceMode <<"), not match of des(dims count:"<<mSliceMode<<")!";
std::cerr << "Assign value fail!Src slice(dims count:" << slice.mSliceMode
<< "), not match of des(dims count:" << mSliceMode << ")!" << std::endl;
return *this;
}
if (slice.mSize!=mSize) {
std::cerr <<"Assign value fail!Src slice(dim 1 size:"<< slice.mSize <<"), not match of des(dim 1 size:"<<mSize<<")!";
std::cerr <<"Assign value fail!Src slice(dim 1 size:"<< slice.mSize <<"), not match of des(dim 1 size:"<<mSize<<")!"<<std::endl;
return *this;
}
if (slice.mSize2!=mSize2) {
std::cerr <<"Assign value fail!Src slice(dim 2 size:"<< slice.mSize2 <<"), not match of des(dim 2 size:"<<mSize2<<")!";
std::cerr <<"Assign value fail!Src slice(dim 2 size:"<< slice.mSize2 <<"), not match of des(dim 2 size:"<<mSize2<<")!"<<std::endl;
return *this;
}
if (slice.mType!=mType) {
std::cerr <<"Assign value fail!Src slice(value type:"<< slice.mType <<"), not match of des(value type:"<<mType<<")!";
std::cerr <<"Assign value fail!Src slice(value type:"<< slice.mType <<"), not match of des(value type:"<<mType<<")!"<<std::endl;
return *this;
}
switch (mSliceMode) {

View File

@@ -17,15 +17,17 @@
int main() {
{
double *dataA = Aurora::malloc(8);
double *dataB = Aurora::malloc(8);;
double *dataC = Aurora::malloc(8);;
double *dataA = Aurora::malloc(8,true);
double *dataB = Aurora::malloc(8);
double *dataC = Aurora::malloc(8);
for (int i = 0; i < 16; ++i) {
dataA[i] = (double) (i + 2);
}
for (int i = 0; i < 8; ++i) {
dataA[i] = (double) (i - 3);
dataB[i] = (double) (i + 2);
dataC[i / 2] = (double) (i + 9);
}
Aurora::Matrix A = Aurora::Matrix::New(dataA, 2, 2, 2);
Aurora::Matrix A = Aurora::Matrix::New(dataA, 2, 2, 2,Aurora::ValueType::Complex);
printf("A:\r\n");
A.printf();
Aurora::Matrix B = Aurora::Matrix::New(dataB, 2, 2, 2);
@@ -54,6 +56,7 @@ int main() {
printf("New A col slice 1 toMatrix:\r\n");
auto Ds = A(Aurora::$, 1, Aurora::$);
auto D = Ds.toMatrix();
printf("D:\r\n");
D.printf();