Add std, and fix fft, ifft bug for cuda

This commit is contained in:
kradchen
2023-12-19 13:12:20 +08:00
parent ea68e6c5af
commit 81078bd69f
5 changed files with 552 additions and 313 deletions

View File

@@ -156,7 +156,6 @@ CudaMatrix Aurora::max(const CudaMatrix &aMatrix, FunctionDirection direction, l
CudaMatrix vxmMax(CudaMatrix aVec, CudaMatrix aMat) {
//col-vec x mat
if (aVec.getDimSize(1) == 1 && aVec.getDimSize(0) == aMat.getDimSize(0)) {
std::cout<<"max mat and col-vec "<<std::endl;
size_t size = aMat.getDataSize();
float* data = nullptr;
cudaMalloc((void**)&data, sizeof(float) * size);
@@ -175,7 +174,6 @@ CudaMatrix vxmMax(CudaMatrix aVec, CudaMatrix aMat) {
// row-vec x mat
else if (aVec.getDimSize(0) == 1 && aVec.getDimSize(1) == aMat.getDimSize(1))
{
std::cout<<"max mat and row-vec "<<std::endl;
size_t size = aMat.getDataSize() ;
float* data = nullptr;
cudaMalloc((void**)&data, sizeof(float) * size);
@@ -376,7 +374,6 @@ CudaMatrix Aurora::min(const CudaMatrix &aMatrix, FunctionDirection direction, l
CudaMatrix vxmMin(CudaMatrix aVec, CudaMatrix aMat) {
//col-vec x mat
if (aVec.getDimSize(1) == 1 && aVec.getDimSize(0) == aMat.getDimSize(0)) {
std::cout<<"min mat and col-vec "<<std::endl;
size_t size = aMat.getDataSize();
float* data = nullptr;
cudaMalloc((void**)&data, sizeof(float) * size);
@@ -395,7 +392,6 @@ CudaMatrix vxmMin(CudaMatrix aVec, CudaMatrix aMat) {
// row-vec x mat
else if (aVec.getDimSize(0) == 1 && aVec.getDimSize(1) == aMat.getDimSize(1))
{
std::cout<<"min mat and row-vec "<<std::endl;
size_t size = aMat.getDataSize() ;
float* data = nullptr;
cudaMalloc((void**)&data, sizeof(float) * size);
@@ -682,7 +678,6 @@ CudaMatrix Aurora::sum(const CudaMatrix &aMatrix, FunctionDirection direction ){
case Column:
default:
{
std::cout<<"Column sum"<<std::endl;
float* matData = aMatrix.getData();
float* retData = nullptr;
int colElementCount = aMatrix.getDimSize(0);
@@ -793,6 +788,21 @@ CudaMatrix Aurora::mean(const CudaMatrix &aMatrix, FunctionDirection direction )
return CudaMatrix();
}
}
CudaMatrix Aurora::std(const CudaMatrix &aMatrix){
if (aMatrix.getDimSize(2) > 1 || aMatrix.isComplex()) {
std::cerr
<< (aMatrix.getDimSize(2) > 1 ? "std() not support 3D data!" : "std() not support complex value type!")
<< std::endl;
return CudaMatrix();
}
auto src = aMatrix.isComplex() ? Aurora::abs(aMatrix) : aMatrix.deepCopy();
int calc_size = src.getDimSize(0) == 1 ? src.getDimSize(1) : src.getDimSize(0);
auto meanM = Aurora::mean(src);
return sqrt(Aurora::sum((src-meanM)^2.0)/((float)calc_size-1.0f));
}
template <typename ValueType>
class RowElementIterator:public thrust::iterator_facade<
RowElementIterator<ValueType>,
@@ -1294,12 +1304,13 @@ __global__ void complexFillKernel(float* aInputData, float* aOutput,unsigned int
for (int offset = 0; offset < aDesColEleCount; offset+=blockDim.x)
{
if(threadIdx.x + offset< aCopySize){
aOutput[2*idx_d] = aInputData[idx_s];
aOutput[2*idx_d + 1] = 0;
aOutput[2 * idx_d + offset * 2] = aInputData[idx_s + offset];
aOutput[2 * idx_d + offset * 2 + 1] = 0;
}
else if(threadIdx.x + offset< aDesColEleCount){
aOutput[2*idx_d] = 0;
aOutput[2*idx_d + 1] = 0;
aOutput[2 * idx_d + offset * 2] = 0;
aOutput[2 * idx_d + offset * 2 + 1] = 0;
}
else{
return;
@@ -1316,12 +1327,12 @@ __global__ void complexCopyKernel(float* aInputData, float* aOutput,unsigned int
for (int offset = 0; offset < aDesColEleCount; offset+=blockDim.x)
{
if(threadIdx.x + offset< aCopySize){
aOutput[2*idx_d] = aInputData[idx_s*2];
aOutput[2*idx_d + 1] = aInputData[idx_s*2+1];
aOutput[2*idx_d + offset * 2 ] = aInputData[idx_s*2 + offset*2];
aOutput[2*idx_d + offset*2+ 1] = aInputData[idx_s*2+ offset*2+1];
}
else if(threadIdx.x + offset< aDesColEleCount){
aOutput[2*idx_d] = 0;
aOutput[2*idx_d + 1] = 0;
aOutput[2*idx_d + offset*2] = 0;
aOutput[2*idx_d + offset*2+ 1] = 0;
}
else{
return;
@@ -1344,7 +1355,9 @@ if (aMatrix.isComplex()){
complexFillKernel<<<aMatrix.getDimSize(1), 256>>>(aMatrix.getData(), data, needCopySize, aMatrix.getDimSize(0),ColEleCount);
}
auto ret = Aurora::CudaMatrix::fromRawData(data,ColEleCount,aMatrix.getDimSize(1),1,Complex);
auto mm = ret.toHostMatrix();
ExecFFT(ret,0);
mm = ret.toHostMatrix();
return ret;
}