Fix Matrix slice toMatrix, add is Scalar and getScalar

This commit is contained in:
Krad
2023-04-27 14:36:15 +08:00
parent 1056f970fa
commit 6d709f74e0
2 changed files with 37 additions and 13 deletions

View File

@@ -3,12 +3,14 @@
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <complex>
//必须在mkl.h和Eigen的头之前<complex>之后 #include "AuroraDefs.h"
#define MKL_Complex16 std::complex<double>
#include "mkl.h" #include <Eigen/Core>
#include <Eigen/Dense>
#include "Function.h" #include "Function.h"
namespace Aurora{ namespace Aurora{
typedef void(*CalcFuncD)(const MKL_INT n, const double a[], const MKL_INT inca, const double b[], typedef void(*CalcFuncD)(const MKL_INT n, const double a[], const MKL_INT inca, const double b[],
const MKL_INT incb, double r[], const MKL_INT incr); const MKL_INT incb, double r[], const MKL_INT incr);
@@ -250,6 +252,17 @@ namespace Aurora {
return !mData || mInfo.empty(); return !mData || mInfo.empty();
} }
bool Matrix::isScalar() const {
return (getDimSize(0) == 1 &&
getDimSize(1) == 1 &&
getDimSize(1));
}
double Matrix::getScalar() const {
if (isNull()) return 0.0;
if (isNull()) return 0.0;
return getData()[0];
}
int Matrix::getDims() const { int Matrix::getDims() const {
if(mInfo[2] > 1) if(mInfo[2] > 1)
{ {
@@ -269,6 +282,7 @@ namespace Aurora {
return 0; return 0;
} }
size_t Matrix::getDataSize() const { size_t Matrix::getDataSize() const {
if (!mData.get())return 0; if (!mData.get())return 0;
size_t ret = 1; size_t ret = 1;
@@ -336,7 +350,7 @@ namespace Aurora {
int colsize = cols>0?cols:1; int colsize = cols>0?cols:1;
int slicesize = slices>0?slices:1; int slicesize = slices>0?slices:1;
int size = rows*colsize*slicesize; int size = rows*colsize*slicesize;
double *newBuffer = malloc(size, type); double *newBuffer = Aurora::malloc(size, type);
cblas_dcopy(size*type,data,1,newBuffer,1); cblas_dcopy(size*type,data,1,newBuffer,1);
return New(newBuffer,rows,cols,slices,type); return New(newBuffer,rows,cols,slices,type);
} }
@@ -350,7 +364,6 @@ namespace Aurora {
getDimSize(2), getDimSize(2),
getValueType()); getValueType());
} }
//operation + //operation +
Matrix Matrix::operator+(double aScalar) const { return operatorMxA(&vdAddI, aScalar, *this);} Matrix Matrix::operator+(double aScalar) const { return operatorMxA(&vdAddI, aScalar, *this);}
Matrix operator+(double aScalar, const Matrix &matrix) {return matrix + aScalar;} Matrix operator+(double aScalar, const Matrix &matrix) {return matrix + aScalar;}
@@ -364,10 +377,10 @@ namespace Aurora {
Matrix Matrix::operator+(Matrix &&aMatrix) const { Matrix Matrix::operator+(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator+(Matrix &&aMatrix, Matrix &aOther) { Matrix operator+(Matrix &&aMatrix, Matrix &aOther) {
return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operation - //operation -
Matrix Matrix::operator-(double aScalar) const { return operatorMxA(&vdSubI, aScalar, *this);} Matrix Matrix::operator-(double aScalar) const { return operatorMxA(&vdSubI, aScalar, *this);}
Matrix operator-(double aScalar, const Matrix &matrix) {return matrix - aScalar;} Matrix operator-(double aScalar, const Matrix &matrix) {return matrix - aScalar;}
@@ -381,10 +394,10 @@ namespace Aurora {
Matrix Matrix::operator-(Matrix &&aMatrix) const { Matrix Matrix::operator-(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator-(Matrix &&aMatrix, Matrix &aOther) { Matrix operator-(Matrix &&aMatrix, Matrix &aOther) {
return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operation * //operation *
Matrix Matrix::operator*(double aScalar) const { return operatorMxA(&vdMulI, aScalar, *this);} Matrix Matrix::operator*(double aScalar) const { return operatorMxA(&vdMulI, aScalar, *this);}
Matrix operator*(double aScalar, const Matrix &matrix) {return matrix * aScalar;} Matrix operator*(double aScalar, const Matrix &matrix) {return matrix * aScalar;}
@@ -398,10 +411,10 @@ namespace Aurora {
Matrix Matrix::operator*(Matrix &&aMatrix) const { Matrix Matrix::operator*(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdMulI,&vzMulI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdMulI,&vzMulI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator*(Matrix &&aMatrix, Matrix &aOther) { Matrix operator*(Matrix &&aMatrix, Matrix &aOther) {
return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdMulI,&vzMulI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operation / //operation /
Matrix Matrix::operator/(double aScalar) const { return operatorMxA(&vdDivI, aScalar, *this);} Matrix Matrix::operator/(double aScalar) const { return operatorMxA(&vdDivI, aScalar, *this);}
Matrix operator/(double aScalar, const Matrix &matrix) {return matrix / aScalar;} Matrix operator/(double aScalar, const Matrix &matrix) {return matrix / aScalar;}
@@ -415,16 +428,16 @@ namespace Aurora {
Matrix Matrix::operator/(Matrix &&aMatrix) const { Matrix Matrix::operator/(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward<Matrix&&>(aMatrix)); return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward<Matrix&&>(aMatrix));
} }
Matrix operator/(Matrix &&aMatrix, Matrix &aOther) { Matrix operator/(Matrix &&aMatrix, Matrix &aOther) {
return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix),true); return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix),true);
} }
//operator ^ (pow) //operator ^ (pow)
Matrix Matrix::operator^(int times) const { return operatorMxA(&vdPowI, times, *this);} Matrix Matrix::operator^(int times) const { return operatorMxA(&vdPowI, times, *this);}
Matrix operator^( Matrix &&matrix,int times) { Matrix operator^( Matrix &&matrix,int times) {
return operatorMxA(&vdPowI, times, std::forward<Matrix&&>(matrix)); return operatorMxA(&vdPowI, times, std::forward<Matrix&&>(matrix));
} }
void Matrix::printf() { void Matrix::printf() {
if(isNull()) if(isNull())
{ {
@@ -463,6 +476,11 @@ namespace Aurora {
} }
} }
void Matrix::printfShape() {
std::cerr << "Matrix shape:(" << getDimSize(0) << "," << getDimSize(1) << ","
<< getDimSize(2) << ")" << std::endl;
}
Matrix::MatrixSlice Matrix::operator()(int aRowIdx, int aColIdx, int aSliceIdx) const { Matrix::MatrixSlice Matrix::operator()(int aRowIdx, int aColIdx, int aSliceIdx) const {
std::vector<int> dims = {aRowIdx, aColIdx, aSliceIdx}; std::vector<int> dims = {aRowIdx, aColIdx, aSliceIdx};
std::vector<int> allDimIndex; std::vector<int> allDimIndex;
@@ -711,7 +729,7 @@ namespace Aurora {
} }
} }
return Matrix::New(data,mSize,mSize2,0,mType); return Matrix::New(data,mSize,mSize2,1,mType);
} }
} }

View File

@@ -149,6 +149,12 @@ namespace Aurora {
*/ */
void printf(); void printf();
void printfShape();
bool isScalar() const;
double getScalar() const;
/** /**
* Get is the Matrix's data is empty or size is zero. * Get is the Matrix's data is empty or size is zero.
* @return * @return