Skip to content

Commit

Permalink
Update (Epochs)SpectrumArray docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jul 30, 2024
1 parent 5a62c43 commit 7badfbf
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 66 deletions.
62 changes: 37 additions & 25 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,8 @@ class Spectrum(BaseSpectrum):
The weights for each taper. Only present if spectra computed with
``method='multitaper'`` and ``output='complex'``.
.. versionadded:: 1.8
See Also
--------
EpochsSpectrum
Expand Down Expand Up @@ -1214,28 +1216,28 @@ def __getitem__(self, item):
return BaseRaw._getitem(self, item, return_times=False)


def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched):
if data.ndim != len(dimnames):
def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched):
if data.ndim != len(dim_names):
raise ValueError(
f"Expected data to have {len(dimnames)} dimensions, got {data.ndim}."
f"Expected data to have {len(dim_names)} dimensions, got {data.ndim}."
)

allowed_dims = ["epoch", "channel", "freq", "segment", "taper"]
if not is_epoched:
allowed_dims.remove("epoch")
# TODO maybe we should be nice and allow plural versions of each dimname?
for dim in dimnames:
_check_option("dimnames", dim, allowed_dims)
if "channel" not in dimnames or "freq" not in dimnames:
raise ValueError("Both 'channel' and 'freq' must be present in `dimnames`.")
for dim in dim_names:
_check_option("dim_names", dim, allowed_dims)
if "channel" not in dim_names or "freq" not in dim_names:
raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.")

if list(dimnames).index("channel") != int(is_epoched):
if list(dim_names).index("channel") != int(is_epoched):
raise ValueError(
f"'channel' must be the {'second' if is_epoched else 'first'} dimension of "
"the data."
)
want_n_chan = _pick_data_channels(info).size
got_n_chan = data.shape[list(dimnames).index("channel")]
got_n_chan = data.shape[list(dim_names).index("channel")]
if got_n_chan != want_n_chan:
raise ValueError(
f"The number of channels in `data` ({got_n_chan}) must match the number of "
Expand All @@ -1244,25 +1246,25 @@ def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched):

# given we limit max array size and ensure channel & freq dims present, only one of
# taper or segment can be present
if "taper" in dimnames:
if dimnames[-2] != "taper": # _psd_from_mt assumes this (called when plotting)
if "taper" in dim_names:
if dim_names[-2] != "taper": # _psd_from_mt assumes this (called when plotting)
raise ValueError(
"'taper' must be the second to last dimension of the data."
)
# expect weights for each taper
actual = None if weights is None else weights.size
expected = data.shape[list(dimnames).index("taper")]
expected = data.shape[list(dim_names).index("taper")]
if actual != expected:
raise ValueError(
f"Expected size of `weights` to be {expected} to match 'n_tapers' in "
f"`data`, got {actual}."
)
elif "segment" in dimnames and dimnames[-1] != "segment":
elif "segment" in dim_names and dim_names[-1] != "segment":
raise ValueError("'segment' must be the last dimension of the data.")

# freq being in wrong position ruled out by above checks
want_n_freq = freqs.size
got_n_freq = data.shape[list(dimnames).index("freq")]
got_n_freq = data.shape[list(dim_names).index("freq")]
if got_n_freq != want_n_freq:
raise ValueError(
f"The number of frequencies in `data` ({got_n_freq}) must match the number "
Expand All @@ -1280,14 +1282,18 @@ class SpectrumArray(Spectrum):
The spectra for each channel.
%(info_not_none)s
%(freqs_tfr_array)s
dimnames : tuple of str
dim_names : tuple of str
The name of the dimensions in the data, in the order they occur. Must contain
``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include
either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g.,
multitaper algorithms) dimension. If including ``'taper'``, you should also pass
a ``weights`` parameter.
.. versionadded:: 1.8
weights : ndarray | None
Weights for the ``'taper'`` dimension, if present (see ``dimnames``).
Weights for the ``'taper'`` dimension, if present (see ``dim_names``).
.. versionadded:: 1.8
%(verbose)s
See Also
Expand All @@ -1310,22 +1316,22 @@ def __init__(
data,
info,
freqs,
dimnames=("channel", "freq"),
dim_names=("channel", "freq"),
weights=None,
*,
verbose=None,
):
# (channel, [taper], freq, [segment])
_check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension

_check_data_shape(data, info, freqs, dimnames, weights, is_epoched=False)
_check_data_shape(data, info, freqs, dim_names, weights, is_epoched=False)

self.__setstate__(
dict(
method="unknown",
data=data,
sfreq=info["sfreq"],
dims=dimnames,
dims=dim_names,
freqs=freqs,
inst_type_str="Array",
data_type=(
Expand Down Expand Up @@ -1376,6 +1382,8 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
The weights for each taper. Only present if spectra computed with
``method='multitaper'`` and ``output='complex'``.
.. versionadded:: 1.8
See Also
--------
EpochsSpectrumArray
Expand Down Expand Up @@ -1554,14 +1562,18 @@ class EpochsSpectrumArray(EpochsSpectrum):
%(freqs_tfr_array)s
%(events_epochs)s
%(event_id)s
dimnames : tuple of str
dim_names : tuple of str
The name of the dimensions in the data, in the order they occur. Must contain
``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include
either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g.,
multitaper algorithms) dimension. If including ``'taper'``, you should also pass
a ``weights`` parameter.
.. versionadded:: 1.8
weights : ndarray | None
Weights for the ``'taper'`` dimension, if present (see ``dimnames``).
Weights for the ``'taper'`` dimension, if present (see ``dim_names``).
.. versionadded:: 1.8
%(verbose)s
See Also
Expand All @@ -1585,30 +1597,30 @@ def __init__(
freqs,
events=None,
event_id=None,
dimnames=("epoch", "channel", "freq"),
dim_names=("epoch", "channel", "freq"),
weights=None,
*,
verbose=None,
):
# (epoch, channel, [taper], freq, [segment])
_check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension

if list(dimnames).index("epoch") != 0:
if list(dim_names).index("epoch") != 0:
raise ValueError("'epoch' must be the first dimension of `data`.")
if events is not None and data.shape[0] != events.shape[0]:
raise ValueError(
f"The first dimension of `data` ({data.shape[0]}) must match the first "
f"dimension of `events` ({events.shape[0]})."
)

_check_data_shape(data, info, freqs, dimnames, weights, is_epoched=True)
_check_data_shape(data, info, freqs, dim_names, weights, is_epoched=True)

self.__setstate__(
dict(
method="unknown",
data=data,
sfreq=info["sfreq"],
dims=dimnames,
dims=dim_names,
freqs=freqs,
inst_type_str="Array",
data_type=(
Expand Down
73 changes: 36 additions & 37 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,6 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path):
assert_array_equal(spect1.freqs, spect2.freqs)


def _get_dimnames(kind, method, output, average):
dimnames = ("epoch", "channel") if kind == "epochs" else ("channel",)
if method == "welch":
dimnames += ("freq",) if average else ("freq", "segment")
else: # i.e. multitaper
dimnames += ("freq",) if output == "power" else ("taper", "freq")
return dimnames


def test_spectrum_array_errors():
"""Test (Epochs)SpectrumArray constructor errors."""
n_epochs = 10
Expand All @@ -457,23 +448,23 @@ def test_spectrum_array_errors():
sfreq = 100
rng = np.random.default_rng(44)
data = rng.random((n_epochs, n_chans, n_freqs))
dimnames = ("epoch", "channel", "freq")
dim_names = ("epoch", "channel", "freq")
info = create_info(n_chans, sfreq, "eeg")
# test incorrect ndims (for SpectrumArray; allows 2-3D data)
with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"):
SpectrumArray(data[0, 0, :], info, freqs, dimnames=dimnames)
SpectrumArray(data[0, 0, :], info, freqs, dim_names=dim_names)
with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"):
SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames)
SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names)
# test incorrect ndims (for EpochsSpectrumArray; allows 3-4D data)
with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"):
EpochsSpectrumArray(data[0, :, :], info, freqs, dimnames=dimnames)
EpochsSpectrumArray(data[0, :, :], info, freqs, dim_names=dim_names)
with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"):
EpochsSpectrumArray(
np.expand_dims(data, axis=(3, 4)), info, freqs, dimnames=dimnames
np.expand_dims(data, axis=(3, 4)), info, freqs, dim_names=dim_names
)
# test incorrect epochs location
with pytest.raises(ValueError, match="'epoch' must be the first dimension"):
EpochsSpectrumArray(data, info, freqs, dimnames=("channel", "epoch", "freq"))
EpochsSpectrumArray(data, info, freqs, dim_names=("channel", "epoch", "freq"))
# test mismatching events shape
events = np.vstack(
(
Expand All @@ -483,52 +474,56 @@ def test_spectrum_array_errors():
)
).T
with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"):
EpochsSpectrumArray(data, info, freqs, events, dimnames=dimnames)
EpochsSpectrumArray(data, info, freqs, events, dim_names=dim_names)
# test data-dimname mismatch
with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"):
EpochsSpectrumArray(data, info, freqs, dimnames=dimnames[:-1])
# test unrecognised dimnames (for SpectrumArray; epoch not allowed)
with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"):
SpectrumArray(data[0, :, :], info, freqs, dimnames=("epoch", "channel"))
# test unrecognised dimnames (for EpochsSpectrumArray)
with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"):
EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "notfreq"))
# test missing dimnames
EpochsSpectrumArray(data, info, freqs, dim_names=dim_names[:-1])
# test unrecognised dim_names (for SpectrumArray; epoch not allowed)
with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"):
SpectrumArray(data[0, :, :], info, freqs, dim_names=("epoch", "channel"))
# test unrecognised dim_names (for EpochsSpectrumArray)
with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"):
EpochsSpectrumArray(
data, info, freqs, dim_names=("epoch", "channel", "notfreq")
)
# test missing dim_names
with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"):
EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "channel"))
EpochsSpectrumArray(
data, info, freqs, dim_names=("epoch", "channel", "channel")
)
with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"):
EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "freq"))
EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "freq"))
with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"):
EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "epoch", "epoch"))
EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "epoch", "epoch"))
# test incorrect channel location (for SpectrumArray; must be 1st dim)
with pytest.raises(ValueError, match="'channel' must be the first dimension"):
SpectrumArray(data[0, :, :], info, freqs, dimnames=("freq", "channel"))
SpectrumArray(data[0, :, :], info, freqs, dim_names=("freq", "channel"))
# test incorrect channel location (for EpochsSpectrumArray; must be 2nd dim)
with pytest.raises(ValueError, match="'channel' must be the second dimension"):
EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "channel"))
EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "channel"))
# test mismatching number of channels
with pytest.raises(ValueError, match=r"number of channels.*good data channels"):
EpochsSpectrumArray(data[:, :-1, :], info, freqs, dimnames=dimnames)
EpochsSpectrumArray(data[:, :-1, :], info, freqs, dim_names=dim_names)
# test incorrect taper position
with pytest.raises(ValueError, match="'taper' must be the second to last dim"):
EpochsSpectrumArray(
np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames + ("taper",)
np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names + ("taper",)
)
# test incorrect weight size
with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"):
EpochsSpectrumArray(
np.expand_dims(data, axis=2),
info,
freqs,
dimnames=("epoch", "channel", "taper", "freq"),
dim_names=("epoch", "channel", "taper", "freq"),
weights=None,
)
with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"):
EpochsSpectrumArray(
np.expand_dims(data, axis=2),
info,
freqs,
dimnames=("epoch", "channel", "taper", "freq"),
dim_names=("epoch", "channel", "taper", "freq"),
weights=np.ones((1, 2, 1)),
)
# test incorrect segment position
Expand All @@ -537,11 +532,11 @@ def test_spectrum_array_errors():
np.expand_dims(data, axis=2),
info,
freqs,
dimnames=("epoch", "channel", "segment", "freq"),
dim_names=("epoch", "channel", "segment", "freq"),
)
# test mismatching number of frequencies
with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"):
EpochsSpectrumArray(data[:, :, :-1], info, freqs, dimnames=dimnames)
EpochsSpectrumArray(data[:, :, :-1], info, freqs, dim_names=dim_names)


@pytest.mark.parametrize(
Expand All @@ -554,7 +549,11 @@ def test_spectrum_array_errors():
)
def test_spectrum_array(kind, method, output, average, tmp_path, request):
"""Test EpochsSpectrumArray and SpectrumArray constructors."""
dimnames = _get_dimnames(kind, method, output, average)
dim_names = ("epoch", "channel") if kind == "epochs" else ("channel",)
if method == "welch":
dim_names += ("freq",) if average else ("freq", "segment")
else: # i.e. multitaper
dim_names += ("freq",) if output == "power" else ("taper", "freq")
if method == "welch" and output == "power" and average:
spectrum = request.getfixturevalue(f"{kind}_spectrum")
else:
Expand All @@ -569,7 +568,7 @@ def test_spectrum_array(kind, method, output, average, tmp_path, request):
data=data,
info=spectrum.info,
freqs=freqs,
dimnames=dimnames,
dim_names=dim_names,
weights=spectrum.weights,
)
_check_spectrum_equivalent(spectrum, spect_arr, tmp_path)
Expand Down
12 changes: 8 additions & 4 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2922,11 +2922,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
docdict["notes_plot_psd_meth"] = _notes_plot_psd.format("method")

docdict["notes_spectrum_array"] = """
It is assumed that the data passed in represent spectral *power* (not amplitude,
phase, model coefficients, etc) and downstream methods (such as
If the data passed in is real-valued, it is assumed to represent spectral *power* (not
amplitude, phase, etc), and downstream methods (such as
:meth:`~mne.time_frequency.SpectrumArray.plot`) assume power data. If you pass in
something other than power, at the very least axis labels will be inaccurate (and
other things may also not work or be incorrect).
real-valued data that is not power, axis labels will be incorrect.
If the data passed in is complex-valued, it is assumed to represent Fourier
coefficients. Downstream plotting methods will treat the data as such, attempting to
convert this to power before visualisation. If you pass in complex-valued data that is
not Fourier coefficients, axis labels will be incorrect.
"""

docdict["notes_timefreqs_tfr_plot_joint"] = """
Expand Down

0 comments on commit 7badfbf

Please sign in to comment.