Fix Matrix free bug(with mkl_malloc)

This commit is contained in:
Krad
2023-04-20 09:32:58 +08:00
parent fedf13d5d0
commit 044de7be5f
2 changed files with 36 additions and 26 deletions

View File

@@ -118,6 +118,14 @@ namespace Aurora {
Matrix::Matrix(std::shared_ptr<double> aData, std::vector<int> aInfo) Matrix::Matrix(std::shared_ptr<double> aData, std::vector<int> aInfo)
: mData(aData), mInfo(aInfo) { : mData(aData), mInfo(aInfo) {
} }
Matrix::Matrix(const Matrix::MatrixSlice& slice) {
auto temp = slice.toMatrix();
this->mData = temp.mData;
this->mInfo = temp.mInfo;
this->mValueType = temp.mValueType;
}
bool Matrix::isNull() const { bool Matrix::isNull() const {
return !mData || mInfo.empty(); return !mData || mInfo.empty();
} }
@@ -185,7 +193,7 @@ namespace Aurora {
vector.push_back(rows); vector.push_back(rows);
if (cols > 0)vector.push_back(cols); if (cols > 0)vector.push_back(cols);
if (slices > 0)vector.push_back(slices); if (slices > 0)vector.push_back(slices);
Matrix ret({data, std::default_delete<double[]>()}, vector); Matrix ret({data, free}, vector);
if (type != Normal)ret.setValueType(type); if (type != Normal)ret.setValueType(type);
return ret; return ret;
} }
@@ -471,39 +479,40 @@ namespace Aurora {
return *this; return *this;
} }
Matrix Matrix::MatrixSlice::toMatrix() { Matrix Matrix::MatrixSlice::toMatrix() const {
auto data = (double *) mkl_malloc(mSize*(mSize2>0?mSize2:1) * sizeof(double)*mType, 64); double * data = (double *) mkl_malloc(mSize*(mSize2>0?mSize2:1) * sizeof(double)*mType, 64);
auto matrix = Matrix::New(data,mSize,mSize2,0,mType);
switch (mSliceMode) { switch (mSliceMode) {
case 2:{ case 2:{
if (mType== Normal) { if (mType== Normal) {
cblas_dcopy_batch_strided(mSize, mData, mStride, cblas_dcopy_batch_strided(mSize, mData, mStride,
mStride2,matrix.getData(), 1, matrix.getDimSize(0), mSize2); mStride2,data, 1, mSize, mSize2);
} }
else { else {
cblas_zcopy_batch_strided(mSize, (std::complex<double> *) mData, mStride, mStride2, cblas_zcopy_batch_strided(mSize, (std::complex<double> *) mData, mStride, mStride2,
(std::complex<double> *) matrix.getData(), 1, matrix.getDimSize(0), (std::complex<double> *) data, 1, mSize,
mSize2); mSize2);
} }
break; break;
} }
case 1:{ case 1:{
if (mType== Normal){ if (mType== Normal){
cblas_dcopy(mSize,mData,mStride,matrix.getData(),1); cblas_dcopy(mSize,mData,mStride,data,1);
} }
else { else {
cblas_zcopy(mSize, (std::complex<double> *) mData, mStride, cblas_zcopy(mSize, (std::complex<double> *) mData, mStride,
(std::complex<double> *) matrix.getData(), 1); (std::complex<double> *) data, 1);
} }
break; break;
} }
case 0: case 0:
default:{ default:{
matrix.getData()[0]= mData[0]; data[0]= mData[0];
if (mType != Normal) matrix.getData()[1] = mData[1]; if (mType != Normal) data[1] = mData[1];
} }
} }
return matrix;
return Matrix::New(data,mSize,mSize2,0,mType);
} }
} }

View File

@@ -14,16 +14,6 @@ namespace Aurora {
class Matrix { class Matrix {
public: public:
explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
std::vector<int> aInfo = std::vector<int>());
static Matrix New(double *data, int rows, int cols = 0, int slices = 0, ValueType type = Normal);
static Matrix New(double *data, const Matrix &shapeMatrix);
static Matrix New(const Matrix &shapeMatrix);
/** /**
* 内部类MatrixSlice用于切片操作 * 内部类MatrixSlice用于切片操作
*/ */
@@ -32,16 +22,27 @@ namespace Aurora {
MatrixSlice(int aSize,int aStride, double* aData,ValueType aType = Normal,int SliceMode = 1,int aSize2 = 0, int aStride2 = 0); MatrixSlice(int aSize,int aStride, double* aData,ValueType aType = Normal,int SliceMode = 1,int aSize2 = 0, int aStride2 = 0);
MatrixSlice& operator=(const MatrixSlice& slice); MatrixSlice& operator=(const MatrixSlice& slice);
MatrixSlice& operator=(const Matrix& matrix); MatrixSlice& operator=(const Matrix& matrix);
Matrix toMatrix(); Matrix toMatrix() const;
private: private:
int mSliceMode = 0;//0 scalar, 1 vector, 2 Matrix int mSliceMode = 0;//0 scalar, 1 vector, 2 Matrix
double* mData; double* mData;
int mSize; int mSize=0;
int mSize2; int mSize2=0;
int mStride; int mStride=1;
int mStride2; int mStride2=0;
ValueType mType; ValueType mType;
friend class Matrix;
}; };
explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
std::vector<int> aInfo = std::vector<int>());
explicit Matrix(const Matrix::MatrixSlice& slice);
static Matrix New(double *data, int rows, int cols = 0, int slices = 0, ValueType type = Normal);
static Matrix New(double *data, const Matrix &shapeMatrix);
static Matrix New(const Matrix &shapeMatrix);
Matrix getDataFromDims2(int aColumn); Matrix getDataFromDims2(int aColumn);