Update ifft add arg n
This commit is contained in:
@@ -757,19 +757,39 @@ Matrix Aurora::fft(const Matrix &aMatrix, long aFFTSize) {
|
||||
return Matrix();
|
||||
}
|
||||
|
||||
Matrix Aurora::ifft(const Matrix &aMatrix) {
|
||||
Matrix Aurora::ifft(const Matrix &aMatrix, long aFFTSize ) {
|
||||
if (!aMatrix.isComplex()){
|
||||
std::cerr<<"ifft input must be complex value"<<std::endl;
|
||||
return Matrix();
|
||||
}
|
||||
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
|
||||
mkl_free_buffers();
|
||||
auto output = malloc(aMatrix.getDataSize(),true);
|
||||
// mkl_free_buffers();
|
||||
// auto output = malloc(aMatrix.getDataSize(),true);
|
||||
MKL_LONG status;
|
||||
|
||||
double *output = nullptr;
|
||||
mkl_free_buffers();
|
||||
MKL_LONG rowSize = (aFFTSize>0)?aFFTSize:aMatrix.getDimSize(0);
|
||||
//实际需要copy赋值的非0值
|
||||
MKL_LONG needCopySize = (rowSize<aMatrix.getDimSize(0))?rowSize:aMatrix.getDimSize(0);
|
||||
MKL_LONG bufferSize = rowSize*aMatrix.getDimSize(1);
|
||||
output = malloc(bufferSize, true);
|
||||
double zero = 0.0;
|
||||
//先全部置为0
|
||||
cblas_dcopy(bufferSize*2, &zero, 0, output, 1);
|
||||
if (!aMatrix.isComplex()) {
|
||||
//按列copy原值
|
||||
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
|
||||
cblas_dcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0), 1, output+i*rowSize*2, 2);
|
||||
}
|
||||
} else {
|
||||
//按列copy原值
|
||||
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
|
||||
cblas_zcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0)*2, 1, output+i*rowSize*2, 1);
|
||||
}
|
||||
}
|
||||
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
|
||||
int size = aMatrix.getDimSize(0);
|
||||
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, size);
|
||||
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, rowSize);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
//通过 setValue 配置Descriptor
|
||||
//使用单独的输出数据缓存
|
||||
@@ -777,14 +797,14 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//设置DFTI_BACKWARD_SCALE !!!很关键,不然值不对
|
||||
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / size);
|
||||
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / rowSize);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,size);
|
||||
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,rowSize);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//每个傅里叶变换的输出距离
|
||||
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,size);
|
||||
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,rowSize);
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
|
||||
//傅里叶变换的数量
|
||||
@@ -806,7 +826,7 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
|
||||
mkl_free_buffers();
|
||||
if (status != DFTI_NO_ERROR) goto error;
|
||||
{
|
||||
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
|
||||
return Matrix::New(output, rowSize, aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
|
||||
}
|
||||
error:
|
||||
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
|
||||
|
||||
@@ -147,7 +147,7 @@ namespace Aurora
|
||||
* @param aMatrix
|
||||
* @return ifft后的复数矩阵
|
||||
*/
|
||||
Matrix ifft(const Matrix &aMatrix);
|
||||
Matrix ifft(const Matrix &aMatrix, long aFFTSize = -1);
|
||||
|
||||
/**
|
||||
* Symmetric逆fft,支持到2维,输入必须是复数,输出必是实数
|
||||
|
||||
Reference in New Issue
Block a user