Add fftshift ,fft with size ,and their test.

This commit is contained in:
kradchen
2023-05-06 15:49:19 +08:00
parent 89ba667107
commit e7317d0ade
3 changed files with 66 additions and 14 deletions

View File

@@ -578,23 +578,34 @@ Matrix Aurora::median(const Matrix &aMatrix) {
}
}
Matrix Aurora::fft(const Matrix &aMatrix) {
Matrix Aurora::fft(const Matrix &aMatrix, long aFFTSize) {
double *output = nullptr;
mkl_free_buffers();
output = malloc(aMatrix.getDataSize(), true);
MKL_LONG rowSize = (aFFTSize>0)?aFFTSize:aMatrix.getDimSize(0);
//实际需要copy赋值的非0值
MKL_LONG needCopySize = (rowSize<aMatrix.getDimSize(0))?rowSize:aMatrix.getDimSize(0);
MKL_LONG bufferSize = rowSize*aMatrix.getDimSize(1);
output = malloc(bufferSize, true);
double zero = 0.0;
//先全部置为0
cblas_dcopy(bufferSize*2, &zero, 0, output, 1);
if (!aMatrix.isComplex()) {
cblas_dcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 2);
double zero = 0.0;
cblas_dcopy(aMatrix.getDataSize(), &zero, 0, output+1, 2);
//按列copy原值
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
cblas_dcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0), 1, output+i*rowSize*2, 2);
}
} else {
cblas_zcopy(aMatrix.getDataSize(), aMatrix.getData(), 1, output, 1);
//按列copy原值
for (int i = 0 ; i < aMatrix.getDimSize(1); ++i) {
cblas_zcopy(needCopySize, aMatrix.getData()+i*aMatrix.getDimSize(0)*2, 1, output+i*rowSize*2, 1);
}
}
DFTI_DESCRIPTOR_HANDLE my_desc_handle = NULL;
MKL_LONG status;
//创建 Descriptor, 精度 double , 输入类型实数, 维度1
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, aMatrix.getDimSize(0));
status = DftiCreateDescriptor(&my_desc_handle, DFTI_DOUBLE, DFTI_COMPLEX, 1, rowSize);
if (status != DFTI_NO_ERROR) goto error;
//通过 setValue 配置Descriptor
@@ -603,11 +614,11 @@ Matrix Aurora::fft(const Matrix &aMatrix) {
if (status != DFTI_NO_ERROR) goto error;
//每个傅里叶变换的输入距离
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,aMatrix.getDimSize(0));
status = DftiSetValue(my_desc_handle,DFTI_INPUT_DISTANCE,rowSize);
if (status != DFTI_NO_ERROR) goto error;
//每个傅里叶变换的输出距离
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,aMatrix.getDimSize(0));
status = DftiSetValue(my_desc_handle,DFTI_OUTPUT_DISTANCE,rowSize);
if (status != DFTI_NO_ERROR) goto error;
//傅里叶变换的数量
@@ -626,7 +637,7 @@ Matrix Aurora::fft(const Matrix &aMatrix) {
status = DftiFreeDescriptor(&my_desc_handle);
if (status != DFTI_NO_ERROR) goto error;
mkl_free_buffers();
return Matrix::New(output, aMatrix.getDimSize(0), aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
return Matrix::New(output, rowSize, aMatrix.getDimSize(1), aMatrix.getDimSize(2), Complex);
error:
std::cerr<<"FFT fail, error message:"<<DftiErrorMessage(status)<<std::endl;
return Matrix();
@@ -710,6 +721,19 @@ Matrix Aurora::ifft_symmetric(const Matrix &aMatrix,long length)
return real(ifft(Matrix::New(calcData,length,1,1,Complex)));
}
void Aurora::fftshift(Matrix &aMatrix){
int backwardLength = aMatrix.getDimSize(0)/2;
int forwardLength = aMatrix.getDimSize(0) - backwardLength;
double* buffer = malloc(forwardLength,true);
for (int i = 0; i<aMatrix.getDimSize(1); ++i) {
double* dataPtr = aMatrix.getData()+aMatrix.getDimSize(0)*i*2;
cblas_dcopy(forwardLength*2, dataPtr+backwardLength*2, 1, buffer, 1);
cblas_dcopy(backwardLength*2, dataPtr, 1, dataPtr+forwardLength*2, 1);
cblas_dcopy(forwardLength*2, buffer, 1, dataPtr, 1);
}
Aurora::free(buffer);
}
Matrix Aurora::hilbert(const Matrix &aMatrix) {
auto x = fft(aMatrix);
auto h = malloc(aMatrix.getDimSize(0));

View File

@@ -103,9 +103,16 @@ namespace Aurora
/**
* FFT,支持到2维输入可以是常数可以是复数输出必是复数
* @param aMatrix 目标矩阵
* @param aFFTSize 目标矩阵需要处理的长度,默认为-1即全部
* @return fft后的复数矩阵
*/
Matrix fft(const Matrix &aMatrix);
Matrix fft(const Matrix &aMatrix, long aFFTSize = -1);
/**
* fftshift,在原有数据上进行修改将fft的数据的前半部分和后半部分交换支持2D数据
* @param aMatrix
*/
void fftshift(Matrix &aMatrix);
/**
* 逆fft支持到2维输入必须是复数输出必是复数