Files
Aurora/src/Matrix.h

163 lines
5.0 KiB
C
Raw Normal View History

2023-04-18 13:31:14 +08:00
#ifndef MATRIX_H
#define MATRIX_H
#include <memory>
2023-04-20 11:21:11 +08:00
#include <complex>
2023-04-18 13:31:14 +08:00
#include <vector>
2023-04-19 15:54:52 +08:00
2023-04-18 13:31:14 +08:00
namespace Aurora {
/**
 * Kind of element stored in a Matrix.
 *
 * Kept as an unscoped enum so existing code can keep writing `Normal` and
 * `Complex` without qualification (e.g. as default arguments).
 */
enum ValueType {
    /// Plain (non-complex) values.
    Normal = 1,
    /// Complex values; Matrix::getValueType documents these as
    /// std::complex<double>, i.e. two doubles per element.
    Complex = 2
};
2023-04-19 15:54:52 +08:00
/**
 * Sentinel index (-1) used as the default value of the `c` and `aSliceIdx`
 * arguments of Matrix::operator().
 * NOTE(review): presumably means "index not specified / take the whole
 * dimension" - confirm against the operator() implementation.
 * `$` in an identifier is accepted by GCC, Clang and MSVC as an extension.
 */
constexpr int $ = -1;
2023-04-18 13:31:14 +08:00
class Matrix {
public:
/**
* MatrixSlice
*/
class MatrixSlice{
public:
2023-04-19 15:54:52 +08:00
MatrixSlice(int aSize,int aStride, double* aData,ValueType aType = Normal,int SliceMode = 1,int aSize2 = 0, int aStride2 = 0);
2023-04-18 13:31:14 +08:00
MatrixSlice& operator=(const MatrixSlice& slice);
MatrixSlice& operator=(const Matrix& matrix);
2023-04-20 11:21:11 +08:00
MatrixSlice& operator=(double value);
MatrixSlice& operator=(std::complex<double> value);
2023-04-20 09:32:58 +08:00
Matrix toMatrix() const;
2023-04-18 13:31:14 +08:00
private:
2023-04-19 15:54:52 +08:00
int mSliceMode = 0;//0 scalar, 1 vector, 2 Matrix
2023-04-18 13:31:14 +08:00
double* mData;
2023-04-20 09:32:58 +08:00
int mSize=0;
int mSize2=0;
int mStride=1;
int mStride2=0;
2023-04-18 13:31:14 +08:00
ValueType mType;
2023-04-20 09:32:58 +08:00
friend class Matrix;
2023-04-18 13:31:14 +08:00
};
2023-04-20 09:32:58 +08:00
explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
std::vector<int> aInfo = std::vector<int>());
explicit Matrix(const Matrix::MatrixSlice& slice);
static Matrix New(double *data, int rows, int cols = 0, int slices = 0, ValueType type = Normal);
static Matrix New(double *data, const Matrix &shapeMatrix);
static Matrix New(const Matrix &shapeMatrix);
2023-04-18 13:31:14 +08:00
Matrix getDataFromDims2(int aColumn);
Matrix getDataFromDims1(int aRow);
/**
*
* @return Matrix对象
*/
Matrix deepCopy() const;
//切片,暂时不支持三维
2023-04-19 15:54:52 +08:00
MatrixSlice operator()(int r, int c = $, int aSliceIdx = $) const;
2023-04-18 13:31:14 +08:00
// add
Matrix operator+(double aScalar) const;
friend Matrix operator+(double aScalar, const Matrix &matrix);
friend Matrix& operator+(double aScalar, Matrix &&matrix);
friend Matrix& operator+(Matrix &&matrix,double aScalar);
Matrix operator+(const Matrix &matrix) const;
Matrix operator+(Matrix &&matrix) const;
friend Matrix operator+(Matrix &&aMatrix,const Matrix &aOther);
// sub
Matrix operator-(double aScalar) const;
friend Matrix operator-(double aScalar, const Matrix &matrix);
friend Matrix& operator-(double aScalar, Matrix &&matrix);
friend Matrix& operator-(Matrix &&matrix,double aScalar);
Matrix operator-(const Matrix &matrix) const;
Matrix operator-(Matrix &&matrix) const;
friend Matrix operator-(Matrix &&aMatrix,const Matrix &aOther);
// mul
Matrix operator*(double aScalar) const;
friend Matrix operator*(double aScalar, const Matrix &matrix);
friend Matrix& operator*(double aScalar, Matrix &&matrix);
friend Matrix& operator*(Matrix &&matrix,double aScalar);
Matrix operator*(const Matrix &matrix) const;
Matrix operator*(Matrix &&matrix) const;
friend Matrix operator*(Matrix &&aMatrix,const Matrix &aOther);
// div
Matrix operator/(double aScalar) const;
friend Matrix operator/(double aScalar, const Matrix &matrix);
friend Matrix& operator/(double aScalar, Matrix &&matrix);
friend Matrix& operator/(Matrix &&matrix,double aScalar);
Matrix operator/(const Matrix &matrix) const;
Matrix operator/(Matrix &&matrix) const;
friend Matrix operator/(Matrix &&aMatrix,const Matrix &aOther);
2023-04-19 11:30:05 +08:00
// pow
Matrix operator^(int times) const;
friend Matrix operator^(Matrix &&matrix,int times);
2023-04-18 13:31:14 +08:00
/**
* print matrix , only support 2d matrix now
*/
void printf();
/**
* Get is the Matrix's data is empty or size is zero.
* @return
*/
bool isNull() const;
/**
* Get dimension count of the matrix
* @return dimension count
*/
int getDims() const;
double *getData() const;
int getDimSize(int index) const;
/**
* Compare matrix shape
* @param other matrix
* @return is identity
*/
bool compareShape(const Matrix& other) const;
/**
* Get the data size
* @return the unit count of the matrix
*/
size_t getDataSize() const;
/**
* Get Value type as normal and complex,
* complex use std::complex<double>,
* it's contains two double value.
* @return
*/
ValueType getValueType() const {
return mValueType;
}
/**
* Set Value type as normal and complex,
* @return
*/
void setValueType(ValueType aValueType) {
mValueType = aValueType;
}
private:
ValueType mValueType = Normal;
std::shared_ptr<double> mData;
std::vector<int> mInfo;
};
}
#endif