Complex mul and div for CudaMatrix
This commit is contained in:
@@ -653,7 +653,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
||||
unsigned long long size = getDataSize() * getValueType();
|
||||
cudaMalloc((void**)&data, sizeof(float) * size);
|
||||
auto out = CudaMatrix::fromRawData(data, getDimSize(0), getDimSize(1), getDimSize(2), getValueType());
|
||||
unaryMul(this->getData(),aMatrix.getData(),out.getData(),this->getDataSize());
|
||||
if (isComplex()){
|
||||
unaryMulc(aMatrix.getData(),getData(),data,aMatrix.getDataSize());
|
||||
}
|
||||
else{
|
||||
unaryMul(aMatrix.getData(),getData(),data,aMatrix.getDataSize());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
CudaMatrix CudaMatrix::operator*(CudaMatrix &&aMatrix) const{
|
||||
@@ -667,7 +672,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
||||
<<" and the matrix1 type is "<<(aMatrix.isComplex()?"Comples":"Real")<<std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if (isComplex()){
|
||||
unaryMulc(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
|
||||
}
|
||||
else{
|
||||
unaryMul(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
|
||||
}
|
||||
return aMatrix;
|
||||
}
|
||||
CudaMatrix operator*(CudaMatrix &&aMatrix,CudaMatrix &aOther){
|
||||
@@ -681,7 +691,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
||||
<<" and the matrix1 type is "<<(aOther.isComplex()?"Comples":"Real")<<std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if (aMatrix.isComplex()){
|
||||
unaryMulc(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
|
||||
}
|
||||
else{
|
||||
unaryMul(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
|
||||
}
|
||||
return aMatrix;
|
||||
}
|
||||
|
||||
@@ -843,7 +858,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
||||
unsigned long long size = aMatrix.getDataSize() * aMatrix.getValueType();
|
||||
cudaMalloc((void**)&data, sizeof(float) * size);
|
||||
auto out = CudaMatrix::fromRawData(data, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), aMatrix.getValueType());
|
||||
unaryDiv(this->getData(),aMatrix.getData(),out.getData(),aMatrix.getDataSize());
|
||||
if (isComplex()){
|
||||
unaryDivc(getData(),aMatrix.getData(),data,aMatrix.getDataSize());
|
||||
}
|
||||
else{
|
||||
unaryDiv(getData(),aMatrix.getData(),data,aMatrix.getDataSize());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
CudaMatrix CudaMatrix::operator/(CudaMatrix &&aMatrix) const{
|
||||
@@ -858,7 +878,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
||||
<<" and the matrix1 size is "<<aMatrix.getDataSize()<<std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if (aMatrix.isComplex()){
|
||||
unaryDivc(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||
}
|
||||
else{
|
||||
unaryDiv(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||
}
|
||||
return aMatrix;
|
||||
}
|
||||
CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther){
|
||||
@@ -873,7 +898,13 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
||||
<<" and the matrix1 size is "<<aOther.getDataSize()<<std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if (aMatrix.isComplex()){
|
||||
unaryDivc(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||
}
|
||||
else{
|
||||
unaryDiv(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||
}
|
||||
|
||||
return aMatrix;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include <CudaMatrixPrivate.cuh>
|
||||
#include <math.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <thrust/complex.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
using namespace thrust::placeholders;
|
||||
@@ -143,6 +144,15 @@ void unaryMul(float* in1, float* in2, float* out, unsigned long length)
|
||||
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
|
||||
}
|
||||
|
||||
void unaryMulc(float* in1, float* in2, float* out, unsigned long length)
|
||||
{
|
||||
thrust::complex<float>* _in1 = (thrust::complex<float>*)in1;
|
||||
thrust::complex<float>* _in2 = (thrust::complex<float>*)in2;
|
||||
thrust::complex<float>* _out = (thrust::complex<float>*)out;;
|
||||
thrust::multiplies<thrust::complex<float>> op;
|
||||
thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op);
|
||||
}
|
||||
|
||||
void unaryMul(float* in1, const float& in2, float* out, unsigned long length)
|
||||
{
|
||||
thrust::transform(thrust::device,in1, in1+length, out, in2 * _1);
|
||||
@@ -163,6 +173,15 @@ void unaryDiv(float* in1, float* in2, float* out, unsigned long length){
|
||||
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
|
||||
}
|
||||
|
||||
void unaryDivc(float* in1, float* in2, float* out, unsigned long length)
|
||||
{
|
||||
thrust::complex<float>* _in1 = (thrust::complex<float>*)in1;
|
||||
thrust::complex<float>* _in2 = (thrust::complex<float>*)in2;
|
||||
thrust::complex<float>* _out = (thrust::complex<float>*)out;;
|
||||
thrust::divides<thrust::complex<float>> op;
|
||||
thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op);
|
||||
}
|
||||
|
||||
void unarySub(const float& in1, float* in2, float* out, unsigned long length){
|
||||
thrust::transform(thrust::device,in2,in2+length,out,in1-_1);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ namespace{
|
||||
void unaryAdd(float* in1, float* in2, float* out, unsigned long length);
|
||||
void unaryAdd(float* in1, const float& in2, float* out, unsigned long length);
|
||||
void unaryMul(float* in1, float* in2, float* out, unsigned long length);
|
||||
void unaryMulc(float* in1, float* in2, float* out, unsigned long length);
|
||||
|
||||
void unaryMul(float* in1, const float& in2, float* out, unsigned long length);
|
||||
|
||||
void unaryNeg(float* in1, float* out, unsigned long length);
|
||||
@@ -19,6 +21,8 @@ void unaryPow(float* in1, float N,float* out, unsigned long length);
|
||||
|
||||
void unarySub(float* in1, float* in2, float* out, unsigned long length);
|
||||
void unaryDiv(float* in1, float* in2, float* out, unsigned long length);
|
||||
void unaryDivc(float* in1, float* in2, float* out, unsigned long length);
|
||||
|
||||
void unarySub(const float& in1, float* in2, float* out, unsigned long length);
|
||||
void unaryDiv(const float& in1, float* in2, float* out, unsigned long length);
|
||||
void unarySub(float* in1, const float& in2, float* out, unsigned long length);
|
||||
|
||||
Reference in New Issue
Block a user