Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add option to store and return TFR taper weights #12910

Open
wants to merge 28 commits into
base: main
Choose a base branch
from

Conversation

tsbinns
Copy link
Contributor

@tsbinns tsbinns commented Oct 22, 2024

Reference issue (if any)

PR for #12851

What does this implement/fix?

Adds an option to return taper weights for complex and phase outputs of the multitaper method in tfr_array_multitaper(), and also ensures taper weights are stored in TFR objects.

Additional information

When working on this, I discovered a couple of other issues with the per-taper TFR implementations (#12851 (comment)), including the fact that the TFR object plotting methods and to_data_frame methods do not account for a taper dimension, leading to errors. Wasn't sure if people want me to also address these here or in a separate PR.

@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I am somewhat unsure on. The existing implementation is to just use conc as-is, however in the MNE-Connectivity implementation that sqrt is taken: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 22, 2024

I'm also somewhat confused about the design of the _make_dpss function:

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
else:
this_n_cycles = n_cycles[0]
t_win = this_n_cycles / float(f)
t = np.arange(0.0, t_win, 1.0 / sfreq)
# Making sure wavelets are centered before tapering
oscillation = np.exp(2.0 * 1j * np.pi * f * (t - t_win / 2.0))
# Get dpss tapers
tapers, conc = dpss_windows(
t.shape[0], time_bandwidth / 2.0, n_taps, sym=False
)
Wk = oscillation * tapers[m]
if zero_mean: # to make it zero mean
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Wm.append(Wk)
Cm.append(Ck)
Ws.append(Wm)
Cs.append(Cm)

It is looping over tapers, and then over frequencies. However, the dpss_windows function it calls internally provides the tapers and their weights for all tapers of a given frequency.

Would it not be more efficient to only loop over frequencies and take advantage of the fact that this will also return information for each taper?

@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 22, 2024

I also have a question regarding testing: for the I/O tests, we're reading TFR objects that do not have a weights property (just gets assigned to None) when loaded. Do I need to create new TFR objects that actually have some weights, or is the current test sufficient?

Apart from this there are still some tests I need to expand.

mne/time_frequency/multitaper.py Outdated Show resolved Hide resolved
mne/time_frequency/tfr.py Show resolved Hide resolved
@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

mne/time_frequency/tfr.py Show resolved Hide resolved
@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 29, 2024

Thanks for the review @drammock! I will sort out those remaining tests, although I'm in the process of moving at the moment so it might not be for some days.

Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting: would you like me to incorporate that into this PR?

@tsbinns
Copy link
Contributor Author

tsbinns commented Dec 9, 2024

Sorry for the lack of work on this, had to organise things for my PhD defence.

Everything new added here has test coverage now.

@drammock, just a couple points I would appreciate your input on:

  • For the I/O tests, we're reading TFR objects that do not have a weights property (just gets assigned to None) when loaded. Do I need to create new TFR objects that actually have some weights, or is the current test sufficient?
  • Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting (Store n_cycles and time_bandwidth params in *TFR objects #12851 (comment)): would you like me to open a separate PR to keep individual diffs small, or just add the changes here?

Also tagging @larsoner and @ruuskas in case they can help clarify an outstanding point: #12910 (comment)

@tsbinns
Copy link
Contributor Author

tsbinns commented Dec 9, 2024

Currently working on support for a tapers dimension in ...TFRArray objects.

@drammock
Copy link
Member

drammock commented Dec 9, 2024

Do I need to create new TFR objects that actually have some weights, or is the current test sufficient?

Yes I think we should. most (all?) of them are created by pytest fixtures at present. I see 3 options:

  1. tweak the fixtures to always return TFRs that have weights.
  2. when you want to test something specific to weights, monkey-patch some weights (and a taper dim) into the object at the start of the test
  3. write a new fixture (or parametrize an existing one) so that you can get TFRs with/without weights at need.

To really test thoroughly, option (2) is probably best, because then you can also patch in things that are expected to fail, and test that they do fail in the expected way.

@@ -1392,7 +1421,6 @@ def __setstate__(self, state):

defaults = dict(
method="unknown",
dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have removed dims being set in BaseTFR since the possibility of the optional epoch and taper dimensions makes it really difficult to disentangle here. It's much easier to handle this in the individual RawTFR, EpochsTFR, and AverageTFR classes.

Comment on lines +2909 to +2913
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("channel",)
if state["data"].ndim == 4:
state["dims"] += ("taper",)
state["dims"] += ("freq", "time")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example of handling dims in the AverageTFR class where only one dimension (taper) is optional.

Comment on lines +3278 to +3286

Averaging is not supported for data containing a taper dimension.
"""
if "taper" in self._dims:
raise NotImplementedError(
"Averaging multitaper tapers across epochs, frequencies, or times is "
"not supported. If averaging across epochs, consider averaging the "
"epochs before computing the complex/phase spectrum."
)
Copy link
Contributor Author

@tsbinns tsbinns Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of averaging for data with tapers, I went for the same approach we're using for Spectrum and just disallowing this.

I don't think this is an API change requiring a deprecation cycle since:

  1. the docstring expects the data to not have a taper dimension, e.g. If callable, must take a NumPy array of shape (n_epochs, n_channels, n_freqs, n_times).
  2. trying to call this method on an object with a taper dimension would raise an uncaught error: n_epochs, n_channels, n_freqs, n_times = self.data.shape (wouldn't be able to unpack this properly).

So explicitly preventing this method being called with a taper dimension doesn't change current behaviour, it just gives a nicer error as to why this can't be done.

Comment on lines 3942 to +3953
Notes
-----
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
phase data is not supported.

.. versionadded:: 0.11.0
"""
if any("taper" in tfr._dims for tfr in all_tfr):
raise NotImplementedError(
"Aggregating multitaper tapers across TFR datasets is not supported."
)

Copy link
Contributor Author

@tsbinns tsbinns Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a similar case to averaging for the time_frequency.combine_tfr() function (which also gets called by the grand_average() function).

However, unlike the EpochsTFR.average() method, this could be considered an API change since combine_tfr() should currently run with taper data. Does preventing this use case require a deprecation cycle?

On a side note, I noticed that while a public function, combine_tfr() is not listed in the API (the equivalent combine_evoked() is). Is this an oversight or an intended omission?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants