Skip to content

Commit

Permalink
chore(format) reformat files to black codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
wilhelmagren committed Sep 13, 2023
1 parent 29774b6 commit 54fc58c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
33 changes: 20 additions & 13 deletions neurocode/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,33 @@

from torch.utils.data import Dataset


class RecordingDataset(Dataset):
def __init__(self, *args, **kwargs):
self._setup(*args, **kwargs)

def __len__(self):
return self._info['n_recordings']
return self._info["n_recordings"]

def __getitem__(self, indices):
recording, window = indices
return self._data[recording][window]

def __iter__(self):
for idx in range(len(self)):
yield (self._data[idx], self._labels[idx])

def _setup(self, datasets, labels, formatted=False, **kwargs):
if not formatted:
datasets = {recording: dataset for recording, dataset in enumerate(datasets)}
datasets = {
recording: dataset for recording, dataset in enumerate(datasets)
}
labels = {recording: label for recording, label in enumerate(labels)}

lengths = {recording: len(d) for recording, d in enumerate(datasets.values())}
info = {
'lengths': lengths,
'n_recordings': len(datasets),
"lengths": lengths,
"n_recordings": len(datasets),
}
info = {**info, **kwargs}

Expand All @@ -62,20 +65,20 @@ def _setup(self, datasets, labels, formatted=False, **kwargs):

def get_data(self):
return self._data

def get_labels(self):
return self._labels

def get_info(self):
return self._info

def split(self, ratio=0.6, shuffle=True):
split_idx = int(len(self) * ratio)
indices = list(range(len(self)))

if shuffle:
np.random.shuffle(indices)

train_indices = indices[:split_idx]
valid_indices = indices[split_idx:]

Expand All @@ -84,6 +87,10 @@ def split(self, ratio=0.6, shuffle=True):
X_valid = {idx: self.data[k] for idx, k in enumerate(valid_indices)}
Y_valid = {idx: self.labels[k] for idx, k in enumerate(valid_indices)}

train_dataset = RecordingDataset(X_train, Y_train, formatted=True, sfreq=self.info['sfreq'])
valid_dataset = RecordingDataset(X_valid, Y_valid, formatted=True, sfreq=self.info['sfreq'])
train_dataset = RecordingDataset(
X_train, Y_train, formatted=True, sfreq=self.info["sfreq"]
)
valid_dataset = RecordingDataset(
X_valid, Y_valid, formatted=True, sfreq=self.info["sfreq"]
)
return (train_dataset, valid_dataset)
8 changes: 4 additions & 4 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def test_constructor(self):
list(np.arange(4, 14)),
],
[
'sleeping',
'sleeping',
'awake',
]
"sleeping",
"sleeping",
"awake",
],
)
assert isinstance(ds, Dataset)

0 comments on commit 54fc58c

Please sign in to comment.