#include "Function1D.h" #include #include #include //必须在mkl.h和Eigen的头之前,之后 #define MKL_Complex16 std::complex #include "mkl.h" #include #include #include namespace { const int COMPLEX_STRIDE = 2; const int REAL_STRIDE = 1; const int SAME_STRIDE = 1; const double VALUE_ONE = 1.0; } Aurora::Matrix Aurora::complex(const Aurora::Matrix &matrix) { if (matrix.getValueType() == Complex) { std::cerr<<"complex not support complex value type"< *) mkl_malloc(matrix.getDataSize() * sizeof(std::complex), 64); memset(output, 0, (matrix.getDataSize() * sizeof(std::complex))); cblas_dcopy(matrix.getDataSize(), matrix.getData(), REAL_STRIDE, (double *) output, COMPLEX_STRIDE); return Aurora::Matrix::New((double *) output, matrix.getDimSize(0), matrix.getDimSize(1), matrix.getDimSize(2), Complex); } Aurora::Matrix Aurora::real(const Aurora::Matrix &matrix) { if (matrix.getValueType() == Normal) { std::cerr<<"real only support complex value type"< *)matrix.getData(), SAME_STRIDE,output, SAME_STRIDE); } return Aurora::Matrix::New(output, matrix); } Aurora::Matrix Aurora::abs(const Aurora::Matrix&& matrix) { std::cout<<"RR abs"< *)matrix.getData(), SAME_STRIDE,output, SAME_STRIDE); return Aurora::Matrix::New(output, matrix); } } Aurora::Matrix Aurora::sign(const Aurora::Matrix &matrix) { if (matrix.getValueType()==Normal){ auto ret = matrix.deepCopy(); Eigen::Map retV(ret.getData(),ret.getDataSize()); retV = retV.array().sign(); return ret; } else{ //sign(x) = x./abs(x),前提是 x 为复数。 auto output = (double *) mkl_malloc(matrix.getDataSize() * sizeof(std::complex), 64); Matrix absMatrix = abs(matrix); vdDivI(matrix.getDataSize(), matrix.getData(),COMPLEX_STRIDE, absMatrix.getData(), REAL_STRIDE,output,COMPLEX_STRIDE); vdDivI(matrix.getDataSize(), matrix.getData()+1,COMPLEX_STRIDE, absMatrix.getData(), REAL_STRIDE,output+1,COMPLEX_STRIDE); return Aurora::Matrix::New(output, matrix); } } Aurora::Matrix Aurora::sign(const Aurora::Matrix&& matrix) { std::cout<<"RR sign"< retV(matrix.getData(),matrix.getDataSize()); retV = retV.array().sign(); return matrix; } else{ //sign(x) = x./abs(x),前提是 x 为复数。 Matrix absMatrix = abs(matrix); vdDivI(matrix.getDataSize(), matrix.getData(),COMPLEX_STRIDE, absMatrix.getData(), REAL_STRIDE,matrix.getData(),COMPLEX_STRIDE); vdDivI(matrix.getDataSize(), matrix.getData()+1,COMPLEX_STRIDE, absMatrix.getData(), REAL_STRIDE,matrix.getData()+1,COMPLEX_STRIDE); return matrix; } }