Complex mul and div for CudaMatrix
This commit is contained in:
@@ -653,7 +653,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
|||||||
unsigned long long size = getDataSize() * getValueType();
|
unsigned long long size = getDataSize() * getValueType();
|
||||||
cudaMalloc((void**)&data, sizeof(float) * size);
|
cudaMalloc((void**)&data, sizeof(float) * size);
|
||||||
auto out = CudaMatrix::fromRawData(data, getDimSize(0), getDimSize(1), getDimSize(2), getValueType());
|
auto out = CudaMatrix::fromRawData(data, getDimSize(0), getDimSize(1), getDimSize(2), getValueType());
|
||||||
unaryMul(this->getData(),aMatrix.getData(),out.getData(),this->getDataSize());
|
if (isComplex()){
|
||||||
|
unaryMulc(aMatrix.getData(),getData(),data,aMatrix.getDataSize());
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
unaryMul(aMatrix.getData(),getData(),data,aMatrix.getDataSize());
|
||||||
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
CudaMatrix CudaMatrix::operator*(CudaMatrix &&aMatrix) const{
|
CudaMatrix CudaMatrix::operator*(CudaMatrix &&aMatrix) const{
|
||||||
@@ -667,7 +672,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
|||||||
<<" and the matrix1 type is "<<(aMatrix.isComplex()?"Comples":"Real")<<std::endl;
|
<<" and the matrix1 type is "<<(aMatrix.isComplex()?"Comples":"Real")<<std::endl;
|
||||||
return CudaMatrix();
|
return CudaMatrix();
|
||||||
}
|
}
|
||||||
|
if (isComplex()){
|
||||||
|
unaryMulc(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
|
||||||
|
}
|
||||||
|
else{
|
||||||
unaryMul(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
|
unaryMul(this->getData(),aMatrix.getData(),aMatrix.getData(),this->getDataSize());
|
||||||
|
}
|
||||||
return aMatrix;
|
return aMatrix;
|
||||||
}
|
}
|
||||||
CudaMatrix operator*(CudaMatrix &&aMatrix,CudaMatrix &aOther){
|
CudaMatrix operator*(CudaMatrix &&aMatrix,CudaMatrix &aOther){
|
||||||
@@ -681,7 +691,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
|||||||
<<" and the matrix1 type is "<<(aOther.isComplex()?"Comples":"Real")<<std::endl;
|
<<" and the matrix1 type is "<<(aOther.isComplex()?"Comples":"Real")<<std::endl;
|
||||||
return CudaMatrix();
|
return CudaMatrix();
|
||||||
}
|
}
|
||||||
|
if (aMatrix.isComplex()){
|
||||||
|
unaryMulc(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
|
||||||
|
}
|
||||||
|
else{
|
||||||
unaryMul(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
|
unaryMul(aOther.getData(),aMatrix.getData(),aMatrix.getData(),aOther.getDataSize());
|
||||||
|
}
|
||||||
return aMatrix;
|
return aMatrix;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -843,7 +858,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
|||||||
unsigned long long size = aMatrix.getDataSize() * aMatrix.getValueType();
|
unsigned long long size = aMatrix.getDataSize() * aMatrix.getValueType();
|
||||||
cudaMalloc((void**)&data, sizeof(float) * size);
|
cudaMalloc((void**)&data, sizeof(float) * size);
|
||||||
auto out = CudaMatrix::fromRawData(data, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), aMatrix.getValueType());
|
auto out = CudaMatrix::fromRawData(data, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), aMatrix.getValueType());
|
||||||
unaryDiv(this->getData(),aMatrix.getData(),out.getData(),aMatrix.getDataSize());
|
if (isComplex()){
|
||||||
|
unaryDivc(getData(),aMatrix.getData(),data,aMatrix.getDataSize());
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
unaryDiv(getData(),aMatrix.getData(),data,aMatrix.getDataSize());
|
||||||
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
CudaMatrix CudaMatrix::operator/(CudaMatrix &&aMatrix) const{
|
CudaMatrix CudaMatrix::operator/(CudaMatrix &&aMatrix) const{
|
||||||
@@ -858,7 +878,12 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
|||||||
<<" and the matrix1 size is "<<aMatrix.getDataSize()<<std::endl;
|
<<" and the matrix1 size is "<<aMatrix.getDataSize()<<std::endl;
|
||||||
return CudaMatrix();
|
return CudaMatrix();
|
||||||
}
|
}
|
||||||
|
if (aMatrix.isComplex()){
|
||||||
|
unaryDivc(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||||
|
}
|
||||||
|
else{
|
||||||
unaryDiv(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
unaryDiv(this->getData(),aMatrix.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||||
|
}
|
||||||
return aMatrix;
|
return aMatrix;
|
||||||
}
|
}
|
||||||
CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther){
|
CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther){
|
||||||
@@ -873,7 +898,13 @@ bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMat
|
|||||||
<<" and the matrix1 size is "<<aOther.getDataSize()<<std::endl;
|
<<" and the matrix1 size is "<<aOther.getDataSize()<<std::endl;
|
||||||
return CudaMatrix();
|
return CudaMatrix();
|
||||||
}
|
}
|
||||||
|
if (aMatrix.isComplex()){
|
||||||
|
unaryDivc(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||||
|
}
|
||||||
|
else{
|
||||||
unaryDiv(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
unaryDiv(aMatrix.getData(),aOther.getData(),aMatrix.getData(),aMatrix.getDataSize());
|
||||||
|
}
|
||||||
|
|
||||||
return aMatrix;
|
return aMatrix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#include <CudaMatrixPrivate.cuh>
|
#include <CudaMatrixPrivate.cuh>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <thrust/transform.h>
|
#include <thrust/transform.h>
|
||||||
|
#include <thrust/complex.h>
|
||||||
#include <thrust/functional.h>
|
#include <thrust/functional.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
using namespace thrust::placeholders;
|
using namespace thrust::placeholders;
|
||||||
@@ -143,6 +144,15 @@ void unaryMul(float* in1, float* in2, float* out, unsigned long length)
|
|||||||
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
|
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void unaryMulc(float* in1, float* in2, float* out, unsigned long length)
|
||||||
|
{
|
||||||
|
thrust::complex<float>* _in1 = (thrust::complex<float>*)in1;
|
||||||
|
thrust::complex<float>* _in2 = (thrust::complex<float>*)in2;
|
||||||
|
thrust::complex<float>* _out = (thrust::complex<float>*)out;;
|
||||||
|
thrust::multiplies<thrust::complex<float>> op;
|
||||||
|
thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op);
|
||||||
|
}
|
||||||
|
|
||||||
void unaryMul(float* in1, const float& in2, float* out, unsigned long length)
|
void unaryMul(float* in1, const float& in2, float* out, unsigned long length)
|
||||||
{
|
{
|
||||||
thrust::transform(thrust::device,in1, in1+length, out, in2 * _1);
|
thrust::transform(thrust::device,in1, in1+length, out, in2 * _1);
|
||||||
@@ -163,6 +173,15 @@ void unaryDiv(float* in1, float* in2, float* out, unsigned long length){
|
|||||||
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
|
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void unaryDivc(float* in1, float* in2, float* out, unsigned long length)
|
||||||
|
{
|
||||||
|
thrust::complex<float>* _in1 = (thrust::complex<float>*)in1;
|
||||||
|
thrust::complex<float>* _in2 = (thrust::complex<float>*)in2;
|
||||||
|
thrust::complex<float>* _out = (thrust::complex<float>*)out;;
|
||||||
|
thrust::divides<thrust::complex<float>> op;
|
||||||
|
thrust::transform(thrust::device,_in1,_in1+length,_in2,_out,op);
|
||||||
|
}
|
||||||
|
|
||||||
void unarySub(const float& in1, float* in2, float* out, unsigned long length){
|
void unarySub(const float& in1, float* in2, float* out, unsigned long length){
|
||||||
thrust::transform(thrust::device,in2,in2+length,out,in1-_1);
|
thrust::transform(thrust::device,in2,in2+length,out,in1-_1);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ namespace{
|
|||||||
void unaryAdd(float* in1, float* in2, float* out, unsigned long length);
|
void unaryAdd(float* in1, float* in2, float* out, unsigned long length);
|
||||||
void unaryAdd(float* in1, const float& in2, float* out, unsigned long length);
|
void unaryAdd(float* in1, const float& in2, float* out, unsigned long length);
|
||||||
void unaryMul(float* in1, float* in2, float* out, unsigned long length);
|
void unaryMul(float* in1, float* in2, float* out, unsigned long length);
|
||||||
|
void unaryMulc(float* in1, float* in2, float* out, unsigned long length);
|
||||||
|
|
||||||
void unaryMul(float* in1, const float& in2, float* out, unsigned long length);
|
void unaryMul(float* in1, const float& in2, float* out, unsigned long length);
|
||||||
|
|
||||||
void unaryNeg(float* in1, float* out, unsigned long length);
|
void unaryNeg(float* in1, float* out, unsigned long length);
|
||||||
@@ -19,6 +21,8 @@ void unaryPow(float* in1, float N,float* out, unsigned long length);
|
|||||||
|
|
||||||
void unarySub(float* in1, float* in2, float* out, unsigned long length);
|
void unarySub(float* in1, float* in2, float* out, unsigned long length);
|
||||||
void unaryDiv(float* in1, float* in2, float* out, unsigned long length);
|
void unaryDiv(float* in1, float* in2, float* out, unsigned long length);
|
||||||
|
void unaryDivc(float* in1, float* in2, float* out, unsigned long length);
|
||||||
|
|
||||||
void unarySub(const float& in1, float* in2, float* out, unsigned long length);
|
void unarySub(const float& in1, float* in2, float* out, unsigned long length);
|
||||||
void unaryDiv(const float& in1, float* in2, float* out, unsigned long length);
|
void unaryDiv(const float& in1, float* in2, float* out, unsigned long length);
|
||||||
void unarySub(float* in1, const float& in2, float* out, unsigned long length);
|
void unarySub(float* in1, const float& in2, float* out, unsigned long length);
|
||||||
|
|||||||
Reference in New Issue
Block a user