Files
Aurora/src/Matrix.h
2023-04-20 11:21:11 +08:00

163 lines
5.0 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef MATRIX_H
#define MATRIX_H
#include <complex>
#include <memory>
#include <utility>
#include <vector>
namespace Aurora {

/**
 * Element storage type of a Matrix (see Matrix::getValueType()).
 * The enumerator values are spelled out so they cannot shift accidentally;
 * Normal starts at 1 exactly as in the original declaration.
 */
enum ValueType {
    Normal = 1,  ///< plain (non-complex) double values
    Complex = 2  ///< std::complex<double> values, i.e. two doubles per element
};

/**
 * Sentinel index used as the default for the optional column and slice
 * arguments of Matrix::operator().
 * NOTE(review): presumably means "argument not given / whole dimension";
 * confirm against the operator() implementation.
 * `$` in an identifier is a compiler extension (GCC, Clang and MSVC accept
 * it), not portable standard C++; the name is kept for source compatibility.
 */
constexpr int $ = -1;

/**
 * Matrix whose element buffer is held through a std::shared_ptr<double>.
 *
 * Matrix declares no copy/move operations of its own, so copying a Matrix
 * copies the shared_ptr: the copy shares the same element buffer. Use
 * deepCopy() to obtain an independent buffer.
 *
 * NOTE(review): std::shared_ptr<double> uses `delete` (not `delete[]`) unless
 * a custom deleter is supplied; confirm every place that builds mData from an
 * array passes a matching deleter.
 */
class Matrix {
public:
    /**
     * Nested proxy class used for slicing operations (returned by operator()).
     *
     * It refers to elements through a raw, non-owning pointer (mData) and has
     * no destructor of its own. It presumably points into the source Matrix's
     * buffer, so it should not outlive that Matrix.
     */
    class MatrixSlice {
    public:
        /**
         * @param aSize      extent of the (first) sliced dimension
         * @param aStride    stride between consecutive elements of that dimension
         * @param aData      first referenced element (not owned)
         * @param aType      element type of the referenced data
         * @param aSliceMode 0 = scalar, 1 = vector, 2 = matrix (see mSliceMode)
         * @param aSize2     extent of the second dimension (presumably matrix mode only)
         * @param aStride2   stride of the second dimension (presumably matrix mode only)
         */
        MatrixSlice(int aSize, int aStride, double* aData, ValueType aType = Normal,
                    int aSliceMode = 1, int aSize2 = 0, int aStride2 = 0);

        // Declared explicitly: a user-declared copy assignment operator makes
        // the implicitly generated copy constructor deprecated. Behaviour is
        // the same member-wise copy as before.
        MatrixSlice(const MatrixSlice& slice) = default;

        // Assignment into the slice. NOTE(review): these presumably write the
        // right-hand side into the referenced elements (proxy semantics)
        // rather than re-pointing the slice; see Matrix.cpp.
        MatrixSlice& operator=(const MatrixSlice& slice);
        MatrixSlice& operator=(const Matrix& matrix);
        MatrixSlice& operator=(double value);
        MatrixSlice& operator=(std::complex<double> value);

        /// Materializes the slice as a Matrix.
        Matrix toMatrix() const;

    private:
        int mSliceMode = 0;       // 0 scalar, 1 vector, 2 matrix
        double* mData = nullptr;  // not owned
        int mSize = 0;
        int mSize2 = 0;
        int mStride = 1;
        int mStride2 = 0;
        ValueType mType = Normal;
        friend class Matrix;
    };

    /**
     * Builds a Matrix from a buffer handle and shape information.
     * @param aData buffer handle (presumably stored as mData, i.e. shared, not copied)
     * @param aInfo shape information (presumably one entry per dimension — confirm)
     */
    explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
                    std::vector<int> aInfo = std::vector<int>());

    /// Builds a Matrix from a slice of another Matrix.
    explicit Matrix(const Matrix::MatrixSlice& slice);

    /**
     * Creates a Matrix over `data` with the given shape and value type.
     * NOTE(review): a size of 0 presumably means "dimension absent", and
     * whether the Matrix takes ownership of `data` is decided in Matrix.cpp.
     */
    static Matrix New(double *data, int rows, int cols = 0, int slices = 0, ValueType type = Normal);
    /// Creates a Matrix over `data` with the same shape as `shapeMatrix`.
    static Matrix New(double *data, const Matrix &shapeMatrix);
    /// Creates a Matrix with the same shape as `shapeMatrix`.
    static Matrix New(const Matrix &shapeMatrix);

    /// NOTE(review): presumably extracts column `aColumn`; confirm in Matrix.cpp.
    Matrix getDataFromDims2(int aColumn);
    /// NOTE(review): presumably extracts row `aRow`; confirm in Matrix.cpp.
    Matrix getDataFromDims1(int aRow);

    /**
     * Deep copy operation.
     * @return a Matrix that does not share its element buffer with this one
     */
    Matrix deepCopy() const;

    /**
     * Slicing; 3-D slicing is not supported yet.
     * `c` and `aSliceIdx` default to the `$` sentinel.
     * NOTE(review): this is a const member, yet the returned slice carries a
     * non-const data pointer and assignment operators, so it can presumably
     * be used to modify a const Matrix.
     */
    MatrixSlice operator()(int r, int c = $, int aSliceIdx = $) const;

    // Arithmetic operators. Each operation is overloaded for lvalue and rvalue
    // (Matrix&&) operands; the rvalue forms presumably reuse the temporary's
    // buffer for the result.
    //
    // NOTE(review): the friend overloads taking (double, Matrix&&) and
    // (Matrix&&, double) return Matrix&, a reference to an argument that is
    // usually a temporary. Copying the result (`Matrix m = a + 1.0 + 2.0;`)
    // is fine, but binding it to a reference (`const Matrix& r = a + 1.0 + 2.0;`)
    // dangles after the full-expression. Changing the return type also needs
    // the definitions in Matrix.cpp to change.
    //
    // The (Matrix&&, Matrix&&) overloads exist because, when both operands are
    // temporaries, the member `operator?(Matrix&&) const` (better match for
    // the right operand) and the friend `operator?(Matrix&&, const Matrix&)`
    // (better match for the left operand) are otherwise ambiguous, so e.g.
    // `(a + b) + (c + d)` would not compile. They forward to the
    // (Matrix&&, const Matrix&) form, which keeps the operand order.

    // add
    Matrix operator+(double aScalar) const;
    friend Matrix operator+(double aScalar, const Matrix &matrix);
    friend Matrix& operator+(double aScalar, Matrix &&matrix);
    friend Matrix& operator+(Matrix &&matrix, double aScalar);
    Matrix operator+(const Matrix &matrix) const;
    Matrix operator+(Matrix &&matrix) const;
    friend Matrix operator+(Matrix &&aMatrix, const Matrix &aOther);
    friend Matrix operator+(Matrix &&aMatrix, Matrix &&aOther) {
        return std::move(aMatrix) + static_cast<const Matrix &>(aOther);
    }

    // sub
    Matrix operator-(double aScalar) const;
    friend Matrix operator-(double aScalar, const Matrix &matrix);
    friend Matrix& operator-(double aScalar, Matrix &&matrix);
    friend Matrix& operator-(Matrix &&matrix, double aScalar);
    Matrix operator-(const Matrix &matrix) const;
    Matrix operator-(Matrix &&matrix) const;
    friend Matrix operator-(Matrix &&aMatrix, const Matrix &aOther);
    friend Matrix operator-(Matrix &&aMatrix, Matrix &&aOther) {
        return std::move(aMatrix) - static_cast<const Matrix &>(aOther);
    }

    // mul
    Matrix operator*(double aScalar) const;
    friend Matrix operator*(double aScalar, const Matrix &matrix);
    friend Matrix& operator*(double aScalar, Matrix &&matrix);
    friend Matrix& operator*(Matrix &&matrix, double aScalar);
    Matrix operator*(const Matrix &matrix) const;
    Matrix operator*(Matrix &&matrix) const;
    friend Matrix operator*(Matrix &&aMatrix, const Matrix &aOther);
    friend Matrix operator*(Matrix &&aMatrix, Matrix &&aOther) {
        return std::move(aMatrix) * static_cast<const Matrix &>(aOther);
    }

    // div
    Matrix operator/(double aScalar) const;
    friend Matrix operator/(double aScalar, const Matrix &matrix);
    friend Matrix& operator/(double aScalar, Matrix &&matrix);
    friend Matrix& operator/(Matrix &&matrix, double aScalar);
    Matrix operator/(const Matrix &matrix) const;
    Matrix operator/(Matrix &&matrix) const;
    friend Matrix operator/(Matrix &&aMatrix, const Matrix &aOther);
    friend Matrix operator/(Matrix &&aMatrix, Matrix &&aOther) {
        return std::move(aMatrix) / static_cast<const Matrix &>(aOther);
    }

    // pow.
    // NOTE: `^` binds more loosely than `+`, `*` and `==` in C++, so
    // `m ^ 2 + 1` parses as `m ^ (2 + 1)`; write `(m ^ 2) + 1`.
    Matrix operator^(int times) const;
    friend Matrix operator^(Matrix &&matrix, int times);

    /**
     * Prints the matrix; only 2-D matrices are supported for now.
     */
    void printf();

    /**
     * Checks whether the Matrix has no data or a size of zero.
     * @return true if the data is empty or the size is zero
     */
    bool isNull() const;

    /**
     * Gets the number of dimensions of the matrix.
     * @return dimension count
     */
    int getDims() const;

    /**
     * Raw pointer to the element buffer. NOTE(review): presumably mData.get(),
     * so it stays valid only while some Matrix sharing the buffer is alive.
     */
    double *getData() const;

    /// Size of dimension `index` (index base to be confirmed in Matrix.cpp).
    int getDimSize(int index) const;

    /**
     * Compares the shape of this matrix with another one.
     * @param other matrix to compare with
     * @return true if both shapes are identical
     */
    bool compareShape(const Matrix& other) const;

    /**
     * Gets the data size.
     * @return the number of units (elements) in the matrix.
     * NOTE(review): confirm whether a complex element counts as one unit or two.
     */
    size_t getDataSize() const;

    /**
     * Gets the value type: Normal or Complex.
     * Complex values use std::complex<double>, which holds two doubles.
     * @return the current value type
     */
    ValueType getValueType() const {
        return mValueType;
    }

    /**
     * Sets the value type: Normal or Complex.
     * @param aValueType the new value type
     */
    void setValueType(ValueType aValueType) {
        mValueType = aValueType;
    }

private:
    ValueType mValueType = Normal;
    std::shared_ptr<double> mData;
    std::vector<int> mInfo;
};

} // namespace Aurora
#endif