diff --git a/src/Function1D.cpp b/src/Function1D.cpp index c5aad1d..d1b5ea9 100644 --- a/src/Function1D.cpp +++ b/src/Function1D.cpp @@ -928,6 +928,80 @@ void Aurora::compareSet(Matrix& aMatrix,double compareValue, double newValue,Com } } +void Aurora::compareSet(Matrix& aMatrix,Matrix& aCompareMatrix,double compareValue, double newValue,CompareOp op){ + Eigen::Map v(aMatrix.getData(),aMatrix.getDataSize()); + Eigen::Map c(aCompareMatrix.getData(),aCompareMatrix.getDataSize()); + switch (op) { + case EQ: + v = (c.array() == compareValue).select(newValue, v); + break; + case GT: + v = (c.array() > compareValue).select(newValue, v); + break; + case LT: + v = (c.array() < compareValue).select(newValue, v); + break; + case NG: + v = (c.array() <= compareValue).select(newValue, v); + break; + case NL: + v = (c.array() >= compareValue).select(newValue, v); + break; + case NE: + v = (c.array() != compareValue).select(newValue, v); + break; + } +} + +void Aurora::compareSet(Matrix& aValueAndCompareMatrix,Matrix& aOtherCompareMatrix, double newValue,CompareOp op){ + Eigen::Map v(aValueAndCompareMatrix.getData(),aValueAndCompareMatrix.getDataSize()); + Eigen::Map c(aOtherCompareMatrix.getData(),aOtherCompareMatrix.getDataSize()); + switch (op) { + case EQ: + v = (v.array() == c.array()).select(newValue, v); + break; + case GT: + v = (v.array() > c.array()).select(newValue, v); + break; + case LT: + v = (v.array() < c.array()).select(newValue, v); + break; + case NG: + v = (v.array() <= c.array()).select(newValue, v); + break; + case NL: + v = (v.array() >= c.array()).select(newValue, v); + break; + case NE: + v = (v.array() != c.array()).select(newValue, v); + break; + } +} +void Aurora::compareSet(Matrix& aCompareMatrix,double compareValue, Matrix& aNewValueMatrix,CompareOp op){ + Eigen::Map v(aCompareMatrix.getData(),aCompareMatrix.getDataSize()); + Eigen::Map nv(aNewValueMatrix.getData(),aNewValueMatrix.getDataSize()); + switch (op) { + case EQ: + v = (v.array() == compareValue).select(nv, v); + break; + case GT: + v = (v.array() > compareValue).select(nv, v); + break; + case LT: + v = (v.array() < compareValue).select(nv, v); + break; + case NG: + v = (v.array() <= compareValue).select(nv, v); + break; + case NL: + v = (v.array() >= compareValue).select(nv, v); + break; + case NE: + v = (v.array() != compareValue).select(nv, v); + break; + } +} + Matrix Aurora::convertfp16tofloat(short* aData, int aRows, int aColumns) { diff --git a/src/Function1D.h b/src/Function1D.h index e89f9b3..4d341b0 100644 --- a/src/Function1D.h +++ b/src/Function1D.h @@ -134,7 +134,10 @@ namespace Aurora { enum CompareOp{ EQ,GT,LT,NG,NL,NE }; - void compareSet(Matrix& aMatrix,double compareValue, double newValue,CompareOp op); + void compareSet(Matrix& aValueMatrix,double compareValue, double newValue,CompareOp op); + void compareSet(Matrix& aValueMatrix,Matrix& aCompareMatrix,double compareValue, double newValue,CompareOp op); + void compareSet(Matrix& aDesAndCompareMatrix,Matrix& aOtherCompareMatrix, double newValue,CompareOp op); + void compareSet(Matrix& aCompareMatrix,double compareValue, Matrix& aNewValueMatrix,CompareOp op); Matrix convertfp16tofloat(short* aData, int aRows, int aColumns); }; diff --git a/test/Function1D_Test.cpp b/test/Function1D_Test.cpp index 8b4a114..1bd7ba4 100644 --- a/test/Function1D_Test.cpp +++ b/test/Function1D_Test.cpp @@ -28,8 +28,10 @@ protected: }; TEST_F(Function1D_Test,compareSet){ double * dataA =Aurora::malloc(9); + double * dataB =Aurora::malloc(9); for (int i = 0; i < 9; ++i) { dataA[i]=(double)(i-3); + dataB[i]=(double)(i+2); } Aurora::Matrix A = Aurora::Matrix::New(dataA,3,3); EXPECT_EQ(-3, A[0]); @@ -39,6 +41,22 @@ TEST_F(Function1D_Test,compareSet){ EXPECT_EQ(0, A[0]); EXPECT_EQ(0, A[1]); EXPECT_EQ(0, A[2]); + Aurora::Matrix B = Aurora::Matrix::New(dataB,3,3); + compareSet(A,B,4,-1,Aurora::NE); + EXPECT_EQ(-1, A[0]); + EXPECT_EQ(-1, A[1]); + EXPECT_EQ(0, A[2]); + EXPECT_EQ(-1, A[3]); + compareSet(A,B,9,Aurora::LT); + EXPECT_EQ(9, A[0]); + EXPECT_EQ(9, A[1]); + EXPECT_EQ(9, A[2]); + EXPECT_EQ(9, A[3]); + compareSet(A,9,B,Aurora::EQ); + EXPECT_EQ(2, A[0]); + EXPECT_EQ(3, A[1]); + EXPECT_EQ(4, A[2]); + EXPECT_EQ(5, A[3]); }