From 22a49127b7416af958ef83a9eb6be684deeb8a68 Mon Sep 17 00:00:00 2001 From: kradchen Date: Sat, 6 May 2023 16:21:07 +0800 Subject: [PATCH] Add Matrix compare with >=, <= and their test --- src/Matrix.cpp | 104 +++++++++++++++++++++++++++++++++++++++++++ src/Matrix.h | 8 ++++ test/Matrix_Test.cpp | 72 ++++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+) diff --git a/src/Matrix.cpp b/src/Matrix.cpp index d084816..4360a77 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -604,6 +604,58 @@ namespace Aurora { return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2)); } + Matrix Matrix::operator>=(double aScalar) const { + if (isComplex()) { + std::cerr << "Complex cann't compare!" << std::endl; + return Matrix(); + } + Eigen::Map v(getData(), getDataSize()); + double *ret = malloc(getDataSize()); + Eigen::Map result(ret, getDataSize()); + result.setConstant(0.0); + result = (v.array() >= aScalar).select(1.0, result); + return New(ret, getDimSize(0), getDimSize(1), getDimSize(2)); + } + + Matrix operator>=(double aScalar, const Matrix &matrix) { + if (matrix.isComplex()) { + std::cerr << "Complex cann't compare!" << std::endl; + return Matrix(); + } + Eigen::Map v(matrix.getData(), matrix.getDataSize()); + double *ret = malloc(matrix.getDataSize()); + Eigen::Map result(ret, matrix.getDataSize()); + result.setConstant(0.0); + result = (aScalar >= v.array() ).select(1.0, result); + return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2)); + } + + Matrix Matrix::operator>=(const Matrix &matrix) const { + if (isComplex() || matrix.isComplex()) { + std::cerr << "Complex cann't compare!" << std::endl; + return Matrix(); + } + if (!compareShape(matrix) && !isScalar() && !matrix.isScalar()) { + std::cerr << "Matrix not equal, matrix 1(" << matrix.getDimSize(0) << "," << matrix.getDimSize(1) << "," + << matrix.getDimSize(2) << "), matrix 2(" << getDimSize(0) << "," << getDimSize(1) << "," + << getDimSize(2) << ")" << std::endl; + return Matrix(); + } + if(isScalar()){ + return getData()[0]<=matrix; + } + if(matrix.isScalar()){ + return (*this)>=matrix.getData()[0]; + } + Eigen::Map v(getData(), getDataSize()); + Eigen::Map v2(matrix.getData(), matrix.getDataSize()); + double *ret = malloc(matrix.getDataSize()); + Eigen::Map result(ret, matrix.getDataSize()); + result.setConstant(0.0); + result = (v.array() >= v2.array()).select(1.0, result); + return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2)); + } + Matrix Matrix::operator<(double aScalar) const { if (isComplex()) { std::cerr << "Complex cann't compare!" << std::endl; @@ -656,6 +708,58 @@ namespace Aurora { return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2)); } + Matrix Matrix::operator<=(double aScalar) const { + if (isComplex()) { + std::cerr << "Complex cann't compare!" << std::endl; + return Matrix(); + } + Eigen::Map v(getData(), getDataSize()); + double *ret = malloc(getDataSize()); + Eigen::Map result(ret, getDataSize()); + result.setConstant(0.0); + result = (v.array() <= aScalar).select(1.0, result); + return New(ret, getDimSize(0), getDimSize(1), getDimSize(2)); + } + + Matrix operator<=(double aScalar, const Matrix &matrix) { + if (matrix.isComplex()) { + std::cerr << "Complex cann't compare!" << std::endl; + return Matrix(); + } + Eigen::Map v(matrix.getData(), matrix.getDataSize()); + double *ret = malloc(matrix.getDataSize()); + Eigen::Map result(ret, matrix.getDataSize()); + result.setConstant(0.0); + result = (aScalar <= v.array() ).select(1.0, result); + return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2)); + } + + Matrix Matrix::operator<=(const Matrix &matrix) const { + if (isComplex() || matrix.isComplex()) { + std::cerr << "Complex cann't compare!" << std::endl; + return Matrix(); + } + if (!compareShape(matrix) && !isScalar() && !matrix.isScalar()) { + std::cerr << "Matrix not equal, matrix 1(" << matrix.getDimSize(0) << "," << matrix.getDimSize(1) << "," + << matrix.getDimSize(2) << "), matrix 2(" << getDimSize(0) << "," << getDimSize(1) << "," + << getDimSize(2) << ")" << std::endl; + return Matrix(); + } + if(isScalar()){ + return getData()[0]<=matrix; + } + if(matrix.isScalar()){ + return (*this)<=matrix.getData()[0]; + } + Eigen::Map v(getData(), getDataSize()); + Eigen::Map v2(matrix.getData(), matrix.getDataSize()); + double *ret = malloc(matrix.getDataSize()); + Eigen::Map result(ret, matrix.getDataSize()); + result.setConstant(0.0); + result = (v.array() <= v2.array()).select(1.0, result); + return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2)); + } + Matrix Matrix::operator==(double aScalar) const { if (isComplex()) { std::cerr << "Complex cann't compare!" << std::endl; diff --git a/src/Matrix.h b/src/Matrix.h index 0be0acc..df95176 100644 --- a/src/Matrix.h +++ b/src/Matrix.h @@ -157,6 +157,14 @@ namespace Aurora { friend Matrix operator<(double aScalar, const Matrix &matrix); Matrix operator<(const Matrix &matrix) const; + Matrix operator>=(double aScalar) const; + friend Matrix operator>=(double aScalar, const Matrix &matrix); + Matrix operator>=(const Matrix &matrix) const; + + Matrix operator<=(double aScalar) const; + friend Matrix operator<=(double aScalar, const Matrix &matrix); + Matrix operator<=(const Matrix &matrix) const; + Matrix operator==(double aScalar) const; friend Matrix operator==(double aScalar, const Matrix &matrix); Matrix operator==(const Matrix &matrix) const; diff --git a/test/Matrix_Test.cpp b/test/Matrix_Test.cpp index 1c39381..55deaf0 100644 --- a/test/Matrix_Test.cpp +++ b/test/Matrix_Test.cpp @@ -285,4 +285,76 @@ TEST_F(Matrix_Test, matrixOpertaor) { EXPECT_EQ(C.getData()[0], -2); EXPECT_EQ(C.getData()[1], -2); } +} + +TEST_F(Matrix_Test, matrixCompare){ + double *dataA = new double[8]; + double *dataB = new double[8]; + for (int i = 0; i < 8; ++i) { + dataA[i] = (double) (i); + dataB[i] = (double) (2); + } + Aurora::Matrix A = Aurora::Matrix::fromRawData(dataA, 2, 2, 2); + Aurora::Matrix B = Aurora::Matrix::fromRawData(dataB, 2, 2, 2); + auto C = A < B; + EXPECT_EQ(C.getData()[0], 1); + EXPECT_EQ(C.getData()[1], 1); + EXPECT_EQ(C.getData()[2], 0); + EXPECT_EQ(C.getData()[3], 0); + C = A <= B; + EXPECT_EQ(C.getData()[0], 1); + EXPECT_EQ(C.getData()[1], 1); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 0); + C = A > B; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 0); + EXPECT_EQ(C.getData()[2], 0); + EXPECT_EQ(C.getData()[3], 1); + C = A >= B; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 0); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 1); + C = A == B; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 0); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 0); + C = A == 1.0; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 1); + EXPECT_EQ(C.getData()[2], 0); + EXPECT_EQ(C.getData()[3], 0); + + C = A > 1.0; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 0); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 1); + + C = A >= 1.0; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 1); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 1); + + C = A < 3.0; + EXPECT_EQ(C.getData()[0], 1); + EXPECT_EQ(C.getData()[1], 1); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 0); + + C = A <= 2.0; + EXPECT_EQ(C.getData()[0], 1); + EXPECT_EQ(C.getData()[1], 1); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 0); + + C = A == 2.0; + EXPECT_EQ(C.getData()[0], 0); + EXPECT_EQ(C.getData()[1], 0); + EXPECT_EQ(C.getData()[2], 1); + EXPECT_EQ(C.getData()[3], 0); + } \ No newline at end of file