From 5291b62d4481f7a9f27ef5ce0f5ad95c79e6e337 Mon Sep 17 00:00:00 2001 From: Matt Kinsey Date: Mon, 21 Nov 2022 14:37:33 -0500 Subject: [PATCH] fix filtering.resample output for even values of num parameter (#517) This just fixes the handling of the nyquest frequency bins when resampling using the FFT method (`filtering.resample`). Without enforcing the fourier space be hermitian symmetric like this, the resampled output had a few undesirable properties: 1. It had a different integrated power spectrum than the input 2. It was not reversible i.e. `resample(resample(y, ), ) != y` 3. The output disagreed with `scipy.signal.resample`. The max error was around 1e-2 in my application. Thanks Authors: - Matt Kinsey (https://github.com/mattkinsey) - Ray Douglass (https://github.com/raydouglass) - gpuCI (https://github.com/GPUtester) - Adam Thompson (https://github.com/awthomp) - AJ Schmidt (https://github.com/ajschmidt8) - Mike Wendt (https://github.com/mike-wendt) Approvers: - Adam Thompson (https://github.com/awthomp) URL: https://github.com/rapidsai/cusignal/pull/517 --- python/cusignal/filtering/resample.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/cusignal/filtering/resample.py b/python/cusignal/filtering/resample.py index 88d75721..eccffafd 100644 --- a/python/cusignal/filtering/resample.py +++ b/python/cusignal/filtering/resample.py @@ -262,11 +262,30 @@ def resample(x, num, t=None, axis=0, window=None, domain="time"): newshape = list(x.shape) newshape[axis] = num N = int(np.minimum(num, Nx)) + nyq = N // 2 + 1 # Slice index that includes Nyquist Y = cp.zeros(newshape, dtype=X.dtype) - sl[axis] = slice(0, (N + 1) // 2) - Y[tuple(sl)] = X[tuple(sl)] - sl[axis] = slice(-(N - 1) // 2, None) + sl[axis] = slice(0, nyq) Y[tuple(sl)] = X[tuple(sl)] + if N > 2: # avoid empty slice + sl[axis] = slice(nyq - N, None) + Y[tuple(sl)] = X[tuple(sl)] + + # symmetrize nyquest freq bins if N is even + if N % 2 == 0: + if num < Nx: + # select the component of Y at frequency +N/2, + # add the component of X at -N/2 + sl[axis] = slice(-N // 2, -N // 2 + 1) + Y[tuple(sl)] += X[tuple(sl)] + elif Nx < num: + # select the component at frequency +N/2 and halve it + sl[axis] = slice(N // 2, N // 2 + 1) + Y[tuple(sl)] *= 0.5 + temp = Y[tuple(sl)] + # set the component at -N/2 equal to the component at +N/2 + sl[axis] = slice(num - N // 2, num - N // 2 + 1) + Y[tuple(sl)] = temp + y = cp.fft.ifft(Y, axis=axis) * (float(num) / float(Nx)) if x.dtype.char not in ["F", "D"]: