Add Matrix compare with >=, <= and their test

This commit is contained in:
kradchen
2023-05-06 16:21:07 +08:00
parent e7317d0ade
commit 22a49127b7
3 changed files with 184 additions and 0 deletions

View File

@@ -604,6 +604,58 @@ namespace Aurora {
return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2));
}
Matrix Matrix::operator>=(double aScalar) const {
if (isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
return Matrix();
}
Eigen::Map<Eigen::VectorXd> v(getData(), getDataSize());
double *ret = malloc(getDataSize());
Eigen::Map<Eigen::VectorXd> result(ret, getDataSize());
result.setConstant(0.0);
result = (v.array() >= aScalar).select(1.0, result);
return New(ret, getDimSize(0), getDimSize(1), getDimSize(2));
}
Matrix operator>=(double aScalar, const Matrix &matrix) {
if (matrix.isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
return Matrix();
}
Eigen::Map<Eigen::VectorXd> v(matrix.getData(), matrix.getDataSize());
double *ret = malloc(matrix.getDataSize());
Eigen::Map<Eigen::VectorXd> result(ret, matrix.getDataSize());
result.setConstant(0.0);
result = (aScalar >= v.array() ).select(1.0, result);
return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2));
}
Matrix Matrix::operator>=(const Matrix &matrix) const {
if (isComplex() || matrix.isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
return Matrix();
}
if (!compareShape(matrix) && !isScalar() && !matrix.isScalar()) {
std::cerr << "Matrix not equal, matrix 1(" << matrix.getDimSize(0) << "," << matrix.getDimSize(1) << ","
<< matrix.getDimSize(2) << "), matrix 2(" << getDimSize(0) << "," << getDimSize(1) << ","
<< getDimSize(2) << ")" << std::endl;
return Matrix();
}
if(isScalar()){
return getData()[0]<=matrix;
}
if(matrix.isScalar()){
return (*this)>=matrix.getData()[0];
}
Eigen::Map<Eigen::VectorXd> v(getData(), getDataSize());
Eigen::Map<Eigen::VectorXd> v2(matrix.getData(), matrix.getDataSize());
double *ret = malloc(matrix.getDataSize());
Eigen::Map<Eigen::VectorXd> result(ret, matrix.getDataSize());
result.setConstant(0.0);
result = (v.array() >= v2.array()).select(1.0, result);
return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2));
}
Matrix Matrix::operator<(double aScalar) const {
if (isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
@@ -656,6 +708,58 @@ namespace Aurora {
return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2));
}
Matrix Matrix::operator<=(double aScalar) const {
if (isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
return Matrix();
}
Eigen::Map<Eigen::VectorXd> v(getData(), getDataSize());
double *ret = malloc(getDataSize());
Eigen::Map<Eigen::VectorXd> result(ret, getDataSize());
result.setConstant(0.0);
result = (v.array() <= aScalar).select(1.0, result);
return New(ret, getDimSize(0), getDimSize(1), getDimSize(2));
}
Matrix operator<=(double aScalar, const Matrix &matrix) {
if (matrix.isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
return Matrix();
}
Eigen::Map<Eigen::VectorXd> v(matrix.getData(), matrix.getDataSize());
double *ret = malloc(matrix.getDataSize());
Eigen::Map<Eigen::VectorXd> result(ret, matrix.getDataSize());
result.setConstant(0.0);
result = (aScalar <= v.array() ).select(1.0, result);
return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2));
}
Matrix Matrix::operator<=(const Matrix &matrix) const {
if (isComplex() || matrix.isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;
return Matrix();
}
if (!compareShape(matrix) && !isScalar() && !matrix.isScalar()) {
std::cerr << "Matrix not equal, matrix 1(" << matrix.getDimSize(0) << "," << matrix.getDimSize(1) << ","
<< matrix.getDimSize(2) << "), matrix 2(" << getDimSize(0) << "," << getDimSize(1) << ","
<< getDimSize(2) << ")" << std::endl;
return Matrix();
}
if(isScalar()){
return getData()[0]<=matrix;
}
if(matrix.isScalar()){
return (*this)<=matrix.getData()[0];
}
Eigen::Map<Eigen::VectorXd> v(getData(), getDataSize());
Eigen::Map<Eigen::VectorXd> v2(matrix.getData(), matrix.getDataSize());
double *ret = malloc(matrix.getDataSize());
Eigen::Map<Eigen::VectorXd> result(ret, matrix.getDataSize());
result.setConstant(0.0);
result = (v.array() <= v2.array()).select(1.0, result);
return Matrix::New(ret, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2));
}
Matrix Matrix::operator==(double aScalar) const {
if (isComplex()) {
std::cerr << "Complex cann't compare!" << std::endl;

View File

@@ -157,6 +157,14 @@ namespace Aurora {
friend Matrix operator<(double aScalar, const Matrix &matrix);
Matrix operator<(const Matrix &matrix) const;
Matrix operator>=(double aScalar) const;
friend Matrix operator>=(double aScalar, const Matrix &matrix);
Matrix operator>=(const Matrix &matrix) const;
Matrix operator<=(double aScalar) const;
friend Matrix operator<=(double aScalar, const Matrix &matrix);
Matrix operator<=(const Matrix &matrix) const;
Matrix operator==(double aScalar) const;
friend Matrix operator==(double aScalar, const Matrix &matrix);
Matrix operator==(const Matrix &matrix) const;

View File

@@ -285,4 +285,76 @@ TEST_F(Matrix_Test, matrixOpertaor) {
EXPECT_EQ(C.getData()[0], -2);
EXPECT_EQ(C.getData()[1], -2);
}
}
TEST_F(Matrix_Test, matrixCompare){
double *dataA = new double[8];
double *dataB = new double[8];
for (int i = 0; i < 8; ++i) {
dataA[i] = (double) (i);
dataB[i] = (double) (2);
}
Aurora::Matrix A = Aurora::Matrix::fromRawData(dataA, 2, 2, 2);
Aurora::Matrix B = Aurora::Matrix::fromRawData(dataB, 2, 2, 2);
auto C = A < B;
EXPECT_EQ(C.getData()[0], 1);
EXPECT_EQ(C.getData()[1], 1);
EXPECT_EQ(C.getData()[2], 0);
EXPECT_EQ(C.getData()[3], 0);
C = A <= B;
EXPECT_EQ(C.getData()[0], 1);
EXPECT_EQ(C.getData()[1], 1);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 0);
C = A > B;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 0);
EXPECT_EQ(C.getData()[2], 0);
EXPECT_EQ(C.getData()[3], 1);
C = A >= B;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 0);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 1);
C = A == B;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 0);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 0);
C = A == 1.0;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 1);
EXPECT_EQ(C.getData()[2], 0);
EXPECT_EQ(C.getData()[3], 0);
C = A > 1.0;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 0);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 1);
C = A >= 1.0;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 1);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 1);
C = A < 3.0;
EXPECT_EQ(C.getData()[0], 1);
EXPECT_EQ(C.getData()[1], 1);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 0);
C = A <= 2.0;
EXPECT_EQ(C.getData()[0], 1);
EXPECT_EQ(C.getData()[1], 1);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 0);
C = A == 2.0;
EXPECT_EQ(C.getData()[0], 0);
EXPECT_EQ(C.getData()[1], 0);
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 0);
}