Add cuda auroraUnion and unittest.

This commit is contained in:
sunwen
2023-12-04 11:27:17 +08:00
parent 65c78cd878
commit 5e77b4dafe
3 changed files with 38 additions and 0 deletions

View File

@@ -5,6 +5,7 @@
#include "Matrix.h"
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <thrust/device_vector.h>
#include <thrust/transform.h>
@@ -1232,3 +1233,23 @@ CudaMatrix Aurora::linspaceCuda(float aStart, float aEnd, int aNum)
cudaDeviceSynchronize();
return Aurora::CudaMatrix::fromRawData(data,aNum);
}
/**
 * Computes the set union of the elements of two real-valued CudaMatrix objects.
 *
 * Both inputs are copied back-to-back into a freshly allocated device buffer,
 * sorted ascending and de-duplicated, so the result holds every distinct value
 * of either input exactly once, in increasing order.
 *
 * @param aMatrix1 first input; must be non-null and not complex
 * @param aMatrix2 second input; must be non-null and not complex
 * @return a CudaMatrix built from the de-duplicated buffer, or a default
 *         constructed CudaMatrix when an input is null/complex or when a CUDA
 *         allocation/copy fails (a diagnostic is written to std::cerr).
 */
CudaMatrix Aurora::auroraUnion(const CudaMatrix& aMatrix1, const CudaMatrix& aMatrix2)
{
    // Report null and complex inputs separately so the diagnostic names the real cause.
    if (aMatrix1.isNull() || aMatrix2.isNull())
    {
        std::cerr<<"auroraUnion not support null cudamatrix"<<std::endl;
        return CudaMatrix();
    }
    if (aMatrix1.isComplex() || aMatrix2.isComplex())
    {
        std::cerr<<"auroraUnion not support complex cudamatrix"<<std::endl;
        return CudaMatrix();
    }

    const size_t size1 = aMatrix1.getDataSize();
    const size_t size2 = aMatrix2.getDataSize();
    const size_t total = size1 + size2;

    // Stage both inputs contiguously on the device; stop at the first failing call
    // so that a failed allocation is never written to or sorted.
    float* data = nullptr;
    cudaError_t status = cudaMalloc((void**)&data, sizeof(float) * total);
    if (status == cudaSuccess)
    {
        status = cudaMemcpy(data, aMatrix1.getData(), sizeof(float) * size1, cudaMemcpyDeviceToDevice);
    }
    if (status == cudaSuccess)
    {
        status = cudaMemcpy(data + size1, aMatrix2.getData(), sizeof(float) * size2, cudaMemcpyDeviceToDevice);
    }
    if (status != cudaSuccess)
    {
        std::cerr<<"auroraUnion cuda error: "<<cudaGetErrorString(status)<<std::endl;
        // cudaFree(nullptr) is a no-op, so this is safe even if cudaMalloc failed.
        cudaFree(data);
        return CudaMatrix();
    }

    // thrust::unique only removes *adjacent* duplicates, hence the sort first.
    float* endPointer = nullptr;
    try
    {
        thrust::sort(thrust::device, data, data + total);
        endPointer = thrust::unique(thrust::device, data, data + total);
    }
    catch (...)
    {
        // Thrust reports failures by throwing; release the scratch buffer before
        // propagating so the allocation is not leaked.
        cudaFree(data);
        throw;
    }

    // NOTE(review): fromRawData presumably takes ownership of `data` (same usage as
    // linspaceCuda) -- confirm. The buffer stays sized for `total` elements while
    // only the first (endPointer - data) are reported.
    return CudaMatrix::fromRawData(data, endPointer - data);
}

View File

@@ -73,6 +73,8 @@ namespace Aurora
CudaMatrix linspaceCuda(float aStart, float aEnd, int aNum);
// Set union of two real (non-complex), non-null CudaMatrix inputs: the elements of
// both are concatenated, sorted ascending and de-duplicated. Returns an empty
// CudaMatrix (and logs to std::cerr) when either input is null or complex.
CudaMatrix auroraUnion(const CudaMatrix& aMatrix1, const CudaMatrix& aMatrix2);
// ------compareSet----------------------------------------------------