diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 4d4d3d957..5f7a08796 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -17,6 +17,8 @@ Develop branch Enhancements ~~~~~~~~~~~~ +- Increasing the version in the pre-commit config (:gh:`631` by pre-commit bot) +- Implementation of Pseudo Online framework (:gh:`641` by `Igor Carrara`_) - Update version of pyRiemann to 0.7 (:gh:`671` by `Gregoire Cattan`_) @@ -44,6 +46,7 @@ Enhancements - Add new dataset :class:`moabb.datasets.Liu2024` dataset (:gh:`619` by `Taha Habib`_) + Bugs ~~~~ - Fix caching in the workflows (:gh:`632` by `Pierre Guetschel`_) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py new file mode 100644 index 000000000..a84cd66b5 --- /dev/null +++ b/examples/plot_pseudoonline.py @@ -0,0 +1,60 @@ +# Set up the Directory for made it run on a server. + +import numpy as np +from pyriemann.classification import MDM, FgMDM +from pyriemann.estimation import Covariances +from sklearn.pipeline import Pipeline + +from moabb.datasets import BNCI2014_001 +from moabb.evaluations import WithinSessionEvaluation +from moabb.paradigms import MotorImagery + + +sub = 1 + +# Initialize parameter for the Band Pass filter +fmin = 8 +fmax = 30 +tmax = 3 + +# Load Dataset and switch to Pseudoonline mode +dataset = BNCI2014_001() +dataset.pseudoonline = True + +# events = ["right_hand", "left_hand"] +events = list(dataset.event_id.keys()) + +paradigm = MotorImagery( + events=events, n_classes=len(events), fmin=fmin, fmax=fmax, tmax=tmax, overlap=50 +) + +X, y, meta = paradigm.get_data(dataset=dataset, subjects=[sub]) +print("Print Events_id:", y) +unique, counts = np.unique(y, return_counts=True) +print("Number of events per class:", dict(zip(unique, counts))) + + +pipelines = {} +pipelines["MDM"] = Pipeline( + steps=[ + ("Covariances", Covariances("cov")), + ("MDM", MDM(metric=dict(mean="riemann", distance="riemann"))), + ] +) + +pipelines["FgMDM"] = 
Pipeline( + steps=[("Covariances", Covariances("cov")), ("FgMDM", FgMDM())] +) + +dataset.subject_list = dataset.subject_list[int(sub) - 1 : int(sub)] +# Select an evaluation Within Session +evaluation_online = WithinSessionEvaluation( + paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=1 +) + +# Print the results +results_ALL = evaluation_online.process(pipelines) +results_pipeline = results_ALL.groupby(["pipeline"], as_index=False)["score"].mean() +results_pipeline_std = results_ALL.groupby(["pipeline"], as_index=False)["score"].std() +results_pipeline["std"] = results_pipeline_std["score"] +print(results_pipeline) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index 0d4672482..60fc7bafc 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -319,6 +319,7 @@ def __init__( paradigm, doi=None, unit_factor=1e6, + overlap=False, ): """Initialize function for the BaseDataset.""" try: @@ -348,6 +349,7 @@ def __init__( self.paradigm = paradigm self.doi = doi self.unit_factor = unit_factor + self.overlap = overlap def _create_process_pipeline(self): return Pipeline( diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index 0cef42910..e0ae31ca7 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -1,7 +1,7 @@ """BNCI 2014-001 Motor imagery dataset.""" import numpy as np -from mne import create_info +from mne import create_info, find_events from mne.channels import make_standard_montage from mne.io import RawArray from mne.utils import verbose @@ -33,6 +33,7 @@ def load_data( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): # noqa: D301 """Get paths to local copies of a BNCI dataset files. 
@@ -116,6 +117,7 @@ def load_data( baseurl_list[dataset], only_filenames, verbose, + pseudoonline, ) @@ -128,6 +130,7 @@ def _load_data_001_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 001-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -144,13 +147,21 @@ def _load_data_001_2014( sessions = {} filenames = [] + time_task = 4 + time_fix = 2 for session_idx, r in enumerate(["T", "E"]): url = "{u}001-2014/A{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path) filenames += filename if only_filenames: continue - runs, ev = _convert_mi(filename[0], ch_names, ch_types) + + if pseudoonline: + runs, ev = _convert_mi_pseudoonline( + filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline + ) + else: + runs, ev = _convert_mi(filename[0], ch_names, ch_types) # FIXME: deal with run with no event (1:3) and name them sessions[f"{session_idx}{_map[r]}"] = { str(ii): run for ii, run in enumerate(runs) @@ -169,12 +180,15 @@ def _load_data_002_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 002-2014 dataset.""" if (subject < 1) or (subject > 14): raise ValueError("Subject must be between 1 and 14. Got %d." % subject) runs = [] + time_task = 5 + time_fix = 3 filenames = [] for r in ["T", "E"]: url = "{u}002-2014/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -183,7 +197,12 @@ def _load_data_002_2014( if only_filenames: continue # FIXME: electrode position and name are not provided directly. 
- raws, _ = _convert_mi(filename, None, ["eeg"] * 15) + if pseudoonline: + raws, _ = _convert_mi_pseudoonline( + filename, time_task, time_fix, None, ["eeg"] * 15, pseudoonline + ) + else: + raws, _ = _convert_mi(filename, None, ["eeg"] * 15) runs.extend(zip([r] * len(raws), raws)) if only_filenames: return filenames @@ -200,6 +219,7 @@ def _load_data_004_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 004-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -209,6 +229,8 @@ def _load_data_004_2014( ch_types = ["eeg"] * 3 + ["eog"] * 3 sessions = [] + time_task = 4.5 + time_fix = 3 filenames = [] for r in ["T", "E"]: url = "{u}004-2014/B{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -216,7 +238,12 @@ def _load_data_004_2014( filenames.append(filename) if only_filenames: continue - raws, _ = _convert_mi(filename, ch_names, ch_types) + if pseudoonline: + raws, _ = _convert_mi_pseudoonline( + filename, time_task, time_fix, ch_names, ch_types, pseudoonline + ) + else: + raws, _ = _convert_mi(filename, ch_names, ch_types) sessions.extend(zip([r] * len(raws), raws)) if only_filenames: @@ -234,7 +261,14 @@ def _load_data_008_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) + """Load data for 008-2014 dataset.""" if (subject < 1) or (subject > 8): raise ValueError("Subject must be between 1 and 8. Got %d." % subject) @@ -260,7 +294,12 @@ def _load_data_009_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 009-2014 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 10. Got %d." 
% subject) @@ -299,6 +338,7 @@ def _load_data_001_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 001-2015 dataset.""" if (subject < 1) or (subject > 12): @@ -318,6 +358,8 @@ def _load_data_001_2015( ch_types = ["eeg"] * 13 sessions = {} + time_task = 5 + time_fix = 0 filenames = [] for session_idx, r in ses: url = "{u}001-2015/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -325,7 +367,12 @@ def _load_data_001_2015( filenames += filename if only_filenames: continue - runs, ev = _convert_mi(filename[0], ch_names, ch_types) + if pseudoonline: + runs, ev = _convert_mi_pseudoonline( + filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline + ) + else: + runs, ev = _convert_mi(filename[0], ch_names, ch_types) sessions[f"{session_idx}{r}"] = {str(ii): run for ii, run in enumerate(runs)} if only_filenames: return filenames @@ -341,7 +388,12 @@ def _load_data_003_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 003-2015 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -400,7 +452,12 @@ def _load_data_004_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 004-2015 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." 
% subject) @@ -434,7 +491,12 @@ def _load_data_009_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 009-2015 dataset.""" if (subject < 1) or (subject > 21): raise ValueError("Subject must be between 1 and 21. Got %d." % subject) @@ -465,7 +527,12 @@ def _load_data_010_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 010-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -497,7 +564,12 @@ def _load_data_012_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 012-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -524,7 +596,12 @@ def _load_data_013_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 013-2015 dataset.""" if (subject < 1) or (subject > 6): raise ValueError("Subject must be between 1 and 6. Got %d." % subject) @@ -584,6 +661,50 @@ def _convert_mi(filename, ch_names, ch_types): return runs, event_id +def _convert_mi_pseudoonline( + filename, time_task, time_fix, ch_names, ch_types, pseudoonline +): + """Process (Graz) motor imagery data from MAT files. + + Parameters + ---------- + filename : str + Path to the MAT file. 
+ time_task: float + Actual duration of the task + time_fix: + Duration of Fixation Cross + ch_names : list of str + List of channel names. + ch_types : list of str + List of channel types. + + Returns + ------- + raw : instance of RawArray + returns list of recording runs.""" + runs = [] + event_id = {} + data = loadmat(filename, struct_as_record=False, squeeze_me=True) + + if isinstance(data["data"], np.ndarray): + run_array = data["data"] + else: + run_array = [data["data"]] + + for run in run_array: + raw, evd = _convert_run_pseudoonline( + run, time_task, time_fix, ch_names, ch_types, None, pseudoonline + ) + if raw is None: + continue + runs.append(raw) + event_id.update(evd) + # change labels to match rest + standardize_keys(event_id) + return runs, event_id + + def standardize_keys(d): master_list = [ ["both feet", "feet"], @@ -634,6 +755,72 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): return raw, event_id +def _convert_run_pseudoonline( + run, + time_task, + time_fix, + ch_names=None, + ch_types=None, + verbose=None, + pseudoonline=False, +): + """Convert one run to raw.""" + # parse eeg data + event_id = {} + n_chan = run.X.shape[1] + montage = make_standard_montage("standard_1005") + eeg_data = 1e-6 * run.X + sfreq = run.fs + + if not ch_names: + ch_names = ["EEG%d" % ch for ch in range(1, n_chan + 1)] + montage = None # no montage + + if not ch_types: + ch_types = ["eeg"] * n_chan + + trigger = np.zeros((len(eeg_data), 1)) + # some runs does not contains trials i.e baseline runs + if len(run.trial) > 0: + trigger[run.trial - 1, 0] = run.y + else: + return None, None + + eeg_data = np.c_[eeg_data, trigger] + ch_names = ch_names + ["stim"] + ch_types = ch_types + ["stim"] + event_id = {ev: (ii + 1) for ii, ev in enumerate(run.classes)} + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) + raw = RawArray(data=eeg_data.T, info=info, verbose=verbose) + raw.set_montage(montage) + + if pseudoonline: + # 
================================================================================================================= + # Add the "nothing" event, coded with label 9 + # ================================================================================================================= + # The idea is to replace the old stim channel with a new STIM channel that marks each event at the exact time the + # task starts, and that also marks the start of the "nothing" phase. + events = find_events(raw, stim_channel="stim") + stim_data = np.zeros((1, len(raw.times))) + + # Offset (in samples) from the start of the task to the "nothing" event placed right after the task ends + time_nothing = (sfreq * time_task) + 1 + # Offset (in samples) from a "stim" event to the actual start of the task: the "stim" events mark when the + # fixation cross appears, not when the task begins. + time_fixation_cross = sfreq * time_fix + for i in np.arange(len(events[:, 0])): + stim_data[0, int(events[i, 0] + time_fixation_cross)] = events[i, 2] + stim_data[0, int(events[i, 0] + time_fixation_cross + time_nothing)] = 9 + + info = create_info(ch_names=["STI"], ch_types=["stim"], sfreq=sfreq) + new_stim = RawArray(data=stim_data, info=info, verbose=verbose) + raw.add_channels([new_stim], force_update_info=True) + raw.drop_channels(["stim"]) # Delete old stim channel + event_id["nothing"] = 9 + + return raw, event_id + + @verbose def _convert_run_p300_sl(run, verbose=None): """Convert one p300 run from santa lucia file format.""" @@ -735,7 +922,12 @@ class MNEBNCI(BaseDataset): def _get_single_subject_data(self, subject): """Return data for a single subject.""" - sessions = load_data(subject=subject, dataset=self.code, verbose=False) + sessions = load_data( + subject=subject, + dataset=self.code, + verbose=False, + pseudoonline=self.pseudoonline, + ) return sessions def data_path( @@ -749,6 +941,7 @@ def data_path( path=path, force_update=force_update, only_filenames=True, + pseudoonline=self.pseudoonline, ) @@ -799,7 +992,13 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=2, -
events={"left_hand": 1, "right_hand": 2, "feet": 3, "tongue": 4}, + events={ + "left_hand": 1, + "right_hand": 2, + "feet": 3, + "tongue": 4, + "nothing": 9, + }, code="BNCI2014-001", interval=[2, 6], paradigm="imagery", @@ -852,7 +1051,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 15)), sessions_per_subject=1, - events={"right_hand": 1, "feet": 2}, + events={"right_hand": 1, "feet": 2, "nothing": 9}, code="BNCI2014-002", interval=[3, 8], paradigm="imagery", @@ -926,7 +1125,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=5, - events={"left_hand": 1, "right_hand": 2}, + events={"left_hand": 1, "right_hand": 2, "nothing": 9}, code="BNCI2014-004", interval=[3, 7.5], paradigm="imagery", @@ -1088,7 +1287,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 13)), sessions_per_subject=2, - events={"right_hand": 1, "feet": 2}, + events={"right_hand": 1, "feet": 2, "nothing": 9}, code="BNCI2015-001", interval=[0, 5], paradigm="imagery", diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index 2bf30a0bf..6a6ac4f80 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -31,6 +31,55 @@ def _unsafe_pick_events(events, include): raise e +def _events_pseudoonline(events, tmin, tmax, sfreq, overlap): + """ + Create the events of the sliding windows used for the pseudo-online evaluation. + :param events: Original events, as recorded during the acquisition of the dataset + :param tmin: Start time of the epoch (tmin MUST be 0). It is counted from the first value of the dataset interval, + e.g. from 2 for the interval [2, 6] used in standard MOABB + :param tmax: End time of the epoch, i.e. the end of the window.
+ :param sfreq: Sampling frequency of the recorded signal + :param overlap: Percentage of overlap between consecutive sliding windows + :return: + the new events, one at every starting point of the sliding windows, each with a single label + """ + # Duration of a window in seconds + duration_s = tmax - tmin + # Duration of a window in samples + duration = duration_s * sfreq + # Step between the starts of two consecutive windows, in samples + ove = (((tmax - tmin) / 100) * (100 - overlap)) * sfreq + + # Total number of new events that need to be created + total = int((events[-1, 0] - events[0, 0]) / (100 - overlap)) + events_new = np.zeros((total, 3), dtype=int) + # The first new event is a copy of the first original event + events_new[0, :] = events[0, :] + + j = 0 + i = 1 + # Keep going while the previous window ends at or before the last original event + while events_new[i - 1, 0] + duration <= events[-1, -0]: + # The new event starts `ove` samples after the previous one + events_new[i, 0] = events_new[i - 1, 0] + ove + # If the new window ends at or before the next original event, it gets the label of the current original event. + # Otherwise it straddles the next original event and gets the label that covers the larger part of the window. + # On a 50/50 split it gets the label of the next event, since the subject wants to switch in that direction.
+ if events_new[i, 0] + duration <= events[j + 1, 0]: + events_new[i, 2] = events[j, 2] + else: + First = abs(events[j + 1, 0] - events_new[i, 0]) + Second = abs((events_new[i, 0] + duration) - events[j + 1, 0]) + if First > Second: + events_new[i, 2] = events[j, 2] + else: + events_new[i, 2] = events[j + 1, 2] + j = j + 1 + i = i + 1 + + return events_new + + class ForkPipelines(TransformerMixin, BaseEstimator): def __init__(self, transformers: List[Tuple[str, Union[Pipeline, TransformerMixin]]]): for _, t in transformers: @@ -77,6 +126,7 @@ def transform(self, raw, y=None): events = mne.find_events(raw, shortest_event=0, verbose=False) events = _unsafe_pick_events(events, include=list(self.event_id.values())) events[:, 0] += offset + if len(events) != 0: annotations = mne.annotations_from_events( events, @@ -87,6 +137,60 @@ def transform(self, raw, y=None): ) annotations.set_durations(duration) raw.set_annotations(annotations) + # raw.plot() + # print("OK") + else: + log.warning("No events found, skipping setting annotations.") + return raw + + +class SetRawAnnotations_PseudoOnline(FixedTransformer): + """ + Always sets the annotations, even if the events list is empty + """ + + def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap): + assert isinstance(event_id, dict) # not None + self.event_id = event_id + if len(set(event_id.values())) != len(event_id): + raise ValueError("Duplicate event code") + self.event_desc = dict((code, desc) for desc, code in self.event_id.items()) + self.interval = interval + self.overlap = overlap + self.tmin = tmin + self.tmax = tmax + + def transform(self, raw, y=None): + if raw.annotations: + return raw + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) + if len(stim_channels) == 0: + log.warning( + "No stim channel nor annotations found, skipping setting annotations." 
+ ) + return raw + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) + duration = self.tmax - self.tmin + + if len(events) != 0: + annotations = mne.annotations_from_events( + events, + raw.info["sfreq"], + self.event_desc, + first_samp=raw.first_samp, + verbose=False, + ) + annotations.set_durations(duration) + raw.set_annotations(annotations) + # raw.plot() + # print("OK") else: log.warning("No events found, skipping setting annotations.") return raw @@ -125,6 +229,54 @@ def transform(self, raw, y=None): return _unsafe_pick_events(events, list(self.event_id.values())) +class RawToEvents_PseudoOnline(FixedTransformer): + """ + Always returns an array for shape (n_events, 3), even if no events found + """ + + def __init__( + self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap + ): + assert isinstance(event_id, dict) # not None + self.event_id = event_id + self.interval = interval + self.tmin = tmin + self.tmax = tmax + self.overlap = overlap + + def _find_events(self, raw): + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) + if len(stim_channels) > 0: + # returns empty array if none found + if self.overlap is None: + events = mne.find_events(raw, shortest_event=0, verbose=False) + else: + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) + else: + try: + events, _ = mne.events_from_annotations( + raw, event_id=self.event_id, verbose=False + ) + offset = int(self.interval[0] * raw.info["sfreq"]) + events[:, 0] -= offset # return the original events onset + except ValueError as e: + if str(e) == "Could not find any of the events you specified.": + return np.zeros((0, 3), dtype="int32") + raise e + return events + 
+ def transform(self, raw, y=None): + events = self._find_events(raw) + return _unsafe_pick_events(events, list(self.event_id.values())) + + class RawToEventsP300(RawToEvents): def transform(self, raw, y=None): events = self._find_events(raw) diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index 4a28b8d48..96d159f97 100644 --- a/moabb/evaluations/utils.py +++ b/moabb/evaluations/utils.py @@ -5,6 +5,7 @@ from typing import Sequence from numpy import argmax +from sklearn.metrics import matthews_corrcoef from sklearn.pipeline import Pipeline @@ -37,6 +38,11 @@ def _check_if_is_keras_model(model): return False +def _normalized_mcc(y_true, y_pred): + mcc = matthews_corrcoef(y_true, y_pred) + return (mcc + 1) / 2 + + def _check_if_is_pytorch_model(model): """Check if the model is a Keras model. diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 527d51c20..f73a73434 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -17,7 +17,9 @@ ForkPipelines, RawToEpochs, RawToEvents, + RawToEvents_PseudoOnline, SetRawAnnotations, + SetRawAnnotations_PseudoOnline, get_crop_pipeline, get_filter_pipeline, get_resample_pipeline, @@ -68,6 +70,7 @@ def __init__( baseline: Optional[Tuple[float, float]] = None, channels: Optional[List[str]] = None, resample: Optional[float] = None, + overlap: Optional[float] = None, ): if tmax is not None: if tmin >= tmax: @@ -79,6 +82,7 @@ def __init__( self.tmin = tmin self.tmax = tmax self.interpolate_missing_channels = False + self.overlap = overlap @property @abc.abstractmethod @@ -163,15 +167,29 @@ def make_process_pipelines( process_pipelines = [] for raw_pipeline in raw_pipelines: steps = [] - steps.append( - ( - StepType.RAW, - SetRawAnnotations( - dataset.event_id, - interval=dataset.interval, - ), + if self.overlap is not None: + steps.append( + ( + StepType.RAW, + SetRawAnnotations_PseudoOnline( + dataset.event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + 
overlap=self.overlap, + ), + ) + ) + else: + steps.append( + ( + StepType.RAW, + SetRawAnnotations( + dataset.event_id, + interval=dataset.interval, + ), + ) ) - ) if raw_pipeline is not None: steps.append((StepType.RAW, raw_pipeline)) if epochs_pipeline is not None: @@ -513,6 +531,7 @@ def __init__( baseline=None, channels=None, resample=None, + overlap=None, ): super().__init__( filters=filters, @@ -521,6 +540,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, + overlap=overlap, ) self.events = events @@ -536,4 +556,16 @@ def scoring(self): def _get_events_pipeline(self, dataset): event_id = self.used_events(dataset) - return RawToEvents(event_id=event_id, interval=dataset.interval) + if self.overlap is not None: + return RawToEvents_PseudoOnline( + event_id=event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap, + ) + else: + return RawToEvents( + event_id=event_id, + interval=dataset.interval, + ) diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index 657a8e814..f7af413ba 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -3,8 +3,11 @@ import abc import logging +from sklearn.metrics import make_scorer + from moabb.datasets import utils from moabb.datasets.fake import FakeDataset +from moabb.evaluations.utils import _normalized_mcc from moabb.paradigms.base import BaseParadigm @@ -51,6 +54,8 @@ class BaseMotorImagery(BaseParadigm): resample: float | None (default None) If not None, resample the eeg data with the sampling rate provided. 
+ + overlap: Overlap (in percentage) of the sliding windows approach for the pseudoonline evaluation """ def __init__( @@ -62,7 +67,13 @@ def __init__( baseline=None, channels=None, resample=None, + overlap=None, ): + + if overlap is not None: + print("Overlap available only for pseudo online evaluation") + tmin = 0.0 + super().__init__( filters=filters, events=events, @@ -71,6 +82,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, + overlap=overlap, ) def is_valid(self, dataset): @@ -102,7 +114,10 @@ def datasets(self): @property def scoring(self): - return "accuracy" + if self.overlap is None: + return "accuracy" + else: + return make_scorer(_normalized_mcc) class SinglePass(BaseMotorImagery): @@ -401,7 +416,10 @@ def scoring(self): if self.n_classes == 2: return "roc_auc" else: - return "accuracy" + if self.overlap is None: + return "accuracy" + else: + return make_scorer(_normalized_mcc) class FakeImageryParadigm(LeftRightImagery):