Add negative to Matrix

This commit is contained in:
kradchen
2023-04-28 13:19:12 +08:00
parent e9e1bacdc8
commit 3f12f0e0c1
2 changed files with 24 additions and 8 deletions

View File

@@ -305,10 +305,10 @@ namespace Aurora {
Matrix Matrix::New(double *data, int rows, int cols, int slices, ValueType type) { Matrix Matrix::New(double *data, int rows, int cols, int slices, ValueType type) {
if (!data) return Matrix(); if (!data) return Matrix();
std::vector<int> vector; std::vector<int> vector(3);
vector.push_back(rows); vector[0]=rows;
if (cols > 0)vector.push_back(cols); vector[1] = (cols > 0?cols:1);
if (slices > 0)vector.push_back(slices); vector[2] = (slices > 0?slices:1);
Matrix ret({data, free}, vector); Matrix ret({data, free}, vector);
if (type != Normal)ret.setValueType(type); if (type != Normal)ret.setValueType(type);
return ret; return ret;
@@ -349,7 +349,10 @@ namespace Aurora {
Matrix Matrix::deepCopy() const { Matrix Matrix::deepCopy() const {
double *newBuffer = malloc(getDataSize(), getValueType()==Complex); double *newBuffer = malloc(getDataSize(), getValueType()==Complex);
// size_t data_size = getDataSize() * getValueType();
cblas_dcopy(getDataSize() * getValueType(),getData(),1,newBuffer,1); cblas_dcopy(getDataSize() * getValueType(),getData(),1,newBuffer,1);
// std::memcpy(newBuffer,getData(),data_size*sizeof(double));
return New(newBuffer, return New(newBuffer,
getDimSize(0), getDimSize(0),
getDimSize(1), getDimSize(1),
@@ -679,7 +682,17 @@ namespace Aurora {
} }
Matrix operator-(const Matrix &aMatrix) { Matrix operator-(const Matrix &aMatrix) {
return -(std::forward<Matrix&&>(aMatrix.deepCopy())); double *newBuffer = malloc(aMatrix.getDataSize(), aMatrix.getValueType()==Complex);
double zero = 0.0;
if (aMatrix.isComplex()){
vdSubI( aMatrix.getDataSize(),&zero,0,aMatrix.getData(),2,newBuffer,2);
vdSubI( aMatrix.getDataSize(),&zero,0,aMatrix.getData()+1,2,newBuffer+1,2);
}
else{
vdSubI( aMatrix.getDataSize(),&zero,0,aMatrix.getData(),1,newBuffer,1);
}
return Matrix::New(newBuffer,aMatrix);
} }
Matrix::MatrixSlice::MatrixSlice(int aSize,int aStride, double* aData, ValueType aType, int aSliceMode,int aSize2, int aStride2): Matrix::MatrixSlice::MatrixSlice(int aSize,int aStride, double* aData, ValueType aType, int aSliceMode,int aSize2, int aStride2):
@@ -742,8 +755,7 @@ namespace Aurora {
cblas_dcopy(mSize,slice.mData,slice.mStride,mData,mStride); cblas_dcopy(mSize,slice.mData,slice.mStride,mData,mStride);
} }
else { else {
cblas_zcopy(mSize, (std::complex<double> *) slice.mData, slice.mStride, cblas_dcopy(mSize*2, slice.mData, slice.mStride, mData, mStride);
(std::complex<double> *) mData, mStride);
} }
break; break;
} }

View File

@@ -280,5 +280,9 @@ TEST_F(Matrix_Test, matrixOpertaor) {
DISPLAY_MATRIX(C); DISPLAY_MATRIX(C);
EXPECT_EQ(C.getData()[2], 0); EXPECT_EQ(C.getData()[2], 0);
EXPECT_EQ(C.getData()[3], 2); EXPECT_EQ(C.getData()[3], 2);
C= -B;
EXPECT_EQ(C.getData()[0], -2);
EXPECT_EQ(C.getData()[1], -2);
} }
} }