/*
 * Repository page header captured together with this source file:
 *
 * Files
 * Aurora/src/Matrix.cpp
 * 2023-04-20 11:21:11 +08:00
 *
 * 565 lines
 * 22 KiB
 * C++
 * Raw Blame History
 *
 * This file contains ambiguous Unicode characters
 *
 * This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
 */

#include "Matrix.h"
#include <algorithm>
#include <complex>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
// Must come after <complex> and before mkl.h and any Eigen header, so that MKL
// uses std::complex<double> as its MKL_Complex16 type.
#define MKL_Complex16 std::complex<double>
#include "mkl.h"
#include "Function.h"
namespace Aurora{
// Function-pointer type matching the real-valued MKL VM strided binary
// functions used in this file (vdAddI, vdSubI, vdMulI, vdDivI, vdPowI):
// n elements are read from a and b with increments inca/incb and the results
// are written to r with increment incr.
using CalcFuncD = void (*)(const MKL_INT n, const double a[], const MKL_INT inca,
                           const double b[], const MKL_INT incb, double r[],
                           const MKL_INT incr);
// Complex (MKL_Complex16) counterpart of CalcFuncD, matching vzAddI, vzSubI,
// vzMulI and vzDivI.
using CalcFuncZ = void (*)(const MKL_INT n, const MKL_Complex16 a[], const MKL_INT inca,
                           const MKL_Complex16 b[], const MKL_INT incb, MKL_Complex16 r[],
                           const MKL_INT incr);
// Applies a real-valued MKL VM binary function element-wise as
// aMatrix[i] (op) aScalar and returns the result in a newly allocated matrix
// with the same shape and value type.
//
// Complex data is stored interleaved (re, im, re, im, ...), so the real and
// the imaginary components are processed in two passes with stride 2.
// NOTE(review): aFunc is applied to BOTH components. That matches complex
// arithmetic for * and / by a real scalar, but not for + and - (which should
// leave the imaginary part unchanged) nor for ^ (pow). Confirm the intended
// semantics for complex inputs.
inline Matrix operatorMxA(CalcFuncD aFunc, double aScalar, const Matrix &aMatrix) {
    const bool isComplex = aMatrix.getValueType() == Complex;
    const size_t count = aMatrix.getDataSize();
    double *output = malloc(count, isComplex);
    if (isComplex) {
        // Offset 0 with stride 2 visits the real parts, offset 1 the imaginary
        // parts; a stride of 1 would mix the two and leave half of `output`
        // unwritten.
        aFunc(count, aMatrix.getData(), 2, &aScalar, 0, output, 2);
        aFunc(count, aMatrix.getData() + 1, 2, &aScalar, 0, output + 1, 2);
    } else {
        aFunc(count, aMatrix.getData(), 1, &aScalar, 0, output, 1);
    }
    return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2),
                       aMatrix.getValueType());
}
// Element-wise aMatrix (op) aOther for two matrices of identical shape; the
// result is written to a newly allocated matrix. Returns an empty Matrix when
// the shapes differ.
//  - both real:    aFuncD on the raw doubles;
//  - both complex: aFuncZ on the interleaved std::complex<double> data;
//  - mixed:        the real operand is widened to complex (imaginary part 0)
//                  and aFuncZ is used, so +, -, * and / all follow complex
//                  arithmetic and the left/right operand order is kept.
inline Matrix operatorMxM(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Matrix &aMatrix,
                          const Matrix &aOther) {
    if (!aMatrix.compareShape(aOther)) return Matrix();
    const size_t count = aMatrix.getDataSize();
    if (aMatrix.getValueType() == aOther.getValueType()) {
        if (aMatrix.getValueType() == Normal) {
            double *output = malloc(count);
            aFuncD(count, aMatrix.getData(), 1, aOther.getData(), 1, output, 1);
            return Matrix::New(output, aMatrix);
        }
        double *output = malloc(count, true);
        aFuncZ(count, reinterpret_cast<std::complex<double> *>(aMatrix.getData()), 1,
               reinterpret_cast<std::complex<double> *>(aOther.getData()), 1,
               reinterpret_cast<std::complex<double> *>(output), 1);
        return Matrix::New(output, aOther);
    }
    // Mixed real/complex operands. Running aFuncD separately over the real and
    // imaginary parts is only correct for * (and / by a real divisor);
    // widening the real operand is correct for every operation.
    const bool lhsComplex = aMatrix.getValueType() == Complex;
    const Matrix &realOperand = lhsComplex ? aOther : aMatrix;
    const Matrix &complexOperand = lhsComplex ? aMatrix : aOther;
    const double *realData = realOperand.getData();
    const std::vector<std::complex<double>> widened(realData, realData + count);
    auto *complexData = reinterpret_cast<std::complex<double> *>(complexOperand.getData());
    double *output = malloc(count, true);
    auto *result = reinterpret_cast<std::complex<double> *>(output);
    if (lhsComplex) {
        aFuncZ(count, complexData, 1, widened.data(), 1, result, 1);
    } else {
        aFuncZ(count, widened.data(), 1, complexData, 1, result, 1);
    }
    // The result is complex: take shape and value type from the complex operand.
    return Matrix::New(output, complexOperand);
}
// In-place form of operatorMxA for an expiring matrix: overwrites aMatrix's
// buffer with aMatrix[i] (op) aScalar and returns a reference to aMatrix
// itself, so no new buffer is allocated. The reference aliases the caller's
// temporary and must be consumed within the same full expression.
// NOTE(review): as in operatorMxA, complex data gets aFunc applied to both the
// real and imaginary parts, which is only correct for * and / by a real scalar.
inline Matrix &operatorMxA_RR(CalcFuncD aFunc, double aScalar, Aurora::Matrix &&aMatrix) {
    // Return through an lvalue reference: since C++23 `return aMatrix;` on an
    // rvalue-reference parameter is an xvalue and cannot bind to Matrix&.
    Matrix &target = aMatrix;
    const size_t count = target.getDataSize();
    double *data = target.getData();
    if (target.getValueType() == Complex) {
        // Interleaved (re, im) storage: stride 2 visits one component at a time.
        aFunc(count, data, 2, &aScalar, 0, data, 2);
        aFunc(count, data + 1, 2, &aScalar, 0, data + 1, 2);
    } else {
        aFunc(count, data, 1, &aScalar, 0, data, 1);
    }
    return target;
}
// Element-wise aMatrix (op) aOther where aOther is an expiring rvalue. The
// result is written into aOther's buffer whenever that buffer can hold it
// (both operands real, or aOther complex), and aOther is then moved out as the
// return value. When aMatrix is complex and aOther real, a new complex buffer
// is allocated instead. Returns an empty Matrix when the shapes differ.
// Mixed real/complex operands are handled by widening the real operand to
// complex (imaginary part 0) and calling aFuncZ, keeping the operand order.
inline Matrix operatorMxM_RR(CalcFuncD aFuncD, CalcFuncZ aFuncZ, const Aurora::Matrix &aMatrix,
                             Aurora::Matrix &&aOther) {
    if (!aMatrix.compareShape(aOther)) return Matrix();
    const size_t count = aMatrix.getDataSize();
    if (aMatrix.getValueType() == aOther.getValueType()) {
        if (aMatrix.getValueType() == Normal) {
            aFuncD(count, aMatrix.getData(), 1, aOther.getData(), 1, aOther.getData(), 1);
        } else {
            auto *target = reinterpret_cast<std::complex<double> *>(aOther.getData());
            aFuncZ(count, reinterpret_cast<std::complex<double> *>(aMatrix.getData()), 1,
                   target, 1, target, 1);
        }
        // aOther is a reference parameter, not a local object, so it is not
        // moved implicitly before C++20.
        return std::move(aOther);
    }
    if (aMatrix.getValueType() == Complex) {
        // aOther is real and too small for complex results: allocate a new
        // buffer and take shape and (complex) value type from aMatrix.
        const double *realData = aOther.getData();
        const std::vector<std::complex<double>> widened(realData, realData + count);
        double *output = malloc(count, true);
        aFuncZ(count, reinterpret_cast<std::complex<double> *>(aMatrix.getData()), 1,
               widened.data(), 1, reinterpret_cast<std::complex<double> *>(output), 1);
        return Matrix::New(output, aMatrix);
    }
    // aOther is complex: widen aMatrix and compute straight into aOther's buffer.
    const double *realData = aMatrix.getData();
    const std::vector<std::complex<double>> widened(realData, realData + count);
    auto *target = reinterpret_cast<std::complex<double> *>(aOther.getData());
    aFuncZ(count, widened.data(), 1, target, 1, target, 1);
    return std::move(aOther);
}
};
namespace Aurora {
// Wraps an existing shared buffer and its dimension list. Both parameters are
// taken by value (sink parameters), so they are moved into the members rather
// than copied again.
Matrix::Matrix(std::shared_ptr<double> aData, std::vector<int> aInfo)
    : mData(std::move(aData)), mInfo(std::move(aInfo)) {
}
// Materializes a slice view into a new, owning matrix by converting it with
// MatrixSlice::toMatrix() and taking over the temporary's members.
Matrix::Matrix(const Matrix::MatrixSlice& slice) {
    Matrix temp = slice.toMatrix();
    // temp is discarded afterwards, so move instead of copying: this avoids an
    // atomic refcount increment on mData and a heap copy of mInfo.
    mData = std::move(temp.mData);
    mInfo = std::move(temp.mInfo);
    mValueType = temp.mValueType;
}
// A matrix is "null" when it has no buffer attached or records no dimensions.
bool Matrix::isNull() const {
    const bool hasBuffer = static_cast<bool>(mData);
    const bool hasShape = !mInfo.empty();
    return !(hasBuffer && hasShape);
}
// Number of dimensions recorded for this matrix.
int Matrix::getDims() const {
    const auto dimensionCount = mInfo.size();
    return static_cast<int>(dimensionCount);
}
// Non-owning pointer to the first double of the shared buffer (null when no
// buffer is attached). Complex data is laid out as interleaved (re, im) pairs.
double *Matrix::getData() const {
    double *raw = mData.get();
    return raw;
}
// Size of dimension aIndex (0 = rows, 1 = columns, 2 = slices). Returns 0 for
// an index outside [0, 3) or beyond the number of stored dimensions.
int Matrix::getDimSize(int aIndex) const {
    const bool inRange = aIndex >= 0 && aIndex < 3 && aIndex < getDims();
    return inRange ? mInfo.at(aIndex) : 0;
}
// Total element count: the product of all dimension sizes, or 0 when no buffer
// is attached. For complex matrices this counts complex elements, not doubles.
size_t Matrix::getDataSize() const {
    if (mData == nullptr) {
        return 0;
    }
    size_t product = 1;
    for (std::size_t i = 0; i < mInfo.size(); ++i) {
        product *= mInfo[i];
    }
    return product;
}
// True when both matrices have the same number of dimensions with identical
// sizes along each one. The value type (real/complex) is not compared.
bool Matrix::compareShape(const Matrix &other) const {
    return mInfo == other.mInfo;
}
// Returns column aColumn of a 2-D matrix as a new 1-D matrix. aColumn is
// 1-based: column c starts at element (c - 1) * rows. Returns an empty Matrix
// when there is no buffer, the matrix is not 2-D, or aColumn is outside
// [1, columns]. Complex elements are copied as whole (re, im) pairs and the
// result keeps the complex value type.
Matrix Matrix::getDataFromDims2(int aColumn) {
    if (!mData || 2 != getDims() || aColumn < 1 || aColumn > mInfo.back()) {
        return Matrix();
    }
    const int rows = mInfo.at(0);
    const bool isComplex = getValueType() == Complex;
    // Number of doubles in one column (complex data is interleaved re/im).
    const size_t count = static_cast<size_t>(rows) * (isComplex ? 2 : 1);
    std::shared_ptr<double> resultData(new double[count], std::default_delete<double[]>());
    const double *columnStart = mData.get() + static_cast<size_t>(aColumn - 1) * count;
    std::copy(columnStart, columnStart + count, resultData.get());
    Matrix result(resultData, std::vector<int>{rows});
    if (isComplex) result.setValueType(Complex);
    return result;
}
// Returns element aRow (1-based) of a 1-D matrix as a new single-element
// matrix. Returns an empty Matrix when there is no buffer, the matrix is not
// 1-D, or aRow is outside [1, length]. A complex element keeps both parts and
// the result keeps the complex value type.
Matrix Matrix::getDataFromDims1(int aRow) {
    if (!mData || 1 != getDims() || aRow < 1 || aRow > mInfo.back()) {
        return Matrix();
    }
    const bool isComplex = getValueType() == Complex;
    // Doubles per element (complex data is interleaved re/im).
    const size_t width = isComplex ? 2 : 1;
    std::shared_ptr<double> resultData(new double[width], std::default_delete<double[]>());
    const double *element = mData.get() + static_cast<size_t>(aRow - 1) * width;
    std::copy(element, element + width, resultData.get());
    Matrix result(resultData, std::vector<int>{1});
    if (isComplex) result.setValueType(Complex);
    return result;
}
// Wraps `data` in a Matrix that releases it with `free` once the last
// reference is dropped (buffers passed here come from this file's `malloc`
// helper). The row count is always recorded; columns and slices are recorded
// only when positive. A null `data` yields an empty Matrix.
Matrix Matrix::New(double *data, int rows, int cols, int slices, ValueType type) {
    if (data == nullptr) {
        return Matrix();
    }
    std::vector<int> dims{rows};
    for (int extent : {cols, slices}) {
        if (extent > 0) {
            dims.push_back(extent);
        }
    }
    Matrix ret({data, free}, dims);
    if (type != Normal) {
        ret.setValueType(type);
    }
    return ret;
}
// Wraps `data` using the first three dimension sizes and the value type of
// shapeMatrix. The data itself is not inspected or copied.
Matrix Matrix::New(double *data, const Matrix &shapeMatrix) {
    const int rows = shapeMatrix.getDimSize(0);
    const int cols = shapeMatrix.getDimSize(1);
    const int slices = shapeMatrix.getDimSize(2);
    return New(data, rows, cols, slices, shapeMatrix.getValueType());
}
// Allocates a buffer sized for shapeMatrix (via this file's `malloc` helper;
// presumably uninitialized — confirm) and wraps it with the same shape and
// value type. No values are copied.
Matrix Matrix::New(const Matrix &shapeMatrix) {
    return New(malloc(shapeMatrix.getDataSize(), shapeMatrix.getValueType()), shapeMatrix);
}
// Returns an independent copy with its own buffer holding the same values,
// the same first three dimensions and the same value type.
// NOTE(review): the byte count multiplies by getValueType(), which assumes the
// enum's numeric value equals doubles per element (Normal = 1, Complex = 2) —
// confirm in Matrix.h.
Matrix Matrix::deepCopy() const {
    const size_t elements = getDataSize();
    const auto type = getValueType();
    double *copy = malloc(elements, type);
    memcpy(copy, getData(), sizeof(double) * elements * type);
    return New(copy, *this);
}
//operation +
Matrix Matrix::operator+(double aScalar) const { return operatorMxA(&vdAddI, aScalar, *this);}
Matrix operator+(double aScalar, const Matrix &matrix) {return matrix + aScalar;}
Matrix Matrix::operator+(const Matrix &matrix) const {return operatorMxM(vdAddI, vzAddI, *this, matrix);}
Matrix &operator+(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdAddI,aScalar, std::forward<Matrix&&>(matrix));
}
Matrix &operator+(Matrix &&matrix,double aScalar) {
return operatorMxA_RR(&vdAddI,aScalar, std::forward<Matrix&&>(matrix));
}
Matrix Matrix::operator+(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdAddI,&vzAddI,*this,std::forward<Matrix&&>(aMatrix));
}
Matrix operator+(Matrix &&aMatrix, const Matrix &aOther) {
return operatorMxM_RR(&vdAddI,&vzAddI,aOther,std::forward<Matrix&&>(aMatrix));
}
//operation -
Matrix Matrix::operator-(double aScalar) const { return operatorMxA(&vdSubI, aScalar, *this);}
Matrix operator-(double aScalar, const Matrix &matrix) {return matrix - aScalar;}
Matrix Matrix::operator-(const Matrix &matrix) const {return operatorMxM(vdSubI, vzSubI, *this, matrix);}
Matrix &operator-(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdSubI,aScalar, std::forward<Matrix&&>(matrix));
}
Matrix &operator-(Matrix &&matrix,double aScalar) {
return operatorMxA_RR(&vdSubI,aScalar, std::forward<Matrix&&>(matrix));
}
Matrix Matrix::operator-(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdSubI,&vzSubI,*this,std::forward<Matrix&&>(aMatrix));
}
Matrix operator-(Matrix &&aMatrix, const Matrix &aOther) {
return operatorMxM_RR(&vdSubI,&vzSubI,aOther,std::forward<Matrix&&>(aMatrix));
}
//operation *
// All products here are element-wise: scalar products go through
// operatorMxA / operatorMxA_RR (the latter reuses an expiring matrix's
// buffer), matrix-by-matrix products through operatorMxM / operatorMxM_RR.
Matrix Matrix::operator*(double aScalar) const {
    return operatorMxA(vdMulI, aScalar, *this);
}
// Multiplication commutes, so scalar * matrix reuses matrix * scalar.
Matrix operator*(double aScalar, const Matrix &matrix) {
    return matrix * aScalar;
}
Matrix Matrix::operator*(const Matrix &matrix) const {
    return operatorMxM(&vdMulI, &vzMulI, *this, matrix);
}
Matrix &operator*(double aScalar, Matrix &&matrix) {
    return operatorMxA_RR(vdMulI, aScalar, std::move(matrix));
}
Matrix &operator*(Matrix &&matrix, double aScalar) {
    return operatorMxA_RR(vdMulI, aScalar, std::move(matrix));
}
Matrix Matrix::operator*(Matrix &&aMatrix) const {
    return operatorMxM_RR(vdMulI, vzMulI, *this, std::move(aMatrix));
}
// operatorMxM_RR computes aOther * aMatrix into aMatrix's buffer; the product
// commutes, so that equals aMatrix * aOther.
Matrix operator*(Matrix &&aMatrix, const Matrix &aOther) {
    return operatorMxM_RR(vdMulI, vzMulI, aOther, std::move(aMatrix));
}
//operation /
Matrix Matrix::operator/(double aScalar) const { return operatorMxA(&vdDivI, aScalar, *this);}
Matrix operator/(double aScalar, const Matrix &matrix) {return matrix / aScalar;}
Matrix Matrix::operator/(const Matrix &matrix) const {return operatorMxM(vdDivI, vzDivI, *this, matrix);}
Matrix &operator/(double aScalar, Matrix &&matrix) {
return operatorMxA_RR(&vdDivI,aScalar, std::forward<Matrix&&>(matrix));
}
Matrix &operator/(Matrix &&matrix,double aScalar) {
return operatorMxA_RR(&vdDivI,aScalar, std::forward<Matrix&&>(matrix));
}
Matrix Matrix::operator/(Matrix &&aMatrix) const {
return operatorMxM_RR(&vdDivI,&vzDivI,*this,std::forward<Matrix&&>(aMatrix));
}
Matrix operator/(Matrix &&aMatrix, const Matrix &aOther) {
return operatorMxM_RR(&vdDivI,&vzDivI,aOther,std::forward<Matrix&&>(aMatrix));
}
//operator ^ (pow)
// Element-wise power with an integer exponent, passed to vdPowI as a double.
// Both overloads allocate a new result buffer.
Matrix Matrix::operator^(int times) const {
    const double exponent = times;
    return operatorMxA(vdPowI, exponent, *this);
}
Matrix operator^(Matrix &&matrix, int times) {
    return operatorMxA(vdPowI, static_cast<double>(times), matrix);
}
void Matrix::printf() {
int k_count = getDimSize(2)==0?1:getDimSize(2);
int j_count = getDimSize(1)==0?1:getDimSize(1);
for (int k = 0; k <k_count; ++k) {
::printf("slice %d:\r\n[",k);
for (int i = 0; i < getDimSize(0); ++i) {
::printf("[");
for (int j = 0; j < j_count; ++j) {
::printf("%f2, ",getData()[k*getDimSize(1)*getDimSize(0)+j*getDimSize(0)+i]);
}
::printf("]\r\n");
}
::printf("]\r\n");
}
}
// Returns a MatrixSlice view into this matrix's buffer; nothing is copied, and
// assignments through the view write into this matrix's data.
// Indices are 0-based: each one is multiplied directly by its dimension stride.
// Passing `$` for an index selects that whole dimension. `$` is defined outside
// this file (presumably Matrix.h) — TODO confirm its value/meaning.
// The slice mode is the number of `$` indices that name an existing
// dimension: 2 -> matrix view, 1 -> vector view, 0 -> single element.
// NOTE(review): indices are not bounds-checked; out-of-range values yield a
// view pointing outside the buffer. With `$` for all three indices on a 3-D
// matrix, mode is 3 and the default branch returns a 1-element view that
// still carries mode 3.
Matrix::MatrixSlice Matrix::operator()(int aRowIdx, int aColIdx, int aSliceIdx) const {
std::vector<int> dims = {aRowIdx, aColIdx, aSliceIdx};
// Dimension numbers (0 = row, 1 = column, 2 = slice) selected with `$`, in order.
std::vector<int> allDimIndex;
int mode = 0;
for (int j = 0; j < 3; ++j) {
// A `$` on a dimension this matrix does not have is ignored.
if (dims[j]==$ && this->getDims()>j){
++mode;
allDimIndex.push_back(j);
}
}
// Element strides of the column-major layout: rows are contiguous, then
// columns, then slices.
int rowStride = 1;
int colStride = getDimSize(0);
int sliceStride = getDimSize(0)*getDimSize(1);
int strides[3] = {rowStride, colStride, sliceStride};
// A `$` index contributes no offset: the view starts at position 0 of that dimension.
int rowOffset = aRowIdx == $ ? 0 : aRowIdx;
int colOffset = aColIdx == $ ? 0 : aColIdx;
int sliceOffset = aSliceIdx == $ ? 0 : aSliceIdx;
// The element offset is scaled to doubles by getValueType(); this assumes its
// numeric value is the number of doubles per element (1 real, 2 complex) —
// confirm in Matrix.h. Strides stay in element units because the complex
// copy paths cast to std::complex<double>* and step whole elements.
double *startPointer = getData() + (rowStride * rowOffset
+ colStride * colOffset
+ sliceStride * sliceOffset) * getValueType();
// Size/stride of the first `$` dimension (1/1 when there is none).
int size1 = allDimIndex.empty()?1:getDimSize(allDimIndex[0]);
int stride1 = allDimIndex.empty()?1:strides[allDimIndex[0]];
switch (mode) {
//matrix mode: two `$` dimensions; the second one supplies size2/stride2
case 2:{
int size2 = getDimSize(allDimIndex[1]);
int stride2 = strides[allDimIndex[1]];
return Matrix::MatrixSlice(size1, stride1, startPointer, getValueType(), mode, size2, stride2);
}
//vector mode: exactly one `$` dimension
case 1:
{
return Matrix::MatrixSlice(size1, stride1, startPointer, getValueType(), mode);
}
//scalar mode (mode 0), or mode 3 when `$` was given for all three dimensions
case 0:
default: {
return Matrix::MatrixSlice(1 , 1, startPointer,getValueType(), mode);
}
}
}
// Stores a strided view descriptor over existing data; the pointer is kept as
// given and no elements are copied. aSize/aStride describe the first sliced
// dimension, aSize2/aStride2 the second one (used in matrix mode, mode 2).
Matrix::MatrixSlice::MatrixSlice(int aSize, int aStride, double *aData, ValueType aType,
                                 int aSliceMode, int aSize2, int aStride2)
    : mSliceMode(aSliceMode),
      mData(aData),
      mSize(aSize),
      mSize2(aSize2),
      mStride(aStride),
      mStride2(aStride2),
      mType(aType) {}
// Copies the elements viewed by `slice` into the elements viewed by this slice,
// writing straight into the underlying buffer. Self-assignment is a no-op.
// Both views must be non-null and agree in slice mode, both sizes and value
// type; otherwise a message goes to std::cerr and nothing is copied.
Matrix::MatrixSlice &Matrix::MatrixSlice::operator=(const Matrix::MatrixSlice &slice) {
    if (&slice == this) {
        return *this;
    }
    bool compatible = false;
    if (!mData) {
        std::cerr <<"Assign value fail!Des data pointer is null!";
    } else if (!slice.mData) {
        std::cerr <<"Assign value fail!Src data pointer is null!";
    } else if (mSliceMode != slice.mSliceMode) {
        std::cerr <<"Assign value fail!Src slice(dims count:"<< slice.mSliceMode <<"), not match of des(dims count:"<<mSliceMode<<")!";
    } else if (mSize != slice.mSize) {
        std::cerr <<"Assign value fail!Src slice(dim 1 size:"<< slice.mSize <<"), not match of des(dim 1 size:"<<mSize<<")!";
    } else if (mSize2 != slice.mSize2) {
        std::cerr <<"Assign value fail!Src slice(dim 2 size:"<< slice.mSize2 <<"), not match of des(dim 2 size:"<<mSize2<<")!";
    } else if (mType != slice.mType) {
        std::cerr <<"Assign value fail!Src slice(value type:"<< slice.mType <<"), not match of des(value type:"<<mType<<")!";
    } else {
        compatible = true;
    }
    if (!compatible) {
        return *this;
    }
    const bool complexData = (mType != Normal);
    if (mSliceMode == 2) {
        // mSize2 strided vectors of mSize elements each.
        if (complexData) {
            cblas_zcopy_batch_strided(mSize, reinterpret_cast<std::complex<double> *>(slice.mData),
                                      slice.mStride, slice.mStride2,
                                      reinterpret_cast<std::complex<double> *>(mData),
                                      mStride, mStride2, mSize2);
        } else {
            cblas_dcopy_batch_strided(mSize, slice.mData, slice.mStride, slice.mStride2,
                                      mData, mStride, mStride2, mSize2);
        }
    } else if (mSliceMode == 1) {
        if (complexData) {
            cblas_zcopy(mSize, reinterpret_cast<std::complex<double> *>(slice.mData), slice.mStride,
                        reinterpret_cast<std::complex<double> *>(mData), mStride);
        } else {
            cblas_dcopy(mSize, slice.mData, slice.mStride, mData, mStride);
        }
    } else {
        // Single element: real part, plus imaginary part for complex data.
        mData[0] = slice.mData[0];
        if (complexData) {
            mData[1] = slice.mData[1];
        }
    }
    return *this;
}
// Copies a dense matrix into the elements viewed by this slice, writing
// straight into the underlying buffer. Requirements:
//  - vector/matrix views (mode 1/2): the matrix's dimension count equals the
//    slice mode and its sizes equal mSize / mSize2;
//  - scalar views (mode 0): the matrix holds exactly one element;
//  - in all modes the value types must match.
// On any mismatch, or a null pointer on either side, a message goes to
// std::cerr and nothing is copied.
Matrix::MatrixSlice &Matrix::MatrixSlice::operator=(const Matrix &matrix) {
    if(!mData){
        std::cerr <<"Assign value fail!Des data pointer is null!";
        return *this;
    }
    if(!matrix.getData()){
        std::cerr <<"Assign value fail!Src data pointer is null!";
        return *this;
    }
    if (mSliceMode == 0) {
        // The dims/size checks below can never pass for a scalar view:
        // getDims() == 0 forces getDimSize(0) == 0, which never equals mSize.
        // That made the scalar copy unreachable, so compare element counts.
        if (matrix.getDataSize() != 1) {
            std::cerr <<"Assign value fail!Src matrix(size:"<< matrix.getDataSize() <<"), not match of des(scalar)!";
            return *this;
        }
    } else {
        if (matrix.getDims()!=mSliceMode) {
            std::cerr <<"Assign value fail!Src matrix(dims count:"<< matrix.getDims() <<"), not match of des(dims count:"<<mSliceMode<<")!";
            return *this;
        }
        if (matrix.getDimSize(0)!=mSize) {
            std::cerr <<"Assign value fail!Src matrix(dim 1 size:"<< matrix.getDimSize(0)<<"), not match of des(dim 1 size:"<<mSize<<")!";
            return *this;
        }
        if (matrix.getDimSize(1)!=mSize2) {
            std::cerr <<"Assign value fail!Src matrix(dim 2 size:"<< matrix.getDimSize(1) <<"), not match of des(dim 2 size:"<<mSize2<<")!";
            return *this;
        }
    }
    if (matrix.getValueType()!=mType) {
        std::cerr <<"Assign value fail!Src matrix(value type:"<< matrix.getValueType() <<"), not match of des(value type:"<<mType<<")!";
        return *this;
    }
    switch (mSliceMode) {
        case 2:{
            // The dense source is column-major: columns are getDimSize(0) elements apart.
            if (mType== Normal) {
                cblas_dcopy_batch_strided(mSize, matrix.getData(), 1, matrix.getDimSize(0), mData, mStride,
                                          mStride2, mSize2);
            }
            else {
                cblas_zcopy_batch_strided(mSize,(std::complex<double>*)matrix.getData(),1,matrix.getDimSize(0),
                                          (std::complex<double>*)mData,mStride,mStride2,mSize2);
            }
            break;
        }
        case 1:{
            if (mType== Normal){
                cblas_dcopy(mSize,matrix.getData(),1,mData,mStride);
            }
            else {
                cblas_zcopy(mSize, (std::complex<double> *) matrix.getData(),1,
                            (std::complex<double> *) mData, mStride);
            }
            break;
        }
        case 0:
        default:{
            mData[0] = matrix.getData()[0];
            if (mType != Normal)mData[1] = matrix.getData()[1];
        }
    }
    return *this;
}
// Writes a real value through a scalar view (mode 0, size 1) of a real matrix.
// A null view, a non-scalar view or a complex target logs to std::cerr and
// leaves the data untouched.
Matrix::MatrixSlice &Matrix::MatrixSlice::operator=(double value) {
    if (!mData) {
        std::cerr <<"Assign value fail!Des data pointer is null!";
    } else if (mSliceMode != 0) {
        std::cerr <<"Assign value fail!Des slicemode is"<<mSliceMode<<", not scalar mode!";
    } else if (mSize != 1) {
        std::cerr <<"Assign value fail!Des size:"<<mSize<<", not scalar mode!";
    } else if (mType != Normal) {
        std::cerr <<"Assign value fail!Des type is complex!";
    } else {
        *mData = value;
    }
    return *this;
}
// Writes a complex value through a scalar view (mode 0, size 1) of a complex
// matrix, storing the real part then the imaginary part. A null view, a
// non-scalar view or a non-complex target logs to std::cerr and leaves the
// data untouched.
Matrix::MatrixSlice &Matrix::MatrixSlice::operator=(std::complex<double> value) {
    if (!mData) {
        std::cerr <<"Assign value fail!Des data pointer is null!";
    } else if (mSliceMode != 0) {
        std::cerr <<"Assign value fail!Des slicemode is"<<mSliceMode<<", not scalar mode!";
    } else if (mSize != 1) {
        std::cerr <<"Assign value fail!Des size:"<<mSize<<", not scalar mode!";
    } else if (mType != Complex) {
        std::cerr <<"Assign value fail!Des type is not complex!";
    } else {
        mData[0] = value.real();
        mData[1] = value.imag();
    }
    return *this;
}
// Copies the elements viewed by this slice into a new, densely packed
// (column-major) matrix with mSize rows, plus mSize2 columns in matrix mode,
// and the same value type. Returns an empty Matrix when the view has no data.
Matrix Matrix::MatrixSlice::toMatrix() const {
    if (!mData) {
        return Matrix();
    }
    const size_t count = static_cast<size_t>(mSize) * (mSize2 > 0 ? mSize2 : 1);
    // Allocate with the same helper used for every other buffer handed to
    // Matrix::New in this file. New installs `free` as the deleter, and memory
    // from mkl_malloc must instead be released with mkl_free.
    double *data = malloc(count, mType == Complex);
    switch (mSliceMode) {
        case 2:{
            // Gather mSize2 strided vectors into consecutive columns of mSize.
            if (mType== Normal) {
                cblas_dcopy_batch_strided(mSize, mData, mStride,
                                          mStride2,data, 1, mSize, mSize2);
            }
            else {
                cblas_zcopy_batch_strided(mSize, (std::complex<double> *) mData, mStride, mStride2,
                                          (std::complex<double> *) data, 1, mSize,
                                          mSize2);
            }
            break;
        }
        case 1:{
            if (mType== Normal){
                cblas_dcopy(mSize,mData,mStride,data,1);
            }
            else {
                cblas_zcopy(mSize, (std::complex<double> *) mData, mStride,
                            (std::complex<double> *) data, 1);
            }
            break;
        }
        case 0:
        default:{
            data[0]= mData[0];
            if (mType != Normal) data[1] = mData[1];
        }
    }
    return Matrix::New(data,mSize,mSize2,0,mType);
}
}