feat: Add valid cuda function.

This commit is contained in:
sunwen
2024-12-24 10:44:06 +08:00
parent 5407c3ccb6
commit 04e0c4b38d
2 changed files with 49 additions and 1 deletions

View File

@@ -1675,4 +1675,50 @@ CudaMatrix Aurora::hilbert(const CudaMatrix &aMatrix)
x = x * h;
auto result = ifft(x);
return result;
}
}
/**
 * Gathers selected fixed-length segments ("columns") of aData into aOutput.
 *
 * Thread t (flat 1D global index) handles output column t: it reads the source
 * column index from aValid[t] (stored as float, truncated to int) and copies
 * aOutputRowCount consecutive floats from
 *   aData[index * aOutputRowCount ...]  to  aOutput[t * aOutputRowCount ...].
 *
 * Launch layout: 1D grid of 1D blocks with at least aOutputColumnCount threads
 * in total; surplus threads exit without touching memory. No shared memory.
 *
 * @param aData              source buffer, read as contiguous segments of aOutputRowCount floats
 * @param aValid             aOutputColumnCount source-segment indices, encoded as floats
 * @param aOutput            destination buffer of aOutputRowCount * aOutputColumnCount floats
 * @param aOutputRowCount    number of floats copied per selected segment
 * @param aOutputColumnCount number of segments to copy (length of aValid)
 */
__global__ void validKernel(const float* aData, const float* aValid, float* aOutput, int aOutputRowCount, int aOutputColumnCount)
{
    int threadIndex = blockIdx.x * blockDim.x + threadIdx.x;
    // Guard BEFORE touching aValid: the grid is rounded up to whole blocks, so
    // tail threads would otherwise read past the end of the index buffer.
    if (threadIndex >= aOutputColumnCount)
    {
        return;
    }
    int dataIndex = (int)aValid[threadIndex];
    // 64-bit offsets so index * rowCount cannot overflow a 32-bit int.
    const long long sourceOffset = (long long)dataIndex * aOutputRowCount;
    const long long outputOffset = (long long)threadIndex * aOutputRowCount;
    for (int i = 0; i < aOutputRowCount; ++i)
    {
        aOutput[outputOffset + i] = aData[sourceOffset + i];
    }
}
/**
 * Builds a new matrix from the columns of aData whose flag in aValid equals 1.
 *
 * The flags are copied from device to host, the positions of all entries equal
 * to 1 are collected in ascending order, and validKernel copies the matching
 * segments of aData.getDimSize(0) floats into a freshly allocated device buffer.
 *
 * @param aData  device matrix to select from; dimension 0 is used as the row count
 * @param aValid device matrix of flags; element i == 1 selects segment i of aData
 * @return matrix created by CudaMatrix::fromRawData(result, rowCount, selectedCount);
 *         when nothing is selected the buffer pointer is nullptr and the column
 *         count is 0.
 *
 * NOTE(review): selected indices travel to the kernel as floats, so positions
 * above 2^24 are not exactly representable — confirm inputs stay below that.
 * NOTE(review): aValid is not checked against the column count of aData; the
 * caller presumably guarantees they match — verify.
 * NOTE(review): CUDA return codes are not checked here; wire them into the
 * project's error-handling convention if one exists.
 */
Aurora::CudaMatrix Aurora::valid(const Aurora::CudaMatrix aData, const Aurora::CudaMatrix aValid)
{
    const int validSize = aValid.getDataSize();
    const int rowCount = aData.getDimSize(0);

    // Bring the flags to the host so the selected positions can be compacted.
    float* hostValid = new float[validSize];
    cudaMemcpy(hostValid, aValid.getData(), sizeof(float) * validSize, cudaMemcpyDeviceToHost);

    float* validProcessed = new float[validSize];
    int validColumnCount = 0;
    for (int i = 0; i < validSize; ++i)
    {
        if (hostValid[i] == 1)
        {
            validProcessed[validColumnCount] = i;
            ++validColumnCount;
        }
    }
    delete[] hostValid;

    float* result = nullptr;
    // With nothing selected there is no work: skip the allocations and the
    // launch instead of handing the kernel a zero-byte (null) index buffer.
    if (validColumnCount > 0)
    {
        float* validProcessedDevice = nullptr;
        cudaMalloc((void**)&validProcessedDevice, sizeof(float) * validColumnCount);
        cudaMemcpy(validProcessedDevice, validProcessed, sizeof(float) * validColumnCount, cudaMemcpyHostToDevice);
        cudaMalloc((void**)&result, sizeof(float) * validColumnCount * rowCount);

        const int threadPerBlock = 1024;
        // Ceiling division: exactly enough blocks to give every column a thread.
        const int blockPerGrid = (validColumnCount + threadPerBlock - 1) / threadPerBlock;
        validKernel<<<blockPerGrid, threadPerBlock>>>(aData.getData(), validProcessedDevice, result, rowCount, validColumnCount);
        // Wait for the kernel before releasing the index buffer it reads.
        cudaDeviceSynchronize();
        cudaFree(validProcessedDevice);
    }

    delete[] validProcessed;
    return Aurora::CudaMatrix::fromRawData(result, rowCount, validColumnCount);
}

View File

@@ -86,6 +86,8 @@ namespace Aurora
*/
CudaMatrix ifft_symmetric(const CudaMatrix &aMatrix,long aLength);
/**
 * Selects the columns of aData whose corresponding element in aValid equals 1.
 * @param aData  matrix to select from; its dimension 0 is used as the row count
 * @param aValid flag matrix; element i equal to 1 keeps column i of aData
 * @return matrix with aData's row count and one column per selected flag
 */
CudaMatrix valid(const CudaMatrix aData, const CudaMatrix aValid);
}
#endif // __FUNCTION2D_CUDA_H__