Add sum and sum's unit test.

This commit is contained in:
Krad
2023-04-23 17:32:04 +08:00
parent 46be233087
commit c3462330c1
3 changed files with 69 additions and 1 deletions

View File

@@ -222,7 +222,7 @@ Matrix Aurora::min(const Matrix &aMatrix, const Matrix &aOther) {
Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction) {
if (aMatrix.getDimSize(2)>1 || aMatrix.isComplex()) {
std::cerr
<< (aMatrix.getDimSize(2) > 1 ? "min() not support 3D data!" : "min() not support complex value type!")
<< (aMatrix.getDimSize(2) > 1 ? "max() not support 3D data!" : "max() not support complex value type!")
<< std::endl;
return Matrix();
}
@@ -244,6 +244,7 @@ Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction) {
return Matrix::New(ret,aMatrix.getDimSize(0),1);
}
case Column:
default:
{
Eigen::Map<Eigen::MatrixXd> srcMatrix(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0));
@@ -253,3 +254,40 @@ Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction) {
}
}
}
Matrix Aurora::sum(const Matrix &aMatrix, FunctionDirection direction) {
if (aMatrix.getDimSize(2)>1 || aMatrix.isComplex()) {
std::cerr
<< (aMatrix.getDimSize(2) > 1 ? "sum() not support 3D data!" : "sum() not support complex value type!")
<< std::endl;
return Matrix();
}
switch (direction)
{
case All:
{
double * ret = malloc(1);
ret[0] = cblas_dasum(aMatrix.getDataSize(),aMatrix.getData(),1);
return Matrix::New(ret,1);
}
case Row:
{
double * ret = malloc(aMatrix.getDimSize(0));
for (int i = 0; i < aMatrix.getDimSize(0); ++i) {
ret[i] = cblas_dasum(aMatrix.getDimSize(1), aMatrix.getData() + i,
aMatrix.getDimSize(0));
}
return Matrix::New(ret,aMatrix.getDimSize(0),1);
}
case Column:
default:
{
double * ret = malloc(aMatrix.getDimSize(0));
for (int i = 0; i < aMatrix.getDimSize(1); ++i) {
ret[i] = cblas_dasum(aMatrix.getDimSize(0), aMatrix.getData()+aMatrix.getDimSize(0)*i,
1);
}
return Matrix::New(ret,1,aMatrix.getDimSize(1));
}
}
}

View File

@@ -33,6 +33,7 @@ namespace Aurora {
* @return
*/
Matrix max(const Matrix& aMatrix,FunctionDirection direction = Column);
/**
* 比较两个矩阵,求对应位置的最小值,不支持三维
* @attention 矩阵形状不一样时如A为[MxN],则B应为标量或[1xN]的行向量
@@ -42,6 +43,14 @@ namespace Aurora {
*/
Matrix min(const Matrix& aMatrix,const Matrix& aOther);
/**
* 求矩阵和,可按行、列、单元, 目前不支持三维,不支持复数
* @param aMatrix 矩阵
* @param direction 方向Column, Row, All
* @return
*/
Matrix sum(const Matrix& aMatrix,FunctionDirection direction = Column);
};