Add Matrix-Vector and Matrix-Matrix product

function mul to Matrix class.
This commit is contained in:
kradchen
2023-05-16 10:00:53 +08:00
parent 0148966137
commit b2bacd8894
3 changed files with 125 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
#include "Matrix.h"
#include <mkl_cblas.h>
#include <string>
#include <cstring>
#include <iostream>
@@ -833,6 +834,85 @@ namespace Aurora {
return Matrix::New(newBuffer,aMatrix);
}
Matrix Matrix::mul(const Matrix& aOther){
if (isNull() || aOther.isNull()){
std::cerr<<"Mul operation fail, Matrix is null!"<<std::endl;
return Matrix();
}
if (getDimSize(2)>1 || aOther.getDimSize(2)>1){
std::cerr<<"In matrix mul operation not support 3d data!"<<std::endl;
return Matrix();
}
// V x ?
if (isVector())
{
// row vector
if (getDimSize(0) > 1){
std::cerr<<"In matrix mul operation Left vector must be a column vector!"<<std::endl;
return Matrix();
}
//V x Scalar
if (aOther.isScalar())
{
return (*this)*aOther.getScalar();
}
else{
//right size
if(getDimSize(1) == aOther.getDimSize(0))
{
auto result = deepCopy();
cblas_dgemv(CblasColMajor,CblasTrans,aOther.getDimSize(0),aOther.getDimSize(1),1.0,
aOther.getData(),aOther.getDimSize(0),getData(),1,0,result.getData(),1);
return result;
}
std::cerr << "Matrix mul operation fail, can't do Matrix("
<< getDimSize(0) << "," << getDimSize(1)
<< ") * Matrix(" << aOther.getDimSize(0) << ","
<< aOther.getDimSize(1) << ")" << std::endl;
}
}
else if(isScalar()){
return getScalar()*aOther;
}
// M x ?
else if(aOther.isScalar()){
return aOther.getScalar()*(*this);
} else if (aOther.isVector()) {
// right size
if (getDimSize(1) == aOther.getDimSize(0)) {
auto result = deepCopy();
cblas_dgemv(CblasColMajor, CblasNoTrans, getDimSize(0),
getDimSize(1), 1.0, getData(),
getDimSize(0), aOther.getData(), 1, 0,
result.getData(), 1);
return result;
}
std::cerr << "Matrix mul operation fail, can't do Matrix("
<< getDimSize(0) << "," << getDimSize(1) << ") * Matrix("
<< aOther.getDimSize(0) << "," << aOther.getDimSize(1)
<< ")" << std::endl;
return Matrix();
}
//MxM
else {
// fit size
if (getDimSize(1) == aOther.getDimSize(0))
{
double * output = malloc(getDimSize(0)*aOther.getDimSize(1));
int M = getDimSize(0);
int N = aOther.getDimSize(1);
int K = getDimSize(1);
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, 1.0, getData(), M, aOther.getData(),K, 0, output, M);
return Matrix::New(output,M,N);
}
std::cerr << "Matrix mul operation fail, can't do Matrix("
<< getDimSize(0) << "," << getDimSize(1) << ") * Matrix("
<< aOther.getDimSize(0) << "," << aOther.getDimSize(1)
<< ")" << std::endl;
return Matrix();
}
}
Matrix::MatrixSlice::MatrixSlice(int aSize,int aStride, double* aData, ValueType aType, int aSliceMode,int aSize2, int aStride2):
mSliceMode(aSliceMode),mData(aData),
mSize(aSize),mSize2(aSize2),

View File

@@ -169,6 +169,23 @@ namespace Aurora {
friend Matrix operator==(double aScalar, const Matrix &matrix);
Matrix operator==(const Matrix &matrix) const;
/**
* 矩阵乘法
* @attention 目前只支持矩阵乘向量
* @param aOther
* @return Matrix
*/
Matrix mul(const Matrix& aOther);
/**
* 矩阵乘法
* @attention 目前只支持矩阵乘向量
* @param aOther
* @return Matrix
*/
Matrix mul(Matrix&& aOther);
/**
* print matrix , only support 2d matrix now
*/

View File

@@ -357,4 +357,32 @@ TEST_F(Matrix_Test, matrixCompare){
EXPECT_EQ(C.getData()[2], 1);
EXPECT_EQ(C.getData()[3], 0);
}
TEST_F(Matrix_Test, matrixfunction){
double *dataA = new double[9]{1,6,9,4,1,0,5,8,1};
double *dataB = new double[3]{1,2,3};
Aurora::Matrix A = Aurora::Matrix::fromRawData(dataA, 3, 3);
Aurora::Matrix B = Aurora::Matrix::fromRawData(dataB, 1, 3);
auto C = B.mul(A);
EXPECT_DOUBLE_EQ(C.getData()[0], 40);
EXPECT_DOUBLE_EQ(C.getData()[1], 6);
EXPECT_DOUBLE_EQ(C.getData()[2], 24);
B.forceReshape(3, 1, 1);
C = A.mul(B);
EXPECT_DOUBLE_EQ(C.getData()[0], 24);
EXPECT_DOUBLE_EQ(C.getData()[1], 32);
EXPECT_DOUBLE_EQ(C.getData()[2], 12);
double *dataD = new double[9]{2.1,3,9,-3,1,0,51,-8,1};
Aurora::Matrix D = Aurora::Matrix::fromRawData(dataD, 3, 3);
C = A.mul(D);
EXPECT_DOUBLE_EQ(C.getData()[0], 59.1);
EXPECT_DOUBLE_EQ(C.getData()[1], 87.6);
EXPECT_DOUBLE_EQ(C.getData()[2], 27.9);
EXPECT_DOUBLE_EQ(C.getData()[3], 1);
EXPECT_DOUBLE_EQ(C.getData()[4], -17);
EXPECT_DOUBLE_EQ(C.getData()[5], -27);
EXPECT_DOUBLE_EQ(C.getData()[6], 24);
EXPECT_DOUBLE_EQ(C.getData()[7], 306);
EXPECT_DOUBLE_EQ(C.getData()[8], 460);
}