Add setBlock method to Matrix.
This commit is contained in:
107
src/Matrix.cpp
107
src/Matrix.cpp
@@ -620,7 +620,7 @@ namespace Aurora {
|
|||||||
|
|
||||||
bool Matrix::setBlockValue(int aDim,int aBeginIndex, int aEndIndex, double value) {
|
bool Matrix::setBlockValue(int aDim,int aBeginIndex, int aEndIndex, double value) {
|
||||||
if(aDim>2 ){
|
if(aDim>2 ){
|
||||||
std::cerr<<"block only support 1D-3D data!"<<std::endl;
|
std::cerr<<"setBlockValue only support 1D-3D data!"<<std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
//横向vector,切面为0,为1,都强制设置aDim为1来处理
|
//横向vector,切面为0,为1,都强制设置aDim为1来处理
|
||||||
@@ -628,22 +628,19 @@ namespace Aurora {
|
|||||||
aDim = 1;
|
aDim = 1;
|
||||||
}
|
}
|
||||||
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0){
|
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0){
|
||||||
std::cerr<<"block BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
|
std::cerr<<"setBlockValue BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0){
|
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0){
|
||||||
std::cerr<<"block EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
|
std::cerr<<"setBlockValue EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
|
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
|
||||||
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
|
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
|
||||||
double * dataOutput = malloc(dataSize);
|
|
||||||
int colStride = getDimSize(0);
|
int colStride = getDimSize(0);
|
||||||
int sliceStride = getDimSize(0)*getDimSize(1);
|
int sliceStride = getDimSize(0)*getDimSize(1);
|
||||||
switch (aDim) {
|
switch (aDim) {
|
||||||
case 0:{
|
case 0:{
|
||||||
int colStride2 = dimLength;
|
|
||||||
int sliceStride2 = dimLength*getDimSize(1);
|
|
||||||
for (size_t i = 0; i < getDimSize(2); i++)
|
for (size_t i = 0; i < getDimSize(2); i++)
|
||||||
{
|
{
|
||||||
for (size_t j = 0; j < getDimSize(1); j++)
|
for (size_t j = 0; j < getDimSize(1); j++)
|
||||||
@@ -681,6 +678,104 @@ namespace Aurora {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Matrix::setBlock(int aDim,int aBeginIndex, int aEndIndex, const Matrix& src){
|
||||||
|
if(aDim>2 ){
|
||||||
|
std::cerr<<"setBlock only support 1D-3D data!"<<std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//横向vector,切面为0,为1,都强制设置aDim为1来处理
|
||||||
|
if (isVector() && aDim == 0 && getDimSize(1)>1){
|
||||||
|
aDim = 1;
|
||||||
|
}
|
||||||
|
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0){
|
||||||
|
std::cerr<<"setBlock BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0){
|
||||||
|
std::cerr<<"block EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
|
||||||
|
size_t newdims[3]{0};
|
||||||
|
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
|
||||||
|
int colStride = getDimSize(0);
|
||||||
|
int sliceStride = getDimSize(0)*getDimSize(1);
|
||||||
|
switch (aDim) {
|
||||||
|
case 0:
|
||||||
|
{
|
||||||
|
newdims[0] = dimLength;
|
||||||
|
newdims[1] = getDimSize(1);
|
||||||
|
newdims[2] = getDimSize(2);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
{
|
||||||
|
newdims[0] = getDimSize(0);
|
||||||
|
newdims[1] = dimLength;
|
||||||
|
newdims[2] = getDimSize(2);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
{
|
||||||
|
newdims[0] = getDimSize(0);
|
||||||
|
newdims[1] = getDimSize(1);
|
||||||
|
newdims[2] = dimLength;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (src.getDimSize(0)!= newdims[0]
|
||||||
|
|| src.getDimSize(1)!= newdims[1]
|
||||||
|
|| src.getDimSize(2)!= newdims[2])
|
||||||
|
{
|
||||||
|
std::cerr << "setBlock src Matrix(" << src.getDimSize(0) << ","
|
||||||
|
<< src.getDimSize(1) << "," << src.getDimSize(2)
|
||||||
|
<< ") not match the des shape(" << newdims[0] << ","
|
||||||
|
<< newdims[1] << "," << newdims[2] << ")"
|
||||||
|
<< std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
switch (aDim) {
|
||||||
|
//copy row
|
||||||
|
case 0:{
|
||||||
|
int colStride2 = dimLength;
|
||||||
|
int sliceStride2 = dimLength*getDimSize(1);
|
||||||
|
for (size_t i = 0; i < getDimSize(2); i++)
|
||||||
|
{
|
||||||
|
for (size_t j = 0; j < getDimSize(1); j++)
|
||||||
|
{
|
||||||
|
cblas_dcopy(
|
||||||
|
dimLength,
|
||||||
|
src.getData()+j * colStride2 + i * sliceStride2, 1,
|
||||||
|
getData() + aBeginIndex + j * colStride + i * sliceStride,
|
||||||
|
1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// copy column
|
||||||
|
case 1:{
|
||||||
|
int colStride2 = getDimSize(0);
|
||||||
|
int sliceStride2 = dimLength*getDimSize(0);
|
||||||
|
int copySize = sliceStride2;
|
||||||
|
for (size_t i = 0; i < getDimSize(2); i++)
|
||||||
|
{
|
||||||
|
cblas_dcopy(copySize,
|
||||||
|
src.getData()+i * sliceStride2,1,
|
||||||
|
getData() + aBeginIndex * colStride +
|
||||||
|
i * sliceStride, 1);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
//copy slice
|
||||||
|
case 2:{
|
||||||
|
int copySize = dimLength*sliceStride;
|
||||||
|
cblas_dcopy(copySize,src.getData(),1,getData() + aBeginIndex * sliceStride ,1);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Matrix::printf() {
|
void Matrix::printf() {
|
||||||
if(isNull())
|
if(isNull())
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -185,6 +185,8 @@ namespace Aurora {
|
|||||||
|
|
||||||
bool setBlockValue(int aDim,int aBeginIndx, int aEndIndex,double value);
|
bool setBlockValue(int aDim,int aBeginIndx, int aEndIndex,double value);
|
||||||
|
|
||||||
|
bool setBlock(int aDim,int aBeginIndx, int aEndIndex,const Matrix& src);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 矩阵乘法
|
* 矩阵乘法
|
||||||
* @attention 目前只支持矩阵乘向量
|
* @attention 目前只支持矩阵乘向量
|
||||||
|
|||||||
@@ -422,9 +422,52 @@ TEST_F(Matrix_Test, matrixfunction){
|
|||||||
{
|
{
|
||||||
for (size_t k = 1; k < E.getDimSize(0); k++)
|
for (size_t k = 1; k < E.getDimSize(0); k++)
|
||||||
{
|
{
|
||||||
auto index1 = k+j*E.getDimSize(0)+i*E.getDimSize(1)*E.getDimSize(0);
|
EXPECT_DOUBLE_AE(E(k,i,j).toMatrix().getScalar(),-1);
|
||||||
EXPECT_DOUBLE_AE(E[index1],-1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
auto z = Aurora::zeros(1,4,5);
|
||||||
|
EXPECT_TRUE(E.setBlock(0, 0, 0,z));
|
||||||
|
for (size_t i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
for (size_t j = 0; j < 5; j++)
|
||||||
|
{
|
||||||
|
EXPECT_DOUBLE_AE(E(0,i,j).toMatrix().getScalar(), 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto z2 = Aurora::zeros(3,2,5);
|
||||||
|
EXPECT_FALSE(E.setBlock(1, 1, 2,z));
|
||||||
|
EXPECT_TRUE(E.setBlock(1, 1, 2,z2));
|
||||||
|
for (size_t j = 0; j < 5; j++)
|
||||||
|
{
|
||||||
|
for (size_t i = 1; i < 3; i++)
|
||||||
|
{
|
||||||
|
for (size_t k = 0; k < 3; k++)
|
||||||
|
{
|
||||||
|
EXPECT_DOUBLE_AE(E(k,i,j).toMatrix().getScalar(), 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto o2 = Aurora::ones(3,4,2);
|
||||||
|
EXPECT_FALSE(E.setBlock(2, 1, 2,z));
|
||||||
|
EXPECT_FALSE(E.setBlock(2, 1, 2,z2));
|
||||||
|
EXPECT_TRUE(E.setBlock(2, 1, 2,o2));
|
||||||
|
for (size_t j = 1; j < 3; j++)
|
||||||
|
{
|
||||||
|
for (size_t i = 1; i < 3; i++)
|
||||||
|
{
|
||||||
|
for (size_t k = 0; k < 3; k++)
|
||||||
|
{
|
||||||
|
EXPECT_DOUBLE_AE(E(k,i,j).toMatrix().getScalar(), 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto ZZ = Aurora::zeros(1,20);
|
||||||
|
auto OO = Aurora::ones(1,10);
|
||||||
|
EXPECT_TRUE(ZZ.setBlock(1, 10, 19, OO));
|
||||||
|
for (size_t i = 0; i < ZZ.getDataSize(); i++)
|
||||||
|
{
|
||||||
|
EXPECT_DOUBLE_AE(ZZ[i], i>9?1.0:0.0);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user