2023-04-18 13:31:14 +08:00
|
|
|
|
#ifndef MATRIX_H
|
|
|
|
|
|
#define MATRIX_H
|
|
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
2023-04-20 11:21:11 +08:00
|
|
|
|
#include <complex>
|
2023-04-18 13:31:14 +08:00
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
2023-04-19 15:54:52 +08:00
|
|
|
|
|
2023-04-18 13:31:14 +08:00
|
|
|
|
namespace Aurora {
|
|
|
|
|
|
/**
 * Tag describing the kind of values a Matrix holds.
 *
 * The enumeration is deliberately left unscoped so that the enumerators keep
 * being usable as plain `Normal` / `Complex` (they appear as default
 * arguments and default member initializers in Matrix).
 */
enum ValueType {
    Normal = 1,  ///< Non-complex values; numbering starts at 1, not 0.
    Complex = 2  ///< Complex values (Matrix documents these as std::complex<double>).
};
|
2023-04-19 15:54:52 +08:00
|
|
|
|
// Sentinel index (-1). Matrix::operator() uses it as the default value of its
// column and slice-index arguments.
// NOTE(review): '$' in an identifier is an implementation-defined extension
// (accepted by GCC, Clang and MSVC) - confirm every target compiler allows it.
constexpr int $ = -1;
|
2023-04-18 13:31:14 +08:00
|
|
|
|
|
|
|
|
|
|
class Matrix {
|
|
|
|
|
|
public:
|
|
|
|
|
|
/**
|
|
|
|
|
|
* 内部类MatrixSlice,用于切片操作
|
|
|
|
|
|
*/
|
|
|
|
|
|
class MatrixSlice{
|
|
|
|
|
|
public:
|
2023-04-19 15:54:52 +08:00
|
|
|
|
MatrixSlice(int aSize,int aStride, double* aData,ValueType aType = Normal,int SliceMode = 1,int aSize2 = 0, int aStride2 = 0);
|
2023-04-18 13:31:14 +08:00
|
|
|
|
MatrixSlice& operator=(const MatrixSlice& slice);
|
|
|
|
|
|
MatrixSlice& operator=(const Matrix& matrix);
|
2023-04-20 11:21:11 +08:00
|
|
|
|
MatrixSlice& operator=(double value);
|
|
|
|
|
|
MatrixSlice& operator=(std::complex<double> value);
|
2023-04-20 09:32:58 +08:00
|
|
|
|
Matrix toMatrix() const;
|
2023-04-18 13:31:14 +08:00
|
|
|
|
private:
|
2023-04-19 15:54:52 +08:00
|
|
|
|
int mSliceMode = 0;//0 scalar, 1 vector, 2 Matrix
|
2023-04-18 13:31:14 +08:00
|
|
|
|
double* mData;
|
2023-04-20 09:32:58 +08:00
|
|
|
|
int mSize=0;
|
|
|
|
|
|
int mSize2=0;
|
|
|
|
|
|
int mStride=1;
|
|
|
|
|
|
int mStride2=0;
|
2023-04-18 13:31:14 +08:00
|
|
|
|
ValueType mType;
|
2023-04-20 09:32:58 +08:00
|
|
|
|
friend class Matrix;
|
2023-04-18 13:31:14 +08:00
|
|
|
|
};
|
2023-04-20 09:32:58 +08:00
|
|
|
|
explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
|
|
|
|
|
|
std::vector<int> aInfo = std::vector<int>());
|
|
|
|
|
|
|
|
|
|
|
|
explicit Matrix(const Matrix::MatrixSlice& slice);
|
|
|
|
|
|
|
|
|
|
|
|
static Matrix New(double *data, int rows, int cols = 0, int slices = 0, ValueType type = Normal);
|
|
|
|
|
|
|
|
|
|
|
|
static Matrix New(double *data, const Matrix &shapeMatrix);
|
|
|
|
|
|
|
|
|
|
|
|
static Matrix New(const Matrix &shapeMatrix);
|
2023-04-18 13:31:14 +08:00
|
|
|
|
|
|
|
|
|
|
Matrix getDataFromDims2(int aColumn);
|
|
|
|
|
|
|
|
|
|
|
|
Matrix getDataFromDims1(int aRow);
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* 深拷贝操作
|
|
|
|
|
|
* @return 深拷贝的Matrix对象
|
|
|
|
|
|
*/
|
|
|
|
|
|
Matrix deepCopy() const;
|
|
|
|
|
|
|
|
|
|
|
|
//切片,暂时不支持三维
|
2023-04-19 15:54:52 +08:00
|
|
|
|
MatrixSlice operator()(int r, int c = $, int aSliceIdx = $) const;
|
2023-04-18 13:31:14 +08:00
|
|
|
|
|
|
|
|
|
|
// add
|
|
|
|
|
|
Matrix operator+(double aScalar) const;
|
|
|
|
|
|
friend Matrix operator+(double aScalar, const Matrix &matrix);
|
|
|
|
|
|
friend Matrix& operator+(double aScalar, Matrix &&matrix);
|
|
|
|
|
|
friend Matrix& operator+(Matrix &&matrix,double aScalar);
|
|
|
|
|
|
Matrix operator+(const Matrix &matrix) const;
|
|
|
|
|
|
Matrix operator+(Matrix &&matrix) const;
|
|
|
|
|
|
friend Matrix operator+(Matrix &&aMatrix,const Matrix &aOther);
|
|
|
|
|
|
|
|
|
|
|
|
// sub
|
|
|
|
|
|
Matrix operator-(double aScalar) const;
|
|
|
|
|
|
friend Matrix operator-(double aScalar, const Matrix &matrix);
|
|
|
|
|
|
friend Matrix& operator-(double aScalar, Matrix &&matrix);
|
|
|
|
|
|
friend Matrix& operator-(Matrix &&matrix,double aScalar);
|
|
|
|
|
|
Matrix operator-(const Matrix &matrix) const;
|
|
|
|
|
|
Matrix operator-(Matrix &&matrix) const;
|
|
|
|
|
|
friend Matrix operator-(Matrix &&aMatrix,const Matrix &aOther);
|
|
|
|
|
|
|
|
|
|
|
|
// mul
|
|
|
|
|
|
Matrix operator*(double aScalar) const;
|
|
|
|
|
|
friend Matrix operator*(double aScalar, const Matrix &matrix);
|
|
|
|
|
|
friend Matrix& operator*(double aScalar, Matrix &&matrix);
|
|
|
|
|
|
friend Matrix& operator*(Matrix &&matrix,double aScalar);
|
|
|
|
|
|
Matrix operator*(const Matrix &matrix) const;
|
|
|
|
|
|
Matrix operator*(Matrix &&matrix) const;
|
|
|
|
|
|
friend Matrix operator*(Matrix &&aMatrix,const Matrix &aOther);
|
|
|
|
|
|
|
|
|
|
|
|
// div
|
|
|
|
|
|
Matrix operator/(double aScalar) const;
|
|
|
|
|
|
friend Matrix operator/(double aScalar, const Matrix &matrix);
|
|
|
|
|
|
friend Matrix& operator/(double aScalar, Matrix &&matrix);
|
|
|
|
|
|
friend Matrix& operator/(Matrix &&matrix,double aScalar);
|
|
|
|
|
|
Matrix operator/(const Matrix &matrix) const;
|
|
|
|
|
|
Matrix operator/(Matrix &&matrix) const;
|
|
|
|
|
|
friend Matrix operator/(Matrix &&aMatrix,const Matrix &aOther);
|
|
|
|
|
|
|
2023-04-19 11:30:05 +08:00
|
|
|
|
// pow
|
|
|
|
|
|
Matrix operator^(int times) const;
|
|
|
|
|
|
friend Matrix operator^(Matrix &&matrix,int times);
|
|
|
|
|
|
|
2023-04-18 13:31:14 +08:00
|
|
|
|
/**
|
|
|
|
|
|
* print matrix , only support 2d matrix now
|
|
|
|
|
|
*/
|
|
|
|
|
|
void printf();
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* Get is the Matrix's data is empty or size is zero.
|
|
|
|
|
|
* @return
|
|
|
|
|
|
*/
|
|
|
|
|
|
bool isNull() const;
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* Get dimension count of the matrix
|
|
|
|
|
|
* @return dimension count
|
|
|
|
|
|
*/
|
|
|
|
|
|
int getDims() const;
|
|
|
|
|
|
|
|
|
|
|
|
double *getData() const;
|
|
|
|
|
|
|
|
|
|
|
|
int getDimSize(int index) const;
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* Compare matrix shape
|
|
|
|
|
|
* @param other matrix
|
|
|
|
|
|
* @return is identity
|
|
|
|
|
|
*/
|
|
|
|
|
|
bool compareShape(const Matrix& other) const;
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* Get the data size
|
|
|
|
|
|
* @return the unit count of the matrix
|
|
|
|
|
|
*/
|
|
|
|
|
|
size_t getDataSize() const;
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* Get Value type as normal and complex,
|
|
|
|
|
|
* complex use std::complex<double>,
|
|
|
|
|
|
* it's contains two double value.
|
|
|
|
|
|
* @return
|
|
|
|
|
|
*/
|
|
|
|
|
|
ValueType getValueType() const {
|
|
|
|
|
|
return mValueType;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
* Set Value type as normal and complex,
|
|
|
|
|
|
* @return
|
|
|
|
|
|
*/
|
|
|
|
|
|
void setValueType(ValueType aValueType) {
|
|
|
|
|
|
mValueType = aValueType;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
ValueType mValueType = Normal;
|
|
|
|
|
|
std::shared_ptr<double> mData;
|
|
|
|
|
|
std::vector<int> mInfo;
|
|
|
|
|
|
};
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#endif
|