Add fft, ifft, hilbert and their unit test.
This commit is contained in:
@@ -516,3 +516,128 @@ Matrix Aurora::median(const Matrix &aMatrix) {
|
||||
}
|
||||
}
|
||||
|
||||
Matrix Aurora::fft(const Matrix &aMatrix) {
|
||||
double *output = nullptr;
|
||||
output = malloc(aMatrix.getDataSize(), true);
|
||||
if (!aMatrix.isComplex()) {
|
||||
cblas_dcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 2);
|
||||
double zero = 0.0;
|
||||
cblas_dcopy(aMatrix.getDataSize(), &zero, 0, output+1, 2);
|
||||
} else {
|
||||
cblas_zcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 1);
|
||||
}
|
||||
|
||||
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
|
||||
MKL_LONG status;
|
||||
|
||||
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
|
||||
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//通过 setValue 配置Descriptor
|
||||
//使用单独的输出数据缓存
|
||||
status = DftiSetValue(my_desc_handle, DFTI_PLACEMENT, DFTI_INPLACE);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//每个傅里叶变换的输入距离
|
||||
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//每个傅里叶变换的输出距离
|
||||
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//傅里叶变换的数量
|
||||
status = DftiSetValue(my_desc_handle,DFTI_NUMBER_OF_TRANSFORMS,aMatrix.getDimSize(1));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//提交 修改配置后的Descriptor(实际上会进行FFT的计算初始化)
|
||||
status = DftiCommitDescriptor(my_desc_handle);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//执行计算
|
||||
status = DftiComputeForward(my_desc_handle, output, output);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//释放资源
|
||||
status = DftiFreeDescriptor(&my_desc_handle);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
|
||||
error:
|
||||
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
|
||||
return Matrix();
|
||||
}
|
||||
|
||||
Matrix Aurora::ifft(const Matrix &aMatrix) {
|
||||
if (!aMatrix.isComplex()){
|
||||
std::cerr<<"ifft input must be complex value"<<std::endl;
|
||||
return Matrix();
|
||||
}
|
||||
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
|
||||
auto output = malloc(aMatrix.getDataSize(),true);
|
||||
MKL_LONG status;
|
||||
|
||||
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
|
||||
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
//通过 setValue 配置Descriptor
|
||||
//使用单独的输出数据缓存
|
||||
status = DftiSetValue(my_desc_handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//设置DFTI_BACKWARD_SCALE !!!很关键,不然值不对
|
||||
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//每个傅里叶变换的输出距离
|
||||
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,aMatrix.getDimSize(0));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//傅里叶变换的数量
|
||||
status = DftiSetValue(my_desc_handle,DFTI_NUMBER_OF_TRANSFORMS,aMatrix.getDimSize(1));
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
status = DftiCommitDescriptor(my_desc_handle);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//提交 修改配置后的Descriptor(实际上会进行FFT的计算初始化)
|
||||
status = DftiCommitDescriptor(my_desc_handle);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//执行计算
|
||||
status = DftiComputeBackward(my_desc_handle, aMatrix.getData(), output);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
//释放资源
|
||||
status = DftiFreeDescriptor(&my_desc_handle);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
{
|
||||
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
|
||||
}
|
||||
error:
|
||||
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
|
||||
return Matrix();
|
||||
}
|
||||
|
||||
Matrix Aurora::hilbert(const Matrix &aMatrix) {
|
||||
auto x = fft(aMatrix);
|
||||
auto h = new double[aMatrix.getDimSize(0)];
|
||||
auto two = 2.0;
|
||||
auto zero = 0.0;
|
||||
cblas_dcopy(aMatrix.getDimSize(0), &zero, 0, h, 1);
|
||||
cblas_dcopy(aMatrix.getDimSize(0) / 2, &two, 0, h, 1);
|
||||
h[aMatrix.getDimSize(0) / 2] = ((aMatrix.getDimSize(0) << 31) >> 31) ? 2.0 : 1.0;
|
||||
h[0] = 1.0;
|
||||
for (int i = 0; i < aMatrix.getDimSize(1); ++i) {
|
||||
auto p = (double *)(x.getData() + aMatrix.getDimSize(0)* i*2);
|
||||
vdMulI(aMatrix.getDimSize(0), p, 2, h, 1, p, 2);
|
||||
vdMulI(aMatrix.getDimSize(0), p + 1, 2, h, 1, p + 1, 2);
|
||||
}
|
||||
auto result = ifft( x);
|
||||
delete[] h;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace Aurora {
|
||||
|
||||
/**
|
||||
* 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 矩阵
|
||||
* @param aMatrix 目标矩阵
|
||||
* @param direction 方向,Column, Row, All
|
||||
* @return
|
||||
*/
|
||||
@@ -30,9 +30,9 @@ namespace Aurora {
|
||||
|
||||
/**
|
||||
* 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 矩阵
|
||||
* @param aMatrix 目标矩阵
|
||||
* @param direction 方向,Column, Row, All
|
||||
* @return
|
||||
* @return 最大值矩阵
|
||||
*/
|
||||
Matrix max(const Matrix& aMatrix,FunctionDirection direction = Column);
|
||||
|
||||
@@ -41,36 +41,85 @@ namespace Aurora {
|
||||
/**
|
||||
* 比较两个矩阵,求对应位置的最小值,不支持三维
|
||||
* @attention 矩阵形状不一样时,如A为[MxN],则B应为标量或[1xN]的行向量
|
||||
* @param aMatrix
|
||||
* @param aOther
|
||||
* @return
|
||||
* @param aMatrix 目标矩阵1
|
||||
* @param aOther 目标矩阵2
|
||||
* @return 最小值矩阵
|
||||
*/
|
||||
Matrix min(const Matrix& aMatrix,const Matrix& aOther);
|
||||
|
||||
/**
|
||||
* 求矩阵和,可按行、列、单元, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 矩阵
|
||||
* @param aMatrix 目标矩阵
|
||||
* @param direction 方向,Column, Row, All
|
||||
* @return
|
||||
* @return 求和结果矩阵
|
||||
*/
|
||||
Matrix sum(const Matrix& aMatrix,FunctionDirection direction = Column);
|
||||
|
||||
/**
|
||||
* 求矩阵平均值,可按行、列、单元, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 矩阵
|
||||
* @param aMatrix 目标矩阵
|
||||
* @param direction 方向,Column, Row, All
|
||||
* @param aIncludeNan 是否包含nan
|
||||
* @return
|
||||
* @return 平均值矩阵
|
||||
*/
|
||||
Matrix mean(const Matrix& aMatrix,FunctionDirection direction = Column, bool aIncludeNan = true);
|
||||
|
||||
/**
|
||||
* 矩阵排序 按列, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 目标矩阵
|
||||
* @return 排序后矩阵
|
||||
*/
|
||||
Matrix sort(const Matrix& aMatrix);
|
||||
|
||||
/**
|
||||
* 矩阵排序 按列, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 目标矩阵
|
||||
* @return 排序后矩阵
|
||||
*/
|
||||
Matrix sort(Matrix&& aMatrix);
|
||||
|
||||
/**
|
||||
* 矩阵排序 按行, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 目标矩阵
|
||||
* @return 排序后矩阵
|
||||
*/
|
||||
Matrix sortrows(const Matrix& aMatrix);
|
||||
|
||||
/**
|
||||
* 矩阵排序 按行, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 目标矩阵
|
||||
* @return 排序后矩阵
|
||||
*/
|
||||
Matrix sortrows(Matrix&& aMatrix);
|
||||
|
||||
/**
|
||||
* 对矩阵求中间值 按列, 目前不支持三维,不支持复数
|
||||
* @param aMatrix 目标矩阵
|
||||
* @return 中值矩阵
|
||||
*/
|
||||
Matrix median(const Matrix& aMatrix);
|
||||
|
||||
/**
|
||||
* FFT,支持到2维,输入可以是常数可以是复数,输出必是复数
|
||||
* @param aMatrix 目标矩阵
|
||||
* @return fft后的复数矩阵
|
||||
*/
|
||||
Matrix fft(const Matrix& aMatrix);
|
||||
|
||||
/**
|
||||
* 逆fft,支持到2维,输入必须是复数,输出必是复数
|
||||
* @attention 如有需要可使用real去除虚部
|
||||
* @param aMatrix
|
||||
* @return ifft后的复数矩阵
|
||||
*/
|
||||
Matrix ifft(const Matrix& aMatrix);
|
||||
|
||||
/**
|
||||
* hilbert,支持到2维,输入必须是复数,输出必是复数
|
||||
* @param aMatrix
|
||||
* @return
|
||||
*/
|
||||
Matrix hilbert(const Matrix& aMatrix);
|
||||
};
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user