Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/multitaper psd estimate #417

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
e0796e2
Initial commit multitaper psd estimate
ackurth Jan 11, 2021
9d87185
Fix wrong variable
ackurth Jan 11, 2021
eb2d501
Fix minor bug
ackurth Jan 11, 2021
a38b17a
Fix variable name
ackurth Jan 11, 2021
e63e6bb
Fix typos
ackurth Jan 11, 2021
96a9d93
Add functionality to generate test data
ackurth Jan 17, 2021
ef624e0
Use rfftfreq to generate frequencies instead of fftfreq
ackurth Jan 17, 2021
fecd726
Add example for comparison for multitaper and welch
ackurth Jan 17, 2021
6be6c53
Refactor NW -> nw
rjurkus Jan 25, 2021
43d3de1
Refactor NW -> nw in the test section
rjurkus Jan 25, 2021
4d64623
Add functionality to multitaper to handle frequency resolution
ackurth Feb 12, 2021
efd6f11
Remove subsampling
ackurth Feb 12, 2021
3526e5f
Initial commit for multitaper_psd tests
rjurkus Feb 12, 2021
99bcdd1
Merge branch 'master' into feature/multitaper_psd_estimate
rjurkus Feb 15, 2021
841ab51
Run all tests in test_spectral.py
rjurkus Feb 15, 2021
54998dc
Tests for multitaper_psd (copied/adapted from welch_psd)
rjurkus Feb 15, 2021
37458ef
Added quantities and their handling to multitaper psd
ackurth Feb 22, 2021
3d570a2
Remove code doubling for one/multi dim case, minor edits
ackurth Feb 22, 2021
5731f4e
Clean up examples
ackurth Feb 22, 2021
a25b2c9
Correct test, remove wrong test
ackurth Feb 22, 2021
e08aba0
Fixed PEP8 issues
rjurkus Feb 22, 2021
855e607
Removed if __name__ == '__main__'
rjurkus Feb 22, 2021
72e03b4
multitaper_psd -> psd refactor
rjurkus Feb 22, 2021
4b7559b
Refactor frequency_resolution -> peak_resolution
rjurkus Apr 16, 2021
63be87b
Refactor frequency_resolution -> peak_resolution in test_spectral.py
rjurkus Apr 16, 2021
8350d20
Added a test comparing multitaper_psd against nitime multitaper imple…
rjurkus Apr 16, 2021
4ac9e68
Removed the line regarding None for peak_resolution
rjurkus Jun 18, 2021
01d71c9
Added a note regarding parameter hierarchy and a line of code to ensu…
rjurkus Jun 18, 2021
760d318
Refactor slepain_fcts -> slepian_fcts
rjurkus Jun 18, 2021
3afd1ca
Added a comment that sym=False in dpss is used for spectral analysis
rjurkus Jun 18, 2021
3bd81bc
Changed elif to if for the second num_tapers check
rjurkus Jun 18, 2021
6cd853f
Added a comment regarding broadcasting if data has more than 1 dim
rjurkus Jun 18, 2021
7fea31a
Adjusted comparison to nitime to use both rtol and atol
rjurkus Jun 18, 2021
cfc914c
Changed num_tapers to be of type int (not int64)
rjurkus Jun 18, 2021
787485f
Added tests for peak_resolution, num_tapers and nw parameter hierarchy
rjurkus Jun 18, 2021
e0bb3a7
Initial commit multitaper psd estimate
ackurth Jan 11, 2021
9b6ec18
Fix wrong variable
ackurth Jan 11, 2021
cb7f6fd
Fix minor bug
ackurth Jan 11, 2021
3ae4664
Fix variable name
ackurth Jan 11, 2021
5f45d54
Fix typos
ackurth Jan 11, 2021
975a36f
Improve memory efficiency of _create_sparse_matrix in BinnedSpikeTrai…
morales-gregorio Jan 12, 2021
b774956
Add functionality to generate test data
ackurth Jan 17, 2021
b3fff9f
Use rfftfreq to generate frequencies instead of fftfreq
ackurth Jan 17, 2021
abbb548
Add example for comparison for multitaper and welch
ackurth Jan 17, 2021
25c746a
Joint-ISI dithering: fixed a bug regarding first ISI bin (#396)
pbouss Jan 21, 2021
358e30a
Refactor NW -> nw
rjurkus Jan 25, 2021
178b8b9
Refactor NW -> nw in the test section
rjurkus Jan 25, 2021
305fc52
CUDA accelerated ASSET (#351)
dizcza Jan 29, 2021
ba331ee
Bin-Shuffling: reimplemented the continuos time version (#397)
pbouss Feb 1, 2021
ebc0a97
Account for unidirectional spiketrain->segment links in synchrofact d…
Kleinjohann Feb 1, 2021
b46f80e
Speed up bin shuffling (#400)
pbouss Feb 3, 2021
6c3ea44
Memory efficient and faster implementation of ASSET pmat analytical (…
dizcza Feb 3, 2021
73bdb3c
missing comma in BibTeX entry (#401)
apdavison Feb 8, 2021
019187a
Add functionality to multitaper to handle frequency resolution
ackurth Feb 12, 2021
aa72902
Remove subsampling
ackurth Feb 12, 2021
3943be4
Initial commit for multitaper_psd tests
rjurkus Feb 12, 2021
6dc83e9
Run all tests in test_spectral.py
rjurkus Feb 15, 2021
172a4f0
Tests for multitaper_psd (copied/adapted from welch_psd)
rjurkus Feb 15, 2021
196e016
Added quantities and their handling to multitaper psd
ackurth Feb 22, 2021
b6dad55
Remove code doubling for one/multi dim case, minor edits
ackurth Feb 22, 2021
7fd2992
Clean up examples
ackurth Feb 22, 2021
2e0ddaf
Correct test, remove wrong test
ackurth Feb 22, 2021
2c1bc51
Fixed PEP8 issues
rjurkus Feb 22, 2021
94e9468
Removed if __name__ == '__main__'
rjurkus Feb 22, 2021
45f5df7
multitaper_psd -> psd refactor
rjurkus Feb 22, 2021
36f6e2e
Refactor frequency_resolution -> peak_resolution
rjurkus Apr 16, 2021
a755af8
Refactor frequency_resolution -> peak_resolution in test_spectral.py
rjurkus Apr 16, 2021
f987a75
Added a test comparing multitaper_psd against nitime multitaper imple…
rjurkus Apr 16, 2021
a2dcc13
Removed the line regarding None for peak_resolution
rjurkus Jun 18, 2021
a81dc26
Added a note regarding parameter hierarchy and a line of code to ensu…
rjurkus Jun 18, 2021
35f5e08
Refactor slepain_fcts -> slepian_fcts
rjurkus Jun 18, 2021
4b36d5e
Added a comment that sym=False in dpss is used for spectral analysis
rjurkus Jun 18, 2021
d3084b1
Changed elif to if for the second num_tapers check
rjurkus Jun 18, 2021
0b534e3
Added a comment regarding broadcasting if data has more than 1 dim
rjurkus Jun 18, 2021
281b151
Adjusted comparison to nitime to use both rtol and atol
rjurkus Jun 18, 2021
a7d7078
Changed num_tapers to be of type int (not int64)
rjurkus Jun 18, 2021
85e4868
Added tests for peak_resolution, num_tapers and nw parameter hierarchy
rjurkus Jun 18, 2021
8ae206e
Merge branch 'feature/multitaper_psd_estimate' of github.com:INM-6/el…
ackurth Oct 1, 2021
ed2a0c4
Add segmentation of signal to multitaper_psd
ackurth Oct 7, 2021
5c33635
Correct adding new axis
ackurth Oct 7, 2021
1816bc0
Remove printing number of tapers
ackurth Oct 7, 2021
075d4c1
Add frequency resolution to multiatper_psd
ackurth Oct 8, 2021
6e4b24b
Code factor issue fix (else is not necessary)
rjurkus Nov 3, 2021
fd6da43
Fixed a typo in multitaper_psd docstring
rjurkus Nov 3, 2021
026c5e0
Consistent language and formatting in Raises section of multitaper_ps…
rjurkus Nov 3, 2021
b3b906a
Fixed a typo in a comment of multitaper_psd, also:
rjurkus Nov 3, 2021
f3e7dc5
Update docstring peak_resolution
ackurth Nov 5, 2021
b2df694
Removed basic Raises section entries
rjurkus Nov 5, 2021
a36e7d4
Updated test_spectral.py test_multitaper_psd_against_nitime to downlo…
rjurkus Dec 10, 2021
44f0817
Merge remote-tracking branch 'upstream/master' into feature/multitape…
rjurkus Jan 26, 2022
d0a8601
fixed typos in inline comments
Moritz-Alexander-Kern Feb 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions elephant/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,252 @@ def welch_psd(signal, n_segments=8, len_segment=None,
return freqs, psd


def multitaper_psd(signal, n_segments=1, len_segment=None,
frequency_resolution=None, overlap=0.5, fs=1,
nw=4, num_tapers=None, peak_resolution=None, axis=-1):
"""
Estimates power spectrum density (PSD) of a given 'neo.AnalogSignal'
using Multitaper method

The PSD is obtained through the following steps:

1. Cut the given data into several overlapping segments. The degree of
overlap can be specified by parameter `overlap` (default is 0.5,
i.e. segments are overlapped by the half of their length).
The number and the length of the segments are determined according
to the parameters `n_segments`, `len_segment` or `frequency_resolution`.
By default, the data is cut into 8 segments;

2. Calculate 'num_tapers' approximately independent estimates of the
spectrum by multiplying the signal with the discrete prolate spheroidal
functions (also known as Slepian function) and calculate the PSD of each
tapered segment

3. Average the approximately independent estimates of each segment to
decrease overall variance of the estimates

4. Average the obtained estimates for each segment

Parameters
----------
signal : neo.AnalogSignal
Time series data of which PSD is estimated. When `signal` is np.ndarray
sampling frequency should be given through keyword argument `fs`.
Signal should be passed as (n_channels, n_samples)
fs : float, optional
Specifies the sampling frequency of the input time series
Default: 1.0.
n_segments : int, optional
Number of segments. The length of segments is adjusted so that
overlapping segments cover the entire stretch of the given data. This
parameter is ignored if `len_segment` or `frequency_resolution` is
given.
Default: 8.
len_segment : int, optional
Length of segments. This parameter is ignored if `frequency_resolution`
is given. If None, it will be determined from other parameters.
Default: None.
frequency_resolution : pq.Quantity or float, optional
Desired frequency resolution of the obtained PSD estimate in terms of
the interval between adjacent frequency bins. When given as a `float`,
it is taken as frequency in Hz.
If None, it will be determined from other parameters.
Default: None.
overlap : float, optional
Overlap between segments represented as a float number between 0 (no
overlap) and 1 (complete overlap).
Default: 0.5 (half-overlapped).
nw : float, optional
Time bandwidth product
Default: 4.0.
num_tapers : int, optional
Number of tapers used in 1. to obtain estimate of PSD. By default
[2*nw] - 1 is chosen.
Default: None.
peak_resolution : pq.Quantity float, optional
Quantity in Hz determining the number of tapers used for analysis.
Fine peak resolution --> low numerical value --> low number of tapers
High peak resolution --> high numerical value --> high number of tapers
When given as a `float`, it is taken as frequency in Hz.
Default: None.
axis : int, optional
Axis along which the periodogram is computed.
See Notes [2].
Default: last axis (-1).

Notes
-----
1. There is a parameter hierarchy regarding n_segments and len_segment. The
former parameter is ignored if the latter one is passed.

2. There is a parameter hierarchy regarding nw, num_tapers and
peak_resolution. If peak_resolution is provided, it determines both nw
and the num_tapers. Specifying num_tapers has an effect only if
peak_resolution is not provided.

Returns
-------
freqs : np.ndarray
Frequencies associated with power estimate in `psd`
psd : np.ndarray
PSD estimate of the time series in `signal`

Raises
------
ValueError
If `peak_resolution` is None and `num_tapers` is not a positive number.

If `frequency_resolution` is too high for the given data size.

If `frequency_resolution` is None and `len_segment` is not a positive
number.

If `frequency_resolution` is None and `len_segment` is greater than the
length of data at `axis`.

If both `frequency_resolution` and `len_segment` are None and
`n_segments` is not a positive number.

If both `frequency_resolution` and `len_segment` are None and
`n_segments` is greater than the length of data at `axis`.

TypeError
If `peak_resolution` is None and `num_tapers` is not an int.
"""

# When the input is AnalogSignal, the data is added after rolling the axis
# for time index to the last
data = np.asarray(signal)
if isinstance(signal, neo.AnalogSignal):
data = np.rollaxis(data, 0, len(data.shape))

# Number of data points in time series
if data.ndim == 1:
length_signal = np.shape(data)[0]
else:
length_signal = np.shape(data)[1]

# If the data is given as AnalogSignal, use its attribute to specify the
# sampling frequency
if hasattr(signal, 'sampling_rate'):
fs = signal.sampling_rate.rescale('Hz').magnitude

# If fs and peak resolution is pq.Quantity, get magnitude
if isinstance(fs, pq.quantity.Quantity):
fs = fs.rescale('Hz').magnitude

# Determine length per segment - n_per_seg
if frequency_resolution is not None:
if frequency_resolution <= 0:
raise ValueError("frequency_resolution must be positive")
if isinstance(frequency_resolution, pq.quantity.Quantity):
dF = frequency_resolution.rescale('Hz').magnitude
else:
dF = frequency_resolution
n_per_seg = int(fs / dF)
if n_per_seg > data.shape[axis]:
raise ValueError("frequency_resolution is too high for the given "
"data size")
elif len_segment is not None:
if len_segment <= 0:
raise ValueError("len_seg must be a positive number")
elif data.shape[axis] < len_segment:
raise ValueError("len_seg must be shorter than the data length")
n_per_seg = len_segment
else:
if n_segments <= 0:
raise ValueError("n_segments must be a positive number")
elif data.shape[axis] < n_segments:
raise ValueError("n_segments must be smaller than the data length")
# when only *n_segments* is given, *n_per_seg* is determined by solving
# the following equation:
# n_segments * n_per_seg - (n_segments-1) * overlap * n_per_seg =
# data.shape[-1]
# -------------------- =============================== ^^^^^^^^^^^
# summed segment lengths total overlap data length
n_per_seg = int(data.shape[axis] /
(n_segments - overlap * (n_segments - 1)))

n_overlap = int(n_per_seg * overlap)
n_segments = int((length_signal - n_overlap) / (n_per_seg - n_overlap))

if isinstance(peak_resolution, pq.quantity.Quantity):
peak_resolution = peak_resolution.rescale('Hz').magnitude

# Determine time-halfbandwidth product from given parameters
if peak_resolution is not None:
if peak_resolution <= 0:
raise ValueError("peak_resolution must be positive")
nw = n_per_seg / fs * peak_resolution / 2
num_tapers = int(np.floor(2*nw) - 1)

if num_tapers is None:
num_tapers = int(np.floor(2*nw) - 1)
else:
if not isinstance(num_tapers, int):
raise TypeError("num_tapers must be integer")
if num_tapers <= 0:
raise ValueError("num_tapers must be positive")

# Generate frequencies of PSD estimate
freqs = np.fft.rfftfreq(n_per_seg, d=1/fs)

# Zero-pad signal to fit segment length
remainder = length_signal % n_per_seg

if data.ndim == 1:
data = np.pad(data, pad_width=(0, remainder),
mode='constant', constant_values=0)
# Generate array for storing PSD estimates of segments
psd_estimates = np.zeros((n_segments, len(freqs)))
else:
data = np.pad(data, [(0, 0), (0, remainder)],
mode='constant', constant_values=0)
# Generate array for storing PSD estimates of segments
psd_estimates = np.zeros((n_segments, data.shape[0], len(freqs)))

# Determine the number of samples given overlap
n_overlap_step = n_per_seg - n_overlap

for i in range(n_segments):
# Get slepian functions (sym=False used for spectral analysis)
slepian_fcts = scipy.signal.windows.dpss(M=n_per_seg,
NW=nw,
Kmax=num_tapers,
sym=False)

# Calculate approximately independent spectrum estimates
if data.ndim == 1:
tapered_signal = (data[i * n_overlap_step:
i * n_overlap_step + n_per_seg]
* slepian_fcts)
else:
# Use broadcasting to match dim for point-wise multiplication
tapered_signal = (data[:,
np.newaxis,
i * n_overlap_step:
i * n_overlap_step + n_per_seg]
* slepian_fcts)

# Determine Fourier transform of tapered signal
spectrum_estimates = np.abs(np.fft.rfft(tapered_signal, axis=-1))**2
spectrum_estimates[..., 1:] *= 2

# Average Fourier transform windowed signal
psd_segment = np.mean(spectrum_estimates, axis=-2) / fs

psd_estimates[i] = psd_segment

psd = np.mean(np.asarray(psd_estimates), axis=0)

# Attach proper units to return values
if isinstance(signal, pq.quantity.Quantity):
psd = psd * signal.units * signal.units / pq.Hz
freqs = freqs * pq.Hz

return freqs, psd


@deprecated_alias(x='signal_i', y='signal_j', num_seg='n_segments',
len_seg='len_segment', freq_res='frequency_resolution')
def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None,
Expand Down
Loading