From 7dc3bc221ae0accd12938893d1b9a056107ddb81 Mon Sep 17 00:00:00 2001 From: kradchen Date: Fri, 8 Dec 2023 16:18:05 +0800 Subject: [PATCH] Complex mul and div for CudaMatrix --- src/CudaMatrix.cpp | 43 +++++++++++++++++++++++++++++++++------ src/CudaMatrixPrivate.cu | 23 +++++++++++++++++++-- src/CudaMatrixPrivate.cuh | 4 ++++ 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/src/CudaMatrix.cpp b/src/CudaMatrix.cpp index 0b4f073..d19e3c5 100644 --- a/src/CudaMatrix.cpp +++ b/src/CudaMatrix.cpp @@ -653,7 +653,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat unsigned long long size = getDataSize() * getValueType(); cudaMalloc((void**)&data, sizeof(float) * size); auto out = CudaMatrix::fromRawData(data, getDimSize(0), getDimSize(1), getDimSize(2), getValueType()); - unaryMul(this->getData(),aMatrix.getData(),out.getData(),this->getDataSize()); + if (isComplex()){ + unaryMulc(aMatrix.getData(),getData(),data,aMatrix.getDataSize()); + } + else{ + unaryMul(aMatrix.getData(),getData(),data,aMatrix.getDataSize()); + } return out; } CudaMatrix CudaMatrix::operator*(CudaMatrix &&aMatrix) const{ @@ -667,7 +672,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat <<" and the matrix1 type is "<<(aMatrix.isComplex()?"Comples":"Real")<getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize()); + if (isComplex()){ + unaryMulc(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize()); + } + else{ + unaryMul(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize()); + } return aMatrix; } CudaMatrix operator*(CudaMatrix &&aMatrix,CudaMatrix &aOther){ @@ -681,7 +691,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat <<" and the matrix1 type is "<<(aOther.isComplex()?"Comples":"Real")<getData(),aMatrix.getData(),out.getData(),aMatrix.getDataSize()); + if (isComplex()){ + unaryDivc(getData(),aMatrix.getData(),data,aMatrix.getDataSize()); + } + else{ + unaryDiv(getData(),aMatrix.getData(),data,aMatrix.getDataSize()); + } return out; } CudaMatrix CudaMatrix::operator/(CudaMatrix &&aMatrix) const{ @@ -858,7 +878,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat <<" and the matrix1 size is "<getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize()); + if (aMatrix.isComplex()){ + unaryDivc(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize()); + } + else{ + unaryDiv(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize()); + } return aMatrix; } CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther){ @@ -873,7 +898,13 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat <<" and the matrix1 size is "< #include #include +#include #include #include using namespace thrust::placeholders; @@ -143,6 +144,15 @@ void unaryMul(float* in1, float* in2, float* out, unsigned long length) thrust::transform(thrust::device,in1,in1+length,in2,out,op); } +void unaryMulc(float* in1, float* in2, float* out, unsigned long length) +{ + thrust::complex* _in1 = (thrust::complex*)in1; + thrust::complex* _in2 = (thrust::complex*)in2; + thrust::complex* _out = (thrust::complex*)out;; + thrust::multiplies> op; + thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op); +} + void unaryMul(float* in1, const float& in2, float* out, unsigned long length) { thrust::transform(thrust::device,in1, in1+length, out, in2 * _1); @@ -163,6 +173,15 @@ void unaryDiv(float* in1, float* in2, float* out, unsigned long length){ thrust::transform(thrust::device,in1,in1+length,in2,out,op); } +void unaryDivc(float* in1, float* in2, float* out, unsigned long length) +{ + thrust::complex* _in1 = (thrust::complex*)in1; + thrust::complex* _in2 = (thrust::complex*)in2; + thrust::complex* _out = (thrust::complex*)out;; + thrust::divides> op; + thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op); +} + void unarySub(const float& in1, float* in2, float* out, unsigned long length){ thrust::transform(thrust::device,in2,in2+length,out,in1-_1); } @@ -183,8 +202,8 @@ void unaryDiv(float* in1, const float& in2, float* out, unsigned long length){ void unaryPow(float* in1, float N,float* out, unsigned long length){ if (N == 0.0f) { - thrust::fill(thrust::device,out,out+length,1); - return; + thrust::fill(thrust::device,out,out+length,1); + return; } if (N == 1.0f) { diff --git a/src/CudaMatrixPrivate.cuh b/src/CudaMatrixPrivate.cuh index 8b60081..89b0fdc 100644 --- a/src/CudaMatrixPrivate.cuh +++ b/src/CudaMatrixPrivate.cuh @@ -12,6 +12,8 @@ namespace{ void unaryAdd(float* in1, float* in2, float* out, unsigned long length); void unaryAdd(float* in1, const float& in2, float* out, unsigned long length); void unaryMul(float* in1, float* in2, float* out, unsigned long length); +void unaryMulc(float* in1, float* in2, float* out, unsigned long length); + void unaryMul(float* in1, const float& in2, float* out, unsigned long length); void unaryNeg(float* in1, float* out, unsigned long length); @@ -19,6 +21,8 @@ void unaryPow(float* in1, float N,float* out, unsigned long length); void unarySub(float* in1, float* in2, float* out, unsigned long length); void unaryDiv(float* in1, float* in2, float* out, unsigned long length); +void unaryDivc(float* in1, float* in2, float* out, unsigned long length); + void unarySub(const float& in1, float* in2, float* out, unsigned long length); void unaryDiv(const float& in1, float* in2, float* out, unsigned long length); void unarySub(float* in1, const float& in2, float* out, unsigned long length);