Skip to content

Commit

Permalink
Improve drawing of annotations with matplotlib (#11855)
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheltienne authored Aug 14, 2023
1 parent 522cf44 commit 6f9b03c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 35 deletions.
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ stages:
- bash: |
set -e
python -m pip install --progress-bar off --upgrade pip setuptools wheel
python -m pip install --progress-bar off mne-qt-browser[opengl] pyvista scikit-learn pytest-error-for-skips python-picard "PyQt6!=6.5.1" qtpy
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard "PyQt6!=6.5.1" qtpy
python -m pip uninstall -yq mne
python -m pip install --progress-bar off --upgrade -e .[test]
displayName: 'Install dependencies with pip'
Expand Down
19 changes: 9 additions & 10 deletions mne/viz/_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,15 @@ def _setup_annotation_colors(self):

def _update_annotation_segments(self):
"""Update the array of annotation start/end times."""
segments = list()
raw = self.mne.inst
if len(raw.annotations):
for idx, annot in enumerate(raw.annotations):
annot_start = _sync_onset(raw, annot["onset"])
annot_end = annot_start + max(
annot["duration"], 1 / self.mne.info["sfreq"]
)
segments.append((annot_start, annot_end))
self.mne.annotation_segments = np.array(segments)
self.mne.annotation_segments = np.array([])
if len(self.mne.inst.annotations):
annot_start = _sync_onset(self.mne.inst, self.mne.inst.annotations.onset)
durations = self.mne.inst.annotations.duration.copy()
durations[durations < 1 / self.mne.info["sfreq"]] = (
1 / self.mne.info["sfreq"]
)
annot_end = annot_start + durations
self.mne.annotation_segments = np.vstack((annot_start, annot_end)).T

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# PROJECTOR & BADS
Expand Down
38 changes: 14 additions & 24 deletions mne/viz/_mpl_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,21 +1060,13 @@ def _create_annotation_fig(self):
instructions_ax = div.append_axes(
position="top", size=Fixed(1), pad=Fixed(5 * ANNOTATION_FIG_PAD)
)
# XXX when we support a newer matplotlib (something >3.0) the
# instructions can have inline bold formatting:
# instructions = '\n'.join(
# [r'$\mathbf{Left‐click~&~drag~on~plot:}$ create/modify annotation', # noqa E501
# r'$\mathbf{Right‐click~on~plot~annotation:}$ delete annotation',
# r'$\mathbf{Type~in~annotation~window:}$ modify new label name',
# r'$\mathbf{Enter~(or~click~button):}$ add new label to list',
# r'$\mathbf{Esc:}$ exit annotation mode & close this window'])
instructions = "\n".join(
[
"Left click & drag on plot: create/modify annotation",
"Right click on annotation highlight: delete annotation",
"Type in this window: modify new label name",
"Enter (or click button): add new label to list",
"Esc: exit annotation mode & close this dialog window",
r"$\mathbf{Leftclick~&~drag~on~plot:}$ create/modify annotation",
r"$\mathbf{Rightclick~on~plot~annotation:}$ delete annotation",
r"$\mathbf{Type~in~annotation~window:}$ modify new label name",
r"$\mathbf{Enter~(or~click~button):}$ add new label to list",
r"$\mathbf{Esc:}$ exit annotation mode & close this window",
]
)
instructions_ax.text(
Expand Down Expand Up @@ -1141,15 +1133,13 @@ def _create_annotation_fig(self):
else:
col = self.mne.annotation_segment_colors[self._get_annotation_labels()[0]]

# TODO: we would like useblit=True here, but it behaves oddly when the
# first span is dragged (subsequent spans seem to work OK)
rect_kw = _prop_kw("rect", dict(alpha=0.5, facecolor=col))
selector = SpanSelector(
self.mne.ax_main,
self._select_annotation_span,
"horizontal",
minspan=0.1,
useblit=False,
useblit=True,
button=1,
**rect_kw,
)
Expand Down Expand Up @@ -1342,8 +1332,14 @@ def _select_annotation_span(self, vmin, vmax):
onset = _sync_onset(self.mne.inst, vmin, True) - self.mne.first_time
duration = vmax - vmin
buttons = self.mne.fig_annotation.mne.radio_ax.buttons
labels = [label.get_text() for label in buttons.labels]
if buttons.value_selected is not None:
if buttons is None or buttons.value_selected is None:
logger.warning(
"No annotation-label exists! "
"Add one by typing the name and clicking "
'on "Add new label" in the annotation-dialog.'
)
else:
labels = [label.get_text() for label in buttons.labels]
active_idx = labels.index(buttons.value_selected)
_merge_annotations(
onset, onset + duration, labels[active_idx], self.mne.inst.annotations
Expand All @@ -1352,12 +1348,6 @@ def _select_annotation_span(self, vmin, vmax):
if not self.mne.visible_annotations[buttons.value_selected]:
self.mne.show_hide_annotation_checkboxes.set_active(active_idx)
self._redraw(update_data=False, annotations=True)
else:
logger.warning(
"No annotation-label exists! "
"Add one by typing the name and clicking "
'on "Add new label" in the annotation-dialog.'
)

def _remove_annotation_hover_line(self):
"""Remove annotation line from the plot and reactivate selector."""
Expand Down
83 changes: 83 additions & 0 deletions mne/viz/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,89 @@ def test_remove_annotations(raw, hide_which, browser_backend):
assert len(raw.annotations) == len(hide_which)


def test_merge_annotations(raw, browser_backend):
"""Test merging of annotations in the Qt backend.
Let's not bother in figuring out on which sample the _fake_click actually
dropped the annotation, especially with the 600.614 Hz weird sampling rate.
-> atol = 10 / raw.info["sfreq"]
"""
if browser_backend.name == "matplotlib":
pytest.skip("The MPL backend does not support draggable annotations.")
elif not check_version("mne_qt_browser", "0.5.3"):
pytest.xfail("mne_qt_browser < 0.5.3 does not merge annotations properly")
annot = Annotations(
onset=[1, 3, 4, 5, 7, 8],
duration=[1, 0.5, 0.8, 1, 0.5, 0.5],
description=["bad_test", "bad_test", "bad_test", "test", "test", "test"],
)
raw.set_annotations(annot)
fig = raw.plot()
fig._fake_keypress("a") # start annotation mode
assert len(raw.annotations) == 6
assert_allclose(
raw.annotations.onset,
np.array([1, 3, 4, 5, 7, 8]) + raw.first_samp / raw.info["sfreq"],
atol=10 / raw.info["sfreq"],
)
# drag edge and merge 2 annotations in focus (selected description)
fig._fake_click(
(3.5, 1.0), add_points=[(4.2, 1.0)], xform="data", button=1, kind="drag"
)
assert len(raw.annotations) == 5
assert_allclose(
raw.annotations.onset,
np.array([1, 3, 5, 7, 8]) + raw.first_samp / raw.info["sfreq"],
atol=10 / raw.info["sfreq"],
)
assert_allclose(
raw.annotations.duration,
np.array([1, 1.8, 1, 0.5, 0.5]),
atol=10 / raw.info["sfreq"],
)
# drag annotation and merge 2 annotations in focus (selected description)
fig._fake_click(
(1.5, 1.0), add_points=[(3, 1.0)], xform="data", button=1, kind="drag"
)
assert len(raw.annotations) == 4
assert_allclose(
raw.annotations.onset,
np.array([2.5, 5, 7, 8]) + raw.first_samp / raw.info["sfreq"],
atol=10 / raw.info["sfreq"],
)
assert_allclose(
raw.annotations.duration,
np.array([2.3, 1, 0.5, 0.5]),
atol=10 / raw.info["sfreq"],
)
# drag edge and merge 2 annotations not in focus
fig._fake_click(
(7.5, 1.0), add_points=[(8.2, 1.0)], xform="data", button=1, kind="drag"
)
assert len(raw.annotations) == 3
assert_allclose(
raw.annotations.onset,
np.array([2.5, 5, 7]) + raw.first_samp / raw.info["sfreq"],
atol=10 / raw.info["sfreq"],
)
assert_allclose(
raw.annotations.duration, np.array([2.3, 1, 1.5]), atol=10 / raw.info["sfreq"]
)
# drag annotation and merge 2 annotations not in focus
fig._fake_click(
(5.6, 1.0), add_points=[(7.2, 1.0)], xform="data", button=1, kind="drag"
)
assert len(raw.annotations) == 2
assert_allclose(
raw.annotations.onset,
np.array([2.5, 6.6]) + raw.first_samp / raw.info["sfreq"],
atol=10 / raw.info["sfreq"],
)
assert_allclose(
raw.annotations.duration, np.array([2.3, 1.9]), atol=10 / raw.info["sfreq"]
)


@pytest.mark.parametrize("filtorder", (0, 2)) # FIR, IIR
def test_plot_raw_filtered(filtorder, raw, browser_backend):
"""Test filtering of raw plots."""
Expand Down

0 comments on commit 6f9b03c

Please sign in to comment.