240 lines
7.6 KiB
C++
240 lines
7.6 KiB
C++
#ifndef CUDAMATRIX_H
|
|
#define CUDAMATRIX_H
|
|
|
|
#include "Matrix.h"
|
|
|
|
|
|
namespace Aurora
|
|
{
|
|
class CudaMatrix
|
|
{
|
|
public:
|
|
explicit CudaMatrix(std::shared_ptr<float> aData = std::shared_ptr<float>(),
|
|
std::vector<int> aInfo = std::vector<int>(),
|
|
ValueType aValueType = Normal);
|
|
|
|
/**
|
|
* Create from a Raw data(float array).
|
|
* Use Raw data which create like new float[size]() as a data source
|
|
* and the share_ptr's deleter will be std::default_delete<float[]>
|
|
* @param data
|
|
* @param rows
|
|
* @param cols
|
|
* @param slices
|
|
* @param type
|
|
* @return
|
|
*/
|
|
static CudaMatrix fromRawData(float *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);
|
|
|
|
/**
|
|
* Create from a Raw data(float array) with copy the data to a new mkl memory.
|
|
* @param data
|
|
* @param rows
|
|
* @param cols
|
|
* @param slices
|
|
* @param type
|
|
* @return
|
|
*/
|
|
static CudaMatrix copyFromRawData(float *data, int rows, int cols = 1, int slices = 1, ValueType type = Normal);
|
|
|
|
/**
|
|
* 深拷贝操作
|
|
* @return 深拷贝的Matrix对象
|
|
*/
|
|
CudaMatrix deepCopy() const;
|
|
|
|
// add
|
|
CudaMatrix operator+(float aScalar) const;
|
|
friend CudaMatrix operator+(float aScalar, const CudaMatrix &aMatrix);
|
|
friend CudaMatrix& operator+(float aScalar, CudaMatrix &&aMatrix);
|
|
friend CudaMatrix& operator+(CudaMatrix &&aMatrix,float aScalar);
|
|
CudaMatrix operator+(const CudaMatrix &aMatrix) const;
|
|
CudaMatrix operator+(CudaMatrix &&aMatrix) const;
|
|
friend CudaMatrix operator+(CudaMatrix &&aMatrix,CudaMatrix &aOther);
|
|
|
|
// sub
|
|
CudaMatrix operator-(float aScalar) const;
|
|
friend CudaMatrix operator-(float aScalar, const CudaMatrix &aMatrix);
|
|
friend CudaMatrix& operator-(float aScalar, CudaMatrix &&aMatrix);
|
|
friend CudaMatrix& operator-(CudaMatrix &&aMatrix,float aScalar);
|
|
CudaMatrix operator-(const CudaMatrix &aMatrix) const;
|
|
CudaMatrix operator-(CudaMatrix &&aMatrix) const;
|
|
friend CudaMatrix operator-(CudaMatrix &&aMatrix,CudaMatrix &aOther);
|
|
|
|
//negetive
|
|
friend CudaMatrix operator-(CudaMatrix &&aMatrix);
|
|
friend CudaMatrix operator-(const CudaMatrix &aMatrix);
|
|
|
|
// mul
|
|
CudaMatrix operator*(float aScalar) const;
|
|
friend CudaMatrix operator*(float aScalar, const CudaMatrix &aMatrix);
|
|
friend CudaMatrix& operator*(float aScalar, CudaMatrix &&aMatrix);
|
|
friend CudaMatrix& operator*(CudaMatrix &&aMatrix,float aScalar);
|
|
CudaMatrix operator*(const CudaMatrix &aMatrix) const;
|
|
CudaMatrix operator*(CudaMatrix &&aMatrix) const;
|
|
friend CudaMatrix operator*(CudaMatrix &&aMatrix,CudaMatrix &aOther);
|
|
|
|
// div
|
|
CudaMatrix operator/(float aScalar) const;
|
|
friend CudaMatrix operator/(float aScalar, const CudaMatrix &aMatrix);
|
|
friend CudaMatrix& operator/(float aScalar, CudaMatrix &&aMatrix);
|
|
friend CudaMatrix& operator/(CudaMatrix &&aMatrix,float aScalar);
|
|
CudaMatrix operator/(const CudaMatrix &aMatrix) const;
|
|
CudaMatrix operator/(CudaMatrix &&aMatrix) const;
|
|
friend CudaMatrix operator/(CudaMatrix &&aMatrix, CudaMatrix &aOther);
|
|
|
|
// pow
|
|
CudaMatrix operator^(int times) const;
|
|
friend CudaMatrix operator^(CudaMatrix &&aMatrix,int times);
|
|
|
|
//compare
|
|
CudaMatrix operator>(float aScalar) const;
|
|
friend CudaMatrix operator>(float aScalar, const CudaMatrix &aMatrix);
|
|
CudaMatrix operator>(const CudaMatrix &aMatrix) const;
|
|
|
|
CudaMatrix operator<(float aScalar) const;
|
|
friend CudaMatrix operator<(float aScalar, const CudaMatrix &aMatrix);
|
|
CudaMatrix operator<(const CudaMatrix &aMatrix) const;
|
|
|
|
CudaMatrix operator>=(float aScalar) const;
|
|
friend CudaMatrix operator>=(float aScalar, const CudaMatrix &aMatrix);
|
|
CudaMatrix operator>=(const CudaMatrix &aMatrix) const;
|
|
|
|
CudaMatrix operator<=(float aScalar) const;
|
|
friend CudaMatrix operator<=(float aScalar, const CudaMatrix &aMatrix);
|
|
CudaMatrix operator<=(const CudaMatrix &aMatrix) const;
|
|
|
|
CudaMatrix operator==(float aScalar) const;
|
|
friend CudaMatrix operator==(float aScalar, const CudaMatrix &aMatrix);
|
|
CudaMatrix operator==(const CudaMatrix &aMatrix) const;
|
|
|
|
CudaMatrix operator!=(float aScalar) const;
|
|
friend CudaMatrix operator!=(float aScalar, const CudaMatrix &aMatrix);
|
|
CudaMatrix operator!=(const CudaMatrix &aMatrix) const;
|
|
|
|
// sub
|
|
float& operator[](size_t index);
|
|
float operator[](size_t index) const;
|
|
|
|
/**
|
|
* 切块操作
|
|
*
|
|
* @param aDim 需要切块的维度,
|
|
* @param aBeginIndx 起始索引,包含
|
|
* @param aEndIndex 终止索引,包含
|
|
* @return Matrix 返回矩阵
|
|
*/
|
|
CudaMatrix block(int aDim,int aBeginIndx, int aEndIndex) const;
|
|
|
|
bool setBlockValue(int aDim,int aBeginIndx, int aEndIndex,float value);
|
|
bool setBlockComplexValue(int aDim,int aBeginIndx, int aEndIndex,std::complex<float> value);
|
|
|
|
|
|
bool setBlock(int aDim,int aBeginIndx, int aEndIndex,const CudaMatrix& src);
|
|
|
|
/**
|
|
* 矩阵乘法
|
|
* @attention 目前只支持矩阵乘向量
|
|
* @param aOther
|
|
* @return Matrix
|
|
*/
|
|
CudaMatrix mul(const CudaMatrix& aOther) const;
|
|
|
|
/**
|
|
* 矩阵乘法
|
|
* @attention 目前只支持矩阵乘向量
|
|
* @param aOther
|
|
* @return Matrix
|
|
*/
|
|
CudaMatrix mul(CudaMatrix&& aOther) const;
|
|
|
|
/**
|
|
* print matrix , only support 2d matrix now
|
|
*/
|
|
void printf();
|
|
|
|
void printfShape();
|
|
|
|
bool isScalar() const;
|
|
|
|
float getScalar() const;
|
|
|
|
/**
|
|
* Get is the Matrix's data is empty or size is zero.
|
|
* @return
|
|
*/
|
|
bool isNull() const;
|
|
|
|
bool isNan() const;
|
|
|
|
bool isVector() const;
|
|
|
|
bool isMKLAllocated() const;
|
|
|
|
/**
|
|
* Get dimension count of the matrix
|
|
* @return dimension count
|
|
*/
|
|
int getDims() const;
|
|
|
|
float *getData() const;
|
|
|
|
int getDimSize(int index) const;
|
|
|
|
/**
|
|
* Compare matrix shape
|
|
* @param other matrix
|
|
* @return is identity
|
|
*/
|
|
bool compareShape(const CudaMatrix& other) const;
|
|
|
|
/**
|
|
* Get the data size
|
|
* @return the unit count of the matrix
|
|
*/
|
|
size_t getDataSize() const;
|
|
|
|
/**
|
|
* Get Value type as normal and complex,
|
|
* complex use std::complex<float>,
|
|
* it's contains two float value.
|
|
* @return
|
|
*/
|
|
ValueType getValueType() const {
|
|
return mValueType;
|
|
}
|
|
|
|
/**
|
|
* Set Value type as normal and complex,
|
|
* @return
|
|
*/
|
|
void setValueType(ValueType aValueType) {
|
|
mValueType = aValueType;
|
|
}
|
|
|
|
/**
|
|
* return true if the valueType is complex,
|
|
* @return
|
|
*/
|
|
bool isComplex() const {
|
|
if(mValueType == Complex)
|
|
{
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void forceReshape(int rows, int columns, int slices);
|
|
|
|
Matrix toHostMatrix() const;
|
|
|
|
|
|
private:
|
|
ValueType mValueType = Normal;
|
|
std::shared_ptr<float> mData;
|
|
std::vector<int> mInfo;
|
|
};
|
|
}
|
|
|
|
#endif
|