From 9dd7d972378a53db933c7ff8536d0054527769e3 Mon Sep 17 00:00:00 2001
From: kradchen
Date: Wed, 26 Mar 2025 13:02:43 +0800
Subject: [PATCH] feat: memory Improve for ifft & conj

---
 src/Function1D.cu  | 31 +++++++++++++++++++++++++++++++
 src/Function1D.cuh |  2 ++
 src/Function2D.cu  | 33 ++++++++++++++++++++++++++++++---
 src/Function2D.cuh |  1 +
 4 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/src/Function1D.cu b/src/Function1D.cu
index b648e4a..2251a52 100644
--- a/src/Function1D.cu
+++ b/src/Function1D.cu
@@ -988,6 +988,19 @@ __global__ void conjKernel(float *aInputData, float *aOutput, unsigned int aInpu
     }
 }
 
+// Conjugates complex data in place: each element is a pair of floats and the
+// second float of pair idx is negated. aInputSize counts pairs; one thread per pair.
+__global__ void conjInplaceKernel(float *aInputData, unsigned int aInputSize)
+{
+    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < aInputSize)
+    {
+        unsigned int index = idx * 2;
+        aInputData[index + 1] = -aInputData[index + 1];
+    }
+}
+
+
 CudaMatrix Aurora::conj(const CudaMatrix &aMatrix)
 {
     if (!aMatrix.isComplex())
@@ -1003,6 +1016,24 @@ CudaMatrix Aurora::conj(const CudaMatrix &aMatrix)
     return Aurora::CudaMatrix::fromRawData(data, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), aMatrix.getValueType());
 }
 
+// In-place overload of conj for temporaries: the imaginary parts are negated
+// directly in aMatrix's buffer and that same matrix is returned. Non-complex
+// input is returned as a copy made by CudaMatrix::copyFromRawData.
+CudaMatrix Aurora::conj(CudaMatrix &&aMatrix)
+{
+    if (!aMatrix.isComplex())
+    {
+        return CudaMatrix::copyFromRawData(aMatrix.getData(), aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2));
+    }
+    size_t size = aMatrix.getDataSize();
+    int blocksPerGrid = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    conjInplaceKernel<<<blocksPerGrid, THREADS_PER_BLOCK>>>(aMatrix.getData(), size);
+    cudaDeviceSynchronize();
+    // aMatrix is a named rvalue reference, i.e. an lvalue here; before C++20 a
+    // plain `return aMatrix;` would copy-construct the result, so cast to move.
+    return static_cast<CudaMatrix &&>(aMatrix);
+}
+
 float Aurora::norm(const CudaMatrix &aMatrix, NormMethod aNormMethod)
 {
     float resultValue = 0;
diff --git a/src/Function1D.cuh b/src/Function1D.cuh
index 11b608d..9599dc0 100644
--- a/src/Function1D.cuh
+++ b/src/Function1D.cuh
@@ -63,6 +63,8 @@ namespace Aurora
 
     CudaMatrix conj(const CudaMatrix& aMatrix);
 
+    CudaMatrix conj(CudaMatrix&& aMatrix);
+
     float norm(const CudaMatrix& aMatrix, NormMethod aNormMethod);
 
     CudaMatrix transpose(const CudaMatrix& aMatrix);
diff --git a/src/Function2D.cu b/src/Function2D.cu
index 49daf5e..6838292 100644
--- a/src/Function2D.cu
+++ b/src/Function2D.cu
@@ -39,6 +39,10 @@ using namespace Aurora;
 namespace
 {
     const int THREADS_PER_BLOCK = 256;
+    // Direction codes passed as the second argument of ExecFFT.
+    const int FFT_FORWARD = 0;
+    const int FFT_BACKWARD = 1;
+
 }
 
 __global__ void maxColKernel(float *aInputData, float *aOutput, unsigned int aColSize)
@@ -1454,7 +1458,7 @@ CudaMatrix Aurora::fft(const CudaMatrix &aMatrix, long aFFTSize)
     }
     cudaDeviceSynchronize();
     auto ret = Aurora::CudaMatrix::fromRawData(data, ColEleCount, aMatrix.getDimSize(1), 1, Complex);
-    ExecFFT(ret, 0);
+    ExecFFT(ret, FFT_FORWARD);
     return ret;
 }
 
@@ -1475,16 +1479,39 @@ CudaMatrix Aurora::ifft(const CudaMatrix &aMatrix, long aFFTSize)
     complexCopyKernel<<>>(aMatrix.getData(), data, needCopySize, aMatrix.getDimSize(0), ColEleCount);
     cudaDeviceSynchronize();
     auto ret = Aurora::CudaMatrix::fromRawData(data, ColEleCount, aMatrix.getDimSize(1), 1, Complex);
-    ExecFFT(ret, 1);
+    ExecFFT(ret, FFT_BACKWARD);
     float colEleCountf = 1.f / ColEleCount;
     auto lambda = [=] __device__(const float &v)
     {
         return v * colEleCountf;
-    };
+    } ;
     thrust::transform(thrust::device, ret.getData(), ret.getData() + ret.getDataSize() * 2, ret.getData(), lambda);
     return ret;
 }
 
+// In-place inverse FFT for temporaries: runs ExecFFT(..., FFT_BACKWARD) on
+// aMatrix itself and scales every stored float by 1 / getDimSize(0). Logs to
+// std::cerr and returns an empty CudaMatrix when the input is not complex.
+CudaMatrix Aurora::ifft(CudaMatrix && aMatrix)
+{
+    if (!aMatrix.isComplex())
+    {
+        std::cerr << "ifft input must be complex value" << std::endl;
+        return CudaMatrix();
+    }
+    size_t ColEleCount = aMatrix.getDimSize(0);
+    ExecFFT(aMatrix, FFT_BACKWARD);
+    float colEleCountf = 1.f / ColEleCount;
+    auto lambda = [=] __device__(const float &v)
+    {
+        return v * colEleCountf;
+    } ;
+    thrust::transform(thrust::device, aMatrix.getData(), aMatrix.getData() + aMatrix.getDataSize() * 2, aMatrix.getData(), lambda);
+    // aMatrix is a named rvalue reference, i.e. an lvalue here; before C++20 a
+    // plain `return aMatrix;` would copy-construct the result, so cast to move.
+    return static_cast<CudaMatrix &&>(aMatrix);
+}
+
 __global__ void fftshiftSwapKernel(float *aData, unsigned int aColEleCount)
 {
     unsigned int idx = blockIdx.x * aColEleCount + threadIdx.x;
diff --git a/src/Function2D.cuh b/src/Function2D.cuh
index afa92db..2fde164 100644
--- a/src/Function2D.cuh
+++ b/src/Function2D.cuh
@@ -64,6 +64,7 @@ namespace Aurora
     CudaMatrix fft(const CudaMatrix &aMatrix, long aFFTSize = -1);
 
     CudaMatrix ifft(const CudaMatrix &aMatrix, long aFFTSize = -1);
+    CudaMatrix ifft(CudaMatrix && aMatrix);
 
     CudaMatrix hilbert(const CudaMatrix &aMatrix);
 