Add repmat3d.

This commit is contained in:
sunwen
2023-06-08 15:23:54 +08:00
parent d6bb2dcbbf
commit 98425b3eb8
2 changed files with 34 additions and 0 deletions

View File

@@ -442,6 +442,38 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int
return Matrix(std::shared_ptr<double>(resultData, Aurora::free), resultInfo, aMatrix.getValueType());
}
Matrix Aurora::repmat3d(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes)
{
if(aRowTimes < 1 || aColumnTimes < 1 || aMatrix.getDims() < 3 || aMatrix.isNull())
{
return Matrix();
}
double* start = aMatrix.getData();
int rows = aMatrix.getDimSize(0);
int columns = aMatrix.getDimSize(1);
int slices = aMatrix.getDimSize(2);
double* extended2DimsData = Aurora::malloc(rows * columns * aRowTimes * aColumnTimes * slices);
Matrix extended2DimsMatrix = Matrix::New(extended2DimsData, aRowTimes*rows, aColumnTimes*columns, slices);
for(int i=0; i<aMatrix.getDimSize(2); ++i)
{
Matrix dim2Matrix = Matrix::copyFromRawData(start, rows, columns);
Matrix extendedTemp = repmat(dim2Matrix, aRowTimes, aColumnTimes);
cblas_dcopy(extendedTemp.getDataSize(), extendedTemp.getData(), 1, extended2DimsData, 1);
extended2DimsData += extendedTemp.getDataSize();
start += columns * rows;
}
double* extended3DimsData = Aurora::malloc(rows * columns * aRowTimes * aColumnTimes * aSliceTimes * slices);
Matrix result = Matrix::New(extended3DimsData, aRowTimes*rows, aColumnTimes*columns, slices * aSliceTimes);
for(int i=0;i<aSliceTimes;++i)
{
cblas_dcopy(extended2DimsMatrix.getDataSize(), extended2DimsMatrix.getData(), 1, extended3DimsData, 1);
extended3DimsData+=extended2DimsMatrix.getDataSize();
}
return result;
}
Matrix Aurora::polyval(const Matrix &aP, const Matrix &aX) {
auto result = malloc(aX.getDataSize());

View File

@@ -57,6 +57,8 @@ namespace Aurora {
Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes);
Matrix repmat3d(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes);
Matrix log(const Matrix& aMatrix, int aBaseNum = -1);
Matrix exp(const Matrix& aMatrix);