Add complex support to block function

This commit is contained in:
kradchen
2023-06-13 14:36:01 +08:00
parent 9919b09cf0
commit 46b4e5c04d

View File

@@ -610,7 +610,7 @@ namespace Aurora {
} }
int dimLength = std::abs(aEndIndex-aBeginIndex)+1; int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
int dataSize = getDataSize()/getDimSize(aDim)*dimLength; int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
double * dataOutput = malloc(dataSize); double * dataOutput = malloc(dataSize,isComplex());
int colStride = getDimSize(0); int colStride = getDimSize(0);
int sliceStride = getDimSize(0)*getDimSize(1); int sliceStride = getDimSize(0)*getDimSize(1);
switch (aDim) { switch (aDim) {
@@ -623,12 +623,21 @@ namespace Aurora {
{ {
cblas_dcopy( cblas_dcopy(
dimLength, dimLength,
getData() + aBeginIndex + j * colStride + getData() + (aBeginIndex + j * colStride +
i * sliceStride, i * sliceStride)*getValueType(),
1, dataOutput + colStride2 * j + i * sliceStride2, 1); getValueType(), dataOutput + (colStride2 * j + i * sliceStride2)*getValueType(),
getValueType());
if(isComplex()){
cblas_dcopy(
dimLength,
getData()+1 + (aBeginIndex + j * colStride +
i * sliceStride)*getValueType(),
getValueType(), dataOutput+1 + (colStride2 * j + i * sliceStride2)*getValueType(),
getValueType());
}
} }
} }
return Matrix::New(dataOutput,dimLength,getDimSize(1),getDimSize(2)); return Matrix::New(dataOutput,dimLength,getDimSize(1),getDimSize(2),getValueType());
} }
case 1:{ case 1:{
int colStride2 = getDimSize(0); int colStride2 = getDimSize(0);
@@ -637,16 +646,25 @@ namespace Aurora {
for (size_t i = 0; i < getDimSize(2); i++) for (size_t i = 0; i < getDimSize(2); i++)
{ {
cblas_dcopy(copySize, cblas_dcopy(copySize,
getData() + aBeginIndex * colStride + getData() + getValueType()*(aBeginIndex * colStride +
i * sliceStride, i * sliceStride),
1, dataOutput + i * copySize, 1); getValueType(), dataOutput + getValueType()*(i * copySize),
getValueType());
if (isComplex())
{
cblas_dcopy(copySize,
getData()+1 + getValueType()*(aBeginIndex * colStride +
i * sliceStride),
getValueType(), dataOutput+1 + getValueType()*(i * copySize),
getValueType());
}
} }
return Matrix::New(dataOutput,getDimSize(0),dimLength,getDimSize(2)); return Matrix::New(dataOutput,getDimSize(0),dimLength,getDimSize(2),getValueType());
} }
case 2:{ case 2:{
int copySize = dimLength*sliceStride; int copySize = dimLength*sliceStride;
cblas_dcopy(copySize, getData() + aBeginIndex * sliceStride ,1, dataOutput, 1); cblas_dcopy(copySize*getValueType(), getData() + aBeginIndex * sliceStride*getValueType() ,1, dataOutput, 1);
return Matrix::New(dataOutput,getDimSize(0),getDimSize(1),dimLength); return Matrix::New(dataOutput,getDimSize(0),getDimSize(1),dimLength,getValueType());
} }
} }
@@ -748,12 +766,20 @@ namespace Aurora {
{ {
for (size_t j = 0; j < getDimSize(1); j++) for (size_t j = 0; j < getDimSize(1); j++)
{ {
cblas_zcopy( double real = value.real();
double imag = value.imag();
cblas_dcopy(
dimLength, dimLength,
&value,0, &real,0,
getData() + aBeginIndex*2 + j * colStride*2 + getData() + aBeginIndex*2 + j * colStride*2 +
i * sliceStride*2, i * sliceStride*2,
1); 2);
cblas_dcopy(
dimLength,
&imag,0,
getData() +1 + aBeginIndex*2 + j * colStride*2 +
i * sliceStride*2,
2);
} }
} }
return true; return true;
@@ -764,16 +790,25 @@ namespace Aurora {
int copySize = sliceStride2; int copySize = sliceStride2;
for (size_t i = 0; i < getDimSize(2); i++) for (size_t i = 0; i < getDimSize(2); i++)
{ {
cblas_zcopy(copySize, double real = value.real();
&value, 0 , double imag = value.imag();
cblas_dcopy(copySize,
&real, 0 ,
getData() + aBeginIndex * colStride*2 + getData() + aBeginIndex * colStride*2 +
i * sliceStride*2, 1); i * sliceStride*2,2);
cblas_dcopy(copySize,
&imag, 0 ,
getData() +1 + aBeginIndex * colStride*2 +
i * sliceStride*2, 2);
} }
return true; return true;
} }
case 2:{ case 2:{
int copySize = dimLength*sliceStride; int copySize = dimLength*sliceStride;
cblas_zcopy(copySize, &value, 0 ,getData() + aBeginIndex * sliceStride*2 ,1); double real = value.real();
double imag = value.imag();
cblas_dcopy(copySize, &real, 0 ,getData() + aBeginIndex * sliceStride*2 ,2);
cblas_dcopy(copySize, &imag, 0 ,getData() +1 + aBeginIndex * sliceStride*2 ,2);
return true; return true;
} }
} }
@@ -847,15 +882,26 @@ namespace Aurora {
{ {
for (size_t j = 0; j < getDimSize(1); j++) for (size_t j = 0; j < getDimSize(1); j++)
{ {
cblas_dcopy(dimLength * getValueType(), cblas_dcopy(dimLength,
src.getData() + src.getData() +
j * colStride2 * getValueType() + j * colStride2 * getValueType() +
i * sliceStride2 * getValueType(), i * sliceStride2 * getValueType(),
1, getValueType(),
getData() + aBeginIndex * getValueType() + getData() + aBeginIndex * getValueType() +
j * colStride * getValueType() + j * colStride * getValueType() +
i * sliceStride * getValueType(), i * sliceStride * getValueType(),
1); getValueType());
if (isComplex()){
cblas_dcopy(dimLength ,
src.getData() + 1 +
j * colStride2 * getValueType() +
i * sliceStride2 * getValueType(),
getValueType(),
getData() + 1 + aBeginIndex * getValueType() +
j * colStride * getValueType() +
i * sliceStride * getValueType(),
getValueType());
}
} }
} }
return true; return true;
@@ -867,17 +913,27 @@ namespace Aurora {
int copySize = sliceStride2; int copySize = sliceStride2;
for (size_t i = 0; i < getDimSize(2); i++) for (size_t i = 0; i < getDimSize(2); i++)
{ {
cblas_dcopy(copySize*getValueType(), cblas_dcopy(copySize,
src.getData()+i * sliceStride2*getValueType(),1, src.getData()+i * sliceStride2*getValueType(),getValueType(),
getData() + aBeginIndex * colStride*getValueType() + getData() + aBeginIndex * colStride*getValueType() +
i * sliceStride*getValueType(), 1); i * sliceStride*getValueType(), getValueType());
if (isComplex())
{
cblas_dcopy(copySize,
src.getData()+1 +i * sliceStride2*getValueType(),getValueType(),
getData()+1 + aBeginIndex * colStride*getValueType() +
i * sliceStride*getValueType(), getValueType());
}
} }
return true; return true;
} }
//copy slice //copy slice
case 2:{ case 2:{
int copySize = dimLength*sliceStride; int copySize = dimLength*sliceStride;
cblas_dcopy(copySize*getValueType(),src.getData(),1,getData() + aBeginIndex * sliceStride*getValueType() ,1); cblas_dcopy(copySize,src.getData(),getValueType(),getData() + aBeginIndex * sliceStride*getValueType() ,getValueType());
if(isComplex()){
cblas_dcopy(copySize,src.getData()+1,getValueType(),getData()+1 + aBeginIndex * sliceStride*getValueType() ,getValueType());
}
return true; return true;
} }
} }