Add complex support to Max function

This commit is contained in:
kradchen
2023-05-19 16:24:36 +08:00
parent ae27c18c13
commit 8b7bff2f00

View File

@@ -257,51 +257,53 @@ Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction) {
} }
Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction, long& rowIdx, long& colIdx) { Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction, long& rowIdx, long& colIdx) {
if (aMatrix.getDimSize(2)>1 || aMatrix.isComplex()) { if (aMatrix.getDimSize(2)>1) {
std::cerr std::cerr
<< (aMatrix.getDimSize(2) > 1 ? "max() not support 3D data!" : "max() not support complex value type!") << "max() not support 3D data!"
<< std::endl; << std::endl;
return Matrix(); return Matrix();
} }
auto calcMatrix = aMatrix.isComplex()?abs(aMatrix):aMatrix;
//针对向量行等于列 //针对向量行等于列
if (direction == Column && aMatrix.getDimSize(0)==1){ if (direction == Column && calcMatrix.getDimSize(0)==1){
direction = All; direction = All;
} }
switch (direction) switch (direction)
{ {
case All: case All:
{ {
Eigen::Map<Eigen::MatrixXd> retV(aMatrix.getData(), aMatrix.getDimSize(0), aMatrix.getDimSize(1)); Eigen::Map<Eigen::MatrixXd> retV(calcMatrix.getData(), calcMatrix.getDimSize(0), calcMatrix.getDimSize(1));
double *ret = malloc(1); double *ret = malloc(1);
ret[0] = retV.array().maxCoeff(&rowIdx, &colIdx); ret[0] = retV.array().maxCoeff(&rowIdx, &colIdx);
return Matrix::New(ret,1); return Matrix::New(ret,1);
} }
case Row: case Row:
{ {
Eigen::Map<Eigen::MatrixXd> srcMatrix(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1)); Eigen::Map<Eigen::MatrixXd> srcMatrix(calcMatrix.getData(),calcMatrix.getDimSize(0),calcMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0)); double * ret = malloc(calcMatrix.getDimSize(0));
if (aMatrix.getDimSize(0) == 1){ if (calcMatrix.getDimSize(0) == 1){
ret[0] = srcMatrix.topRows(0).maxCoeff(&rowIdx, &colIdx); ret[0] = srcMatrix.topRows(0).maxCoeff(&rowIdx, &colIdx);
} }
else{ else{
Eigen::Map<Eigen::MatrixXd> retMatrix(ret,aMatrix.getDimSize(0),1); Eigen::Map<Eigen::MatrixXd> retMatrix(ret,calcMatrix.getDimSize(0),1);
retMatrix = srcMatrix.rowwise().maxCoeff(); retMatrix = srcMatrix.rowwise().maxCoeff();
} }
return Matrix::New(ret,aMatrix.getDimSize(0),1); return Matrix::New(ret,calcMatrix.getDimSize(0),1);
} }
case Column: case Column:
default: default:
{ {
Eigen::Map<Eigen::MatrixXd> srcMatrix(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1)); Eigen::Map<Eigen::MatrixXd> srcMatrix(calcMatrix.getData(),calcMatrix.getDimSize(0),calcMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0)); double * ret = malloc(calcMatrix.getDimSize(0));
if (aMatrix.getDimSize(1) == 1){ if (calcMatrix.getDimSize(1) == 1){
ret[0] = srcMatrix.col(0).maxCoeff(&rowIdx, &colIdx); ret[0] = srcMatrix.col(0).maxCoeff(&rowIdx, &colIdx);
} }
else { else {
Eigen::Map<Eigen::MatrixXd> retMatrix(ret,1,aMatrix.getDimSize(1)); Eigen::Map<Eigen::MatrixXd> retMatrix(ret,1,calcMatrix.getDimSize(1));
retMatrix = srcMatrix.colwise().maxCoeff(); retMatrix = srcMatrix.colwise().maxCoeff();
} }
return Matrix::New(ret,1,aMatrix.getDimSize(1)); return Matrix::New(ret,1,calcMatrix.getDimSize(1));
} }
} }
} }