Files
Aurora/src/Matrix.h
2023-05-06 14:16:17 +08:00

251 lines
7.6 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef MATRIX_H
#define MATRIX_H
#include <memory>
#include <complex>
#include <vector>
namespace Aurora {
/**
 * Element storage kind of a Matrix.
 *  - Normal:  every element is a single double.
 *  - Complex: every element is a std::complex<double>, i.e. two doubles.
 */
enum ValueType {
    Normal = 1,
    Complex = 2
};
/**
 * Sentinel index value (-1).
 * Used as the default for the column and slice arguments of
 * Matrix::operator(); presumably it means "argument not given / whole
 * dimension" — confirm against the operator() implementation.
 * NOTE: '$' inside identifiers is a compiler extension (GCC/Clang/MSVC),
 * not ISO C++.
 */
constexpr int $ = -1;
class Matrix {
public:
    /**
     * Inner class used for slicing operations.
     * Returned by Matrix::operator(); it holds a raw pointer into the sliced
     * matrix's data together with the sizes/strides that describe the
     * selection. Assigning to a slice presumably writes through to that
     * data — NOTE(review): confirm in the implementation.
     */
    class MatrixSlice{
    public:
        /**
         * @param aSize      element count of the selection (first dimension)
         * @param aStride    distance between selected elements (first dimension)
         * @param aData      pointer to the first selected element
         * @param aType      value type (Normal / Complex) of the underlying data
         * @param aSliceMode 0 scalar, 1 vector, 2 matrix (same encoding as mSliceMode)
         * @param aSize2     element count along the second dimension (matrix mode)
         * @param aStride2   stride along the second dimension (matrix mode)
         * NOTE(review): meanings inferred from the matching m* members — confirm.
         */
        MatrixSlice(int aSize, int aStride, double* aData, ValueType aType = Normal,
                    int aSliceMode = 1, int aSize2 = 0, int aStride2 = 0);
        MatrixSlice& operator=(const MatrixSlice& slice);
        MatrixSlice& operator=(const Matrix& matrix);
        MatrixSlice& operator=(double value);
        MatrixSlice& operator=(std::complex<double> value);
        /**
         * Build a Matrix from the elements selected by this slice.
         * @return the resulting Matrix
         */
        Matrix toMatrix() const;
    private:
        int mSliceMode = 0;        // 0 scalar, 1 vector, 2 matrix
        // Raw pointer into the sliced matrix's buffer; presumably non-owning
        // (the Matrix keeps ownership through its shared_ptr) — confirm.
        // Default-initialized so a slice never carries an indeterminate pointer.
        double* mData = nullptr;
        int mSize = 0;
        int mSize2 = 0;
        int mStride = 1;
        int mStride2 = 0;
        ValueType mType = Normal;  // default-initialized for the same reason as mData
        friend class Matrix;
    };

    /**
     * @param aData      shared element buffer (may be empty)
     * @param aInfo      shape information (dimension sizes)
     * @param aValueType Normal or Complex
     */
    explicit Matrix(std::shared_ptr<double> aData = std::shared_ptr<double>(),
                    std::vector<int> aInfo = std::vector<int>(),
                    ValueType aValueType = Normal);
    /**
     * Create a Matrix from a slice.
     * @param slice the slice to convert
     */
    explicit Matrix(const Matrix::MatrixSlice& slice);

    /**
     * Create a Matrix that wraps a raw double array without copying it.
     * The array must have been allocated like `new double[size]()`: the
     * Matrix takes ownership, and the shared_ptr's deleter is
     * std::default_delete<double[]>.
     * @param data   raw array, ownership is transferred to the Matrix
     * @param rows   number of rows
     * @param cols   number of columns
     * @param slices number of slices
     * @param type   Normal or Complex (a Complex element occupies two doubles)
     * @return the new Matrix
     */
    static Matrix fromRawData(double *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);
    /**
     * Create a Matrix by copying a raw double array into newly allocated
     * MKL memory. The caller keeps ownership of `data`.
     * @param data   raw array to copy from
     * @param rows   number of rows
     * @param cols   number of columns
     * @param slices number of slices
     * @param type   Normal or Complex
     * @return the new Matrix
     */
    static Matrix copyFromRawData(double *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);
    /**
     * Create an MKL-computation-based Matrix.
     * @attention `data` must have been obtained from Aurora::malloc.
     * @param data   buffer allocated with Aurora::malloc
     * @param rows   number of rows
     * @param cols   number of columns
     * @param slices number of slices
     * @param type   Normal or Complex
     * @return the new Matrix
     */
    static Matrix New(double *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);
    /**
     * Create an MKL-computation-based Matrix with the shape of another matrix.
     * @attention `data` must have been obtained from Aurora::malloc.
     * @param data        buffer allocated with Aurora::malloc
     * @param shapeMatrix matrix whose shape is used
     * @return the new Matrix
     */
    static Matrix New(double *data, const Matrix &shapeMatrix);
    /**
     * Create an MKL-computation-based Matrix with the shape of another matrix.
     * @attention The memory is allocated with the Aurora::malloc function.
     * @param shapeMatrix matrix whose shape is used
     * @return the new Matrix
     */
    static Matrix New(const Matrix &shapeMatrix);
    /**
     * Deep copy.
     * @return a deep-copied Matrix object
     */
    Matrix deepCopy() const;
    // Slicing; 3-D (the slice index) is not supported yet.
    // `$` is the default (sentinel) for the column and slice indices.
    MatrixSlice operator()(int r, int c = $, int aSliceIdx = $) const;

    // NOTE(review): the overloads below that take `Matrix&&` and return
    // `Matrix&` presumably operate in place on the temporary and return a
    // reference to it. Binding that result to a reference that outlives the
    // full-expression dangles. Bind by value (`Matrix m = ...`) — confirm.

    // add
    Matrix operator+(double aScalar) const;
    friend Matrix operator+(double aScalar, const Matrix &matrix);
    friend Matrix& operator+(double aScalar, Matrix &&matrix);
    friend Matrix& operator+(Matrix &&matrix,double aScalar);
    Matrix operator+(const Matrix &matrix) const;
    Matrix operator+(Matrix &&matrix) const;
    friend Matrix operator+(Matrix &&aMatrix,Matrix &aOther);
    // sub
    Matrix operator-(double aScalar) const;
    friend Matrix operator-(double aScalar, const Matrix &matrix);
    friend Matrix& operator-(double aScalar, Matrix &&matrix);
    friend Matrix& operator-(Matrix &&matrix,double aScalar);
    Matrix operator-(const Matrix &matrix) const;
    Matrix operator-(Matrix &&matrix) const;
    friend Matrix operator-(Matrix &&aMatrix,Matrix &aOther);
    // negative (unary minus)
    friend Matrix operator-(Matrix &&aMatrix);
    friend Matrix operator-(const Matrix &aMatrix);
    // mul
    Matrix operator*(double aScalar) const;
    friend Matrix operator*(double aScalar, const Matrix &matrix);
    friend Matrix& operator*(double aScalar, Matrix &&matrix);
    friend Matrix& operator*(Matrix &&matrix,double aScalar);
    Matrix operator*(const Matrix &matrix) const;
    Matrix operator*(Matrix &&matrix) const;
    friend Matrix operator*(Matrix &&aMatrix,Matrix &aOther);
    // div
    Matrix operator/(double aScalar) const;
    friend Matrix operator/(double aScalar, const Matrix &matrix);
    friend Matrix& operator/(double aScalar, Matrix &&matrix);
    friend Matrix& operator/(Matrix &&matrix,double aScalar);
    Matrix operator/(const Matrix &matrix) const;
    Matrix operator/(Matrix &&matrix) const;
    friend Matrix operator/(Matrix &&aMatrix, Matrix &aOther);
    // pow
    Matrix operator^(int times) const;
    friend Matrix operator^(Matrix &&matrix,int times);
    // compare (results are returned as a Matrix; presumably an element-wise mask — confirm)
    Matrix operator>(double aScalar) const;
    friend Matrix operator>(double aScalar, const Matrix &matrix);
    Matrix operator>(const Matrix &matrix) const;
    Matrix operator<(double aScalar) const;
    friend Matrix operator<(double aScalar, const Matrix &matrix);
    Matrix operator<(const Matrix &matrix) const;
    Matrix operator==(double aScalar) const;
    friend Matrix operator==(double aScalar, const Matrix &matrix);
    Matrix operator==(const Matrix &matrix) const;
    /**
     * Print the matrix; only 2-D matrices are supported for now.
     * Note: inside Matrix member functions this name hides ::printf.
     */
    void printf();
    /** Print the matrix shape. */
    void printfShape();
    bool isScalar() const;
    double getScalar() const;
    /**
     * Check whether the matrix has no data or its size is zero.
     * @return true if the matrix is null/empty
     */
    bool isNull() const;
    bool isNan() const;
    bool isVector() const;
    /** @return whether the data buffer was allocated through MKL (mMKL_Allocated). */
    bool isMKLAllocated() const;
    /**
     * Get the dimension count of the matrix.
     * @return dimension count
     */
    int getDims() const;
    /** @return raw pointer to the element buffer (presumably mData.get() — not owning). */
    double *getData() const;
    /**
     * @param index dimension index
     * @return size of that dimension
     */
    int getDimSize(int index) const;
    /**
     * Compare the shape of this matrix with another one.
     * @param other matrix to compare with
     * @return true if the shapes are identical
     */
    bool compareShape(const Matrix& other) const;
    /**
     * Get the data size.
     * @return the unit (element) count of the matrix
     */
    size_t getDataSize() const;
    /**
     * Get the value type: Normal or Complex.
     * Complex elements use std::complex<double>, i.e. each one contains
     * two double values.
     * @return the value type
     */
    ValueType getValueType() const {
        return mValueType;
    }
    /**
     * Set the value type (Normal or Complex).
     * Only the type tag is changed; the data buffer is not touched.
     * @param aValueType new value type
     */
    void setValueType(ValueType aValueType) {
        mValueType = aValueType;
    }
    /**
     * @return true if the value type is Complex
     */
    bool isComplex() const {
        return mValueType == Complex;
    }
    /**
     * Force a new shape. Presumably overwrites the shape info without
     * checking the element count — NOTE(review): confirm in the implementation.
     */
    void forceReshape(int rows, int columns, int slices);
private:
    ValueType mValueType = Normal;
    std::shared_ptr<double> mData;
    std::vector<int> mInfo;          // shape information (dimension sizes)
    bool mMKL_Allocated = false;
};
}
#endif