diff --git a/src/Function1D.cu b/src/Function1D.cu index 3d5e121..7dd5825 100644 --- a/src/Function1D.cu +++ b/src/Function1D.cu @@ -1253,3 +1253,60 @@ CudaMatrix Aurora::auroraUnion(const CudaMatrix& aMatrix1, const CudaMatrix& aMa return CudaMatrix::fromRawData(data, endPointer - data); } + +CudaMatrix Aurora::intersect(const CudaMatrix& aMatrix1, const CudaMatrix& aMatrix2) +{ + if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.isComplex() || aMatrix2.isComplex()) + { + std::cerr<<"intersect not support complex cudamatrix"<>>(aMatrix1.getData(), result.getData(), aMatrix1.getDataSize(), iaResult, size); + cudaDeviceSynchronize(); + + aIa = CudaMatrix::fromRawData(iaResult,size); + return result; +} diff --git a/src/Function1D.cuh b/src/Function1D.cuh index adaefcc..5db2eb7 100644 --- a/src/Function1D.cuh +++ b/src/Function1D.cuh @@ -75,6 +75,10 @@ namespace Aurora CudaMatrix auroraUnion(const CudaMatrix& aMatrix1, const CudaMatrix& aMatrix2); + CudaMatrix intersect(const CudaMatrix& aMatrix1, const CudaMatrix& aMatrix2); + + CudaMatrix intersect(const CudaMatrix& aMatrix1, const CudaMatrix& aMatrix2, CudaMatrix& aIa); + // ------compareSet---------------------------------------------------- diff --git a/test/Function1D_Cuda_Test.cpp b/test/Function1D_Cuda_Test.cpp index f0f45a9..20fb373 100644 --- a/test/Function1D_Cuda_Test.cpp +++ b/test/Function1D_Cuda_Test.cpp @@ -965,3 +965,25 @@ TEST_F(Function1D_Cuda_Test, auroraUnion) { EXPECT_FLOAT_AE(result1[i], result2[i]); } } + +TEST_F(Function1D_Cuda_Test, intersect) { + float* data1 = new float[9]{3,3,2,2,2,1,4,4,7}; + auto matrix1 = Aurora::Matrix::fromRawData(data1, 9,1,1).toDeviceMatrix(); + float* data2 = new float[8]{6,6,7,7,8,1,2}; + auto matrix2 = Aurora::Matrix::fromRawData(data2, 7,1,1).toDeviceMatrix(); + + auto result = Aurora::intersect(matrix1, matrix2).toHostMatrix(); + EXPECT_FLOAT_AE(result.getData()[0],1); + EXPECT_FLOAT_AE(result.getData()[1],2); + EXPECT_FLOAT_AE(result.getData()[2],7); + + Aurora::CudaMatrix ia; + result = Aurora::intersect(matrix1, matrix2, ia).toHostMatrix(); + auto iaHost = ia.toHostMatrix(); + EXPECT_FLOAT_AE(result.getData()[0],1); + EXPECT_FLOAT_AE(result.getData()[1],2); + EXPECT_FLOAT_AE(result.getData()[2],7); + EXPECT_FLOAT_AE(iaHost.getData()[0],6); + EXPECT_FLOAT_AE(iaHost.getData()[1],3); + EXPECT_FLOAT_AE(iaHost.getData()[2],9); +}