Add NewCuda.

This commit is contained in:
sunwen
2023-10-09 17:13:31 +08:00
parent f8f9e453b5
commit 15c0654c5c
4 changed files with 34 additions and 0 deletions

View File

@@ -13,6 +13,7 @@
#include <Eigen/Core>
#include <Eigen/Eigen>
#include <Eigen/Dense>
#include <cuda_runtime.h>
namespace Aurora {
@@ -25,4 +26,9 @@ namespace Aurora {
void free(void* ptr){
mkl_free(ptr);
}
void gpuFree(void* ptr)
{
cudaFree(ptr);
}
}

View File

@@ -10,6 +10,7 @@
namespace Aurora{
float* malloc(size_t size,bool complex = false);
void free(void* ptr);
void gpuFree(void* ptr);
};

View File

@@ -16,6 +16,8 @@
#include "Eigen/src/Core/Matrix.h"
#include "Function.h"
#include <cuda_runtime.h>
namespace Aurora{
typedef void(*CalcFuncD)(const MKL_INT n, const float a[], const MKL_INT inca, const float b[],
const MKL_INT incb, float r[], const MKL_INT incr);
@@ -391,6 +393,26 @@ namespace Aurora {
return ret;
}
Matrix Matrix::NewCuda(float *data, int rows, int cols, int slices, ValueType type)
{
if (!data) return Matrix();
std::vector<int> vector(3);
vector[0]=rows;
vector[1] = (cols > 0?cols:1);
vector[2] = (slices > 0?slices:1);
Matrix ret({data, gpuFree}, vector);
if (type != Normal)ret.setValueType(type);
ret.mCuda_Allocated = true;
return ret;
}
Matrix Matrix::toHostMatrix() const
{
float* data = new float[getDataSize()];
cudaMemcpy(data, mData.get(), sizeof(float) * getDataSize() * (mValueType == Normal ? 1 : 2), cudaMemcpyDeviceToHost);
return fromRawData(data, mInfo[0], mInfo[1], mInfo[2], getValueType());
}
Matrix Matrix::New(float *data, const Matrix &shapeMatrix) {
return New(data,
shapeMatrix.getDimSize(0),

View File

@@ -87,6 +87,7 @@ namespace Aurora {
*/
static Matrix New(float *data, const Matrix &shapeMatrix);
static Matrix NewCuda(float *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);
/**
* New a mkl calculate based Matrix
* @attention Memory are allocate by Aurora:malloc function
@@ -287,11 +288,15 @@ namespace Aurora {
void forceReshape(int rows, int columns, int slices);
Matrix toHostMatrix() const;
private:
ValueType mValueType = Normal;
std::shared_ptr<float> mData;
std::vector<int> mInfo;
bool mMKL_Allocated = false;
bool mCuda_Allocated = false;
};
}