// Aurora::Matrix declarations (listing metadata: 279 lines, 8.4 KiB, C++).
#ifndef MATRIX_H
#define MATRIX_H

#include <complex>
#include <cstddef>
#include <memory>
#include <vector>

namespace Aurora {

/**
 * Kind of element stored in a Matrix.
 */
enum ValueType {
    Normal = 1,  ///< Plain (non-complex) double values.
    Complex      ///< Complex values (std::complex<double>, i.e. two doubles per element).
};

/**
 * Sentinel used as the default for the optional indices of Matrix::operator().
 *
 * NOTE(review): presumably means "index not specified"; confirm against the
 * implementation of Matrix::operator().
 * NOTE: `$` inside an identifier is a compiler extension (accepted by GCC,
 * Clang and MSVC), so this header is not strictly portable. The name is kept
 * because callers may already spell it.
 */
const int $ = -1;

/**
 * Matrix of double (or complex double) values.
 *
 * The element buffer is held through a std::shared_ptr<double>. All member
 * functions other than the small inline accessors are defined outside this
 * header.
 */
class Matrix {
public:
    /**
     * Nested class MatrixSlice, used for slicing operations.
     *
     * Holds a raw, non-owning pointer plus size/stride information.
     * NOTE(review): presumably the slice must not outlive the Matrix it was
     * taken from -- confirm against the implementation.
     */
    class MatrixSlice {
    public:
        /**
         * @param aSize     size along the first sliced direction
         * @param aStride   stride along the first sliced direction
         * @param aData     pointer to the first element of the slice
         * @param aType     value type of the referenced data
         * @param SliceMode 0 scalar, 1 vector, 2 matrix
         * @param aSize2    size along the second direction
         *                  (presumably only used in matrix mode -- TODO confirm)
         * @param aStride2  stride along the second direction
         *                  (presumably only used in matrix mode -- TODO confirm)
         */
        MatrixSlice(int aSize, int aStride, double* aData, ValueType aType = Normal,
                    int SliceMode = 1, int aSize2 = 0, int aStride2 = 0);

        // Assignment from another slice, a whole matrix, a real scalar or a
        // complex scalar.
        MatrixSlice& operator=(const MatrixSlice& slice);
        MatrixSlice& operator=(const Matrix& matrix);
        MatrixSlice& operator=(double value);
        MatrixSlice& operator=(std::complex<double> value);

        /** Convert the slice into a Matrix. */
        Matrix toMatrix() const;

    private:
        int mSliceMode = 0;       // 0 scalar, 1 vector, 2 matrix
        double* mData = nullptr;  // default-initialized so it is never indeterminate
        int mSize = 0;
        int mSize2 = 0;
        int mStride = 1;
        int mStride2 = 0;
        ValueType mType = Normal;  // default-initialized so it is never indeterminate
        friend class Matrix;
    };

    /**
     * Construct from an already managed buffer.
     *
     * @param aData      shared buffer holding the elements
     * @param aInfo      shape information
     *                   (presumably one entry per dimension -- TODO confirm)
     * @param aValueType Normal or Complex
     */
    explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
                    std::vector<int> aInfo = std::vector<int>(),
                    ValueType aValueType = Normal);

    /** Construct from a slice. */
    explicit Matrix(const Matrix::MatrixSlice& slice);

    /**
     * Create a Matrix from raw data (a double array).
     *
     * The raw data must have been created like `new double[size]()`; the
     * shared_ptr's deleter will be std::default_delete<double[]>.
     *
     * @param data   raw element buffer
     * @param rows   number of rows
     * @param cols   number of columns
     * @param slices number of slices
     * @param type   value type of the data
     * @return the created Matrix
     */
    static Matrix fromRawData(double* data, int rows, int cols = 1, int slices = 1,
                              ValueType type = Normal);

    /**
     * Create a Matrix from raw data (a double array), copying the data into
     * newly allocated MKL memory.
     *
     * @param data   raw element buffer to copy from
     * @param rows   number of rows
     * @param cols   number of columns
     * @param slices number of slices
     * @param type   value type of the data
     * @return the created Matrix
     */
    static Matrix copyFromRawData(double* data, int rows, int cols = 1, int slices = 1,
                                  ValueType type = Normal);

    /**
     * Create a Matrix for MKL-based calculation.
     *
     * @attention The memory passed to New must have been obtained from
     *            Aurora::malloc.
     * @param data   element buffer
     * @param rows   number of rows
     * @param cols   number of columns
     * @param slices number of slices
     * @param type   value type of the data
     * @return the created Matrix
     */
    static Matrix New(double* data, int rows, int cols = 1, int slices = 1,
                      ValueType type = Normal);

    /**
     * Create a Matrix for MKL-based calculation, taking the shape from another
     * matrix.
     *
     * @attention The memory passed to New must have been obtained from
     *            Aurora::malloc.
     * @param data        element buffer
     * @param shapeMatrix matrix that provides the shape
     * @return the created Matrix
     */
    static Matrix New(double* data, const Matrix& shapeMatrix);

    /**
     * Create a Matrix for MKL-based calculation, taking the shape from another
     * matrix.
     *
     * @attention The memory is allocated by the Aurora::malloc function.
     * @param shapeMatrix matrix that provides the shape
     * @return the created Matrix
     */
    static Matrix New(const Matrix& shapeMatrix);

    /**
     * Deep copy.
     *
     * @return a deep-copied Matrix object
     */
    Matrix deepCopy() const;

    // Slicing; three-dimensional slicing is not supported for now.
    MatrixSlice operator()(int r, int c = $, int aSliceIdx = $) const;

    // NOTE(review) for the arithmetic operators below:
    //  * The overloads that take a `Matrix&&` together with a scalar return
    //    `Matrix&`. Presumably that reference refers to the rvalue operand, so
    //    the result must be copied/moved into a Matrix before the end of the
    //    full expression rather than bound to a reference -- confirm against
    //    the implementation.
    //  * The `(Matrix&&, Matrix&)` overloads take their right operand by
    //    non-const lvalue reference, so they do not accept const matrices or
    //    temporaries on the right-hand side.

    // add
    Matrix operator+(double aScalar) const;
    friend Matrix operator+(double aScalar, const Matrix& matrix);
    friend Matrix& operator+(double aScalar, Matrix&& matrix);
    friend Matrix& operator+(Matrix&& matrix, double aScalar);
    Matrix operator+(const Matrix& matrix) const;
    Matrix operator+(Matrix&& matrix) const;
    friend Matrix operator+(Matrix&& aMatrix, Matrix& aOther);

    // subtract
    Matrix operator-(double aScalar) const;
    friend Matrix operator-(double aScalar, const Matrix& matrix);
    friend Matrix& operator-(double aScalar, Matrix&& matrix);
    friend Matrix& operator-(Matrix&& matrix, double aScalar);
    Matrix operator-(const Matrix& matrix) const;
    Matrix operator-(Matrix&& matrix) const;
    friend Matrix operator-(Matrix&& aMatrix, Matrix& aOther);

    // negation (unary minus)
    friend Matrix operator-(Matrix&& aMatrix);
    friend Matrix operator-(const Matrix& aMatrix);

    // multiply (see mul() below for the matrix product)
    Matrix operator*(double aScalar) const;
    friend Matrix operator*(double aScalar, const Matrix& matrix);
    friend Matrix& operator*(double aScalar, Matrix&& matrix);
    friend Matrix& operator*(Matrix&& matrix, double aScalar);
    Matrix operator*(const Matrix& matrix) const;
    Matrix operator*(Matrix&& matrix) const;
    friend Matrix operator*(Matrix&& aMatrix, Matrix& aOther);

    // divide
    Matrix operator/(double aScalar) const;
    friend Matrix operator/(double aScalar, const Matrix& matrix);
    friend Matrix& operator/(double aScalar, Matrix&& matrix);
    friend Matrix& operator/(Matrix&& matrix, double aScalar);
    Matrix operator/(const Matrix& matrix) const;
    Matrix operator/(Matrix&& matrix) const;
    friend Matrix operator/(Matrix&& aMatrix, Matrix& aOther);

    // power
    // Caution: in C++ `^` binds more loosely than `+`, `*`, `<` and `==`, so
    // `a ^ 2 + b` parses as `a ^ (2 + b)`; parenthesize power expressions.
    Matrix operator^(int times) const;
    friend Matrix operator^(Matrix&& matrix, int times);

    // comparison
    // These return a Matrix rather than a bool, so they cannot be used directly
    // as an `if` condition.
    Matrix operator>(double aScalar) const;
    friend Matrix operator>(double aScalar, const Matrix& matrix);
    Matrix operator>(const Matrix& matrix) const;

    Matrix operator<(double aScalar) const;
    friend Matrix operator<(double aScalar, const Matrix& matrix);
    Matrix operator<(const Matrix& matrix) const;

    Matrix operator>=(double aScalar) const;
    friend Matrix operator>=(double aScalar, const Matrix& matrix);
    Matrix operator>=(const Matrix& matrix) const;

    Matrix operator<=(double aScalar) const;
    friend Matrix operator<=(double aScalar, const Matrix& matrix);
    Matrix operator<=(const Matrix& matrix) const;

    Matrix operator==(double aScalar) const;
    friend Matrix operator==(double aScalar, const Matrix& matrix);
    Matrix operator==(const Matrix& matrix) const;

    // element access by index; returns the value, not a reference
    double operator[](std::size_t index) const;

    /**
     * Matrix multiplication.
     *
     * @attention Currently only matrix-by-vector multiplication is supported.
     * @param aOther right-hand operand
     * @return the product
     */
    Matrix mul(const Matrix& aOther) const;

    /**
     * Matrix multiplication.
     *
     * @attention Currently only matrix-by-vector multiplication is supported.
     * @param aOther right-hand operand (rvalue)
     * @return the product
     */
    Matrix mul(Matrix&& aOther) const;

    /**
     * Print the matrix; only 2-D matrices are supported for now.
     */
    void printf();

    /** Print the matrix shape. */
    void printfShape();

    bool isScalar() const;

    double getScalar() const;

    /**
     * Whether the matrix data is empty or its size is zero.
     *
     * @return true if the matrix is null
     */
    bool isNull() const;

    bool isNan() const;

    bool isVector() const;

    bool isMKLAllocated() const;

    /**
     * Get the dimension count of the matrix.
     *
     * @return dimension count
     */
    int getDims() const;

    double* getData() const;

    int getDimSize(int index) const;

    /**
     * Compare this matrix's shape with another matrix's shape.
     *
     * @param other matrix to compare with
     * @return true if the shapes are identical
     */
    bool compareShape(const Matrix& other) const;

    /**
     * Get the data size.
     *
     * @return the unit count of the matrix
     */
    std::size_t getDataSize() const;

    /**
     * Get the value type: Normal or Complex.
     * Complex uses std::complex<double>, which contains two double values.
     *
     * @return the current value type
     */
    ValueType getValueType() const {
        return mValueType;
    }

    /**
     * Set the value type: Normal or Complex.
     * Only the stored flag is changed here; the data buffer is not touched.
     *
     * @param aValueType the new value type
     */
    void setValueType(ValueType aValueType) {
        mValueType = aValueType;
    }

    /**
     * @return true if the value type is Complex
     */
    bool isComplex() const {
        return mValueType == Complex;
    }

    void forceReshape(int rows, int columns, int slices);

private:
    ValueType mValueType = Normal;
    std::shared_ptr<double> mData;
    std::vector<int> mInfo;
    bool mMKL_Allocated = false;
};

}  // namespace Aurora

#endif // MATRIX_H