Add block function to Matrix.

This commit is contained in:
kradchen
2023-05-17 10:39:19 +08:00
parent f9e7df3f1c
commit 7babd216de
3 changed files with 86 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
#include "Matrix.h"
#include <cmath>
#include <cstddef>
#include <mkl_cblas.h>
#include <string>
@@ -11,6 +12,8 @@
#include <Eigen/Core>
#include <Eigen/Dense>
#include "Eigen/src/Core/Map.h"
#include "Eigen/src/Core/Matrix.h"
#include "Function.h"
namespace Aurora{
@@ -473,6 +476,64 @@ namespace Aurora {
double& Matrix::operator[](size_t index) { return getData()[index];}
double Matrix::operator[](size_t index) const { return getData()[index];}
Matrix Matrix::block(int aDim,int aBeginIndex, int aEndIndex){
if(aDim>2 ){
std::cerr<<"block only support 1D-3D data!"<<std::endl;
return Matrix();
}
if (aBeginIndex>=getDimSize(aDim) || aBeginIndex<0){
std::cerr<<"block BeginIndx error!BeginIndx:"<<aBeginIndex<<std::endl;
return Matrix();
}
if (aEndIndex>=getDimSize(aDim) || aEndIndex<0){
std::cerr<<"block EndIndex error!EndIndex:"<<aEndIndex<<std::endl;
return Matrix();
}
int dimLength = std::abs(aEndIndex-aBeginIndex)+1;
int dataSize = getDataSize()/getDimSize(aDim)*dimLength;
double * dataOutput = malloc(dataSize);
int colStride = getDimSize(0);
int sliceStride = getDimSize(0)*getDimSize(1);
switch (aDim) {
case 0:{
int colStride2 = dimLength;
int sliceStride2 = dimLength*getDimSize(1);
for (size_t i = 0; i < getDimSize(2); i++)
{
for (size_t j = 0; j < getDimSize(1); j++)
{
cblas_dcopy(
dimLength,
getData() + aBeginIndex + j * colStride +
i * sliceStride,
1, dataOutput + colStride2 * j + i * sliceStride2, 1);
}
}
return Matrix::New(dataOutput,dimLength,getDimSize(1),getDimSize(2));
}
case 1:{
int colStride2 = getDimSize(0);
int sliceStride2 = dimLength*getDimSize(0);
int copySize = sliceStride2;
for (size_t i = 0; i < getDimSize(2); i++)
{
cblas_dcopy(copySize,
getData() + aBeginIndex * colStride +
i * sliceStride,
1, dataOutput + i * copySize, 1);
}
return Matrix::New(dataOutput,getDimSize(0),dimLength,getDimSize(2));
}
case 2:{
int copySize = dimLength*sliceStride;
cblas_dcopy(copySize, getData() + aBeginIndex * sliceStride ,1, dataOutput, 1);
return Matrix::New(dataOutput,getDimSize(0),getDimSize(1),dimLength);
}
}
}
void Matrix::printf() {
if(isNull())

View File

@@ -173,7 +173,15 @@ namespace Aurora {
double& operator[](size_t index);
double operator[](size_t index) const;
/**
* 切块操作
*
* @param aDim 需要切块的维度,
* @param aBeginIndx 起始索引,包含
* @param aEndIndex 终止索引,包含
* @return Matrix 返回矩阵
*/
Matrix block(int aDim,int aBeginIndx, int aEndIndex);
/**
* 矩阵乘法

View File

@@ -158,6 +158,22 @@ TEST_F(Matrix_Test, matrixSlice) {
auto D = C(0, 0, 0).toMatrix();
EXPECT_EQ(1, D.getDataSize());
EXPECT_EQ(9, D.getData()[0]);
double *dataD = Aurora::malloc(27);
for (int i = 0; i < 27; ++i) {
dataD[i] = (double) (i);
}
Aurora::Matrix D1 = Aurora::Matrix::New(dataD, 3, 3, 3);
auto r1 = D1.block(0, 0, 1);
EXPECT_EQ(2,r1.getDimSize(0));
EXPECT_EQ(3,r1.getData()[2]);
auto r2 = D1.block(1, 0, 0);
EXPECT_EQ(1,r2.getDimSize(1));
EXPECT_EQ(10,r2.getData()[4]);
auto r3 = D1.block(2, 1, 2);
EXPECT_EQ(2,r3.getDimSize(2));
EXPECT_EQ(9,r3.getData()[0]);
}
TEST_F(Matrix_Test, matrixOpertaor) {