Update ifft add arg n
This commit is contained in:
@@ -757,19 +757,39 @@ Matrix Aurora::fft(const Matrix &aMatrix, long aFFTSize) {
|
|||||||
return Matrix();
|
return Matrix();
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix Aurora::ifft(const Matrix &aMatrix) {
|
Matrix Aurora::ifft(const Matrix &aMatrix, long aFFTSize ) {
|
||||||
if (!aMatrix.isComplex()){
|
if (!aMatrix.isComplex()){
|
||||||
std::cerr<<"ifft input must be complex value"<<std::endl;
|
std::cerr<<"ifft input must be complex value"<<std::endl;
|
||||||
return Matrix();
|
return Matrix();
|
||||||
}
|
}
|
||||||
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
|
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
|
||||||
mkl_free_buffers();
|
// mkl_free_buffers();
|
||||||
auto output = malloc(aMatrix.getDataSize(),true);
|
// auto output = malloc(aMatrix.getDataSize(),true);
|
||||||
MKL_LONG status;
|
MKL_LONG status;
|
||||||
|
double *output = nullptr;
|
||||||
|
mkl_free_buffers();
|
||||||
|
MKL_LONG rowSize = (aFFTSize>0)?aFFTSize:aMatrix.getDimSize(0);
|
||||||
|
//实际需要copy赋值的非0值
|
||||||
|
MKL_LONG needCopySize = (rowSize<aMatrix.getDimSize(0))?rowSize:aMatrix.getDimSize(0);
|
||||||
|
MKL_LONG bufferSize = rowSize*aMatrix.getDimSize(1);
|
||||||
|
output = malloc(bufferSize, true);
|
||||||
|
double zero = 0.0;
|
||||||
|
//先全部置为0
|
||||||
|
cblas_dcopy(bufferSize*2, &zero, 0, output, 1);
|
||||||
|
if (!aMatrix.isComplex()) {
|
||||||
|
//按列copy原值
|
||||||
|
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
|
||||||
|
cblas_dcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0), 1, output+i*rowSize*2, 2);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
//按列copy原值
|
||||||
|
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
|
||||||
|
cblas_zcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0)*2, 1, output+i*rowSize*2, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
|
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
|
||||||
int size = aMatrix.getDimSize(0);
|
int size = aMatrix.getDimSize(0);
|
||||||
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, size);
|
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, rowSize);
|
||||||
if (status != DFTI_NO_ERROR) goto error;
|
if (status != DFTI_NO_ERROR) goto error;
|
||||||
//通过 setValue 配置Descriptor
|
//通过 setValue 配置Descriptor
|
||||||
//使用单独的输出数据缓存
|
//使用单独的输出数据缓存
|
||||||
@@ -777,14 +797,14 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
|
|||||||
if (status != DFTI_NO_ERROR) goto error;
|
if (status != DFTI_NO_ERROR) goto error;
|
||||||
|
|
||||||
//设置DFTI_BACKWARD_SCALE !!!很关键,不然值不对
|
//设置DFTI_BACKWARD_SCALE !!!很关键,不然值不对
|
||||||
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / size);
|
status = DftiSetValue(my_desc_handle, DFTI_BACKWARD_SCALE, 1.0f / rowSize);
|
||||||
if (status != DFTI_NO_ERROR) goto error;
|
if (status != DFTI_NO_ERROR) goto error;
|
||||||
|
|
||||||
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,size);
|
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,rowSize);
|
||||||
if (status != DFTI_NO_ERROR) goto error;
|
if (status != DFTI_NO_ERROR) goto error;
|
||||||
|
|
||||||
//每个傅里叶变换的输出距离
|
//每个傅里叶变换的输出距离
|
||||||
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,size);
|
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,rowSize);
|
||||||
if (status != DFTI_NO_ERROR) goto error;
|
if (status != DFTI_NO_ERROR) goto error;
|
||||||
|
|
||||||
//傅里叶变换的数量
|
//傅里叶变换的数量
|
||||||
@@ -806,7 +826,7 @@ Matrix Aurora::ifft(const Matrix &aMatrix) {
|
|||||||
mkl_free_buffers();
|
mkl_free_buffers();
|
||||||
if (status != DFTI_NO_ERROR) goto error;
|
if (status != DFTI_NO_ERROR) goto error;
|
||||||
{
|
{
|
||||||
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
|
return Matrix::New(output, rowSize, aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
|
||||||
}
|
}
|
||||||
error:
|
error:
|
||||||
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
|
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ namespace Aurora
|
|||||||
* @param aMatrix
|
* @param aMatrix
|
||||||
* @return ifft后的复数矩阵
|
* @return ifft后的复数矩阵
|
||||||
*/
|
*/
|
||||||
Matrix ifft(const Matrix &aMatrix);
|
Matrix ifft(const Matrix &aMatrix, long aFFTSize = -1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Symmetric逆fft,支持到2维,输入必须是复数,输出必是实数
|
* Symmetric逆fft,支持到2维,输入必须是复数,输出必是实数
|
||||||
|
|||||||
Reference in New Issue
Block a user