This commit is contained in:
sunwen
2023-05-12 16:30:18 +08:00
3 changed files with 39 additions and 8 deletions

2
.gitignore vendored
View File

@@ -1,5 +1,5 @@
.idea/ .idea/
cmake-build-debug/* cmake-*/*
cmake-build-debug-wsl/* cmake-build-debug-wsl/*
build*/ build*/
.vscode/ .vscode/

View File

@@ -148,6 +148,13 @@ namespace Aurora
Matrix dot(const Matrix &aMatrix, const Matrix &aOther, FunctionDirection direction = Column); Matrix dot(const Matrix &aMatrix, const Matrix &aOther, FunctionDirection direction = Column);
/**
* 转换下标为索引值
* @attention 索引值按照其实为1与matlab对应在C++中使用需要-1
* @param aVMatrixSize
* @param aSliceIdxs
* @return
*/
Matrix sub2ind(const Matrix &aVMatrixSize, std::initializer_list<Matrix> aSliceIdxs); Matrix sub2ind(const Matrix &aVMatrixSize, std::initializer_list<Matrix> aSliceIdxs);
}; };

View File

@@ -915,23 +915,45 @@ namespace Aurora {
std::cerr <<"Assign value fail!Src data pointer is null!"; std::cerr <<"Assign value fail!Src data pointer is null!";
return *this; return *this;
} }
if (matrix.getDims()!=mSliceMode) { switch (mSliceMode) {
std::cerr <<"Assign value fail!Src matrix(dims count:"<< matrix.getDims() <<"), not match of des(dims count:"<<mSliceMode<<")!"; case 2://matrix mode
return *this; {
if(matrix.getDims()!=2 || matrix.getDimSize(1)<=1)
{
std::cerr <<"Assign value fail!Src matrix(dims count:"<< matrix.getDims() <<"), not match of des(dims count:"<<mSliceMode<<")!";
return *this;
}
if (matrix.getDimSize(1)!=mSize2) {
std::cerr <<"Assign value fail!Src slice(dim 2 size:"<< matrix.getDimSize(1) <<"), not match of des(dim 2 size:"<<mSize2<<")!";
return *this;
}
break;
}
case 1:{
if(!matrix.isVector())
{
std::cerr <<"Assign value fail!Src matrix(dims count:"<< matrix.getDims() <<"), not match of des(dims count:"<<mSliceMode<<")!";
return *this;
}
break;
}
case 0:{
if(!matrix.isScalar()){
std::cerr <<"Assign value fail!Src matrix(dims count:"<< matrix.getDims() <<"), not match of des(dims count:"<<mSliceMode<<")!";
return *this;
}
}
} }
if (matrix.getDimSize(0)!=mSize) { if (matrix.getDimSize(0)!=mSize) {
std::cerr <<"Assign value fail!Src matrix(dim 1 size:"<< matrix.getDimSize(0)<<"), not match of des(dim 1 size:"<<mSize<<")!"; std::cerr <<"Assign value fail!Src matrix(dim 1 size:"<< matrix.getDimSize(0)<<"), not match of des(dim 1 size:"<<mSize<<")!";
return *this; return *this;
} }
if (matrix.getDimSize(1)!=mSize2) {
std::cerr <<"Assign value fail!Src slice(dim 2 size:"<< matrix.getDimSize(1) <<"), not match of des(dim 2 size:"<<mSize2<<")!";
return *this;
}
if (matrix.getValueType()!=mType) { if (matrix.getValueType()!=mType) {
std::cerr <<"Assign value fail!Src slice(value type:"<< matrix.getValueType() <<"), not match of des(value type:"<<mType<<")!"; std::cerr <<"Assign value fail!Src slice(value type:"<< matrix.getValueType() <<"), not match of des(value type:"<<mType<<")!";
return *this; return *this;
} }
switch (mSliceMode) { switch (mSliceMode) {
//matrix mode
case 2:{ case 2:{
if (mType== Normal) { if (mType== Normal) {
cblas_dcopy_batch_strided(mSize, matrix.getData(), 1, matrix.getDimSize(0), mData, mStride, cblas_dcopy_batch_strided(mSize, matrix.getData(), 1, matrix.getDimSize(0), mData, mStride,
@@ -943,6 +965,7 @@ namespace Aurora {
} }
break; break;
} }
//vector mode
case 1:{ case 1:{
if (mType== Normal){ if (mType== Normal){
cblas_dcopy(mSize,matrix.getData(),1,mData,mStride); cblas_dcopy(mSize,matrix.getData(),1,mData,mStride);
@@ -953,6 +976,7 @@ namespace Aurora {
} }
break; break;
} }
//scalar mode
case 0: case 0:
default:{ default:{
mData[0] = matrix.getData()[0]; mData[0] = matrix.getData()[0];