Skip to content

Commit

Permalink
fix(dataset) change name of recording dataset file
Browse files Browse the repository at this point in the history
  • Loading branch information
wilhelmagren committed Sep 23, 2023
1 parent 0217aa3 commit 911c825
Showing 1 changed file with 177 additions and 0 deletions.
177 changes: 177 additions & 0 deletions neurocode/datasets/recording.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
MIT License
Copyright (c) 2023 Neurocode
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File created: 2022-09-10
Last updated: 2023-09-23
"""

from __future__ import annotations

import logging
from collections import OrderedDict
from typing import (
    Any,
    Iterator,
    Union,
)

import mne
import numpy as np
from torch.utils.data import Dataset

from neurocode.datasets.simulated import SimulatedDataset

logger = logging.getLogger(__name__)


class RecordingDataset(Dataset):
    """Dataset of named recordings paired with per-recording labels.

    Recordings and their labels are kept in two dictionaries that share the
    recording names as keys. Items are addressed with a
    ``(recording, window)`` pair, where ``recording`` is either the recording
    name or its integer position in the insertion order of the data mapping.
    """

    def __init__(
        self,
        data: Union[list[mne.io.Raw], dict[str, mne.io.Raw]],
        labels: Union[
            list[list[mne.label.Label]], dict[str, list[mne.label.Label]]
        ],
        **kwargs: Any,
    ):
        """Create a dataset from recordings and their labels.

        Parameters
        ----------
        data:
            Recordings, either as a list or as a mapping from recording name
            to recording.
        labels:
            Labels per recording, either as a list or as a mapping from
            recording name to labels.
        **kwargs:
            Extra entries stored verbatim in the dataset info dictionary.

        Raises
        ------
        ValueError
            If both `data` and `labels` are lists (no names can be inferred),
            or if a list argument does not have one entry per recording name.
        """
        super().__init__()

        if isinstance(data, list) and isinstance(labels, list):
            raise ValueError(
                "Can not infer recording names when both `data` and `labels` are of "
                "type `list`. At least one of them have to be of type `dict`."
            )

        self._info = {}
        self._format_data_and_labels(data, labels, **kwargs)

    def __len__(self) -> int:
        """Return the number of recordings in the dataset."""
        return self._info["n_recordings"]

    def __getitem__(
        self,
        indices: tuple[Union[int, str], int],
    ) -> Union[int, float, np.ndarray]:
        """Return ``data[recording][window]`` for a ``(recording, window)`` pair.

        An integer `recording` is interpreted as a position in the insertion
        order of the data mapping and is translated to the recording name.
        """
        recording, window = indices

        if isinstance(recording, (int, np.integer)):
            # Dict key views do not support positional access, so materialize
            # the names before indexing.
            recording = list(self._data)[recording]

        return self._data[recording][window]

    def __iter__(self) -> Iterator[tuple[mne.io.Raw, list]]:
        """Yield ``(recording, labels)`` pairs in insertion order."""
        # Both mappings are keyed by recording name, not by position.
        for name, raw in self._data.items():
            yield (raw, self._labels[name])

    def _format_data_and_labels(
        self,
        data: Union[list[mne.io.Raw], dict[str, mne.io.Raw]],
        labels: Union[
            list[list[mne.label.Label]], dict[str, list[mne.label.Label]]
        ],
        **kwargs,
    ):
        """Normalize `data` and `labels` to name-keyed mappings and build info.

        Whichever argument is a list borrows its names, in order, from the keys
        of the other argument. Sets ``self._data``, ``self._labels`` and
        ``self._info``; the info holds `kwargs` plus ``"n_recordings"`` and
        ``"lengths"`` (``len()`` of every recording, keyed by name).

        Raises
        ------
        ValueError
            If a list argument does not have exactly one entry per name in the
            other argument.
        """
        if isinstance(data, list):
            if len(data) != len(labels):
                # `zip` would silently drop the surplus entries otherwise.
                raise ValueError(
                    f"Got {len(data)} recordings but {len(labels)} named labels."
                )
            data = OrderedDict(zip(labels.keys(), data))

        if isinstance(labels, list):
            if len(labels) != len(data):
                raise ValueError(
                    f"Got {len(labels)} labels but {len(data)} named recordings."
                )
            labels = OrderedDict(zip(data.keys(), labels))

        # Derived entries are written last so they always reflect `data`, even
        # when `kwargs` carries stale values for the same keys.
        info = dict(kwargs)
        info["n_recordings"] = len(data)
        info["lengths"] = {name: len(raw) for name, raw in data.items()}

        self._data = data
        self._labels = labels
        self._info = info

    def data(self) -> dict[str, mne.io.Raw]:
        """Return the mapping from recording name to recording."""
        return self._data

    def labels(self) -> dict[str, list[mne.label.Label]]:
        """Return the mapping from recording name to its labels."""
        return self._labels

    def info(self) -> dict[str, Any]:
        """Return the dataset info dictionary."""
        return self._info

    def train_valid_split(
        self,
        *,
        ratio: float = 0.6,
        shuffle: bool = True,
    ) -> tuple[RecordingDataset, RecordingDataset]:
        """Split the recordings into a training and a validation dataset.

        Parameters
        ----------
        ratio:
            Fraction of the recordings assigned to the training dataset; the
            count is ``int(len(self) * ratio)``.
        shuffle:
            If true, recordings are assigned at random using
            ``np.random.shuffle``; otherwise the first recordings in insertion
            order form the training dataset.

        Returns
        -------
        tuple[RecordingDataset, RecordingDataset]
            The ``(train, valid)`` datasets. Each keeps the original insertion
            order of its recordings.
        """
        n_recordings = len(self)
        split_idx = int(n_recordings * ratio)
        indices = np.arange(n_recordings)

        if shuffle:
            np.random.shuffle(indices)

        # A set gives O(1) membership tests; every position that is not a
        # training position belongs to the validation split.
        train_positions = set(indices[:split_idx].tolist())

        X_train, Y_train = {}, {}
        X_valid, Y_valid = {}, {}
        for position, name in enumerate(self._data):
            if position in train_positions:
                X_train[name] = self._data[name]
                Y_train[name] = self._labels[name]
            else:
                X_valid[name] = self._data[name]
                Y_valid[name] = self._labels[name]

        # The derived info entries ("n_recordings", "lengths") passed along
        # here are recomputed by each new dataset for its own subset.
        train_dataset = RecordingDataset(
            data=X_train,
            labels=Y_train,
            **self._info,
        )

        valid_dataset = RecordingDataset(
            data=X_valid,
            labels=Y_valid,
            **self._info,
        )

        return (train_dataset, valid_dataset)

    @classmethod
    def from_simulated(
        cls,
        dataset: SimulatedDataset,
        **kwargs: Any,
    ) -> RecordingDataset:
        """Build a dataset from ``dataset.data()`` and ``dataset.labels()``.

        Extra keyword arguments are forwarded to the constructor.
        """
        return cls(
            data=dataset.data(),
            labels=dataset.labels(),
            **kwargs,
        )

0 comments on commit 911c825

Please sign in to comment.