Add fft, ifft, hilbert and their unit test.

This commit is contained in:
Krad
2023-04-25 13:02:11 +08:00
parent c54471ef6a
commit a6b2474dd1
3 changed files with 210 additions and 36 deletions

View File

@@ -516,3 +516,128 @@ Matrix Aurora::median(const Matrix &aMatrix) {
} }
} }
Matrix Aurora::fft(const Matrix &aMatrix) {
double *output = nullptr;
output = malloc(aMatrix.getDataSize(), true);
if (!aMatrix.isComplex()) {
cblas_dcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 2);
double zero = 0.0;
cblas_dcopy(aMatrix.getDataSize(), &zero, 0, output+1, 2);
} else {
cblas_zcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 1);
}
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
MKL_LONG status;
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
//通过 setValue 配置Descriptor
//使用单独的输出数据缓存
status = DftiSetValue(my_desc_handle, DFTI_PLACEMENT, DFTI_INPLACE);
if (status != DFTI_NO_ERROR) goto error;
//每个傅里叶变换的输入距离
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
//每个傅里叶变换的输出距离
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
//傅里叶变换的数量
status = DftiSetValue(my_desc_handle,DFTI_NUMBER_OF_TRANSFORMS,aMatrix.getDimSize(1));
if (status != DFTI_NO_ERROR) goto error;
//提交 修改配置后的Descriptor(实际上会进行FFT的计算初始化)
status = DftiCommitDescriptor(my_desc_handle);
if (status != DFTI_NO_ERROR) goto error;
//执行计算
status = DftiComputeForward(my_desc_handle, output, output);
if (status != DFTI_NO_ERROR) goto error;
//释放资源
status = DftiFreeDescriptor(&my_desc_handle);
if (status != DFTI_NO_ERROR) goto error;
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
error:
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
return Matrix();
}
Matrix Aurora::ifft(const Matrix &aMatrix) {
if (!aMatrix.isComplex()){
std::cerr<<"ifft input must be complex value"<<std::endl;
return Matrix();
}
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
auto output = malloc(aMatrix.getDataSize(),true);
MKL_LONG status;
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
//通过 setValue 配置Descriptor
//使用单独的输出数据缓存
status = DftiSetValue(my_desc_handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
if (status != DFTI_NO_ERROR) goto error;
//设置DFTI_BACKWARD_SCALE !!!很关键,不然值不对
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
//每个傅里叶变换的输出距离
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,aMatrix.getDimSize(0));
if (status != DFTI_NO_ERROR) goto error;
//傅里叶变换的数量
status = DftiSetValue(my_desc_handle,DFTI_NUMBER_OF_TRANSFORMS,aMatrix.getDimSize(1));
if (status != DFTI_NO_ERROR) goto error;
status = DftiCommitDescriptor(my_desc_handle);
if (status != DFTI_NO_ERROR) goto error;
//提交 修改配置后的Descriptor(实际上会进行FFT的计算初始化)
status = DftiCommitDescriptor(my_desc_handle);
if (status != DFTI_NO_ERROR) goto error;
//执行计算
status = DftiComputeBackward(my_desc_handle, aMatrix.getData(), output);
if (status != DFTI_NO_ERROR) goto error;
//释放资源
status = DftiFreeDescriptor(&my_desc_handle);
if (status != DFTI_NO_ERROR) goto error;
{
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
}
error:
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
return Matrix();
}
Matrix Aurora::hilbert(const Matrix &aMatrix) {
auto x = fft(aMatrix);
auto h = new double[aMatrix.getDimSize(0)];
auto two = 2.0;
auto zero = 0.0;
cblas_dcopy(aMatrix.getDimSize(0), &zero, 0, h, 1);
cblas_dcopy(aMatrix.getDimSize(0) / 2, &two, 0, h, 1);
h[aMatrix.getDimSize(0) / 2] = ((aMatrix.getDimSize(0) << 31) >> 31) ? 2.0 : 1.0;
h[0] = 1.0;
for (int i = 0; i < aMatrix.getDimSize(1); ++i) {
auto p = (double *)(x.getData() + aMatrix.getDimSize(0)* i*2);
vdMulI(aMatrix.getDimSize(0), p, 2, h, 1, p, 2);
vdMulI(aMatrix.getDimSize(0), p + 1, 2, h, 1, p + 1, 2);
}
auto result = ifft( x);
delete[] h;
return result;
}

View File

@@ -20,7 +20,7 @@ namespace Aurora {
/** /**
* 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数 * 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数
* @param aMatrix 矩阵 * @param aMatrix 目标矩阵
* @param direction 方向Column, Row, All * @param direction 方向Column, Row, All
* @return * @return
*/ */
@@ -30,9 +30,9 @@ namespace Aurora {
/** /**
* 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数 * 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数
* @param aMatrix 矩阵 * @param aMatrix 目标矩阵
* @param direction 方向Column, Row, All * @param direction 方向Column, Row, All
* @return * @return 最大值矩阵
*/ */
Matrix max(const Matrix& aMatrix,FunctionDirection direction = Column); Matrix max(const Matrix& aMatrix,FunctionDirection direction = Column);
@@ -41,36 +41,85 @@ namespace Aurora {
/** /**
* 比较两个矩阵,求对应位置的最小值,不支持三维 * 比较两个矩阵,求对应位置的最小值,不支持三维
* @attention 矩阵形状不一样时如A为[MxN],则B应为标量或[1xN]的行向量 * @attention 矩阵形状不一样时如A为[MxN],则B应为标量或[1xN]的行向量
* @param aMatrix * @param aMatrix 目标矩阵1
* @param aOther * @param aOther 目标矩阵2
* @return * @return 最小值矩阵
*/ */
Matrix min(const Matrix& aMatrix,const Matrix& aOther); Matrix min(const Matrix& aMatrix,const Matrix& aOther);
/** /**
* 求矩阵和,可按行、列、单元, 目前不支持三维,不支持复数 * 求矩阵和,可按行、列、单元, 目前不支持三维,不支持复数
* @param aMatrix 矩阵 * @param aMatrix 目标矩阵
* @param direction 方向Column, Row, All * @param direction 方向Column, Row, All
* @return * @return 求和结果矩阵
*/ */
Matrix sum(const Matrix& aMatrix,FunctionDirection direction = Column); Matrix sum(const Matrix& aMatrix,FunctionDirection direction = Column);
/** /**
* 求矩阵平均值,可按行、列、单元, 目前不支持三维,不支持复数 * 求矩阵平均值,可按行、列、单元, 目前不支持三维,不支持复数
* @param aMatrix 矩阵 * @param aMatrix 目标矩阵
* @param direction 方向Column, Row, All * @param direction 方向Column, Row, All
* @param aIncludeNan 是否包含nan * @param aIncludeNan 是否包含nan
* @return * @return 平均值矩阵
*/ */
Matrix mean(const Matrix& aMatrix,FunctionDirection direction = Column, bool aIncludeNan = true); Matrix mean(const Matrix& aMatrix,FunctionDirection direction = Column, bool aIncludeNan = true);
/**
* 矩阵排序 按列, 目前不支持三维,不支持复数
* @param aMatrix 目标矩阵
* @return 排序后矩阵
*/
Matrix sort(const Matrix& aMatrix); Matrix sort(const Matrix& aMatrix);
/**
* 矩阵排序 按列, 目前不支持三维,不支持复数
* @param aMatrix 目标矩阵
* @return 排序后矩阵
*/
Matrix sort(Matrix&& aMatrix); Matrix sort(Matrix&& aMatrix);
/**
* 矩阵排序 按行, 目前不支持三维,不支持复数
* @param aMatrix 目标矩阵
* @return 排序后矩阵
*/
Matrix sortrows(const Matrix& aMatrix); Matrix sortrows(const Matrix& aMatrix);
/**
* 矩阵排序 按行, 目前不支持三维,不支持复数
* @param aMatrix 目标矩阵
* @return 排序后矩阵
*/
Matrix sortrows(Matrix&& aMatrix); Matrix sortrows(Matrix&& aMatrix);
/**
* 对矩阵求中间值 按列, 目前不支持三维,不支持复数
* @param aMatrix 目标矩阵
* @return 中值矩阵
*/
Matrix median(const Matrix& aMatrix); Matrix median(const Matrix& aMatrix);
/**
* FFT,支持到2维输入可以是常数可以是复数输出必是复数
* @param aMatrix 目标矩阵
* @return fft后的复数矩阵
*/
Matrix fft(const Matrix& aMatrix);
/**
* 逆fft支持到2维输入必须是复数输出必是复数
* @attention 如有需要可使用real去除虚部
* @param aMatrix
* @return ifft后的复数矩阵
*/
Matrix ifft(const Matrix& aMatrix);
/**
* hilbert支持到2维输入必须是复数输出必是复数
* @param aMatrix
* @return
*/
Matrix hilbert(const Matrix& aMatrix);
}; };

View File

@@ -311,35 +311,35 @@ TEST_F(Function2D_Test, median) {
} }
TEST_F(Function2D_Test, fftAndComplexAndIfft){ TEST_F(Function2D_Test, fftAndComplexAndIfft){
// double input[10]{1,1,0,2,2,0,1,1,0,2}; double *input = new double[20]{1,1,0,2,2,0,1,1,0,2,1,1,0,2,2,0,1,1,0,2};
// std::complex<double>* complexInput = Aurora::complex(10,input); auto ma = Aurora::Matrix::fromRawData(input,10,2);
// //复数化后实部不变虚部全为0 auto ret = Aurora::fft(ma);
// EXPECT_DOUBLE_EQ(complexInput[1].real(),1.0)<<" complex error"; std::complex<double>* result = (std::complex<double>*)ret.getData();
// EXPECT_DOUBLE_EQ(complexInput[1].imag(),0)<<" complex error"; //检验fft结果与matlab是否对应
// std::complex<double>* result = Aurora::fft(10,complexInput); EXPECT_DOUBLE_EQ(0.0729, fourDecimalRound(result[1].real()));
// delete [] complexInput; EXPECT_DOUBLE_EQ(2.4899, fourDecimalRound(result[2].imag()));
// //检验fft结果与matlab是否对应 EXPECT_DOUBLE_EQ(0.0729, fourDecimalRound(result[11].real()));
// EXPECT_DOUBLE_EQ(0.0729, fourDecimalRound(result[1].real()))<<" fft result value error"; EXPECT_DOUBLE_EQ(2.4899, fourDecimalRound(result[12].imag()));
// EXPECT_DOUBLE_EQ(2.4899, fourDecimalRound(result[2].imag()))<<" fft result value error"; //检验fft的结果是否共轭
// //检验fft的结果是否共轭 EXPECT_DOUBLE_EQ(0, result[4].imag()+result[6].imag());
// EXPECT_DOUBLE_EQ(0, result[4].imag()+result[6].imag())<<" fft result conjugate error"; EXPECT_DOUBLE_EQ(0, result[4].real()-result[6].real());
// EXPECT_DOUBLE_EQ(0, result[4].real()-result[6].real())<<" fft result conjugate error"; ret= Aurora::ifft(ret);
// std::complex<double>* ifftResult = Aurora::ifft(10,result); std::complex<double>* ifftResult = (std::complex<double>*)ret.getData();
// EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[1].real()),1.0)<<" ifft result real value error"; EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[1].real()),1.0);
// EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[1].imag()),0)<<" ifft result imag value error"; EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[3].real()),2.0);
// delete [] result; EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[11].real()),1.0);
// delete [] ifftResult; EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[13].real()),2.0);
} }
TEST_F(Function2D_Test, hilbert) { TEST_F(Function2D_Test, hilbert) {
double input[10]{1,1,0,2,2,0,1,1,0,2}; double *input = new double[20]{1,1,0,2,2,0,1,1,0,2,1,1,0,2,2,0,1,1,0,2};
auto result = Aurora::hilbert(10,input); auto ma = Aurora::Matrix::fromRawData(input,10,2);
EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].real()),1.0)<<" hilbert result real value error"; auto ret = Aurora::hilbert(ma);
EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].imag()),0.3249)<<" hilbert result imag value error"; auto result = (std::complex<double>*)ret.getData();
delete [] result; EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].real()),1.0);
result = Aurora::hilbert(9,input); EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].imag()),0.3249);
EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].real()),1.0)<<" hilbert result real value error"; EXPECT_DOUBLE_EQ(fourDecimalRound(result[11].real()),1.0);
EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].imag()),0.4253)<<" hilbert result imag value error"; EXPECT_DOUBLE_EQ(fourDecimalRound(result[11].imag()),0.3249);
} }
TEST_F(Function2D_Test, interp2) { TEST_F(Function2D_Test, interp2) {