Add complex support to Max function

This commit is contained in:
kradchen
2023-05-19 16:24:36 +08:00
parent ae27c18c13
commit 8b7bff2f00

View File

@@ -257,51 +257,53 @@ Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction) {
}
Matrix Aurora::max(const Matrix &aMatrix, FunctionDirection direction, long& rowIdx, long& colIdx) {
if (aMatrix.getDimSize(2)>1 || aMatrix.isComplex()) {
if (aMatrix.getDimSize(2)>1) {
std::cerr
<< (aMatrix.getDimSize(2) > 1 ? "max() not support 3D data!" : "max() not support complex value type!")
<< "max() not support 3D data!"
<< std::endl;
return Matrix();
}
auto calcMatrix = aMatrix.isComplex()?abs(aMatrix):aMatrix;
//针对向量行等于列
if (direction == Column && aMatrix.getDimSize(0)==1){
if (direction == Column && calcMatrix.getDimSize(0)==1){
direction = All;
}
switch (direction)
{
case All:
{
Eigen::Map<Eigen::MatrixXd> retV(aMatrix.getData(), aMatrix.getDimSize(0), aMatrix.getDimSize(1));
Eigen::Map<Eigen::MatrixXd> retV(calcMatrix.getData(), calcMatrix.getDimSize(0), calcMatrix.getDimSize(1));
double *ret = malloc(1);
ret[0] = retV.array().maxCoeff(&rowIdx, &colIdx);
return Matrix::New(ret,1);
}
case Row:
{
Eigen::Map<Eigen::MatrixXd> srcMatrix(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0));
if (aMatrix.getDimSize(0) == 1){
Eigen::Map<Eigen::MatrixXd> srcMatrix(calcMatrix.getData(),calcMatrix.getDimSize(0),calcMatrix.getDimSize(1));
double * ret = malloc(calcMatrix.getDimSize(0));
if (calcMatrix.getDimSize(0) == 1){
ret[0] = srcMatrix.topRows(0).maxCoeff(&rowIdx, &colIdx);
}
else{
Eigen::Map<Eigen::MatrixXd> retMatrix(ret,aMatrix.getDimSize(0),1);
Eigen::Map<Eigen::MatrixXd> retMatrix(ret,calcMatrix.getDimSize(0),1);
retMatrix = srcMatrix.rowwise().maxCoeff();
}
return Matrix::New(ret,aMatrix.getDimSize(0),1);
return Matrix::New(ret,calcMatrix.getDimSize(0),1);
}
case Column:
default:
{
Eigen::Map<Eigen::MatrixXd> srcMatrix(aMatrix.getData(),aMatrix.getDimSize(0),aMatrix.getDimSize(1));
double * ret = malloc(aMatrix.getDimSize(0));
if (aMatrix.getDimSize(1) == 1){
Eigen::Map<Eigen::MatrixXd> srcMatrix(calcMatrix.getData(),calcMatrix.getDimSize(0),calcMatrix.getDimSize(1));
double * ret = malloc(calcMatrix.getDimSize(0));
if (calcMatrix.getDimSize(1) == 1){
ret[0] = srcMatrix.col(0).maxCoeff(&rowIdx, &colIdx);
}
else {
Eigen::Map<Eigen::MatrixXd> retMatrix(ret,1,aMatrix.getDimSize(1));
Eigen::Map<Eigen::MatrixXd> retMatrix(ret,1,calcMatrix.getDimSize(1));
retMatrix = srcMatrix.colwise().maxCoeff();
}
return Matrix::New(ret,1,aMatrix.getDimSize(1));
return Matrix::New(ret,1,calcMatrix.getDimSize(1));
}
}
}