From 461c856bd93a63093c3284dadddd62d73fa452af Mon Sep 17 00:00:00 2001 From: kradchen Date: Fri, 9 Jun 2023 16:56:56 +0800 Subject: [PATCH] Add slice support to horncat and vertcat --- src/Function1D.cpp | 32 ++++++++++++++++++-------------- test/Function1D_Test.cpp | 12 ++++++++++++ 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/Function1D.cpp b/src/Function1D.cpp index f7ca533..7c364a1 100644 --- a/src/Function1D.cpp +++ b/src/Function1D.cpp @@ -678,41 +678,45 @@ Matrix Aurora::transpose(const Matrix& aMatrix) Matrix Aurora::horzcat(const Matrix& aMatrix1, const Matrix& aMatrix2) { - if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDims() !=2 || aMatrix2.getDims() !=2 || + if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDimSize(2) != aMatrix2.getDimSize(2) || aMatrix1.getDimSize(0) != aMatrix2.getDimSize(0) || aMatrix1.getValueType() != aMatrix2.getValueType()) { return Matrix(); } int column1 = aMatrix1.getDimSize(1); int column2 = aMatrix2.getDimSize(1); + int slice = aMatrix1.getDimSize(2); int row = aMatrix1.getDimSize(0); - size_t size1= aMatrix1.getDataSize(); - size_t size2= aMatrix2.getDataSize(); - double* resultData = Aurora::malloc(size1 + size2,aMatrix1.getValueType()); - cblas_dcopy(size1, aMatrix1.getData(), 1, resultData, 1); - cblas_dcopy(size2, aMatrix2.getData(), 1, resultData + size1, 1); - - return Matrix::New(resultData, row, column1+column2, 1, aMatrix1.getValueType()); + size_t size1= row*column1; + size_t size2= row*column2; + double* resultData = Aurora::malloc(aMatrix1.getDataSize() + aMatrix2.getDataSize(),aMatrix1.getValueType()); + size_t sliceStride = row*(column1+column2); + for (size_t i = 0; i < slice; i++) + { + cblas_dcopy(size1, aMatrix1.getData()+i*size1 , 1, resultData + i*sliceStride, 1); + cblas_dcopy(size2, aMatrix2.getData()+i*size2, 1, resultData + i*sliceStride + size1, 1); + } + return Matrix::New(resultData, row, column1+column2, slice, aMatrix1.getValueType()); } Matrix Aurora::vertcat(const Matrix& aMatrix1, const Matrix& aMatrix2){ - { - if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDims() !=2 || aMatrix2.getDims() !=2 || + + if(aMatrix1.isNull() || aMatrix2.isNull() || aMatrix1.getDimSize(2) != aMatrix2.getDimSize(2) || aMatrix1.getDimSize(1) != aMatrix2.getDimSize(1) || aMatrix1.getValueType() != aMatrix2.getValueType()) { return Matrix(); } int row1 = aMatrix1.getDimSize(0); int row2 = aMatrix2.getDimSize(0); + int slice = aMatrix1.getDimSize(2); int column = aMatrix1.getDimSize(1); size_t size1= aMatrix1.getDataSize(); size_t size2= aMatrix2.getDataSize(); double* resultData = Aurora::malloc(size1 + size2,aMatrix1.getValueType()); - cblas_dcopy_batch_strided(row1, aMatrix1.getData(), 1,row1, resultData, 1, row1+row2, column); - cblas_dcopy_batch_strided(row2, aMatrix2.getData(), 1,row2, resultData + row1, 1, row1+row2, column); + cblas_dcopy_batch_strided(row1, aMatrix1.getData(), 1,row1, resultData, 1, row1+row2, column*slice); + cblas_dcopy_batch_strided(row2, aMatrix2.getData(), 1,row2, resultData + row1, 1, row1+row2, column*slice); - return Matrix::New(resultData, row1+row2, column, 1, aMatrix1.getValueType()); -} + return Matrix::New(resultData, row1+row2, column, slice, aMatrix1.getValueType()); } Matrix Aurora::vecnorm(const Matrix& aMatrix, NormMethod aNormMethod, int aDim) diff --git a/test/Function1D_Test.cpp b/test/Function1D_Test.cpp index be24e2b..d30a1be 100644 --- a/test/Function1D_Test.cpp +++ b/test/Function1D_Test.cpp @@ -429,6 +429,18 @@ TEST_F(Function1D_Test, horzcat) { EXPECT_DOUBLE_EQ(result.getData()[9],6); EXPECT_DOUBLE_EQ(result.getDimSize(0),3); EXPECT_DOUBLE_EQ(result.getDimSize(1),2); + + auto A = Aurora::Matrix::New(Aurora::random(27),3,3,3); + auto B = Aurora::Matrix::New(Aurora::random(18),3,2,3); + auto C = Aurora::horzcat(A, B); + A.printf(); + B.printf(); + C.printf(); + B = Aurora::Matrix::New(Aurora::random(18),2,3,3); + C = Aurora::vertcat(A, B); + A.printf(); + B.printf(); + C.printf(); } TEST_F(Function1D_Test, vertcat) {