163 lines
5.0 KiB
C++
163 lines
5.0 KiB
C++
#ifndef MATRIX_H
|
||
#define MATRIX_H
|
||
|
||
#include <complex>
#include <cstddef>
#include <memory>
#include <vector>
|
||
|
||
|
||
namespace Aurora {

/**
 * Element type stored in a Matrix.
 *   Normal  - each element is a single double.
 *   Complex - each element is a std::complex<double>, i.e. two doubles.
 */
enum ValueType {
    Normal = 1,
    Complex
};

/**
 * Sentinel index (-1), used as the default for the column / slice arguments
 * of Matrix::operator().
 * NOTE(review): presumably means "whole dimension" (like MATLAB's ':') --
 * confirm against the implementation.
 * '$' in an identifier relies on implementation-defined support (accepted by
 * GCC, Clang and MSVC); it is not portable standard C++.
 */
constexpr int $ = -1;

/**
 * Matrix of doubles or std::complex<double> values.
 *
 * The class declares no copy constructor or copy assignment, and its storage
 * is a std::shared_ptr. Copying a Matrix therefore shares the underlying
 * buffer with the source. Use deepCopy() for an independent copy.
 */
class Matrix {
public:
    /**
     * Inner class MatrixSlice, used for slicing operations.
     *
     * Produced by Matrix::operator(). It can be assigned from another slice,
     * a Matrix, a double or a std::complex<double>, and converted back into a
     * Matrix with toMatrix().
     *
     * mData is a raw pointer that no destructor releases, so the slice does
     * not own it.
     * NOTE(review): it presumably points into the parent Matrix's storage, so
     * assignments write through and a slice must not outlive its Matrix --
     * confirm.
     */
    class MatrixSlice {
    public:
        /**
         * NOTE(review): parameter meanings are inferred from their names;
         * confirm against the implementation.
         * @param aSize     element count along the (first) sliced dimension
         * @param aStride   step between consecutive elements along it
         * @param aData     first element of the slice (not owned)
         * @param aType     element type of the underlying data
         * @param SliceMode 0 scalar, 1 vector, 2 matrix
         * @param aSize2    element count along the second dimension (mode 2)
         * @param aStride2  step along the second dimension (mode 2)
         */
        MatrixSlice(int aSize, int aStride, double* aData, ValueType aType = Normal,
                    int SliceMode = 1, int aSize2 = 0, int aStride2 = 0);

        MatrixSlice& operator=(const MatrixSlice& slice);
        MatrixSlice& operator=(const Matrix& matrix);
        MatrixSlice& operator=(double value);
        MatrixSlice& operator=(std::complex<double> value);

        /// Materialize the slice as a Matrix.
        Matrix toMatrix() const;

    private:
        int mSliceMode = 0;        // 0 scalar, 1 vector, 2 Matrix
        double* mData = nullptr;   // not owned; initialized so it is never indeterminate
        int mSize = 0;
        int mSize2 = 0;
        int mStride = 1;
        int mStride2 = 0;
        ValueType mType = Normal;  // initialized so it is never indeterminate

        friend class Matrix;
    };

    /**
     * Construct a Matrix from shared storage and shape information.
     * @param aData element buffer (empty by default)
     * @param aInfo shape information -- presumably the size of each dimension;
     *              NOTE(review): confirm against getDims()/getDimSize().
     * NOTE(review): std::shared_ptr<double>'s default deleter calls `delete`,
     * not `delete[]`. If the buffer was allocated with new[], an array deleter
     * must be supplied -- confirm how callers and the implementation build it.
     */
    explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
                    std::vector<int> aInfo = std::vector<int>());

    /// Build a Matrix from a slice.
    explicit Matrix(const Matrix::MatrixSlice& slice);

    /**
     * Factory: create a Matrix over `data` with the given shape.
     * NOTE(review): cols/slices default to 0 (presumably "dimension absent"),
     * and it is unclear whether the result takes ownership of `data` -- confirm.
     */
    static Matrix New(double *data, int rows, int cols = 0, int slices = 0, ValueType type = Normal);

    /**
     * Factory: create a Matrix over `data` using the shape of `shapeMatrix`.
     * NOTE(review): confirm ownership of `data`.
     */
    static Matrix New(double *data, const Matrix &shapeMatrix);

    /**
     * Factory: create a Matrix with the same shape as `shapeMatrix`.
     * NOTE(review): confirm whether the elements are initialized.
     */
    static Matrix New(const Matrix &shapeMatrix);

    /**
     * NOTE(review): presumably returns column `aColumn` (an index along
     * dimension 2) -- confirm, including the index base.
     */
    Matrix getDataFromDims2(int aColumn);

    /**
     * NOTE(review): presumably returns row `aRow` (an index along dimension 1)
     * -- confirm, including the index base.
     */
    Matrix getDataFromDims1(int aRow);

    /**
     * Deep copy operation.
     * @return a deep-copied Matrix object that does not share storage with *this
     */
    Matrix deepCopy() const;

    /**
     * Slicing; 3-D matrices are not supported yet.
     * `c` and `aSliceIdx` default to `$` (-1).
     * NOTE(review): `$` presumably selects the whole dimension -- confirm.
     */
    MatrixSlice operator()(int r, int c = $, int aSliceIdx = $) const;

    // ---- Arithmetic ------------------------------------------------------
    // Each operator has lvalue and rvalue (Matrix&&) overloads.
    // NOTE(review): the rvalue overloads presumably reuse the temporary's
    // buffer instead of allocating -- confirm.
    //
    // WARNING: the `friend Matrix& op(double, Matrix&&)` and
    // `friend Matrix& op(Matrix&&, double)` overloads return a reference to
    // their rvalue argument. Binding that result to a reference, e.g.
    //     const Matrix& r = 2.0 + makeMatrix();
    // leaves `r` dangling once the temporary dies at the end of the
    // full-expression (no lifetime extension through a returned reference).
    // Always store the result in a Matrix by value. The return types are kept
    // because the out-of-line definitions depend on them.

    // add
    Matrix operator+(double aScalar) const;
    friend Matrix operator+(double aScalar, const Matrix &matrix);
    friend Matrix& operator+(double aScalar, Matrix &&matrix);
    friend Matrix& operator+(Matrix &&matrix,double aScalar);
    Matrix operator+(const Matrix &matrix) const;
    Matrix operator+(Matrix &&matrix) const;
    friend Matrix operator+(Matrix &&aMatrix,const Matrix &aOther);

    // sub
    Matrix operator-(double aScalar) const;
    friend Matrix operator-(double aScalar, const Matrix &matrix);
    friend Matrix& operator-(double aScalar, Matrix &&matrix);
    friend Matrix& operator-(Matrix &&matrix,double aScalar);
    Matrix operator-(const Matrix &matrix) const;
    Matrix operator-(Matrix &&matrix) const;
    friend Matrix operator-(Matrix &&aMatrix,const Matrix &aOther);

    // mul -- NOTE(review): confirm whether Matrix*Matrix is element-wise or a
    // matrix product.
    Matrix operator*(double aScalar) const;
    friend Matrix operator*(double aScalar, const Matrix &matrix);
    friend Matrix& operator*(double aScalar, Matrix &&matrix);
    friend Matrix& operator*(Matrix &&matrix,double aScalar);
    Matrix operator*(const Matrix &matrix) const;
    Matrix operator*(Matrix &&matrix) const;
    friend Matrix operator*(Matrix &&aMatrix,const Matrix &aOther);

    // div -- NOTE(review): confirm whether Matrix/Matrix is element-wise.
    Matrix operator/(double aScalar) const;
    friend Matrix operator/(double aScalar, const Matrix &matrix);
    friend Matrix& operator/(double aScalar, Matrix &&matrix);
    friend Matrix& operator/(Matrix &&matrix,double aScalar);
    Matrix operator/(const Matrix &matrix) const;
    Matrix operator/(Matrix &&matrix) const;
    friend Matrix operator/(Matrix &&aMatrix,const Matrix &aOther);

    // pow -- raise to the integer power `times`.
    // NOTE(review): confirm element-wise vs matrix power.
    // `^` binds more loosely than + - * /, so parenthesize: (a ^ 2) + b.
    Matrix operator^(int times) const;
    friend Matrix operator^(Matrix &&matrix,int times);

    /**
     * Print the matrix; only 2-D matrices are supported for now.
     * Inside Matrix member functions, an unqualified `printf(...)` call finds
     * this member instead of the C library function. Write `::printf` or
     * `std::printf` there.
     */
    void printf();

    /**
     * @return true if the matrix's data is empty or its size is zero
     */
    bool isNull() const;

    /**
     * Get the dimension count of the matrix.
     * @return dimension count
     */
    int getDims() const;

    /**
     * @return raw pointer to the element buffer. The pointer is mutable even
     *         though this is a const member.
     * NOTE(review): presumably mData.get(); the Matrix keeps ownership --
     * confirm.
     */
    double *getData() const;

    /**
     * @return the size of dimension `index`
     * NOTE(review): confirm whether `index` is 0- or 1-based.
     */
    int getDimSize(int index) const;

    /**
     * Compare matrix shapes.
     * @param other matrix to compare against
     * @return true if both matrices have the same shape
     */
    bool compareShape(const Matrix& other) const;

    /**
     * Get the data size.
     * @return the unit (element) count of the matrix
     * NOTE(review): for Complex matrices, confirm whether this counts complex
     * elements or doubles.
     */
    std::size_t getDataSize() const;

    /**
     * Get the value type (Normal or Complex).
     * Complex values are std::complex<double>, i.e. two doubles per element.
     * @return the current value type
     */
    ValueType getValueType() const {
        return mValueType;
    }

    /**
     * Set the value type (Normal or Complex).
     * This only changes the tag; it does not convert the stored data.
     */
    void setValueType(ValueType aValueType) {
        mValueType = aValueType;
    }

private:
    ValueType mValueType = Normal;  // element type, see getValueType()
    std::shared_ptr<double> mData;  // element buffer; shared between shallow copies
    std::vector<int> mInfo;         // shape info -- presumably per-dimension sizes; confirm
};

}  // namespace Aurora
|
||
|
||
#endif
|