Add intersect and intersect's unittest.

This commit is contained in:
sunwen
2023-04-26 16:43:09 +08:00
parent 9a3ec2805e
commit abe69eb26d
3 changed files with 80 additions and 1 deletions

View File

@@ -553,7 +553,7 @@ Matrix Aurora::linspace(double aStart, double aEnd, int aNum)
Matrix Aurora::auroraUnion(const Matrix& aMatrix1, const Matrix& aMatrix2)
{
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getValueType() == Complex || aMatrix2.getValueType() == Complex)
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.isComplex() || aMatrix2.isComplex())
{
return Matrix();
}
@@ -572,3 +572,53 @@ Matrix Aurora::auroraUnion(const Matrix& aMatrix1, const Matrix& aMatrix2)
return Matrix::copyFromRawData(vector.data(),vector.size());
}
Matrix Aurora::intersect(const Matrix& aMatrix1, const Matrix& aMatrix2)
{
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.isComplex() || aMatrix2.isComplex())
{
return Matrix();
}
size_t size1= aMatrix1.getDataSize();
size_t size2= aMatrix2.getDataSize();
std::vector<double> vector1(aMatrix1.getData(), aMatrix1.getData() + size1);
std::vector<double> vector2(aMatrix2.getData(), aMatrix2.getData() + size2);
std::sort(vector1.begin(), vector1.end());
std::sort(vector2.begin(), vector2.end());
std::vector<double> intersection;
std::set_intersection(vector1.begin(), vector1.end(),
vector2.begin(), vector2.end(),
std::back_inserter(intersection));
return Matrix::copyFromRawData(intersection.data(), intersection.size());
}
Matrix Aurora::intersect(const Matrix& aMatrix1, const Matrix& aMatrix2, Matrix& aIa)
{
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.isComplex() || aMatrix2.isComplex())
{
return Matrix();
}
Matrix result = intersect(aMatrix1,aMatrix2);
size_t size = result.getDataSize();
double* iaResult = Aurora::malloc(size);
for(size_t i=0; i<size; ++i)
{
for(size_t j=0; j<aMatrix1.getDataSize(); ++j)
{
if(aMatrix1.getData()[j] == result.getData()[i])
{
iaResult[i] = j + 1;
break;
}
}
}
aIa = Matrix::New(iaResult,size);
return result;
}

View File

@@ -76,6 +76,14 @@ namespace Aurora {
Matrix auroraUnion(const Matrix& aMatrix1, const Matrix& aMatrix2);
Matrix intersect(const Matrix& aMatrix1, const Matrix& aMatrix2);
/**
* 并集
* @param aIa, [C,ia,~] = intersect(A,B)用法中ia的返回值
* @return 并集结果
*/
Matrix intersect(const Matrix& aMatrix1, const Matrix& aMatrix2, Matrix& aIa);
/**
* 多项式计算
* @brief 例如p[1 0 1],x[3 2 5],代表对多项式 y = x^2 + 1 求(x=3, x=2, x=5)时所有的y

View File

@@ -455,3 +455,24 @@ TEST_F(Function1D_Test, auroraUnion) {
EXPECT_DOUBLE_AE(result.getData()[5],7);
EXPECT_DOUBLE_AE(result.getData()[6],8);
}
TEST_F(Function1D_Test, intersect) {
double* data1 = new double[9]{3,3,2,2,2,1,4,4,7};
auto matrix1 = Aurora::Matrix::fromRawData(data1, 9,1,1);
double* data2 = new double[8]{6,6,7,7,8,1,2};
auto matrix2 = Aurora::Matrix::fromRawData(data2, 7,1,1);
auto result = Aurora::intersect(matrix1, matrix2);
EXPECT_DOUBLE_AE(result.getData()[0],1);
EXPECT_DOUBLE_AE(result.getData()[1],2);
EXPECT_DOUBLE_AE(result.getData()[2],7);
Aurora::Matrix ia;
result = Aurora::intersect(matrix1, matrix2, ia);
EXPECT_DOUBLE_AE(result.getData()[0],1);
EXPECT_DOUBLE_AE(result.getData()[1],2);
EXPECT_DOUBLE_AE(result.getData()[2],7);
EXPECT_DOUBLE_AE(ia.getData()[0],6);
EXPECT_DOUBLE_AE(ia.getData()[1],3);
EXPECT_DOUBLE_AE(ia.getData()[2],9);
}