Merge branch 'dtof' of http://192.168.1.9:3000/Bug/Aurora into dtof
This commit is contained in:
@@ -1054,6 +1054,110 @@ CudaMatrix Aurora::inv(const CudaMatrix &aMatrix)
|
||||
return CudaMatrix::fromRawData(data, n, n);
|
||||
}
|
||||
|
||||
CudaMatrix Aurora::inv(CudaMatrix &&aMatrix)
|
||||
{
|
||||
if (aMatrix.getDims() != 2)
|
||||
{
|
||||
std::cerr << "Fail! cuda inv args must be 2d matrix!";
|
||||
return aMatrix;
|
||||
}
|
||||
if (aMatrix.getDimSize(0) != aMatrix.getDimSize(1))
|
||||
{
|
||||
std::cerr << "Fail! cuda inv args must be square matrix!";
|
||||
return aMatrix;
|
||||
}
|
||||
if (aMatrix.getValueType() != Normal)
|
||||
{
|
||||
std::cerr << "Fail! cuda inv args must be normal value type!";
|
||||
return aMatrix;
|
||||
}
|
||||
unsigned int n = aMatrix.getDimSize(0);
|
||||
unsigned int size = aMatrix.getDataSize();
|
||||
cublasHandle_t handle;
|
||||
cublasCreate(&handle);
|
||||
float* data;
|
||||
cudaMalloc((void**)&data, sizeof(float) * size);
|
||||
|
||||
float** deviceInputPinter;
|
||||
cudaMalloc((void**)&deviceInputPinter, sizeof(float *));
|
||||
|
||||
float **deviceOutputPointer;
|
||||
cudaMalloc((void**)&deviceOutputPointer, sizeof(float *));
|
||||
|
||||
invKernel<<<1, 1>>>(deviceInputPinter, aMatrix.getData(), deviceOutputPointer, data);
|
||||
cudaDeviceSynchronize();
|
||||
int* devicePivotArray;
|
||||
cudaMalloc((void**)&devicePivotArray, n * sizeof(int));
|
||||
int* deviceInfoArray;
|
||||
cudaMalloc((void**)&deviceInfoArray, sizeof(int));
|
||||
cublasSgetrfBatched(handle, n, deviceInputPinter, n, devicePivotArray, deviceInfoArray, 1);
|
||||
cublasSgetriBatched(handle, n, deviceInputPinter, n, devicePivotArray, deviceOutputPointer, n, deviceInfoArray, 1);
|
||||
cudaFree(devicePivotArray);
|
||||
cudaFree(deviceInfoArray);
|
||||
cudaFree(deviceOutputPointer);
|
||||
cudaFree(deviceInputPinter);
|
||||
cublasDestroy(handle);
|
||||
|
||||
return CudaMatrix::fromRawData(data, n, n);
|
||||
}
|
||||
|
||||
__global__ void dotColumnKernel(float* aInputData1, float* aInputData2, float* aOutputData, unsigned int aInputRowSize)
|
||||
{
|
||||
__shared__ float sharedValue[THREADS_PER_BLOCK];
|
||||
sharedValue[threadIdx.x] = 0;
|
||||
|
||||
for(unsigned int i=0; i<=aInputRowSize/blockDim.x; ++i)
|
||||
{
|
||||
unsigned int indexByRows = i*blockDim.x + threadIdx.x;
|
||||
if(indexByRows < aInputRowSize)
|
||||
{
|
||||
sharedValue[threadIdx.x] += aInputData1[blockIdx.x*aInputRowSize + indexByRows] * aInputData2[blockIdx.x*aInputRowSize + indexByRows];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for(unsigned int i = blockDim.x/2; i>0; i >>= 1)
|
||||
{
|
||||
if(threadIdx.x < i)
|
||||
{
|
||||
sharedValue[threadIdx.x] += sharedValue[threadIdx.x + i];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
aOutputData[blockIdx.x] = sharedValue[0];
|
||||
}
|
||||
|
||||
CudaMatrix Aurora::dot(const CudaMatrix &aMatrix, const CudaMatrix &aOther, FunctionDirection direction)
|
||||
{
|
||||
if ( direction != Aurora::Column)
|
||||
{
|
||||
std::cerr<< "cuda dot() only support column data!"<< std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if (!aMatrix.compareShape(aOther))
|
||||
{
|
||||
std::cerr<< "cuda dot() matrix must be same shape!"<< std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
if(aMatrix.getValueType() == Aurora::Complex || aOther.getValueType() == Aurora::Complex)
|
||||
{
|
||||
std::cerr<< "cuda dot() do not support complex data!"<< std::endl;
|
||||
return CudaMatrix();
|
||||
}
|
||||
|
||||
unsigned int column = aMatrix.getDimSize(1);
|
||||
unsigned int row = aMatrix.getDimSize(0);
|
||||
if(aMatrix.getDimSize(0) == 1 || aMatrix.getDimSize(1) == 1)
|
||||
{
|
||||
column = 1;
|
||||
row = aMatrix.getDataSize();
|
||||
}
|
||||
float* data = nullptr;
|
||||
cudaMalloc((void**)&data, sizeof(float) * column);
|
||||
dotColumnKernel<<<column, THREADS_PER_BLOCK>>>(aMatrix.getData(), aOther.getData(), data, row);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
return CudaMatrix::fromRawData(data, 1, column);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief
|
||||
|
||||
@@ -45,6 +45,8 @@ namespace Aurora
|
||||
|
||||
CudaMatrix inv(CudaMatrix &&aMatrix);
|
||||
|
||||
CudaMatrix dot(const CudaMatrix &aMatrix, const CudaMatrix &aOther, FunctionDirection direction = Column);
|
||||
|
||||
CudaMatrix fft(const CudaMatrix &aMatrix, long aFFTSize = -1);
|
||||
CudaMatrix ifft(const CudaMatrix &aMatrix, long aFFTSize = -1);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user