Add slice support to horncat and vertcat

This commit is contained in:
kradchen
2023-06-09 16:56:56 +08:00
parent 428d75e548
commit 461c856bd9
2 changed files with 30 additions and 14 deletions

View File

@@ -678,41 +678,45 @@ Matrix Aurora::transpose(const Matrix& aMatrix)
Matrix Aurora::horzcat(const Matrix& aMatrix1, const Matrix& aMatrix2)
{
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDims() !=2 || aMatrix2.getDims() !=2 ||
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDimSize(2) != aMatrix2.getDimSize(2) ||
aMatrix1.getDimSize(0) != aMatrix2.getDimSize(0) || aMatrix1.getValueType() != aMatrix2.getValueType())
{
return Matrix();
}
int column1 = aMatrix1.getDimSize(1);
int column2 = aMatrix2.getDimSize(1);
int slice = aMatrix1.getDimSize(2);
int row = aMatrix1.getDimSize(0);
size_t size1= aMatrix1.getDataSize();
size_t size2= aMatrix2.getDataSize();
double* resultData = Aurora::malloc(size1 + size2,aMatrix1.getValueType());
cblas_dcopy(size1, aMatrix1.getData(), 1, resultData, 1);
cblas_dcopy(size2, aMatrix2.getData(), 1, resultData + size1, 1);
return Matrix::New(resultData, row, column1+column2, 1, aMatrix1.getValueType());
size_t size1= row*column1;
size_t size2= row*column2;
double* resultData = Aurora::malloc(aMatrix1.getDataSize() + aMatrix2.getDataSize(),aMatrix1.getValueType());
size_t sliceStride = row*(column1+column2);
for (size_t i = 0; i < slice; i++)
{
cblas_dcopy(size1, aMatrix1.getData()+i*size1 , 1, resultData + i*sliceStride, 1);
cblas_dcopy(size2, aMatrix2.getData()+i*size2, 1, resultData + i*sliceStride + size1, 1);
}
return Matrix::New(resultData, row, column1+column2, slice, aMatrix1.getValueType());
}
Matrix Aurora::vertcat(const Matrix& aMatrix1, const Matrix& aMatrix2){
{
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDims() !=2 || aMatrix2.getDims() !=2 ||
if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDimSize(2) != aMatrix2.getDimSize(2) ||
aMatrix1.getDimSize(1) != aMatrix2.getDimSize(1) || aMatrix1.getValueType() != aMatrix2.getValueType())
{
return Matrix();
}
int row1 = aMatrix1.getDimSize(0);
int row2 = aMatrix2.getDimSize(0);
int slice = aMatrix1.getDimSize(2);
int column = aMatrix1.getDimSize(1);
size_t size1= aMatrix1.getDataSize();
size_t size2= aMatrix2.getDataSize();
double* resultData = Aurora::malloc(size1 + size2,aMatrix1.getValueType());
cblas_dcopy_batch_strided(row1, aMatrix1.getData(), 1,row1, resultData, 1, row1+row2, column);
cblas_dcopy_batch_strided(row2, aMatrix2.getData(), 1,row2, resultData + row1, 1, row1+row2, column);
cblas_dcopy_batch_strided(row1, aMatrix1.getData(), 1,row1, resultData, 1, row1+row2, column*slice);
cblas_dcopy_batch_strided(row2, aMatrix2.getData(), 1,row2, resultData + row1, 1, row1+row2, column*slice);
return Matrix::New(resultData, row1+row2, column, 1, aMatrix1.getValueType());
}
return Matrix::New(resultData, row1+row2, column, slice, aMatrix1.getValueType());
}
Matrix Aurora::vecnorm(const Matrix& aMatrix, NormMethod aNormMethod, int aDim)

View File

@@ -429,6 +429,18 @@ TEST_F(Function1D_Test, horzcat) {
EXPECT_DOUBLE_EQ(result.getData()[9],6);
EXPECT_DOUBLE_EQ(result.getDimSize(0),3);
EXPECT_DOUBLE_EQ(result.getDimSize(1),2);
auto A = Aurora::Matrix::New(Aurora::random(27),3,3,3);
auto B = Aurora::Matrix::New(Aurora::random(18),3,2,3);
auto C = Aurora::horzcat(A, B);
A.printf();
B.printf();
C.printf();
B = Aurora::Matrix::New(Aurora::random(18),2,3,3);
C = Aurora::vertcat(A, B);
A.printf();
B.printf();
C.printf();
}
TEST_F(Function1D_Test, vertcat) {