Add prod, dot and their unit test.

This commit is contained in:
Krad
2023-04-25 17:28:33 +08:00
parent de1c1b0dda
commit cf96572074
3 changed files with 179 additions and 27 deletions

View File

@@ -288,37 +288,67 @@ Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction, long& row
}
Matrix Aurora::sum(const Matrix &aMatrix, FunctionDirection direction) {
if (aMatrix.getDimSize(2)>1 || aMatrix.isComplex()) {
std::cerr
<< (aMatrix.getDimSize(2) > 1 ? "sum() not support 3D data!" : "sum() not support complex value type!")
if (aMatrix.getDimSize(2)>1 ) {
std::cerr<< "sum() not support 3D data!"
<< std::endl;
return Matrix();
}
switch (direction)
{
case All:
if (aMatrix.isComplex()){
switch (direction)
{
Eigen::Map<Eigen::VectorXd> srcV(aMatrix.getData(),aMatrix.getDataSize());
double * ret = malloc(1);
ret[0] = srcV.array().sum();
return Matrix::New(ret,1);
case All:
{
Eigen::Map<Eigen::VectorXcd> srcV((std::complex<double>*)aMatrix.getData(),aMatrix.getDataSize());
std::complex<double>* ret = (std::complex<double>*)malloc(1,true);
ret[0] = srcV.array().sum();
return Matrix::New((double*)ret,1,1,1,Complex);
}
case Row:
{
Eigen::Map<Eigen::MatrixXcd> srcM((std::complex<double>*)aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
std::complex<double> * ret = (std::complex<double>*)malloc(aMatrix.getDimSize(0),true);
Eigen::Map<Eigen::VectorXcd> retV(ret,aMatrix.getDimSize(0));
retV = srcM.rowwise().sum();
return Matrix::New((double*)ret,aMatrix.getDimSize(0),1,1,Complex);
}
case Column:
default:
{
Eigen::Map<Eigen::MatrixXcd> srcM((std::complex<double>*)aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
std::complex<double>* ret = (std::complex<double>*)malloc(aMatrix.getDimSize(1),true);
Eigen::Map<Eigen::VectorXcd> retV(ret,aMatrix.getDimSize(1));
retV = srcM.colwise().sum();
return Matrix::New((double*)ret,1,aMatrix.getDimSize(1),1,Complex);
}
}
case Row:
}
else{
switch (direction)
{
Eigen::Map<Eigen::MatrixXd> srcM(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0));
Eigen::Map<Eigen::VectorXd> retV(ret,aMatrix.getDimSize(0));
retV = srcM.rowwise().sum();
return Matrix::New(ret,aMatrix.getDimSize(0),1);
}
case Column:
default:
{
Eigen::Map<Eigen::MatrixXd> srcM(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(1));
Eigen::Map<Eigen::VectorXd> retV(ret,aMatrix.getDimSize(0));
retV = srcM.colwise().sum();
return Matrix::New(ret,1,aMatrix.getDimSize(1));
case All:
{
Eigen::Map<Eigen::VectorXd> srcV(aMatrix.getData(),aMatrix.getDataSize());
double * ret = malloc(1);
ret[0] = srcV.array().sum();
return Matrix::New(ret,1);
}
case Row:
{
Eigen::Map<Eigen::MatrixXd> srcM(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0));
Eigen::Map<Eigen::VectorXd> retV(ret,aMatrix.getDimSize(0));
retV = srcM.rowwise().sum();
return Matrix::New(ret,aMatrix.getDimSize(0),1);
}
case Column:
default:
{
Eigen::Map<Eigen::MatrixXd> srcM(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(1));
Eigen::Map<Eigen::VectorXd> retV(ret,aMatrix.getDimSize(0));
retV = srcM.colwise().sum();
return Matrix::New(ret,1,aMatrix.getDimSize(1));
}
}
}
}
@@ -624,7 +654,7 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
Matrix Aurora::hilbert(const Matrix &aMatrix) {
auto x = fft(aMatrix);
auto h = new double[aMatrix.getDimSize(0)];
auto h = malloc(aMatrix.getDimSize(0));
auto two = 2.0;
auto zero = 0.0;
cblas_dcopy(aMatrix.getDimSize(0), &zero, 0, h, 1);
@@ -641,3 +671,60 @@ Matrix Aurora::hilbert(const Matrix &aMatrix) {
return result;
}
Matrix Aurora::prod(const Matrix &aMatrix) {
if (aMatrix.getDimSize(2) > 1 ) {
std::cerr<< "prod() not support 3D data!"
<< std::endl;
return Matrix();
}
if (aMatrix.isComplex()){
Eigen::Map<Eigen::MatrixXcd> srcM((std::complex<double>*)aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
auto ret = malloc(aMatrix.getDimSize(1),true);
Eigen::Map<Eigen::VectorXcd> retV((std::complex<double>*)ret,aMatrix.getDimSize(1));
retV = srcM.colwise().prod();
return Matrix::New(ret,1,aMatrix.getDimSize(1),1,Complex);
}
else{
Eigen::Map<Eigen::MatrixXd> srcM(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
auto ret = malloc(aMatrix.getDimSize(1));
Eigen::Map<Eigen::VectorXd> retV(ret,aMatrix.getDimSize(1));
retV = srcM.colwise().prod();
return Matrix::New(ret,1,aMatrix.getDimSize(1));
}
}
Matrix Aurora::dot(const Matrix &aMatrix,const Matrix& aOther,FunctionDirection direction ) {
if ( direction == All){
std::cerr<< "dot() not support 3D data!"
<< std::endl;
return Matrix();
}
if (!aMatrix.compareShape(aOther)){
std::cerr<< "dot() matrix must be same shape!"
<< std::endl;
return Matrix();
}
if (aMatrix.isComplex()){
return sum(conj(aMatrix)*aOther,direction);
}
else{
if (direction == Column)
{
auto ret = malloc(aMatrix.getDimSize(1));
for (int i = 0; i < aMatrix.getDimSize(1); ++i) {
ret[i]=cblas_ddot(aMatrix.getDimSize(0),aMatrix.getData()+i*aMatrix.getDimSize(0),1,
aOther.getData()+i*aMatrix.getDimSize(0),1);
}
return Matrix::New(ret,1,aMatrix.getDimSize(1),1);
}
else{
auto ret = malloc(aMatrix.getDimSize(0));
for (int i = 0; i < aMatrix.getDimSize(0); ++i) {
ret[i] = cblas_ddot(aMatrix.getDimSize(1),aMatrix.getData()+i,aMatrix.getDimSize(0),
aOther.getData()+i,aMatrix.getDimSize(0));
}
return Matrix::New(ret,aMatrix.getDimSize(1),1,1);
}
}
}

View File

@@ -48,7 +48,7 @@ namespace Aurora {
Matrix min(const Matrix& aMatrix,const Matrix& aOther);
/**
* 求矩阵和,可按行、列、单元, 目前不支持三维,不支持复数
* 求矩阵和,可按行、列、单元, 目前不支持三维
* @param aMatrix 目标矩阵
* @param direction 方向Column, Row, All
* @return 求和结果矩阵
@@ -120,6 +120,15 @@ namespace Aurora {
* @return
*/
Matrix hilbert(const Matrix& aMatrix);
/**
* prod支持到2维
* @param aMatrix
* @return
*/
Matrix prod(const Matrix& aMatrix);
Matrix dot(const Matrix& aMatrix,const Matrix& aOther,FunctionDirection direction = Column);
};