Add setBlock method to Matrix.

This commit is contained in:
kradchen
2023-05-29 14:28:23 +08:00
parent ce91b7a868
commit eb98e532b9
3 changed files with 148 additions and 8 deletions

View File

@@ -620,7 +620,7 @@ namespace Aurora {
bool Matrix::setBlockValue(int aDim,int aBeginIndex, int aEndIndex, double value) {
if(aDim>2 ){
std::cerr<<"block only support 1D-3D data!"<<std::endl;
std::cerr<<"setBlockValue only support 1D-3D data!"<<std::endl;
return false;
}
//横向vector切面为0为1都强制设置aDim为1来处理
@@ -628,22 +628,19 @@ namespace Aurora {
aDim = 1;
}
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0){
std::cerr<<"block BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
std::cerr<<"setBlockValue BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
return false;
}
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0){
std::cerr<<"block EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
std::cerr<<"setBlockValue EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
return false;
}
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
double * dataOutput = malloc(dataSize);
int colStride = getDimSize(0);
int sliceStride = getDimSize(0)*getDimSize(1);
switch (aDim) {
case 0:{
int colStride2 = dimLength;
int sliceStride2 = dimLength*getDimSize(1);
for (size_t i = 0; i < getDimSize(2); i++)
{
for (size_t j = 0; j < getDimSize(1); j++)
@@ -681,6 +678,104 @@ namespace Aurora {
}
bool Matrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const Matrix& src){
if(aDim>2 ){
std::cerr<<"setBlock only support 1D-3D data!"<<std::endl;
return false;
}
//横向vector切面为0为1都强制设置aDim为1来处理
if (isVector() && aDim == 0 && getDimSize(1)>1){
aDim = 1;
}
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0){
std::cerr<<"setBlock BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
return false;
}
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0){
std::cerr<<"block EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
return false;
}
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
size_t newdims[3]{0};
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
int colStride = getDimSize(0);
int sliceStride = getDimSize(0)*getDimSize(1);
switch (aDim) {
case 0:
{
newdims[0] = dimLength;
newdims[1] = getDimSize(1);
newdims[2] = getDimSize(2);
break;
}
case 1:
{
newdims[0] = getDimSize(0);
newdims[1] = dimLength;
newdims[2] = getDimSize(2);
break;
}
case 2:
{
newdims[0] = getDimSize(0);
newdims[1] = getDimSize(1);
newdims[2] = dimLength;
break;
}
}
if (src.getDimSize(0)!= newdims[0]
|| src.getDimSize(1)!= newdims[1]
|| src.getDimSize(2)!= newdims[2])
{
std::cerr << "setBlock src Matrix(" << src.getDimSize(0) << ","
<< src.getDimSize(1) << "," << src.getDimSize(2)
<< ") not match the des shape(" << newdims[0] << ","
<< newdims[1] << "," << newdims[2] << ")"
<< std::endl;
return false;
}
switch (aDim) {
//copy row
case 0:{
int colStride2 = dimLength;
int sliceStride2 = dimLength*getDimSize(1);
for (size_t i = 0; i < getDimSize(2); i++)
{
for (size_t j = 0; j < getDimSize(1); j++)
{
cblas_dcopy(
dimLength,
src.getData()+j * colStride2 + i * sliceStride2, 1,
getData() + aBeginIndex + j * colStride + i * sliceStride,
1);
}
}
return true;
}
// copy column
case 1:{
int colStride2 = getDimSize(0);
int sliceStride2 = dimLength*getDimSize(0);
int copySize = sliceStride2;
for (size_t i = 0; i < getDimSize(2); i++)
{
cblas_dcopy(copySize,
src.getData()+i * sliceStride2,1,
getData() + aBeginIndex * colStride +
i * sliceStride, 1);
}
return true;
}
//copy slice
case 2:{
int copySize = dimLength*sliceStride;
cblas_dcopy(copySize,src.getData(),1,getData() + aBeginIndex * sliceStride ,1);
return true;
}
}
}
void Matrix::printf() {
if(isNull())
{

View File

@@ -185,6 +185,8 @@ namespace Aurora {
bool setBlockValue(int aDim,int aBeginIndx, int aEndIndex,double value);
bool setBlock(int aDim,int aBeginIndx, int aEndIndex,const Matrix& src);
/**
* 矩阵乘法
* @attention 目前只支持矩阵乘向量