Add Matrix-Vector and Matrix-Matrix product

function mul to Matrix class.
This commit is contained in:
kradchen
2023-05-16 10:00:53 +08:00
parent 0148966137
commit b2bacd8894
3 changed files with 125 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
#include "Matrix.h"
#include <mkl_cblas.h>
#include <string>
#include <cstring>
#include <iostream>
@@ -833,6 +834,85 @@ namespace Aurora {
return Matrix::New(newBuffer,aMatrix);
}
Matrix Matrix::mul(const Matrix& aOther){
if (isNull() || aOther.isNull()){
std::cerr<<"Mul operation fail, Matrix is null!"<<std::endl;
return Matrix();
}
if (getDimSize(2)>1 || aOther.getDimSize(2)>1){
std::cerr<<"In matrix mul operation not support 3d data!"<<std::endl;
return Matrix();
}
// V x ?
if (isVector())
{
// row vector
if (getDimSize(0) > 1){
std::cerr<<"In matrix mul operation Left vector must be a column vector!"<<std::endl;
return Matrix();
}
//V x Scalar
if (aOther.isScalar())
{
return (*this)*aOther.getScalar();
}
else{
//right size
if(getDimSize(1) == aOther.getDimSize(0))
{
auto result = deepCopy();
cblas_dgemv(CblasColMajor,CblasTrans,aOther.getDimSize(0),aOther.getDimSize(1),1.0,
aOther.getData(),aOther.getDimSize(0),getData(),1,0,result.getData(),1);
return result;
}
std::cerr << "Matrix mul operation fail, can't do Matrix("
<< getDimSize(0) << "," << getDimSize(1)
<< ") * Matrix(" << aOther.getDimSize(0) << ","
<< aOther.getDimSize(1) << ")" << std::endl;
}
}
else if(isScalar()){
return getScalar()*aOther;
}
// M x ?
else if(aOther.isScalar()){
return aOther.getScalar()*(*this);
} else if (aOther.isVector()) {
// right size
if (getDimSize(1) == aOther.getDimSize(0)) {
auto result = deepCopy();
cblas_dgemv(CblasColMajor, CblasNoTrans, getDimSize(0),
getDimSize(1), 1.0, getData(),
getDimSize(0), aOther.getData(), 1, 0,
result.getData(), 1);
return result;
}
std::cerr << "Matrix mul operation fail, can't do Matrix("
<< getDimSize(0) << "," << getDimSize(1) << ") * Matrix("
<< aOther.getDimSize(0) << "," << aOther.getDimSize(1)
<< ")" << std::endl;
return Matrix();
}
//MxM
else {
// fit size
if (getDimSize(1) == aOther.getDimSize(0))
{
double * output = malloc(getDimSize(0)*aOther.getDimSize(1));
int M = getDimSize(0);
int N = aOther.getDimSize(1);
int K = getDimSize(1);
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, 1.0, getData(), M, aOther.getData(),K, 0, output, M);
return Matrix::New(output,M,N);
}
std::cerr << "Matrix mul operation fail, can't do Matrix("
<< getDimSize(0) << "," << getDimSize(1) << ") * Matrix("
<< aOther.getDimSize(0) << "," << aOther.getDimSize(1)
<< ")" << std::endl;
return Matrix();
}
}
Matrix::MatrixSlice::MatrixSlice(int aSize,int aStride, double* aData, ValueType aType, int aSliceMode,int aSize2, int aStride2):
mSliceMode(aSliceMode),mData(aData),
mSize(aSize),mSize2(aSize2),

View File

@@ -169,6 +169,23 @@ namespace Aurora {
friend Matrix operator==(double aScalar, const Matrix &matrix);
Matrix operator==(const Matrix &matrix) const;
/**
* 矩阵乘法
* @attention 目前只支持矩阵乘向量
* @param aOther
* @return Matrix
*/
Matrix mul(const Matrix& aOther);
/**
* 矩阵乘法
* @attention 目前只支持矩阵乘向量
* @param aOther
* @return Matrix
*/
Matrix mul(Matrix&& aOther);
/**
* print matrix , only support 2d matrix now
*/