diff --git a/src/Function1D.cpp b/src/Function1D.cpp index a620f68..a886db4 100644 --- a/src/Function1D.cpp +++ b/src/Function1D.cpp @@ -614,7 +614,6 @@ Matrix Aurora::intersect(const Matrix& aMatrix1, const Matrix& aMatrix2, Matrix& iaResult[i] = j + 1; break; } - } } @@ -622,3 +621,12 @@ Matrix Aurora::intersect(const Matrix& aMatrix1, const Matrix& aMatrix2, Matrix& return result; } +Matrix Aurora::reshape(const Matrix& aMatrix, int aRows, int aColumns, int aSlices) +{ + if(aMatrix.isNull() || (aMatrix.getDataSize() != aRows * aColumns * aSlices)) + { + return Matrix(); + } + return Matrix::copyFromRawData(aMatrix.getData(),aRows,aColumns,aSlices); +} + diff --git a/src/Function1D.h b/src/Function1D.h index 485bb7d..4f40fa8 100644 --- a/src/Function1D.h +++ b/src/Function1D.h @@ -77,6 +77,8 @@ namespace Aurora { Matrix auroraUnion(const Matrix& aMatrix1, const Matrix& aMatrix2); Matrix intersect(const Matrix& aMatrix1, const Matrix& aMatrix2); + + Matrix reshape(const Matrix& aMatrix, int aRows, int aColumns, int aSlices); /** * 并集 * @param aIa, [C,ia,~] = intersect(A,B)用法中ia的返回值 diff --git a/test/Function1D_Test.cpp b/test/Function1D_Test.cpp index b4908c7..8e16f9b 100644 --- a/test/Function1D_Test.cpp +++ b/test/Function1D_Test.cpp @@ -476,3 +476,20 @@ TEST_F(Function1D_Test, intersect) { EXPECT_DOUBLE_AE(ia.getData()[1],3); EXPECT_DOUBLE_AE(ia.getData()[2],9); } + +TEST_F(Function1D_Test, reshape) { + double* data = new double[9]{3,3,2,2,2,1,4,4,7}; + auto matrix = Aurora::Matrix::fromRawData(data, 9,1,1); + auto result = Aurora::reshape(matrix,3,3,1); + EXPECT_DOUBLE_AE(result.getDimSize(0),3); + EXPECT_DOUBLE_AE(result.getDimSize(1),3); + EXPECT_DOUBLE_AE(result.getDimSize(2),1); + result = Aurora::reshape(matrix,3,1,3); + EXPECT_DOUBLE_AE(result.getDimSize(0),3); + EXPECT_DOUBLE_AE(result.getDimSize(1),1); + EXPECT_DOUBLE_AE(result.getDimSize(2),3); + result = Aurora::reshape(matrix,1,3,3); + EXPECT_DOUBLE_AE(result.getDimSize(0),1); + EXPECT_DOUBLE_AE(result.getDimSize(1),3); + EXPECT_DOUBLE_AE(result.getDimSize(2),3); +}