Update ifft add arg n

This commit is contained in:
kradchen
2023-06-09 14:31:47 +08:00
parent 1a29baeef8
commit 4077843d13
2 changed files with 30 additions and 10 deletions

View File

@@ -757,19 +757,39 @@ Matrix Aurora::fft(const Matrix &aMatrix, long aFFTSize) {
return Matrix();
}
Matrix Aurora::ifft(const Matrix &aMatrix) {
Matrix Aurora::ifft(const Matrix &aMatrix, long aFFTSize ) {
if (!aMatrix.isComplex()){
std::cerr<<"ifft input must be complex value"<<std::endl;
return Matrix();
}
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
mkl_free_buffers();
auto output = malloc(aMatrix.getDataSize(),true);
// mkl_free_buffers();
// auto output = malloc(aMatrix.getDataSize(),true);
MKL_LONG status;
double *output = nullptr;
mkl_free_buffers();
MKL_LONG rowSize = (aFFTSize>0)?aFFTSize:aMatrix.getDimSize(0);
//实际需要copy赋值的非0值
MKL_LONG needCopySize = (rowSize<aMatrix.getDimSize(0))?rowSize:aMatrix.getDimSize(0);
MKL_LONG bufferSize = rowSize*aMatrix.getDimSize(1);
output = malloc(bufferSize, true);
double zero = 0.0;
//先全部置为0
cblas_dcopy(bufferSize*2, &zero, 0, output, 1);
if (!aMatrix.isComplex()) {
//按列copy原值
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
cblas_dcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0), 1, output+i*rowSize*2, 2);
}
} else {
//按列copy原值
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
cblas_zcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0)*2, 1, output+i*rowSize*2, 1);
}
}
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
int size = aMatrix.getDimSize(0);
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, size);
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, rowSize);
if (status != DFTI_NO_ERROR) goto error;
//通过 setValue 配置Descriptor
//使用单独的输出数据缓存
@@ -777,14 +797,14 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
if (status != DFTI_NO_ERROR) goto error;
//设置DFTI_BACKWARD_SCALE !!!很关键,不然值不对
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / size);
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / rowSize);
if (status != DFTI_NO_ERROR) goto error;
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,size);
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,rowSize);
if (status != DFTI_NO_ERROR) goto error;
//每个傅里叶变换的输出距离
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,size);
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,rowSize);
if (status != DFTI_NO_ERROR) goto error;
//傅里叶变换的数量
@@ -806,7 +826,7 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
mkl_free_buffers();
if (status != DFTI_NO_ERROR) goto error;
{
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
return Matrix::New(output, rowSize, aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
}
error:
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;

View File

@@ -147,7 +147,7 @@ namespace Aurora
* @param aMatrix
* @return ifft后的复数矩阵
*/
Matrix ifft(const Matrix &aMatrix);
Matrix ifft(const Matrix &aMatrix, long aFFTSize = -1);
/**
* Symmetric逆fft支持到2维输入必须是复数输出必是实数