Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Butterworth supports sosfiltfilt filter_function #234

Merged
merged 2 commits into from
Jul 10, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 64 additions & 46 deletions elephant/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,15 @@ def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
order : int
Order of Butterworth filter. Default is 4.
filter_function : string
Filtering function to be used. Either 'filtfilt'
(`scipy.signal.filtfilt()`) or 'lfilter' (`scipy.signal.lfilter()`). In
most applications 'filtfilt' should be used, because it doesn't bring
about phase shift due to filtering. Default is 'filtfilt'.
Filtering function to be used. Available filters:
* 'filtfilt': `scipy.signal.filtfilt()`;
* 'lfilter': `scipy.signal.lfilter()`;
* 'sosfiltfilt': `scipy.signal.sosfiltfilt()`.
In most applications 'filtfilt' should be used, because it doesn't
bring about phase shift due to filtering. For numerically stable
filtering, use 'sosfiltfilt' (see issue
https://github.com/NeuralEnsemble/elephant/issues/220).
dizcza marked this conversation as resolved.
Show resolved Hide resolved
Default is 'filtfilt'.
fs : Quantity or float
The sampling frequency of the input time series. When given as float,
its value is taken as frequency in Hz. When the input is given as neo
Expand All @@ -322,42 +327,53 @@ def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
Filtered input data. The shape and type is identical to those of the
input.

"""

def _design_butterworth_filter(Fs, hpfreq=None, lpfreq=None, order=4):
# set parameters for filter design
Fn = Fs / 2.
# - filter type is determined according to the values of cut-off
# frequencies
if lpfreq and hpfreq:
if hpfreq < lpfreq:
Wn = (hpfreq / Fn, lpfreq / Fn)
btype = 'bandpass'
else:
Wn = (lpfreq / Fn, hpfreq / Fn)
btype = 'bandstop'
elif lpfreq:
Wn = lpfreq / Fn
btype = 'lowpass'
elif hpfreq:
Wn = hpfreq / Fn
btype = 'highpass'
else:
raise ValueError(
"Either highpass_freq or lowpass_freq must be given"
)

# return filter coefficients
return scipy.signal.butter(order, Wn, btype=btype)
Raises
------
ValueError
If `filter_function` is not one of 'lfilter', 'filtfilt',
or 'sosfiltfilt'.
When both `highpass_freq` and `lowpass_freq` are None.

"""
available_filters = 'lfilter', 'filtfilt', 'sosfiltfilt'
if filter_function not in available_filters:
raise ValueError("Invalid `filter_function`: {filter_function}. "
"Available filters: {available_filters}".format(
filter_function=filter_function,
available_filters=available_filters))
# design filter
Fs = signal.sampling_rate.rescale(pq.Hz).magnitude \
if hasattr(signal, 'sampling_rate') else fs
Fh = highpass_freq.rescale(pq.Hz).magnitude \
if isinstance(highpass_freq, pq.quantity.Quantity) else highpass_freq
Fl = lowpass_freq.rescale(pq.Hz).magnitude \
if isinstance(lowpass_freq, pq.quantity.Quantity) else lowpass_freq
b, a = _design_butterworth_filter(Fs, Fh, Fl, order)
if hasattr(signal, 'sampling_rate'):
fs = signal.sampling_rate.rescale(pq.Hz).magnitude
if isinstance(highpass_freq, pq.quantity.Quantity):
highpass_freq = highpass_freq.rescale(pq.Hz).magnitude
if isinstance(lowpass_freq, pq.quantity.Quantity):
lowpass_freq = lowpass_freq.rescale(pq.Hz).magnitude
Fn = fs / 2.
# filter type is determined according to the values of cut-off
# frequencies
if lowpass_freq and highpass_freq:
if highpass_freq < lowpass_freq:
Wn = (highpass_freq / Fn, lowpass_freq / Fn)
btype = 'bandpass'
else:
Wn = (lowpass_freq / Fn, highpass_freq / Fn)
btype = 'bandstop'
elif lowpass_freq:
Wn = lowpass_freq / Fn
btype = 'lowpass'
elif highpass_freq:
Wn = highpass_freq / Fn
btype = 'highpass'
else:
raise ValueError(
"Either highpass_freq or lowpass_freq must be given"
)
if filter_function == 'sosfiltfilt':
output = 'sos'
else:
output = 'ba'
designed_filter = scipy.signal.butter(order, Wn, btype=btype,
output=output)

# When the input is AnalogSignal, the axis for time index (i.e. the
# first axis) needs to be rolled to the last
Expand All @@ -366,17 +382,19 @@ def _design_butterworth_filter(Fs, hpfreq=None, lpfreq=None, order=4):
data = np.rollaxis(data, 0, len(data.shape))

# apply filter
if filter_function is 'lfilter':
filtered_data = scipy.signal.lfilter(b, a, data, axis=axis)
elif filter_function is 'filtfilt':
filtered_data = scipy.signal.filtfilt(b, a, data, axis=axis)
if filter_function == 'lfilter':
b, a = designed_filter
filtered_data = scipy.signal.lfilter(b=b, a=a, x=data, axis=axis)
elif filter_function == 'filtfilt':
b, a = designed_filter
filtered_data = scipy.signal.filtfilt(b=b, a=a, x=data, axis=axis)
else:
raise ValueError(
"filter_func must to be either 'filtfilt' or 'lfilter'"
)
filtered_data = scipy.signal.sosfiltfilt(sos=designed_filter,
x=data, axis=axis)

if isinstance(signal, neo.AnalogSignal):
return signal.duplicate_with_new_data(np.rollaxis(filtered_data, -1, 0))
filtered_data = np.rollaxis(filtered_data, -1, 0)
return signal.duplicate_with_new_data(filtered_data)
elif isinstance(signal, pq.quantity.Quantity):
return filtered_data * signal.units
else:
Expand Down
6 changes: 6 additions & 0 deletions elephant/test/test_signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,13 @@ def test_butter_filter_function(self):
_, psd_lfilter = spsig.welch(
dizcza marked this conversation as resolved.
Show resolved Hide resolved
filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x)

kwds['filter_function'] = 'sosfiltfilt'
filtered_noise = elephant.signal_processing.butter(**kwds)
_, psd_sosfiltfilt = spsig.welch(
filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x)

self.assertAlmostEqual(psd_filtfilt[0, 0], psd_lfilter[0, 0])
self.assertAlmostEqual(psd_filtfilt[0, 0], psd_sosfiltfilt[0, 0])

def test_butter_invalid_filter_function(self):
# generate a dummy AnalogSignal
Expand Down