From 9fe1fb609faa3ccadce5800ae1b48914f576102d Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:17:36 +0200 Subject: [PATCH 01/23] Add option to store and return tfr taper weights --- mne/time_frequency/multitaper.py | 11 +++++++ mne/time_frequency/tests/test_tfr.py | 14 ++++++--- mne/time_frequency/tfr.py | 47 ++++++++++++++++++++++++++-- 3 files changed, 65 insertions(+), 7 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..f5f6f79a0b3 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -469,6 +469,7 @@ def tfr_array_multitaper( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -502,6 +503,12 @@ def tfr_array_multitaper( * ``'itc'`` : inter-trial coherence. * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. + + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output="complex"``. + + .. versionadded:: 1.9.0 + %(n_jobs)s The parallelization is implemented across channels. %(verbose)s @@ -520,6 +527,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output="complex"`` and + ``return_weights=True``. See Also -------- @@ -550,6 +560,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index cd3a97ab90a..1799692d6ce 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -432,17 +432,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index eaf173092bb..e9d028e7e50 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -479,6 +487,9 @@ def _compute_tfr( * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +506,10 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. """ # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +531,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -531,13 +549,18 @@ def _compute_tfr( Ws = [W] # to have same dimensionality as the 'multitaper' case elif method == "multitaper": - Ws = _make_dpss( + out = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=return_weights, ) + if return_weights: + Ws, weights = out + else: + Ws = out # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -561,6 +584,8 @@ def _compute_tfr( out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + if return_weights: + weights = np.array(weights) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -585,6 +610,9 @@ def _compute_tfr( out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) + + if return_weights: + return out, weights return out @@ -1203,6 +1231,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1224,6 +1255,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1382,6 +1414,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1410,6 +1443,7 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #XXX won't have # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked @@ -1516,6 +1550,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1694,6 +1732,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. From 82fc2f7fe450dee8445cb9b48993944336e2aedc Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:24:33 +0200 Subject: [PATCH 02/23] Update docstrings --- mne/time_frequency/multitaper.py | 5 ++-- mne/time_frequency/tfr.py | 50 +++++++++++++++++++------------- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index f5f6f79a0b3..fc926af4863 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -505,7 +505,8 @@ def tfr_array_multitaper( coherence across trials. return_weights : bool, default False - If True, return the taper weights. Only applies if ``output="complex"``. + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. .. versionadded:: 1.9.0 @@ -528,7 +529,7 @@ def tfr_array_multitaper( contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. weights : array of shape (n_tapers, n_freqs) - The taper weights. Only returned if ``output="complex"`` and + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and ``return_weights=True``. See Also diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index e9d028e7e50..dd64c18d9e3 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1215,9 +1215,6 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) - # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) - if method == "morlet": - method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -2697,9 +2694,12 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) @@ -2707,32 +2707,42 @@ def to_data_frame( # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + default_index.extend(["freq", "time"]) + if unagg_mt: + name = "taper" + taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times) + mindex.append((name, taper_nums)) + default_index.append(name) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) From a49f9343f7ae1f42f19a5ff3e2ee687404fa3eda Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:38:02 +0200 Subject: [PATCH 03/23] Remove whitespace --- mne/time_frequency/tfr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index dd64c18d9e3..a7cd416cb0d 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -506,7 +506,6 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. - weights : array of shape (n_tapers, n_freqs) The taper weights. Only returned if method='multitaper', output='complex' or 'phase', and return_weights=True. From 7c3dcfa3a38207d12adee76f85fd21efd1d176e0 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:39:51 +0200 Subject: [PATCH 04/23] Add PR num --- mne/time_frequency/tfr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index a7cd416cb0d..91eaad159f5 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1439,7 +1439,7 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) - self._weights = state.get("weights") # objs saved before #XXX won't have + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked From 8c167168e3bfbfd2a1f2394831bcf9c4c214d6f5 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:40:56 +0200 Subject: [PATCH 05/23] Revert "Update docstrings" This reverts commit 82fc2f7fe450dee8445cb9b48993944336e2aedc. --- mne/time_frequency/multitaper.py | 5 ++-- mne/time_frequency/tfr.py | 50 +++++++++++++------------------- 2 files changed, 22 insertions(+), 33 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index fc926af4863..f5f6f79a0b3 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -505,8 +505,7 @@ def tfr_array_multitaper( coherence across trials. return_weights : bool, default False - If True, return the taper weights. Only applies if ``output='complex'`` or - ``'phase'``. + If True, return the taper weights. Only applies if ``output="complex"``. .. versionadded:: 1.9.0 @@ -529,7 +528,7 @@ def tfr_array_multitaper( contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. weights : array of shape (n_tapers, n_freqs) - The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + The taper weights. Only returned if ``output="complex"`` and ``return_weights=True``. See Also diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 91eaad159f5..908d25662e8 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1214,6 +1214,9 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) + # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) + if method == "morlet": + method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -2693,12 +2696,9 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa - # triage for Epoch-derived or unaggregated spectra - from_epo = isinstance(self, EpochsTFR) - unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if from_epo: + if isinstance(self, EpochsTFR): valid_index_args.extend(["epoch", "condition"]) valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) @@ -2706,42 +2706,32 @@ def to_data_frame( # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - ch_axis = self._dims.index("channel") - if not from_epo: + axis = self._dims.index("channel") + if not isinstance(self, EpochsTFR): data = data[np.newaxis] # add singleton "epochs" axis - ch_axis += 1 - if not unagg_mt: - data = np.expand_dims(data, -3) # add singleton "tapers" axis - n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape - # reshape to (epochs*tapers*freqs*times) x signals - data = np.moveaxis(data, ch_axis, -1) - data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) + axis += 1 + n_epochs, n_picks, n_freqs, n_times = data.shape + # reshape to (epochs*freqs*times) x signals + data = np.moveaxis(data, axis, -1) + data = data.reshape(n_epochs * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() - default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs * n_tapers) - freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs) + times = np.tile(times, n_epochs * n_freqs) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if from_epo: - mindex.append( - ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) - ) + if isinstance(self, EpochsTFR): + mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append( - ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) - ) - default_index.extend(["condition", "epoch"]) - default_index.extend(["freq", "time"]) - if unagg_mt: - name = "taper" - taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times) - mindex.append((name, taper_nums)) - default_index.append(name) + mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame + if isinstance(self, EpochsTFR): + default_index = ["condition", "epoch", "freq", "time"] + else: + default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) From 51b8cd0ac419401462e46619a28316095a2ecef9 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:42:03 +0200 Subject: [PATCH 06/23] Remove outdated default setting --- mne/time_frequency/tfr.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 908d25662e8..3a04c9910ce 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1214,9 +1214,6 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) - # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) - if method == "morlet": - method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): From 2f9a4b4bc5daaa61eaa20b08f7c0918733c6cd38 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:43:40 +0200 Subject: [PATCH 07/23] Reapply "Update docstrings" This reverts commit 8c167168e3bfbfd2a1f2394831bcf9c4c214d6f5. --- mne/time_frequency/multitaper.py | 5 ++-- mne/time_frequency/tfr.py | 47 ++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index f5f6f79a0b3..fc926af4863 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -505,7 +505,8 @@ def tfr_array_multitaper( coherence across trials. return_weights : bool, default False - If True, return the taper weights. Only applies if ``output="complex"``. + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. .. versionadded:: 1.9.0 @@ -528,7 +529,7 @@ def tfr_array_multitaper( contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. weights : array of shape (n_tapers, n_freqs) - The taper weights. Only returned if ``output="complex"`` and + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and ``return_weights=True``. See Also diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 3a04c9910ce..91eaad159f5 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -2693,9 +2693,12 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) @@ -2703,32 +2706,42 @@ def to_data_frame( # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + default_index.extend(["freq", "time"]) + if unagg_mt: + name = "taper" + taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times) + mindex.append((name, taper_nums)) + default_index.append(name) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) From b4537b2d0a956da01291e7b499f60586254998ae Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:45:42 +0200 Subject: [PATCH 08/23] Update docstrings --- mne/time_frequency/multitaper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index f5f6f79a0b3..fc926af4863 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -505,7 +505,8 @@ def tfr_array_multitaper( coherence across trials. return_weights : bool, default False - If True, return the taper weights. Only applies if ``output="complex"``. + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. .. versionadded:: 1.9.0 @@ -528,7 +529,7 @@ def tfr_array_multitaper( contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. weights : array of shape (n_tapers, n_freqs) - The taper weights. Only returned if ``output="complex"`` and + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and ``return_weights=True``. See Also From 8d645bb830f550e290da4a5028672336751c5f87 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 29 Oct 2024 20:00:14 +0100 Subject: [PATCH 09/23] Enforce return_weights as named param --- mne/time_frequency/multitaper.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index fc926af4863..f7c6dc51a4c 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -469,9 +469,9 @@ def tfr_array_multitaper( use_fft=True, decim=1, output="complex", - return_weights=False, n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -504,14 +504,13 @@ def tfr_array_multitaper( * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. + %(n_jobs)s + The parallelization is implemented across channels. return_weights : bool, default False If True, return the taper weights. Only applies if ``output='complex'`` or ``'phase'``. .. versionadded:: 1.9.0 - - %(n_jobs)s - The parallelization is implemented across channels. %(verbose)s Returns From 1c02b40b69c27eee41b93f5ca907f302a75df3d4 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Dec 2024 11:37:18 +0000 Subject: [PATCH 10/23] Add missing test coverage --- mne/time_frequency/tests/test_tfr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 1799692d6ce..ec50f22e38c 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -1538,7 +1538,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): """Test Epochs.compute_tfr(output="complex"/"phase").""" tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) - assert len(tfr.shape) == 5 + assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time + assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match @pytest.mark.parametrize("copy", (False, True)) From 54f2a32037b756f2d15eea1ca6ae1ebe991b9f3c Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Dec 2024 11:43:38 +0000 Subject: [PATCH 11/23] Add changelog entry --- doc/changes/devel/12910.newfeature.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 doc/changes/devel/12910.newfeature.rst diff --git a/doc/changes/devel/12910.newfeature.rst b/doc/changes/devel/12910.newfeature.rst new file mode 100644 index 00000000000..d4af832923a --- /dev/null +++ b/doc/changes/devel/12910.newfeature.rst @@ -0,0 +1,3 @@ +Added the option to return taper weights from +:func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the +:class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. \ No newline at end of file From a10799132c4c579c917a0ea422ca1e93f94cd133 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Dec 2024 18:29:41 +0000 Subject: [PATCH 12/23] Begin add support for tapers in array objs --- mne/time_frequency/tfr.py | 70 +++++++++++++++++++++++++++++++++++---- mne/utils/docs.py | 10 ++++++ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 91eaad159f5..53a908dc648 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1421,7 +1421,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1445,7 +1444,7 @@ def __setstate__(self, state): unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1498,14 +1497,26 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) + elif n_taper is not None: + if self._weights is None: + raise RuntimeError("Taper dimension in data, but no weights found.") + if n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): @@ -2788,6 +2799,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2904,6 +2916,10 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" + if state["data"].ndim != 3: + raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.") + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel", "freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -3059,6 +3075,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3143,8 +3160,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3248,7 +3272,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." + ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3620,6 +3653,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3636,6 +3670,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3658,6 +3693,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3668,6 +3704,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3710,6 +3747,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3759,6 +3797,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3824,6 +3875,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3834,6 +3886,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3851,10 +3904,13 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 2ff6984dce9..bc1c10a623a 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5008,6 +5008,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ +docdict["weight_tfr_array"] = """ +weights : array of shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. +""" +docdict["weight_tfr_attr"] = """ +weights : array of shape (n_tapers, n_freqs) | None + The weights for each taper, if present in the data. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. From 01c486c53e27a762551aea877b8ec290dd44431a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Dec 2024 18:29:41 +0000 Subject: [PATCH 13/23] Begin add support for tapers in array objs --- mne/time_frequency/tfr.py | 70 +++++++++++++++++++++++++++++++++++---- mne/utils/docs.py | 10 ++++++ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 3a04c9910ce..e4d69ca8580 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1421,7 +1421,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1445,7 +1444,7 @@ def __setstate__(self, state): unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1498,14 +1497,26 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) + elif n_taper is not None: + if self._weights is None: + raise RuntimeError("Taper dimension in data, but no weights found.") + if n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): @@ -2775,6 +2786,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2891,6 +2903,10 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" + if state["data"].ndim != 3: + raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.") + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel", "freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -3046,6 +3062,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3130,8 +3147,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3235,7 +3259,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." + ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3607,6 +3640,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3623,6 +3657,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3645,6 +3680,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3655,6 +3691,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3697,6 +3734,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3746,6 +3784,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3811,6 +3862,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3821,6 +3873,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3838,10 +3891,13 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 2ff6984dce9..bc1c10a623a 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5008,6 +5008,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ +docdict["weight_tfr_array"] = """ +weights : array of shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. +""" +docdict["weight_tfr_attr"] = """ +weights : array of shape (n_tapers, n_freqs) | None + The weights for each taper, if present in the data. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. From ca27179c0d76d40eceec8323148e91633e54ae92 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Dec 2024 18:33:30 +0000 Subject: [PATCH 14/23] Fix docstring entries --- mne/time_frequency/tfr.py | 1 - mne/utils/docs.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index e4d69ca8580..ec5ecf49732 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -2786,7 +2786,6 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s - %(weights_tfr_attr)s See Also -------- diff --git a/mne/utils/docs.py b/mne/utils/docs.py index bc1c10a623a..893f379ba75 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5008,12 +5008,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ -docdict["weight_tfr_array"] = """ +docdict["weights_tfr_array"] = """ weights : array of shape (n_tapers, n_freqs) | None The weights for each taper. Must be provided if ``data`` has a taper dimension, such as for complex or phase multitaper data. """ -docdict["weight_tfr_attr"] = """ +docdict["weights_tfr_attr"] = """ weights : array of shape (n_tapers, n_freqs) | None The weights for each taper, if present in the data. """ From b14a100341bc7ad1c536c1d2aded0bfaf9c7c3ac Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Dec 2024 13:45:42 +0000 Subject: [PATCH 15/23] Fix faulty state check --- mne/time_frequency/tfr.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index ec5ecf49732..1829bb1bb98 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1501,26 +1501,25 @@ def _check_state(self): msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] + n_freq = self._data.shape[self._dims.index("freq")] + n_time = self._data.shape[self._dims.index("time")] n_taper = ( self._data.shape[self._dims.index("taper")] if "taper" in self._dims else None ) - n_freq = self._data.shape[self._dims.index("freq")] - n_time = self._data.shape[self._dims.index("time")] + if n_taper is not None and self._weights is None: + raise ValueError("Taper dimension in data, but no weights found.") if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) - elif n_taper is not None: - if self._weights is None: - raise RuntimeError("Taper dimension in data, but no weights found.") - if n_taper != self._weights.shape[0]: - msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) - elif n_freq != self._weights.shape[1]: - msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) + elif n_taper is not None and n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_taper is not None and n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) else: return raise ValueError(msg) From 972aba235799015355465e92a07adb2e50a95709 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Dec 2024 13:46:00 +0000 Subject: [PATCH 16/23] Add weights to AverageTFR --- mne/time_frequency/tfr.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 1829bb1bb98..52ebcd2c1f3 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -2785,6 +2785,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2901,10 +2902,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" - if state["data"].ndim != 3: - raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.") + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) # Set dims now since optional tapers makes it difficult to disentangle later - state["dims"] = ("channel", "freq", "time") + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2948,6 +2954,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2960,6 +2967,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2970,12 +2978,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) From e11fa2b74562380e5d4ca2bb5b7979a5968a8342 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Dec 2024 13:46:14 +0000 Subject: [PATCH 17/23] Expand test coverage --- mne/time_frequency/tests/test_tfr.py | 86 +++++++++++++++++++++++++--- 1 file changed, 78 insertions(+), 8 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index ec50f22e38c..96ae7997caa 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -668,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): with tfr.info._unlock(): tfr.info["meas_date"] = want assert tfr_loaded == tfr + # test with taper dimension and weights + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs + state = tfr.__getstate__() + state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim + state["weights"] = weights # add weights + state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims + tfr = EpochsTFR(inst=state) + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr_loaded == tfr # test overwrite with pytest.raises(OSError, match="Destination file exists."): tfr.save(fname, overwrite=False) @@ -726,17 +737,31 @@ def test_average_tfr_init(full_evoked): AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) -def test_epochstfr_init_errors(epochs_tfr): - """Test __init__ for EpochsTFR.""" - state = epochs_tfr.__getstate__() - with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_init_errors(inst, request, average_tfr): + """Test __init__ for Raw/Epochs/AverageTFR.""" + # Load data + inst = _get_inst(inst, request, average_tfr=average_tfr) + state = inst.__getstate__() + # Prepare for TFRArray object instantiation + inst_name = inst.__class__.__name__ + class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) + ndims_mapping = dict( + RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D") + ) + TFR = class_mapping[inst_name] + allowed_ndims = ndims_mapping[inst_name] + # Check errors caught + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=inst.data[..., 0])) + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): - EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1])) + TFR(inst=state | dict(times=inst.times[:-1])) with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): - EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + TFR(inst=state | dict(freqs=inst.freqs[:-1])) @pytest.mark.parametrize( @@ -1158,6 +1183,15 @@ def test_averaging_epochsTFR(): ): power.average(method=np.mean) + # Check it doesn't run for taper spectra + tapered = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" + ) + with pytest.raises( + NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." + ): + tapered.average() + def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" @@ -1551,6 +1585,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) +@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(inst, evoked, request): + """Test Raw/Epochs/AverageTFRArray instantiation with tapered spectra.""" + # Load data object + inst = _get_inst(inst, request, evoked=evoked) + inst.pick("mag") + # Compute TFR with taper dimension (can be complex or phase output) + tfr = inst.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output="complex" + ) + tfr_array, weights = tfr.get_data(), tfr.weights + # Prepare for TFRArray object instantiation + defaults = dict( + info=inst.info, data=tfr_array, times=inst.times, freqs=freqs_linspace + ) + class_mapping = dict(Raw=RawTFRArray, Epochs=EpochsTFRArray, Evoked=AverageTFRArray) + TFRArray = class_mapping[inst.__class__.__name__] + # Check TFRArray instantiation runs with good data + TFRArray(**defaults, weights=weights) + # Check taper dimension but no weights caught + with pytest.raises( + ValueError, match="Taper dimension in data, but no weights found." + ): + TFRArray(**defaults) + # Check mismatching n_taper in weights caught + with pytest.raises( + ValueError, match=r"Taper axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:-1]) + # Check mismatching n_freq in weights caught + with pytest.raises( + ValueError, match=r"Frequency axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:, :-1]) + + def test_tfr_proj(epochs): """Test `compute_tfr(proj=True)`.""" epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) From 999d12232ba9e83f9778b81e94078e3eb0d0cd4c Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Dec 2024 16:02:14 +0000 Subject: [PATCH 18/23] Disallow aggregating tapers in combine_tfr --- mne/time_frequency/tests/test_tfr.py | 40 ++++++++++++++++++++++++++-- mne/time_frequency/tfr.py | 8 ++++++ mne/utils/numerics.py | 3 +++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 96ae7997caa..09c8a35defa 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -739,7 +739,7 @@ def test_average_tfr_init(full_evoked): @pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) def test_tfr_init_errors(inst, request, average_tfr): - """Test __init__ for Raw/Epochs/AverageTFR.""" + """Test __init__ for {Raw,Epochs,Average}TFR.""" # Load data inst = _get_inst(inst, request, average_tfr=average_tfr) state = inst.__getstate__() @@ -1587,7 +1587,7 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): @pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) def test_tfrarray_tapered_spectra(inst, evoked, request): - """Test Raw/Epochs/AverageTFRArray instantiation with tapered spectra.""" + """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" # Load data object inst = _get_inst(inst, request, evoked=evoked) inst.pick("mag") @@ -1802,3 +1802,39 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): assert re.match( rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() ) + + +def test_combine_tfr_error_catch(request, average_tfr): + """Test combine_tfr() catches errors.""" + # check unrecognised weights string caught + with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): + combine_tfr([average_tfr, average_tfr], weights="foo") + # check bad weights size caught + with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): + combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) + # check different channel names caught + state = average_tfr.__getstate__() + new_info = average_tfr.info.copy() + average_tfr_bad = AverageTFR( + inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) + ) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_tfr([average_tfr, average_tfr_bad]) + # check different times caught + average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) + with pytest.raises( + AssertionError, match=".* do not contain the same time instants" + ): + combine_tfr([average_tfr, average_tfr_bad]) + # check taper dim caught + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs + state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) + state["weights"] = weights + state["dims"] = ("channel", "taper", "freq", "time") + average_tfr_taper = AverageTFR(inst=state) + with pytest.raises( + NotImplementedError, + match="Aggregating multitaper tapers across TFR datasets is not supported.", + ): + combine_tfr([average_tfr_taper, average_tfr_taper]) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 52ebcd2c1f3..6a77d174efd 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -3941,8 +3941,16 @@ def combine_tfr(all_tfr, weights="nave"): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ + if any("taper" in tfr._dims for tfr in all_tfr): + raise NotImplementedError( + "Aggregating multitaper tapers across TFR datasets is not supported." + ) + tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index c287fb42305..4bf8d094f81 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -550,6 +550,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ # check if all elements in the given list are evoked data From e12b09a90d37d3da51c6d55dc036b7c8b0cb2a75 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Dec 2024 16:22:56 +0000 Subject: [PATCH 19/23] Updated docstrings --- mne/utils/docs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 893f379ba75..1d3338bccfb 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5009,13 +5009,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ docdict["weights_tfr_array"] = """ -weights : array of shape (n_tapers, n_freqs) | None +weights : array, shape (n_tapers, n_freqs) | None The weights for each taper. Must be provided if ``data`` has a taper dimension, such as for complex or phase multitaper data. """ docdict["weights_tfr_attr"] = """ -weights : array of shape (n_tapers, n_freqs) | None - The weights for each taper, if present in the data. +weights : array, shape (n_tapers, n_freqs) | None + The weights used for each taper in the time-frequency estimates. """ docdict["window_psd"] = """\ From 728701e0a33cb67ebfd243f59002d3289f83e26a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Dec 2024 20:54:37 +0000 Subject: [PATCH 20/23] Add placeholder versionadded tags --- mne/time_frequency/multitaper.py | 2 +- mne/utils/docs.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index f7c6dc51a4c..0fa48db49d7 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -510,7 +510,7 @@ def tfr_array_multitaper( If True, return the taper weights. Only applies if ``output='complex'`` or ``'phase'``. - .. versionadded:: 1.9.0 + .. versionadded:: 1.X.0 %(verbose)s Returns diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 1d3338bccfb..5ff7b480ddc 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5012,6 +5012,8 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): weights : array, shape (n_tapers, n_freqs) | None The weights for each taper. Must be provided if ``data`` has a taper dimension, such as for complex or phase multitaper data. + + .. versionadded:: 1.X.0 """ docdict["weights_tfr_attr"] = """ weights : array, shape (n_tapers, n_freqs) | None From 80126a701735dd73952fd32e3a34ab9ae6dfc3b7 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 11 Dec 2024 21:40:34 +0000 Subject: [PATCH 21/23] Fix to_data_frame bug with tapers --- mne/time_frequency/tests/test_tfr.py | 50 +++++++++++++++++++------- mne/time_frequency/tfr.py | 53 ++++++++++++++++++---------- 2 files changed, 73 insertions(+), 30 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 09c8a35defa..62db87f3a83 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -1292,12 +1292,15 @@ def test_to_data_frame(): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) srate = 1000.0 - freqs = np.arange(5) + freqs = np.arange(n_freqs) + tapers = np.arange(n_tapers) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 5 + n_epos) @@ -1310,6 +1313,7 @@ def test_to_data_frame(): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) # test index checking with pytest.raises(ValueError, match="options. Valid index options are"): @@ -1321,10 +1325,21 @@ def test_to_data_frame(): # test wide format df_wide = tfr.to_data_frame() assert all(np.isin(tfr.ch_names, df_wide.columns)) - assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) + assert all( + np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) + ) # test long format df_long = tfr.to_data_frame(long_format=True) - expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") + expected = ( + "condition", + "epoch", + "freq", + "time", + "channel", + "ch_type", + "value", + "taper", + ) assert set(expected) == set(df_long.columns) assert set(tfr.ch_names) == set(df_long["channel"]) assert len(df_long) == tfr.data.size @@ -1332,21 +1347,29 @@ def test_to_data_frame(): df_long = tfr.to_data_frame(long_format=True, index=["freq"]) del df_wide, df_long # test whether data is in correct shape - df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) + df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) data = tfr.data assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) # compare arbitrary observation: assert ( - df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] - == data[1, 3, 1, 2] + df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] + == data[1, 3, 1, 1, 2] ) # Check also for AverageTFR: + # (remove taper dimension before averaging) + state = tfr.__getstate__() + state["data"] = state["data"][:, :, 0] + state["dims"] = ("epoch", "channel", "freq", "time") + state["weights"] = None + tfr = EpochsTFR(inst=state) tfr = tfr.average() with pytest.raises(ValueError, match="options. Valid index options are"): tfr.to_data_frame(index=["epoch", "condition"]) with pytest.raises(ValueError, match='"epoch" is not a valid option'): tfr.to_data_frame(index="epoch") + with pytest.raises(ValueError, match='"taper" is not a valid option'): + tfr.to_data_frame(index="taper") with pytest.raises(TypeError, match="index must be `None` or a string "): tfr.to_data_frame(index=np.arange(400)) # test wide format @@ -1382,11 +1405,13 @@ def test_to_data_frame_index(index): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - freqs = np.arange(5) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) @@ -1399,6 +1424,7 @@ def test_to_data_frame_index(index): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1406,7 +1432,7 @@ def test_to_data_frame_index(index): index = [index] assert list(df.index.names) == index # test that non-indexed data were present as columns - non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) + non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) if len(non_index): assert all(np.isin(non_index, df.columns)) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 3555ce14963..49bc15d8833 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1837,6 +1837,7 @@ def get_data( tmax=None, return_times=False, return_freqs=False, + return_tapers=False, ): """Get time-frequency data in NumPy array format. @@ -1852,6 +1853,10 @@ def get_data( return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. + return_tapers : bool + Whether to return the taper numbers. Default is ``False``. + + .. versionadded:: 1.X.0 Returns ------- @@ -1863,6 +1868,9 @@ def get_data( freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. + tapers : array | None + The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be + ``None`` if a taper dimension is not present in the data. Notes ----- @@ -1900,7 +1908,13 @@ def get_data( if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) - if not return_times and not return_freqs: + if return_tapers: + if "taper" in self._dims: + tapers = np.arange(self.shape[self._dims.index("taper")]) + else: + tapers = None + out.append(tapers) + if not return_times and not return_freqs and not return_tapers: return out[0] return tuple(out) @@ -2676,21 +2690,21 @@ def to_data_frame( ): """Export data in tabular structure as a pandas DataFrame. - Channels are converted to columns in the DataFrame. By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. + Channels are converted to columns in the DataFrame. By default, additional + columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'`` + (epoch event description) are added, unless ``index`` is not ``None`` (in which + case the columns specified in ``index`` will be used to form the DataFrame's + index instead). ``'epoch'``, and ``'condition'`` are not supported for + ``AverageTFR``. ``'taper'`` is only supported when a taper dimensions is + present, such as for complex or phase multitaper data. Parameters ---------- %(picks_all)s %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. + Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, + and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and + ``'taper'`` for ``AverageTFR``. Defaults to ``None``. %(long_format_df_epo)s %(time_format_df)s @@ -2710,12 +2724,16 @@ def to_data_frame( valid_index_args = ["time", "freq"] if from_epo: valid_index_args.extend(["epoch", "condition"]) + if unagg_mt: + valid_index_args.append("taper") valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) - data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) + data, times, freqs, tapers = self.get_data( + picks, return_times=True, return_freqs=True, return_tapers=True + ) ch_axis = self._dims.index("channel") if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis @@ -2731,7 +2749,7 @@ def to_data_frame( default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) times = np.tile(times, n_epochs * n_freqs * n_tapers) - freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers) mindex.append(("time", times)) mindex.append(("freq", freqs)) if from_epo: @@ -2744,12 +2762,11 @@ def to_data_frame( ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) ) default_index.extend(["condition", "epoch"]) - default_index.extend(["freq", "time"]) if unagg_mt: - name = "taper" - taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times) - mindex.append((name, taper_nums)) - default_index.append(name) + tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times) + mindex.append(("taper", tapers)) + default_index.append("taper") + default_index.extend(["freq", "time"]) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame df = _build_data_frame( From 82dfab9c2c4a6e0959bfb049a8ff0475bda67cee Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 14 Dec 2024 19:00:34 +0000 Subject: [PATCH 22/23] Fix plotting with tapers --- mne/time_frequency/tests/test_tfr.py | 64 ++++++++++---- mne/time_frequency/tfr.py | 121 +++++++++++---------------- mne/viz/tests/test_topomap.py | 25 +++++- mne/viz/topomap.py | 14 +++- 4 files changed, 136 insertions(+), 88 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 62db87f3a83..4eec02af6f4 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -859,6 +859,25 @@ def test_plot(): plt.close("all") +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_plot_multitaper_complex_phase(output): + """Test TFR plotting of data with a taper dimension.""" + # Create example data with a taper dimension + n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) + data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) + if output == "complex": + data = data + np.random.rand(*data.shape) * 1j # add imaginary data + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, weights=weights + ) + # Check that plotting works + tfr.plot() + + @pytest.mark.parametrize( "timefreqs,title,combine", ( @@ -1611,23 +1630,23 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) -@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) -def test_tfrarray_tapered_spectra(inst, evoked, request): +@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(obj_type): """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" - # Load data object - inst = _get_inst(inst, request, evoked=evoked) - inst.pick("mag") - # Compute TFR with taper dimension (can be complex or phase output) - tfr = inst.compute_tfr( - method="multitaper", freqs=freqs_linspace, n_cycles=4, output="complex" - ) - tfr_array, weights = tfr.get_data(), tfr.weights + # Create example data with a taper dimension + n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6) + data_shape = (n_chans, n_tapers, n_freqs, n_times) + if obj_type == "epochs": + data_shape = (n_epochs,) + data_shape + data = np.random.rand(*data_shape) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") # Prepare for TFRArray object instantiation - defaults = dict( - info=inst.info, data=tfr_array, times=inst.times, freqs=freqs_linspace - ) - class_mapping = dict(Raw=RawTFRArray, Epochs=EpochsTFRArray, Evoked=AverageTFRArray) - TFRArray = class_mapping[inst.__class__.__name__] + defaults = dict(info=info, data=data, times=times, freqs=freqs) + class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray) + TFRArray = class_mapping[obj_type] # Check TFRArray instantiation runs with good data TFRArray(**defaults, weights=weights) # Check taper dimension but no weights caught @@ -1830,7 +1849,20 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): ) -def test_combine_tfr_error_catch(request, average_tfr): +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): + """Test plot_joint/topo/topomap() for data with a taper dimension.""" + # Compute TFR with taper dimension + tfr = evoked.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output + ) + # Check that plotting works + tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed + tfr.plot_topo() + tfr.plot_topomap() + + +def test_combine_tfr_error_catch(average_tfr): """Test combine_tfr() catches errors.""" # check unrecognised weights string caught with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 49bc15d8833..04c43f9f4d7 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1660,6 +1660,7 @@ def _onselect( fmax=fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) # average over times and freqs @@ -2026,6 +2027,7 @@ def plot( baseline=baseline, mode=mode, dB=dB, + taper_weights=self.weights, verbose=verbose, ) # shape @@ -2036,6 +2038,9 @@ def plot( want_shape[ch_axis] = len(idx_picks) if combine is None else 1 want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = [ + n for i, n in enumerate(want_shape) if self._dims[i] != "taper" + ] # tapers must be aggregated over by now want_shape = tuple(want_shape) # combine combine_was_none = combine is None @@ -2379,6 +2384,7 @@ def plot_joint( fmax=_fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) _data = _data.mean(axis=(-1, -2)) # avg over times and freqs @@ -2527,23 +2533,23 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks - # TODO this is the only remaining call to _preproc_tfr; should be refactored - # (to use _prep_data_for_plot?) - data, times, freqs, vmin, vmax = _preproc_tfr( + # baseline, crop, convert complex to power, aggregate tapers, and dB scaling + data, times, freqs = _prep_data_for_plot( data, times, freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - info["sfreq"], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + taper_weights=self.weights, + verbose=verbose, ) + # get vlims + vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) if layout is None: from mne import find_layout @@ -4054,62 +4060,6 @@ def _centered(arr, newsize): return arr[tuple(myslice)] -def _preproc_tfr( - data, - times, - freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - sfreq, - copy=None, -): - """Aux Function to prepare tfr computation.""" - if copy is None: - copy = baseline is not None - data = rescale(data, times, baseline, mode, copy=copy) - - if np.iscomplexobj(data): - # complex amplitude → real power (for plotting); if data are - # real-valued they should already be power - data = (data * data.conj()).real - - # crop time - itmin, itmax = None, None - idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0] - if tmin is not None: - itmin = idx[0] - if tmax is not None: - itmax = idx[-1] + 1 - - times = times[itmin:itmax] - - # crop freqs - ifmin, ifmax = None, None - idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0] - if fmin is not None: - ifmin = idx[0] - if fmax is not None: - ifmax = idx[-1] + 1 - - freqs = freqs[ifmin:ifmax] - - # crop data - data = data[:, ifmin:ifmax, itmin:itmax] - - if dB: - data = 10 * np.log10(data) - - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) - return data, times, freqs, vmin, vmax - - def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") @@ -4344,6 +4294,7 @@ def _prep_data_for_plot( baseline=None, mode=None, dB=False, + taper_weights=None, verbose=None, ): # baseline @@ -4357,9 +4308,39 @@ def _prep_data_for_plot( freqs = freqs[freq_mask] # crop data data = data[..., freq_mask, :][..., time_mask] - # complex amplitude → real power; real-valued data is already power (or ITC) + # handle unaggregated multitaper (complex or phase multitaper data) + if taper_weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + if np.iscomplexobj(data): # complex coefficients → power + data = _tfr_from_mt(data, taper_weights) + else: # tapered phase data → weighted phase data + data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = (data * data.conj()).real if dB: data = 10 * np.log10(data) return data, times, freqs + + +def _tfr_from_mt(x_mt, weights): + """Aggregate complex multitaper coefficients over tapers and convert to power. + + Parameters + ---------- + x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + The complex-valued multitaper coefficients. + weights : array, shape (n_tapers, n_freqs) + The weights to use to combine the tapered estimates. + + Returns + ------- + tfr : array, shape (n_channels, n_freqs, n_times) + The time-frequency power estimates. + """ + weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + tfr = weights * x_mt + tfr *= tfr.conj() + tfr = tfr.real.sum(axis=1) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + return tfr diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index afa9341c00e..b87d0d39f89 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -44,7 +44,7 @@ compute_bridged_electrodes, compute_current_source_density, ) -from mne.time_frequency.tfr import AverageTFRArray +from mne.time_frequency.tfr import AverageTFR, AverageTFRArray from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap from mne.viz.tests.test_raw import _proj_status from mne.viz.topomap import ( @@ -610,6 +610,29 @@ def test_plot_tfr_topomap(): ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) + # test data with taper dimension (real) + data = np.expand_dims(data, axis=1) + weights = np.random.rand(1, n_freqs) + tfr = AverageTFRArray( + info=info, + data=data, + times=times, + freqs=np.arange(n_freqs), + nave=nave, + weights=weights, + ) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # test data with taper dimension (complex) + state = tfr.__getstate__() + tfr = AverageTFR(inst=state | dict(data=data * (1 + 1j))) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # remove taper dim before proceeding + data = data[:, 0] + # test real numbers tfr = AverageTFRArray( info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 147919a9c9d..e9240b8917d 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -1882,7 +1882,7 @@ def plot_tfr_topomap( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - data = tfr.data[picks, :, :] + data = tfr.data[picks] # merging grads before rescaling makes ERDs visible if merge_channels: @@ -1890,6 +1890,18 @@ def plot_tfr_topomap( data = rescale(data, tfr.times, baseline, mode, copy=True) + # handle unaggregated multitaper (complex or phase multitaper data) + if tfr.weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims + data = weights * data + if np.iscomplexobj(data): # complex coefficients → power + data *= data.conj() + data = data.real.sum(axis=1) + data *= 2 / (weights * weights.conj()).real.sum(axis=1) + else: # tapered phase data → weighted phase data + data = data.mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = np.sqrt((data * data.conj()).real) From 012bd949ce426f52925ed9ddece033f498f11475 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 20 Dec 2024 11:36:48 +0000 Subject: [PATCH 23/23] Add version tag --- mne/time_frequency/multitaper.py | 2 +- mne/time_frequency/tfr.py | 2 +- mne/utils/docs.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 0fa48db49d7..463879a8860 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -510,7 +510,7 @@ def tfr_array_multitaper( If True, return the taper weights. Only applies if ``output='complex'`` or ``'phase'``. - .. versionadded:: 1.X.0 + .. versionadded:: 1.10.0 %(verbose)s Returns diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 04c43f9f4d7..38a334a657f 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1857,7 +1857,7 @@ def get_data( return_tapers : bool Whether to return the taper numbers. Default is ``False``. - .. versionadded:: 1.X.0 + .. versionadded:: 1.10.0 Returns ------- diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 082a50feb35..6b1841046d1 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5013,7 +5013,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The weights for each taper. Must be provided if ``data`` has a taper dimension, such as for complex or phase multitaper data. - .. versionadded:: 1.X.0 + .. versionadded:: 1.10.0 """ docdict["weights_tfr_attr"] = """ weights : array, shape (n_tapers, n_freqs) | None