From 46b4e5c04d8b021b264ddae15a44b5adb299a433 Mon Sep 17 00:00:00 2001 From: kradchen Date: Tue, 13 Jun 2023 14:36:01 +0800 Subject: [PATCH] Add complex support to block function --- src/Matrix.cpp | 106 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/src/Matrix.cpp b/src/Matrix.cpp index aad91f2..005e208 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -610,7 +610,7 @@ namespace Aurora { } int dimLength = std::abs(aEndIndex-aBeginIndex)+1; int dataSize = getDataSize()/getDimSize(aDim)*dimLength; - double * dataOutput = malloc(dataSize); + double * dataOutput = malloc(dataSize,isComplex()); int colStride = getDimSize(0); int sliceStride = getDimSize(0)*getDimSize(1); switch (aDim) { @@ -623,12 +623,21 @@ namespace Aurora { { cblas_dcopy( dimLength, - getData() + aBeginIndex + j * colStride + - i * sliceStride, - 1, dataOutput + colStride2 * j + i * sliceStride2, 1); + getData() + (aBeginIndex + j * colStride + + i * sliceStride)*getValueType(), + getValueType(), dataOutput + (colStride2 * j + i * sliceStride2)*getValueType(), + getValueType()); + if(isComplex()){ + cblas_dcopy( + dimLength, + getData()+1 + (aBeginIndex + j * colStride + + i * sliceStride)*getValueType(), + getValueType(), dataOutput+1 + (colStride2 * j + i * sliceStride2)*getValueType(), + getValueType()); + } } } - return Matrix::New(dataOutput,dimLength,getDimSize(1),getDimSize(2)); + return Matrix::New(dataOutput,dimLength,getDimSize(1),getDimSize(2),getValueType()); } case 1:{ int colStride2 = getDimSize(0); @@ -637,16 +646,25 @@ namespace Aurora { for (size_t i = 0; i < getDimSize(2); i++) { cblas_dcopy(copySize, - getData() + aBeginIndex * colStride + - i * sliceStride, - 1, dataOutput + i * copySize, 1); + getData() + getValueType()*(aBeginIndex * colStride + + i * sliceStride), + getValueType(), dataOutput + getValueType()*(i * copySize), + getValueType()); + if (isComplex()) + { + cblas_dcopy(copySize, + getData()+1 + getValueType()*(aBeginIndex * colStride + + i * sliceStride), + getValueType(), dataOutput+1 + getValueType()*(i * copySize), + getValueType()); + } } - return Matrix::New(dataOutput,getDimSize(0),dimLength,getDimSize(2)); + return Matrix::New(dataOutput,getDimSize(0),dimLength,getDimSize(2),getValueType()); } case 2:{ int copySize = dimLength*sliceStride; - cblas_dcopy(copySize, getData() + aBeginIndex * sliceStride ,1, dataOutput, 1); - return Matrix::New(dataOutput,getDimSize(0),getDimSize(1),dimLength); + cblas_dcopy(copySize*getValueType(), getData() + aBeginIndex * sliceStride*getValueType() ,1, dataOutput, 1); + return Matrix::New(dataOutput,getDimSize(0),getDimSize(1),dimLength,getValueType()); } } @@ -748,12 +766,20 @@ namespace Aurora { { for (size_t j = 0; j < getDimSize(1); j++) { - cblas_zcopy( + double real = value.real(); + double imag = value.imag(); + cblas_dcopy( dimLength, - &value,0, + &real,0, getData() + aBeginIndex*2 + j * colStride*2 + i * sliceStride*2, - 1); + 2); + cblas_dcopy( + dimLength, + &imag,0, + getData() +1 + aBeginIndex*2 + j * colStride*2 + + i * sliceStride*2, + 2); } } return true; @@ -764,16 +790,25 @@ namespace Aurora { int copySize = sliceStride2; for (size_t i = 0; i < getDimSize(2); i++) { - cblas_zcopy(copySize, - &value, 0 , + double real = value.real(); + double imag = value.imag(); + cblas_dcopy(copySize, + &real, 0 , getData() + aBeginIndex * colStride*2 + - i * sliceStride*2, 1); + i * sliceStride*2,2); + cblas_dcopy(copySize, + &imag, 0 , + getData() +1 + aBeginIndex * colStride*2 + + i * sliceStride*2, 2); } return true; } case 2:{ int copySize = dimLength*sliceStride; - cblas_zcopy(copySize, &value, 0 ,getData() + aBeginIndex * sliceStride*2 ,1); + double real = value.real(); + double imag = value.imag(); + cblas_dcopy(copySize, &real, 0 ,getData() + aBeginIndex * sliceStride*2 ,2); + cblas_dcopy(copySize, &imag, 0 ,getData() +1 + aBeginIndex * sliceStride*2 ,2); return true; } } @@ -847,15 +882,26 @@ namespace Aurora { { for (size_t j = 0; j < getDimSize(1); j++) { - cblas_dcopy(dimLength * getValueType(), + cblas_dcopy(dimLength, src.getData() + j * colStride2 * getValueType() + i * sliceStride2 * getValueType(), - 1, + getValueType(), getData() + aBeginIndex * getValueType() + j * colStride * getValueType() + i * sliceStride * getValueType(), - 1); + getValueType()); + if (isComplex()){ + cblas_dcopy(dimLength , + src.getData() + 1 + + j * colStride2 * getValueType() + + i * sliceStride2 * getValueType(), + getValueType(), + getData() + 1 + aBeginIndex * getValueType() + + j * colStride * getValueType() + + i * sliceStride * getValueType(), + getValueType()); + } } } return true; @@ -867,17 +913,27 @@ namespace Aurora { int copySize = sliceStride2; for (size_t i = 0; i < getDimSize(2); i++) { - cblas_dcopy(copySize*getValueType(), - src.getData()+i * sliceStride2*getValueType(),1, + cblas_dcopy(copySize, + src.getData()+i * sliceStride2*getValueType(),getValueType(), getData() + aBeginIndex * colStride*getValueType() + - i * sliceStride*getValueType(), 1); + i * sliceStride*getValueType(), getValueType()); + if (isComplex()) + { + cblas_dcopy(copySize, + src.getData()+1 +i * sliceStride2*getValueType(),getValueType(), + getData()+1 + aBeginIndex * colStride*getValueType() + + i * sliceStride*getValueType(), getValueType()); + } } return true; } //copy slice case 2:{ int copySize = dimLength*sliceStride; - cblas_dcopy(copySize*getValueType(),src.getData(),1,getData() + aBeginIndex * sliceStride*getValueType() ,1); + cblas_dcopy(copySize,src.getData(),getValueType(),getData() + aBeginIndex * sliceStride*getValueType() ,getValueType()); + if(isComplex()){ + cblas_dcopy(copySize,src.getData()+1,getValueType(),getData()+1 + aBeginIndex * sliceStride*getValueType() ,getValueType()); + } return true; } }