From a6b2474dd15b0a7d0e9cecefcfa9347fa56f67f4 Mon Sep 17 00:00:00 2001 From: Krad Date: Tue, 25 Apr 2023 13:02:11 +0800 Subject: [PATCH] Add fft, ifft, hilbert and their unit test. --- src/Function2D.cpp | 125 +++++++++++++++++++++++++++++++++++++++ src/Function2D.h | 69 +++++++++++++++++---- test/Function2D_Test.cpp | 52 ++++++++-------- 3 files changed, 210 insertions(+), 36 deletions(-) diff --git a/src/Function2D.cpp b/src/Function2D.cpp index 0353dc9..f12c08c 100644 --- a/src/Function2D.cpp +++ b/src/Function2D.cpp @@ -516,3 +516,128 @@ Matrix Aurora::median(const Matrix &aMatrix) { } } +Matrix Aurora::fft(const Matrix &aMatrix) { + double *output = nullptr; + output = malloc(aMatrix.getDataSize(), true); + if (!aMatrix.isComplex()) { + cblas_dcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 2); + double zero = 0.0; + cblas_dcopy(aMatrix.getDataSize(), &zero, 0, output+1, 2); + } else { + cblas_zcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 1); + } + + DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL; + MKL_LONG status; + + //创建 Descriptor, 精度 double , 输入类型实数, 维度1 + status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, aMatrix.getDimSize(0)); + if (status != DFTI_NO_ERROR) goto error; + + //通过 setValue 配置Descriptor + //使用单独的输出数据缓存 + status = DftiSetValue(my_desc_handle, DFTI_PLACEMENT, DFTI_INPLACE); + if (status != DFTI_NO_ERROR) goto error; + + //每个傅里叶变换的输入距离 + status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,aMatrix.getDimSize(0)); + if (status != DFTI_NO_ERROR) goto error; + + //每个傅里叶变换的输出距离 + status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,aMatrix.getDimSize(0)); + if (status != DFTI_NO_ERROR) goto error; + + //傅里叶变换的数量 + status = DftiSetValue(my_desc_handle,DFTI_NUMBER_OF_TRANSFORMS,aMatrix.getDimSize(1)); + if (status != DFTI_NO_ERROR) goto error; + + //提交 修改配置后的Descriptor(实际上会进行FFT的计算初始化) + status = DftiCommitDescriptor(my_desc_handle); + if (status != DFTI_NO_ERROR) goto error; + + //执行计算 + status = DftiComputeForward(my_desc_handle, output, output); + if (status != DFTI_NO_ERROR) goto error; + + //释放资源 + status = DftiFreeDescriptor(&my_desc_handle); + if (status != DFTI_NO_ERROR) goto error; + + return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex); + error: + std::cerr<<"FFT fail, error message:"<> 31) ? 2.0 : 1.0; + h[0] = 1.0; + for (int i = 0; i < aMatrix.getDimSize(1); ++i) { + auto p = (double *)(x.getData() + aMatrix.getDimSize(0)* i*2); + vdMulI(aMatrix.getDimSize(0), p, 2, h, 1, p, 2); + vdMulI(aMatrix.getDimSize(0), p + 1, 2, h, 1, p + 1, 2); + } + auto result = ifft( x); + delete[] h; + return result; +} + diff --git a/src/Function2D.h b/src/Function2D.h index 1286f76..cc9d74e 100644 --- a/src/Function2D.h +++ b/src/Function2D.h @@ -20,7 +20,7 @@ namespace Aurora { /** * 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数 - * @param aMatrix 矩阵 + * @param aMatrix 目标矩阵 * @param direction 方向,Column, Row, All * @return */ @@ -30,9 +30,9 @@ namespace Aurora { /** * 求矩阵最小值,可按行、列、单元, 目前不支持三维,不支持复数 - * @param aMatrix 矩阵 + * @param aMatrix 目标矩阵 * @param direction 方向,Column, Row, All - * @return + * @return 最大值矩阵 */ Matrix max(const Matrix& aMatrix,FunctionDirection direction = Column); @@ -41,36 +41,85 @@ namespace Aurora { /** * 比较两个矩阵,求对应位置的最小值,不支持三维 * @attention 矩阵形状不一样时,如A为[MxN],则B应为标量或[1xN]的行向量 - * @param aMatrix - * @param aOther - * @return + * @param aMatrix 目标矩阵1 + * @param aOther 目标矩阵2 + * @return 最小值矩阵 */ Matrix min(const Matrix& aMatrix,const Matrix& aOther); /** * 求矩阵和,可按行、列、单元, 目前不支持三维,不支持复数 - * @param aMatrix 矩阵 + * @param aMatrix 目标矩阵 * @param direction 方向,Column, Row, All - * @return + * @return 求和结果矩阵 */ Matrix sum(const Matrix& aMatrix,FunctionDirection direction = Column); /** * 求矩阵平均值,可按行、列、单元, 目前不支持三维,不支持复数 - * @param aMatrix 矩阵 + * @param aMatrix 目标矩阵 * @param direction 方向,Column, Row, All * @param aIncludeNan 是否包含nan - * @return + * @return 平均值矩阵 */ Matrix mean(const Matrix& aMatrix,FunctionDirection direction = Column, bool aIncludeNan = true); + /** + * 矩阵排序 按列, 目前不支持三维,不支持复数 + * @param aMatrix 目标矩阵 + * @return 排序后矩阵 + */ Matrix sort(const Matrix& aMatrix); + + /** + * 矩阵排序 按列, 目前不支持三维,不支持复数 + * @param aMatrix 目标矩阵 + * @return 排序后矩阵 + */ Matrix sort(Matrix&& aMatrix); + /** + * 矩阵排序 按行, 目前不支持三维,不支持复数 + * @param aMatrix 目标矩阵 + * @return 排序后矩阵 + */ Matrix sortrows(const Matrix& aMatrix); + + /** + * 矩阵排序 按行, 目前不支持三维,不支持复数 + * @param aMatrix 目标矩阵 + * @return 排序后矩阵 + */ Matrix sortrows(Matrix&& aMatrix); + /** + * 对矩阵求中间值 按列, 目前不支持三维,不支持复数 + * @param aMatrix 目标矩阵 + * @return 中值矩阵 + */ Matrix median(const Matrix& aMatrix); + + /** + * FFT,支持到2维,输入可以是常数可以是复数,输出必是复数 + * @param aMatrix 目标矩阵 + * @return fft后的复数矩阵 + */ + Matrix fft(const Matrix& aMatrix); + + /** + * 逆fft,支持到2维,输入必须是复数,输出必是复数 + * @attention 如有需要可使用real去除虚部 + * @param aMatrix + * @return ifft后的复数矩阵 + */ + Matrix ifft(const Matrix& aMatrix); + + /** + * hilbert,支持到2维,输入必须是复数,输出必是复数 + * @param aMatrix + * @return + */ + Matrix hilbert(const Matrix& aMatrix); }; diff --git a/test/Function2D_Test.cpp b/test/Function2D_Test.cpp index 92196eb..5ae3785 100644 --- a/test/Function2D_Test.cpp +++ b/test/Function2D_Test.cpp @@ -311,35 +311,35 @@ TEST_F(Function2D_Test, median) { } TEST_F(Function2D_Test, fftAndComplexAndIfft){ -// double input[10]{1,1,0,2,2,0,1,1,0,2}; -// std::complex* complexInput = Aurora::complex(10,input); -// //复数化后,实部不变,虚部全为0 -// EXPECT_DOUBLE_EQ(complexInput[1].real(),1.0)<<" complex error"; -// EXPECT_DOUBLE_EQ(complexInput[1].imag(),0)<<" complex error"; -// std::complex* result = Aurora::fft(10,complexInput); -// delete [] complexInput; -// //检验fft结果与matlab是否对应 -// EXPECT_DOUBLE_EQ(0.0729, fourDecimalRound(result[1].real()))<<" fft result value error"; -// EXPECT_DOUBLE_EQ(2.4899, fourDecimalRound(result[2].imag()))<<" fft result value error"; -// //检验fft的结果是否共轭 -// EXPECT_DOUBLE_EQ(0, result[4].imag()+result[6].imag())<<" fft result conjugate error"; -// EXPECT_DOUBLE_EQ(0, result[4].real()-result[6].real())<<" fft result conjugate error"; -// std::complex* ifftResult = Aurora::ifft(10,result); -// EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[1].real()),1.0)<<" ifft result real value error"; -// EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[1].imag()),0)<<" ifft result imag value error"; -// delete [] result; -// delete [] ifftResult; + double *input = new double[20]{1,1,0,2,2,0,1,1,0,2,1,1,0,2,2,0,1,1,0,2}; + auto ma = Aurora::Matrix::fromRawData(input,10,2); + auto ret = Aurora::fft(ma); + std::complex* result = (std::complex*)ret.getData(); + //检验fft结果与matlab是否对应 + EXPECT_DOUBLE_EQ(0.0729, fourDecimalRound(result[1].real())); + EXPECT_DOUBLE_EQ(2.4899, fourDecimalRound(result[2].imag())); + EXPECT_DOUBLE_EQ(0.0729, fourDecimalRound(result[11].real())); + EXPECT_DOUBLE_EQ(2.4899, fourDecimalRound(result[12].imag())); + //检验fft的结果是否共轭 + EXPECT_DOUBLE_EQ(0, result[4].imag()+result[6].imag()); + EXPECT_DOUBLE_EQ(0, result[4].real()-result[6].real()); + ret= Aurora::ifft(ret); + std::complex* ifftResult = (std::complex*)ret.getData(); + EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[1].real()),1.0); + EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[3].real()),2.0); + EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[11].real()),1.0); + EXPECT_DOUBLE_EQ(fourDecimalRound(ifftResult[13].real()),2.0); } TEST_F(Function2D_Test, hilbert) { - double input[10]{1,1,0,2,2,0,1,1,0,2}; - auto result = Aurora::hilbert(10,input); - EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].real()),1.0)<<" hilbert result real value error"; - EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].imag()),0.3249)<<" hilbert result imag value error"; - delete [] result; - result = Aurora::hilbert(9,input); - EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].real()),1.0)<<" hilbert result real value error"; - EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].imag()),0.4253)<<" hilbert result imag value error"; + double *input = new double[20]{1,1,0,2,2,0,1,1,0,2,1,1,0,2,2,0,1,1,0,2}; + auto ma = Aurora::Matrix::fromRawData(input,10,2); + auto ret = Aurora::hilbert(ma); + auto result = (std::complex*)ret.getData(); + EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].real()),1.0); + EXPECT_DOUBLE_EQ(fourDecimalRound(result[1].imag()),0.3249); + EXPECT_DOUBLE_EQ(fourDecimalRound(result[11].real()),1.0); + EXPECT_DOUBLE_EQ(fourDecimalRound(result[11].imag()),0.3249); } TEST_F(Function2D_Test, interp2) {