Fix repmat with complex bug.

This commit is contained in:
sunwen
2023-04-20 17:52:36 +08:00
parent ed7312992f
commit 84fee55eb8
2 changed files with 12 additions and 7 deletions

View File

@@ -237,7 +237,8 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes)
{
return Matrix();
}
int originalDataSize = aMatrix.getDataSize();
int complexStep = aMatrix.getValueType();
int originalDataSize = aMatrix.getDataSize() * complexStep;
double* resultData = Aurora::malloc(originalDataSize * aRowTimes * aColumnTimes);
int row = aMatrix.getDimSize(0);
int column = 1;
@@ -252,10 +253,10 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes)
{
for(int j=1; j<=aRowTimes; ++j)
{
std::copy(originalData, originalData+row, resultDataTemp);
resultDataTemp += row;
std::copy(originalData, originalData+row*complexStep, resultDataTemp);
resultDataTemp += row*complexStep;
}
originalData += row;
originalData += row*complexStep;
}
resultDataTemp = resultData;
@@ -273,7 +274,7 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes)
resultInfo.push_back(column);
}
return Matrix(std::shared_ptr<double>(resultData, Aurora::free),resultInfo);
return Matrix(std::shared_ptr<double>(resultData, Aurora::free),resultInfo, aMatrix.getValueType());
}
Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes)
@@ -282,8 +283,10 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int
{
return Matrix();
}
int complexStep = aMatrix.getValueType();
Matrix resultTemp = Aurora::repmat(aMatrix, aRowTimes, aColumnTimes);
int resultTempDataSize = resultTemp.getDataSize();
int resultTempDataSize = resultTemp.getDataSize() * complexStep;
double* resultData = Aurora::malloc(resultTempDataSize * aSliceTimes);
std::copy(resultTemp.getData(), resultTemp.getData() + resultTempDataSize, resultData);
for(int i=1; i<aSliceTimes; ++i)
@@ -307,5 +310,5 @@ Matrix Aurora::repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int
resultInfo.push_back(aSliceTimes);
}
return Matrix(std::shared_ptr<double>(resultData, Aurora::free),resultInfo);
return Matrix(std::shared_ptr<double>(resultData, Aurora::free), resultInfo, aMatrix.getValueType());
}

View File

@@ -46,6 +46,8 @@ namespace Aurora {
Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes);
Matrix repmat(const Matrix& aMatrix,int aRowTimes, int aColumnTimes, int aSliceTimes);
Matrix log(const Matrix& aMatrix, int aBaseNum = -1);
};