Add CudaMatrix block, setBlock, setBlockValue.
This commit is contained in:
@@ -160,10 +160,6 @@ CudaMatrix CudaMatrix::block(int aDim,int aBeginIndex, int aEndIndex) const
|
||||
std::cerr<<"CudaMatrix block only support 1D-3D data!"<<std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if (isVector() && aDim == 0 && getDimSize(1)>1)
|
||||
{
|
||||
aDim = 1;
|
||||
}
|
||||
|
||||
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0)
|
||||
{
|
||||
@@ -220,7 +216,6 @@ CudaMatrix CudaMatrix::block(int aDim,int aBeginIndex, int aEndIndex) const
|
||||
return CudaMatrix::fromRawData(dataOutput,getDimSize(0),dimLength,getDimSize(2),getValueType());
|
||||
}
|
||||
case 2:
|
||||
default:
|
||||
{
|
||||
int copySize = dimLength*sliceStride;
|
||||
cudaMemcpy(dataOutput,
|
||||
@@ -228,19 +223,273 @@ CudaMatrix CudaMatrix::block(int aDim,int aBeginIndex, int aEndIndex) const
|
||||
sizeof(float) * copySize*getValueType(), cudaMemcpyDeviceToDevice);
|
||||
return CudaMatrix::fromRawData(dataOutput,getDimSize(0),getDimSize(1),dimLength,getValueType());
|
||||
}
|
||||
default:
|
||||
{
|
||||
return CudaMatrix();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CudaMatrix::setBlockValue(int aDim,int aBeginIndx, int aEndIndex,float value)
|
||||
bool CudaMatrix::setBlockValue(int aDim,int aBeginIndex, int aEndIndex,float aValue)
|
||||
{
|
||||
if(aDim>2 )
|
||||
{
|
||||
std::cerr<<"CudaMatrix block only support 1D-3D data!"<<std::endl;
|
||||
std::cerr<<"CudaMatrix setblockValue only support 1D-3D data!"<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setblockValue BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setblockValue EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aEndIndex < aBeginIndex)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setblockValue EndIndex can not less than BeginIndex ! BeginIndex:"<<aBeginIndex <<", EndIndex:"<<aEndIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int dimLength = aEndIndex - aBeginIndex + 1;
|
||||
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
|
||||
int colStride = getDimSize(0);
|
||||
int sliceStride = getDimSize(0)*getDimSize(1);
|
||||
|
||||
switch (aDim)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
int colStride2 = dimLength;
|
||||
int sliceStride2 = dimLength*getDimSize(1);
|
||||
for (size_t i = 0; i < getDimSize(2); i++)
|
||||
{
|
||||
for (size_t j = 0; j < getDimSize(1); j++)
|
||||
{
|
||||
float* begin = mData.get() + (aBeginIndex + j * colStride + i * sliceStride)*getValueType();
|
||||
thrustFill(begin, begin + colStride2*getValueType(), aValue);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
int colStride2 = getDimSize(0);
|
||||
int copySize = dimLength*getDimSize(0);
|
||||
for (size_t i = 0; i < getDimSize(2); i++)
|
||||
{
|
||||
float* begin = mData.get() + getValueType()*(aBeginIndex * colStride + i * sliceStride);
|
||||
thrustFill(begin, begin + copySize * getValueType(), aValue);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
int copySize = dimLength*sliceStride;
|
||||
float* begin = mData.get() + aBeginIndex * sliceStride*getValueType();
|
||||
thrustFill(begin, begin + copySize *getValueType(), aValue);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CudaMatrix::setBlockComplexValue(int aDim,int aBeginIndex, int aEndIndex, std::complex<float> aValue)
|
||||
{
|
||||
if(getValueType() != Complex)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlockComplexValue only support complex matrix"<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(aDim>2 )
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlockComplexValue only support 1D-3D data!"<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlockComplexValue BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlockComplexValue EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aEndIndex < aBeginIndex)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlockComplexValue EndIndex can not less than BeginIndex ! BeginIndex:"<<aBeginIndex <<", EndIndex:"<<aEndIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int dimLength = aEndIndex - aBeginIndex + 1;
|
||||
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
|
||||
int colStride = getDimSize(0);
|
||||
int sliceStride = getDimSize(0)*getDimSize(1);
|
||||
|
||||
switch (aDim)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
int colStride2 = dimLength;
|
||||
int sliceStride2 = dimLength*getDimSize(1);
|
||||
for (size_t i = 0; i < getDimSize(2); i++)
|
||||
{
|
||||
for (size_t j = 0; j < getDimSize(1); j++)
|
||||
{
|
||||
float* begin = mData.get() + (aBeginIndex + j * colStride + i * sliceStride)*getValueType();
|
||||
thrustFill(begin, begin + colStride2*getValueType(), aValue);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
int colStride2 = getDimSize(0);
|
||||
int copySize = dimLength*getDimSize(0);
|
||||
for (size_t i = 0; i < getDimSize(2); i++)
|
||||
{
|
||||
float* begin = mData.get() + getValueType()*(aBeginIndex * colStride + i * sliceStride);
|
||||
thrustFill(begin, begin + copySize * getValueType(), aValue);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
int copySize = dimLength*sliceStride;
|
||||
float* begin = mData.get() + aBeginIndex * sliceStride*getValueType();
|
||||
thrustFill(begin, begin + copySize *getValueType(), aValue);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CudaMatrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const CudaMatrix& aMatrix)
|
||||
{
|
||||
if( aDim>2 )
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlock only support 1D-3D data!"<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlock BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0)
|
||||
{
|
||||
std::cerr<<"CudaMatrix setBlock EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
|
||||
size_t newdims[3]{0};
|
||||
|
||||
switch (aDim)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
newdims[0] = dimLength;
|
||||
newdims[1] = getDimSize(1);
|
||||
newdims[2] = getDimSize(2);
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
newdims[0] = getDimSize(0);
|
||||
newdims[1] = dimLength;
|
||||
newdims[2] = getDimSize(2);
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
newdims[0] = getDimSize(0);
|
||||
newdims[1] = getDimSize(1);
|
||||
newdims[2] = dimLength;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (aMatrix.getDimSize(0)!= newdims[0]
|
||||
|| aMatrix.getDimSize(1)!= newdims[1]
|
||||
|| aMatrix.getDimSize(2)!= newdims[2])
|
||||
{
|
||||
std::cerr << "CudaMatrix setBlock src Matrix(" << aMatrix.getDimSize(0) << ","
|
||||
<< aMatrix.getDimSize(1) << "," << aMatrix.getDimSize(2)
|
||||
<< ") not match the des shape(" << newdims[0] << ","
|
||||
<< newdims[1] << "," << newdims[2] << ")"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
|
||||
int colStride = getDimSize(0);
|
||||
int sliceStride = getDimSize(0)*getDimSize(1);
|
||||
switch (aDim)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
int colStride2 = dimLength;
|
||||
int sliceStride2 = dimLength*getDimSize(1);
|
||||
for (size_t i = 0; i < getDimSize(2); i++)
|
||||
{
|
||||
for (size_t j = 0; j < getDimSize(1); j++)
|
||||
{
|
||||
cudaMemcpy(mData.get() + (aBeginIndex + j * colStride + i * sliceStride)*getValueType(),
|
||||
aMatrix.getData() + (colStride2 * j + i * sliceStride2)*getValueType(),
|
||||
sizeof(float) * colStride2*getValueType(), cudaMemcpyDeviceToDevice);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
int colStride2 = getDimSize(0);
|
||||
int copySize = dimLength*getDimSize(0);
|
||||
for (size_t i = 0; i < getDimSize(2); i++)
|
||||
{
|
||||
cudaMemcpy(mData.get() + getValueType()*(aBeginIndex * colStride + i * sliceStride),
|
||||
aMatrix.getData() + getValueType()*(i * copySize),
|
||||
sizeof(float) * copySize*getValueType(), cudaMemcpyDeviceToDevice);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
int copySize = dimLength*sliceStride;
|
||||
cudaMemcpy(mData.get() + aBeginIndex * sliceStride*getValueType(),
|
||||
aMatrix.getData(),
|
||||
sizeof(float) * copySize*getValueType(), cudaMemcpyDeviceToDevice);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CudaMatrix CudaMatrix::operator+(float aScalar) const{
|
||||
if (isComplex())
|
||||
|
||||
@@ -90,3 +90,13 @@ void unaryPow(float* in1, float N,float* out, unsigned long length){
|
||||
thrust::transform(thrust::device,in1,in1+length,out,PowOperator(N));
|
||||
}
|
||||
|
||||
void thrustFill(float* aBegin, float* aEnd, float aValue)
|
||||
{
|
||||
thrust::fill(thrust::device, aBegin, aEnd, aValue);
|
||||
}
|
||||
|
||||
void thrustFill(float* aBegin, float* aEnd, std::complex<float> aValue)
|
||||
{
|
||||
thrust::fill(thrust::device, (std::complex<float>*)aBegin, (std::complex<float>*)aEnd, aValue);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#ifndef __CUDAMATRIX_CUH__
|
||||
#define __CUDAMATRIX_CUH__
|
||||
|
||||
#include <complex>
|
||||
|
||||
void unaryAdd(float* in1, float* in2, float* out, unsigned long length);
|
||||
void unaryAdd(float* in1, const float& in2, float* out, unsigned long length);
|
||||
void unaryMul(float* in1, float* in2, float* out, unsigned long length);
|
||||
@@ -17,5 +19,8 @@ void unaryDiv(const float& in1, float* in2, float* out, unsigned long length);
|
||||
void unarySub(float* in1, const float& in2, float* out, unsigned long length);
|
||||
void unaryDiv(float* in1, const float& in2, float* out, unsigned long length);
|
||||
|
||||
void thrustFill(float* aBegin, float* aEnd, float aValue);
|
||||
void thrustFill(float* aBegin, float* aEnd, std::complex<float> aValue);
|
||||
|
||||
|
||||
#endif // __CUDAMATRIX_H__
|
||||
Reference in New Issue
Block a user