Files
Aurora/src/CudaMatrix.h
2023-10-30 10:28:24 +08:00

240 lines
7.6 KiB
C++

#ifndef CUDAMATRIX_H
#define CUDAMATRIX_H
#include <complex>
#include <cstddef>
#include <memory>
#include <vector>

#include "Matrix.h"
namespace Aurora
{
class CudaMatrix
{
public:
    /**
     * Construct from an already-owned data buffer and shape info.
     * @param aData      shared data buffer (empty by default)
     * @param aInfo      shape info vector (empty by default);
     *                   presumably the per-dimension sizes — TODO confirm
     * @param aValueType Normal or Complex element type
     */
    explicit CudaMatrix(std::shared_ptr<float> aData = std::shared_ptr<float>(),
                        std::vector<int> aInfo = std::vector<int>(),
                        ValueType aValueType = Normal);

    /**
     * Create a matrix that takes ownership of a raw float array (no copy).
     *
     * The original docs describe the source as allocated like `new float[size]()`,
     * with the shared_ptr deleter being std::default_delete<float[]>.
     * NOTE(review): this text appears inherited from Matrix; confirm whether
     * `data` is expected to be host or device memory for CudaMatrix, and which
     * deleter is actually installed.
     * @param data   raw buffer; ownership is transferred to the matrix
     * @param rows   number of rows
     * @param cols   number of columns (default 1)
     * @param slices number of slices (default 1)
     * @param type   Normal or Complex
     * @return the new matrix
     */
    static CudaMatrix fromRawData(float *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);

    /**
     * Create a matrix by copying a raw float array into newly allocated storage
     * owned by the matrix. The caller keeps ownership of `data`.
     * NOTE(review): the original docs said "a new mkl memory", which looks
     * copied from Matrix; confirm the destination memory space (presumably
     * device memory) and the expected memory space of `data`.
     * @param data   source buffer to copy from
     * @param rows   number of rows
     * @param cols   number of columns (default 1)
     * @param slices number of slices (default 1)
     * @param type   Normal or Complex
     * @return the new matrix
     */
    static CudaMatrix copyFromRawData(float *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);

    /**
     * Deep copy operation.
     * @return a deep-copied CudaMatrix object
     */
    CudaMatrix deepCopy() const;

    // NOTE(review): the `friend CudaMatrix& operatorX(..., CudaMatrix&&...)`
    // overloads below return a reference (presumably to the rvalue argument,
    // modified in place). Binding that result to a reference that outlives the
    // full-expression, e.g. `const CudaMatrix& r = 1.f + (a + b);`, dangles.
    // Returning by value would be safer but requires changing the definitions.

    // Addition
    CudaMatrix operator+(float aScalar) const;
    friend CudaMatrix operator+(float aScalar, const CudaMatrix &aMatrix);
    friend CudaMatrix& operator+(float aScalar, CudaMatrix &&aMatrix);
    friend CudaMatrix& operator+(CudaMatrix &&aMatrix, float aScalar);
    CudaMatrix operator+(const CudaMatrix &aMatrix) const;
    CudaMatrix operator+(CudaMatrix &&aMatrix) const;
    friend CudaMatrix operator+(CudaMatrix &&aMatrix, CudaMatrix &aOther);

    // Subtraction
    CudaMatrix operator-(float aScalar) const;
    friend CudaMatrix operator-(float aScalar, const CudaMatrix &aMatrix);
    friend CudaMatrix& operator-(float aScalar, CudaMatrix &&aMatrix);
    friend CudaMatrix& operator-(CudaMatrix &&aMatrix, float aScalar);
    CudaMatrix operator-(const CudaMatrix &aMatrix) const;
    CudaMatrix operator-(CudaMatrix &&aMatrix) const;
    friend CudaMatrix operator-(CudaMatrix &&aMatrix, CudaMatrix &aOther);

    // Unary negation
    friend CudaMatrix operator-(CudaMatrix &&aMatrix);
    friend CudaMatrix operator-(const CudaMatrix &aMatrix);

    // Multiplication (operator*; see mul() for matrix multiplication)
    CudaMatrix operator*(float aScalar) const;
    friend CudaMatrix operator*(float aScalar, const CudaMatrix &aMatrix);
    friend CudaMatrix& operator*(float aScalar, CudaMatrix &&aMatrix);
    friend CudaMatrix& operator*(CudaMatrix &&aMatrix, float aScalar);
    CudaMatrix operator*(const CudaMatrix &aMatrix) const;
    CudaMatrix operator*(CudaMatrix &&aMatrix) const;
    friend CudaMatrix operator*(CudaMatrix &&aMatrix, CudaMatrix &aOther);

    // Division
    CudaMatrix operator/(float aScalar) const;
    friend CudaMatrix operator/(float aScalar, const CudaMatrix &aMatrix);
    friend CudaMatrix& operator/(float aScalar, CudaMatrix &&aMatrix);
    friend CudaMatrix& operator/(CudaMatrix &&aMatrix, float aScalar);
    CudaMatrix operator/(const CudaMatrix &aMatrix) const;
    CudaMatrix operator/(CudaMatrix &&aMatrix) const;
    friend CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther);

    // Power (operator^ is overloaded here, it is NOT bitwise XOR).
    // Note that ^ has lower precedence than +, -, *, / and comparisons,
    // so parenthesize: `(a ^ 2) + b`.
    CudaMatrix operator^(int times) const;
    friend CudaMatrix operator^(CudaMatrix &&aMatrix, int times);

    // Comparison (each returns a CudaMatrix, not a bool)
    CudaMatrix operator>(float aScalar) const;
    friend CudaMatrix operator>(float aScalar, const CudaMatrix &aMatrix);
    CudaMatrix operator>(const CudaMatrix &aMatrix) const;
    CudaMatrix operator<(float aScalar) const;
    friend CudaMatrix operator<(float aScalar, const CudaMatrix &aMatrix);
    CudaMatrix operator<(const CudaMatrix &aMatrix) const;
    CudaMatrix operator>=(float aScalar) const;
    friend CudaMatrix operator>=(float aScalar, const CudaMatrix &aMatrix);
    CudaMatrix operator>=(const CudaMatrix &aMatrix) const;
    CudaMatrix operator<=(float aScalar) const;
    friend CudaMatrix operator<=(float aScalar, const CudaMatrix &aMatrix);
    CudaMatrix operator<=(const CudaMatrix &aMatrix) const;
    CudaMatrix operator==(float aScalar) const;
    friend CudaMatrix operator==(float aScalar, const CudaMatrix &aMatrix);
    CudaMatrix operator==(const CudaMatrix &aMatrix) const;
    CudaMatrix operator!=(float aScalar) const;
    friend CudaMatrix operator!=(float aScalar, const CudaMatrix &aMatrix);
    CudaMatrix operator!=(const CudaMatrix &aMatrix) const;

    // Element access by flat index (subscript).
    // NOTE(review): if mData refers to device memory, the reference returned by
    // the non-const overload cannot be dereferenced on the host — confirm.
    float& operator[](size_t index);
    float operator[](size_t index) const;

    /**
     * Block (slice) operation.
     *
     * @param aDim       the dimension to slice along
     * @param aBeginIndx start index, inclusive
     * @param aEndIndex  end index, inclusive
     * @return the sliced matrix
     */
    CudaMatrix block(int aDim,int aBeginIndx, int aEndIndex) const;
    // The setBlock* family presumably writes into the same inclusive
    // [aBeginIndx, aEndIndex] range along aDim that block() reads, and the
    // bool result presumably reports success — TODO confirm.
    bool setBlockValue(int aDim,int aBeginIndx, int aEndIndex,float value);
    bool setBlockComplexValue(int aDim,int aBeginIndx, int aEndIndex,std::complex<float> value);
    bool setBlock(int aDim,int aBeginIndx, int aEndIndex,const CudaMatrix& src);

    /**
     * Matrix multiplication.
     * @attention Currently only matrix-times-vector is supported.
     * @param aOther right-hand operand
     * @return the product matrix
     */
    CudaMatrix mul(const CudaMatrix& aOther) const;
    /**
     * Matrix multiplication (rvalue right-hand operand).
     * @attention Currently only matrix-times-vector is supported.
     * @param aOther right-hand operand
     * @return the product matrix
     */
    CudaMatrix mul(CudaMatrix&& aOther) const;

    /**
     * Print the matrix; only 2-D matrices are supported for now.
     */
    void printf();
    // Print the matrix shape.
    void printfShape();
    bool isScalar() const;
    float getScalar() const;
    /**
     * Return whether the matrix's data is empty or its size is zero.
     */
    bool isNull() const;
    bool isNan() const;
    bool isVector() const;
    bool isMKLAllocated() const;
    /**
     * Get the dimension count of the matrix.
     * @return dimension count
     */
    int getDims() const;
    float *getData() const;
    int getDimSize(int index) const;
    /**
     * Compare this matrix's shape with another's.
     * @param other matrix to compare against
     * @return true if the shapes are identical
     */
    bool compareShape(const CudaMatrix& other) const;
    /**
     * Get the data size.
     * @return the unit count of the matrix (for Complex matrices, confirm
     *         whether this counts complex values or underlying floats)
     */
    size_t getDataSize() const;

    /**
     * Get the value type: Normal or Complex.
     * Complex uses std::complex<float>, i.e. two float values per element.
     */
    ValueType getValueType() const { return mValueType; }

    /**
     * Set the value type (Normal or Complex). Only the type tag changes;
     * the data buffer and shape info are left untouched.
     */
    void setValueType(ValueType aValueType) { mValueType = aValueType; }

    /**
     * Return true if the value type is Complex.
     */
    bool isComplex() const { return mValueType == Complex; }

    // Presumably overwrites the shape with rows x columns x slices without
    // validating the element count — TODO confirm against the implementation.
    void forceReshape(int rows, int columns, int slices);
    // Presumably copies the data into a host-side Matrix — TODO confirm.
    Matrix toHostMatrix() const;

private:
    ValueType mValueType = Normal;  // Normal or Complex element type
    std::shared_ptr<float> mData;   // shared data buffer
    std::vector<int> mInfo;         // shape info (presumably dimension sizes — confirm)
};
}
#endif