Add Matrix-Vector and Matrix-Matrix product
function mul to Matrix class.
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
#include "Matrix.h"
|
#include "Matrix.h"
|
||||||
|
|
||||||
|
#include <mkl_cblas.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
@@ -833,6 +834,85 @@ namespace Aurora {
|
|||||||
return Matrix::New(newBuffer,aMatrix);
|
return Matrix::New(newBuffer,aMatrix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Matrix Matrix::mul(const Matrix& aOther){
|
||||||
|
if (isNull() || aOther.isNull()){
|
||||||
|
std::cerr<<"Mul operation fail, Matrix is null!"<<std::endl;
|
||||||
|
return Matrix();
|
||||||
|
}
|
||||||
|
if (getDimSize(2)>1 || aOther.getDimSize(2)>1){
|
||||||
|
std::cerr<<"In matrix mul operation not support 3d data!"<<std::endl;
|
||||||
|
return Matrix();
|
||||||
|
}
|
||||||
|
// V x ?
|
||||||
|
if (isVector())
|
||||||
|
{
|
||||||
|
// row vector
|
||||||
|
if (getDimSize(0) > 1){
|
||||||
|
std::cerr<<"In matrix mul operation Left vector must be a column vector!"<<std::endl;
|
||||||
|
return Matrix();
|
||||||
|
}
|
||||||
|
//V x Scalar
|
||||||
|
if (aOther.isScalar())
|
||||||
|
{
|
||||||
|
return (*this)*aOther.getScalar();
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
//right size
|
||||||
|
if(getDimSize(1) == aOther.getDimSize(0))
|
||||||
|
{
|
||||||
|
auto result = deepCopy();
|
||||||
|
cblas_dgemv(CblasColMajor,CblasTrans,aOther.getDimSize(0),aOther.getDimSize(1),1.0,
|
||||||
|
aOther.getData(),aOther.getDimSize(0),getData(),1,0,result.getData(),1);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
std::cerr << "Matrix mul operation fail, can't do Matrix("
|
||||||
|
<< getDimSize(0) << "," << getDimSize(1)
|
||||||
|
<< ") * Matrix(" << aOther.getDimSize(0) << ","
|
||||||
|
<< aOther.getDimSize(1) << ")" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(isScalar()){
|
||||||
|
return getScalar()*aOther;
|
||||||
|
}
|
||||||
|
// M x ?
|
||||||
|
else if(aOther.isScalar()){
|
||||||
|
return aOther.getScalar()*(*this);
|
||||||
|
} else if (aOther.isVector()) {
|
||||||
|
// right size
|
||||||
|
if (getDimSize(1) == aOther.getDimSize(0)) {
|
||||||
|
auto result = deepCopy();
|
||||||
|
cblas_dgemv(CblasColMajor, CblasNoTrans, getDimSize(0),
|
||||||
|
getDimSize(1), 1.0, getData(),
|
||||||
|
getDimSize(0), aOther.getData(), 1, 0,
|
||||||
|
result.getData(), 1);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
std::cerr << "Matrix mul operation fail, can't do Matrix("
|
||||||
|
<< getDimSize(0) << "," << getDimSize(1) << ") * Matrix("
|
||||||
|
<< aOther.getDimSize(0) << "," << aOther.getDimSize(1)
|
||||||
|
<< ")" << std::endl;
|
||||||
|
return Matrix();
|
||||||
|
}
|
||||||
|
//MxM
|
||||||
|
else {
|
||||||
|
// fit size
|
||||||
|
if (getDimSize(1) == aOther.getDimSize(0))
|
||||||
|
{
|
||||||
|
double * output = malloc(getDimSize(0)*aOther.getDimSize(1));
|
||||||
|
int M = getDimSize(0);
|
||||||
|
int N = aOther.getDimSize(1);
|
||||||
|
int K = getDimSize(1);
|
||||||
|
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, 1.0, getData(), M, aOther.getData(),K, 0, output, M);
|
||||||
|
return Matrix::New(output,M,N);
|
||||||
|
}
|
||||||
|
std::cerr << "Matrix mul operation fail, can't do Matrix("
|
||||||
|
<< getDimSize(0) << "," << getDimSize(1) << ") * Matrix("
|
||||||
|
<< aOther.getDimSize(0) << "," << aOther.getDimSize(1)
|
||||||
|
<< ")" << std::endl;
|
||||||
|
return Matrix();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Matrix::MatrixSlice::MatrixSlice(int aSize,int aStride, double* aData, ValueType aType, int aSliceMode,int aSize2, int aStride2):
|
Matrix::MatrixSlice::MatrixSlice(int aSize,int aStride, double* aData, ValueType aType, int aSliceMode,int aSize2, int aStride2):
|
||||||
mSliceMode(aSliceMode),mData(aData),
|
mSliceMode(aSliceMode),mData(aData),
|
||||||
mSize(aSize),mSize2(aSize2),
|
mSize(aSize),mSize2(aSize2),
|
||||||
|
|||||||
17
src/Matrix.h
17
src/Matrix.h
@@ -169,6 +169,23 @@ namespace Aurora {
|
|||||||
friend Matrix operator==(double aScalar, const Matrix &matrix);
|
friend Matrix operator==(double aScalar, const Matrix &matrix);
|
||||||
Matrix operator==(const Matrix &matrix) const;
|
Matrix operator==(const Matrix &matrix) const;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 矩阵乘法
|
||||||
|
* @attention 目前只支持矩阵乘向量
|
||||||
|
* @param aOther
|
||||||
|
* @return Matrix
|
||||||
|
*/
|
||||||
|
Matrix mul(const Matrix& aOther);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 矩阵乘法
|
||||||
|
* @attention 目前只支持矩阵乘向量
|
||||||
|
* @param aOther
|
||||||
|
* @return Matrix
|
||||||
|
*/
|
||||||
|
Matrix mul(Matrix&& aOther);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* print matrix , only support 2d matrix now
|
* print matrix , only support 2d matrix now
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -357,4 +357,32 @@ TEST_F(Matrix_Test, matrixCompare){
|
|||||||
EXPECT_EQ(C.getData()[2], 1);
|
EXPECT_EQ(C.getData()[2], 1);
|
||||||
EXPECT_EQ(C.getData()[3], 0);
|
EXPECT_EQ(C.getData()[3], 0);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(Matrix_Test, matrixfunction){
|
||||||
|
double *dataA = new double[9]{1,6,9,4,1,0,5,8,1};
|
||||||
|
double *dataB = new double[3]{1,2,3};
|
||||||
|
Aurora::Matrix A = Aurora::Matrix::fromRawData(dataA, 3, 3);
|
||||||
|
Aurora::Matrix B = Aurora::Matrix::fromRawData(dataB, 1, 3);
|
||||||
|
auto C = B.mul(A);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[0], 40);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[1], 6);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[2], 24);
|
||||||
|
B.forceReshape(3, 1, 1);
|
||||||
|
C = A.mul(B);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[0], 24);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[1], 32);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[2], 12);
|
||||||
|
double *dataD = new double[9]{2.1,3,9,-3,1,0,51,-8,1};
|
||||||
|
Aurora::Matrix D = Aurora::Matrix::fromRawData(dataD, 3, 3);
|
||||||
|
C = A.mul(D);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[0], 59.1);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[1], 87.6);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[2], 27.9);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[3], 1);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[4], -17);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[5], -27);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[6], 24);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[7], 306);
|
||||||
|
EXPECT_DOUBLE_EQ(C.getData()[8], 460);
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user