Complex mul and div for CudaMatrix

This commit is contained in:
kradchen
2023-12-08 16:18:05 +08:00
parent a65ee38196
commit 7dc3bc221a
3 changed files with 62 additions and 8 deletions

View File

@@ -653,7 +653,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
unsigned long long size = getDataSize() * getValueType();
cudaMalloc((void**)&data, sizeof(float) * size);
auto out = CudaMatrix::fromRawData(data, getDimSize(0), getDimSize(1), getDimSize(2), getValueType());
unaryMul(this->getData(),aMatrix.getData(),out.getData(),this->getDataSize());
if (isComplex()){
unaryMulc(aMatrix.getData(),getData(),data,aMatrix.getDataSize());
}
else{
unaryMul(aMatrix.getData(),getData(),data,aMatrix.getDataSize());
}
return out;
}
CudaMatrix CudaMatrix::operator*(CudaMatrix &&aMatrix) const{
@@ -667,7 +672,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
<<" and the matrix1 type is "<<(aMatrix.isComplex()?"Comples":"Real")<<std::endl;
return CudaMatrix();
}
unaryMul(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
if (isComplex()){
unaryMulc(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
}
else{
unaryMul(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
}
return aMatrix;
}
CudaMatrix operator*(CudaMatrix &&aMatrix,CudaMatrix &aOther){
@@ -681,7 +691,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
<<" and the matrix1 type is "<<(aOther.isComplex()?"Comples":"Real")<<std::endl;
return CudaMatrix();
}
unaryMul(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
if (aMatrix.isComplex()){
unaryMulc(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
}
else{
unaryMul(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
}
return aMatrix;
}
@@ -843,7 +858,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
unsigned long long size = aMatrix.getDataSize() * aMatrix.getValueType();
cudaMalloc((void**)&data, sizeof(float) * size);
auto out = CudaMatrix::fromRawData(data, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), aMatrix.getValueType());
unaryDiv(this->getData(),aMatrix.getData(),out.getData(),aMatrix.getDataSize());
if (isComplex()){
unaryDivc(getData(),aMatrix.getData(),data,aMatrix.getDataSize());
}
else{
unaryDiv(getData(),aMatrix.getData(),data,aMatrix.getDataSize());
}
return out;
}
CudaMatrix CudaMatrix::operator/(CudaMatrix &&aMatrix) const{
@@ -858,7 +878,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
<<" and the matrix1 size is "<<aMatrix.getDataSize()<<std::endl;
return CudaMatrix();
}
unaryDiv(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
if (aMatrix.isComplex()){
unaryDivc(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
}
else{
unaryDiv(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
}
return aMatrix;
}
CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther){
@@ -873,7 +898,13 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
<<" and the matrix1 size is "<<aOther.getDataSize()<<std::endl;
return CudaMatrix();
}
unaryDiv(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
if (aMatrix.isComplex()){
unaryDivc(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
}
else{
unaryDiv(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
}
return aMatrix;
}

View File

@@ -1,6 +1,7 @@
#include <CudaMatrixPrivate.cuh>
#include <math.h>
#include <thrust/transform.h>
#include <thrust/complex.h>
#include <thrust/functional.h>
#include <thrust/execution_policy.h>
using namespace thrust::placeholders;
@@ -143,6 +144,15 @@ void unaryMul(float* in1, float* in2, float* out, unsigned long length)
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
}
void unaryMulc(float* in1, float* in2, float* out, unsigned long length)
{
thrust::complex<float>* _in1 = (thrust::complex<float>*)in1;
thrust::complex<float>* _in2 = (thrust::complex<float>*)in2;
thrust::complex<float>* _out = (thrust::complex<float>*)out;;
thrust::multiplies<thrust::complex<float>> op;
thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op);
}
void unaryMul(float* in1, const float& in2, float* out, unsigned long length)
{
thrust::transform(thrust::device,in1, in1+length, out, in2 * _1);
@@ -163,6 +173,15 @@ void unaryDiv(float* in1, float* in2, float* out, unsigned long length){
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
}
void unaryDivc(float* in1, float* in2, float* out, unsigned long length)
{
thrust::complex<float>* _in1 = (thrust::complex<float>*)in1;
thrust::complex<float>* _in2 = (thrust::complex<float>*)in2;
thrust::complex<float>* _out = (thrust::complex<float>*)out;;
thrust::divides<thrust::complex<float>> op;
thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op);
}
void unarySub(const float& in1, float* in2, float* out, unsigned long length){
thrust::transform(thrust::device,in2,in2+length,out,in1-_1);
}
@@ -183,8 +202,8 @@ void unaryDiv(float* in1, const float& in2, float* out, unsigned long length){
void unaryPow(float* in1, float N,float* out, unsigned long length){
if (N == 0.0f)
{
thrust::fill(thrust::device,out,out+length,1);
return;
thrust::fill(thrust::device,out,out+length,1);
return;
}
if (N == 1.0f)
{

View File

@@ -12,6 +12,8 @@ namespace{
void unaryAdd(float* in1, float* in2, float* out, unsigned long length);
void unaryAdd(float* in1, const float& in2, float* out, unsigned long length);
void unaryMul(float* in1, float* in2, float* out, unsigned long length);
void unaryMulc(float* in1, float* in2, float* out, unsigned long length);
void unaryMul(float* in1, const float& in2, float* out, unsigned long length);
void unaryNeg(float* in1, float* out, unsigned long length);
@@ -19,6 +21,8 @@ void unaryPow(float* in1, float N,float* out, unsigned long length);
void unarySub(float* in1, float* in2, float* out, unsigned long length);
void unaryDiv(float* in1, float* in2, float* out, unsigned long length);
void unaryDivc(float* in1, float* in2, float* out, unsigned long length);
void unarySub(const float& in1, float* in2, float* out, unsigned long length);
void unaryDiv(const float& in1, float* in2, float* out, unsigned long length);
void unarySub(float* in1, const float& in2, float* out, unsigned long length);