Add some compareSet functions

This commit is contained in:
kradchen
2023-06-09 14:32:02 +08:00
parent 4077843d13
commit 611da0872b
3 changed files with 96 additions and 1 deletions

View File

@@ -928,6 +928,80 @@ void Aurora::compareSet(Matrix& aMatrix,double compareValue, double newValue,Com
}
}
void Aurora::compareSet(Matrix& aMatrix,Matrix& aCompareMatrix,double compareValue, double newValue,CompareOp op){
Eigen::Map<Eigen::VectorXd> v(aMatrix.getData(),aMatrix.getDataSize());
Eigen::Map<Eigen::VectorXd> c(aCompareMatrix.getData(),aCompareMatrix.getDataSize());
switch (op) {
case EQ:
v = (c.array() == compareValue).select(newValue, v);
break;
case GT:
v = (c.array() > compareValue).select(newValue, v);
break;
case LT:
v = (c.array() < compareValue).select(newValue, v);
break;
case NG:
v = (c.array() <= compareValue).select(newValue, v);
break;
case NL:
v = (c.array() >= compareValue).select(newValue, v);
break;
case NE:
v = (c.array() != compareValue).select(newValue, v);
break;
}
}
void Aurora::compareSet(Matrix& aValueAndCompareMatrix,Matrix& aOtherCompareMatrix, double newValue,CompareOp op){
Eigen::Map<Eigen::VectorXd> v(aValueAndCompareMatrix.getData(),aValueAndCompareMatrix.getDataSize());
Eigen::Map<Eigen::VectorXd> c(aOtherCompareMatrix.getData(),aOtherCompareMatrix.getDataSize());
switch (op) {
case EQ:
v = (v.array() == c.array()).select(newValue, v);
break;
case GT:
v = (v.array() > c.array()).select(newValue, v);
break;
case LT:
v = (v.array() < c.array()).select(newValue, v);
break;
case NG:
v = (v.array() <= c.array()).select(newValue, v);
break;
case NL:
v = (v.array() >= c.array()).select(newValue, v);
break;
case NE:
v = (v.array() != c.array()).select(newValue, v);
break;
}
}
void Aurora::compareSet(Matrix& aCompareMatrix,double compareValue, Matrix& aNewValueMatrix,CompareOp op){
Eigen::Map<Eigen::VectorXd> v(aCompareMatrix.getData(),aCompareMatrix.getDataSize());
Eigen::Map<Eigen::VectorXd> nv(aNewValueMatrix.getData(),aNewValueMatrix.getDataSize());
switch (op) {
case EQ:
v = (v.array() == compareValue).select(nv, v);
break;
case GT:
v = (v.array() > compareValue).select(nv, v);
break;
case LT:
v = (v.array() < compareValue).select(nv, v);
break;
case NG:
v = (v.array() <= compareValue).select(nv, v);
break;
case NL:
v = (v.array() >= compareValue).select(nv, v);
break;
case NE:
v = (v.array() != compareValue).select(nv, v);
break;
}
}
Matrix Aurora::convertfp16tofloat(short* aData, int aRows, int aColumns)
{

View File

@@ -134,7 +134,10 @@ namespace Aurora {
enum CompareOp{
EQ,GT,LT,NG,NL,NE
};
void compareSet(Matrix& aMatrix,double compareValue, double newValue,CompareOp op);
void compareSet(Matrix& aValueMatrix,double compareValue, double newValue,CompareOp op);
void compareSet(Matrix& aValueMatrix,Matrix& aCompareMatrix,double compareValue, double newValue,CompareOp op);
void compareSet(Matrix& aDesAndCompareMatrix,Matrix& aOtherCompareMatrix, double newValue,CompareOp op);
void compareSet(Matrix& aCompareMatrix,double compareValue, Matrix& aNewValueMatrix,CompareOp op);
Matrix convertfp16tofloat(short* aData, int aRows, int aColumns);
};