Fix repmat with complex bug.

This commit is contained in:
sunwen
2023-04-20 17:52:36 +08:00
parent ed7312992f
commit 84fee55eb8
2 changed files with 12 additions and 7 deletions

View File

@@ -237,7 +237,8 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes)
{ {
return Matrix(); return Matrix();
} }
int originalDataSize = aMatrix.getDataSize(); int complexStep = aMatrix.getValueType();
int originalDataSize = aMatrix.getDataSize() * complexStep;
double* resultData = Aurora::malloc(originalDataSize * aRowTimes * aColumnTimes); double* resultData = Aurora::malloc(originalDataSize * aRowTimes * aColumnTimes);
int row = aMatrix.getDimSize(0); int row = aMatrix.getDimSize(0);
int column = 1; int column = 1;
@@ -252,10 +253,10 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes)
{ {
for(int j=1; j<=aRowTimes; ++j) for(int j=1; j<=aRowTimes; ++j)
{ {
std::copy(originalData, originalData+row, resultDataTemp); std::copy(originalData, originalData+row*complexStep, resultDataTemp);
resultDataTemp += row; resultDataTemp += row*complexStep;
} }
originalData += row; originalData += row*complexStep;
} }
resultDataTemp = resultData; resultDataTemp = resultData;
@@ -273,7 +274,7 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes)
resultInfo.push_back(column); resultInfo.push_back(column);
} }
return Matrix(std::shared_ptr<double>(resultData, Aurora::free),resultInfo); return Matrix(std::shared_ptr<double>(resultData, Aurora::free),resultInfo, aMatrix.getValueType());
} }
Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes) Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes)
@@ -282,8 +283,10 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int
{ {
return Matrix(); return Matrix();
} }
int complexStep = aMatrix.getValueType();
Matrix resultTemp = Aurora::repmat(aMatrix, aRowTimes, aColumnTimes); Matrix resultTemp = Aurora::repmat(aMatrix, aRowTimes, aColumnTimes);
int resultTempDataSize = resultTemp.getDataSize(); int resultTempDataSize = resultTemp.getDataSize() * complexStep;
double* resultData = Aurora::malloc(resultTempDataSize * aSliceTimes); double* resultData = Aurora::malloc(resultTempDataSize * aSliceTimes);
std::copy(resultTemp.getData(), resultTemp.getData() + resultTempDataSize, resultData); std::copy(resultTemp.getData(), resultTemp.getData() + resultTempDataSize, resultData);
for(int i=1; i<aSliceTimes; ++i) for(int i=1; i<aSliceTimes; ++i)
@@ -307,5 +310,5 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int
resultInfo.push_back(aSliceTimes); resultInfo.push_back(aSliceTimes);
} }
return Matrix(std::shared_ptr<double>(resultData, Aurora::free),resultInfo); return Matrix(std::shared_ptr<double>(resultData, Aurora::free), resultInfo, aMatrix.getValueType());
} }

View File

@@ -46,6 +46,8 @@ namespace Aurora {
Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes); Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes);
Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes); Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes);
Matrix log(const Matrix& aMatrix, int aBaseNum = -1);
}; };