Bugfix/rand split torch data #81

Open
wants to merge 14 commits into base: master
18 changes: 13 additions & 5 deletions dn3/configuratron/config.py
@@ -231,7 +231,9 @@ def get_pop(key, default=None):
self.dumped = get_pop('pre-dumped')
self.hpf = get_pop('hpf', None)
self.lpf = get_pop('lpf', None)
self.filter_data = self.hpf is not None or self.lpf is not None
self.notch_freq = get_pop('notch_freq', None)
self.use_avg_ref = get_pop('use_avg_ref', False)
self.filter_data = self.hpf is not None or self.lpf is not None or self.notch_freq is not None or self.use_avg_ref
if self.filter_data:
self.preload = True
self.stride = get_pop('stride', 1)
@@ -465,7 +467,7 @@ def add_custom_raw_loader(self, custom_loader):
Parameters
----------
custom_loader: callable
A function that expects a single :any:`pathlib.Path()` instance as argument and returns an
A function that expects a :any:`DatasetConfig` and a :any:`pathlib.Path()` instance as arguments and returns an
instance of :any:`mne.io.Raw()`. To gracefully ignore problematic sessions, raise
:any:`DN3ConfigException` within.

@@ -493,7 +495,7 @@ def add_progress_callbacks(self, session_callback=None, thinker_callback=None):

def _load_raw(self, path: Path):
if self._custom_raw_loader is not None:
return self._custom_raw_loader(path)
return self._custom_raw_loader(self, path)
if path.suffix in self._extension_handlers:
return self._extension_handlers[path.suffix](str(path), preload=self.preload)
print("Handler for file {} with extension {} not found.".format(str(path), path.suffix))
@@ -506,9 +508,15 @@ def _load_raw(self, path: Path):

@staticmethod
def _prepare_session(raw, tlen, decimate, desired_sfreq, desired_samples, picks, exclude_channels, rename_channels,
hpf, lpf):
hpf, lpf, notch_freq, use_avg_ref):
if use_avg_ref:
raw.set_eeg_reference(ref_channels='average')
if notch_freq is not None:
raw.notch_filter(notch_freq)
if hpf is not None or lpf is not None:
raw = raw.filter(hpf, lpf)



lowpass = raw.info.get('lowpass', None)
raw_sfreq = raw.info['sfreq']
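_prepare_session now optionally re-references to the channel average and applies a notch filter before the existing band-pass step, and the configuratron forces preload whenever any of these options is set. A standalone sketch of the same ordering applied to an MNE Raw object; the particular cutoff and notch frequencies are made-up examples:

import mne

def preprocess(raw: mne.io.Raw, hpf=0.5, lpf=40., notch_freq=60., use_avg_ref=True):
    # Same order as _prepare_session above: average reference, then notch, then band-pass.
    if use_avg_ref:
        raw.set_eeg_reference(ref_channels='average')
    if notch_freq is not None:
        raw.notch_filter(notch_freq)
    if hpf is not None or lpf is not None:
        raw = raw.filter(hpf, lpf)
    return raw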
@@ -564,7 +572,7 @@ def load_and_prepare(sess):
sess = Path(sess)
r = self._load_raw(sess)
return (sess, *self._prepare_session(r, self.tlen, self.decimate, self._sfreq, self._samples, self.picks,
self.exclude_channels, self.rename_channels, self.hpf, self.lpf))
self.exclude_channels, self.rename_channels, self.hpf, self.lpf, self.notch_freq, self.use_avg_ref))
sess, raw, tlen, picks, new_sfreq = load_and_prepare(session)

# Fixme - deprecate the decimate option in favour of specifying desired sfreq's
38 changes: 32 additions & 6 deletions dn3/data/dataset.py
@@ -1,3 +1,4 @@
from typing import Optional
import mne
import torch
import copy
@@ -15,6 +16,7 @@
from pathlib import Path
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.dataset import Subset as TorchSubset


class DN3ataset(TorchDataset):
@@ -180,8 +182,6 @@ def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4,
loaded = [np.concatenate([loaded[i], batch[i]], axis=0) for i in range(len(batch))]

return loaded


class _Recording(DN3ataset, ABC):
"""
Abstract base class for any supported recording
@@ -381,7 +381,7 @@ def __init__(self, epochs: mne.Epochs, session_id=0, person_id=0, force_label=No
self.epoch_codes_to_class_labels = event_mapping
else:
reverse_mapping = {v: k for k, v in event_mapping.items()}
self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.keys()))}
self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.values()))}
skip_epochs = list() if skip_epochs is None else skip_epochs
self._skip_map = [i for i in range(len(self.epochs.events)) if i not in skip_epochs]
self._skip_map = dict(zip(range(len(self._skip_map)), self._skip_map))
@@ -433,6 +433,26 @@ def get_targets(self):
return np.apply_along_axis(lambda x: self.epoch_codes_to_class_labels[x[0]], 1,
self.epochs.events[list(self._skip_map.values()), -1, np.newaxis]).squeeze()

class DN3ataSubSet(DN3ataset):
"""
Wrap a torch subset of a DN3ataset.
"""
def __init__(self, dn3ata: DN3ataset, subset: TorchSubset):
DN3ataset.__init__(self)
self.dataset = subset.dataset
self.indices = subset.indices
if not hasattr(dn3ata, 'get_targets'):
raise ValueError("dn3ata must have a get_targets method")
self.targets = dn3ata.get_targets()[subset.indices]

def __getitem__(self, idx):
return TorchSubset.__getitem__(self, idx)

def __len__(self):
return TorchSubset.__len__(self)

def get_targets(self):
return self.targets

class Thinker(DN3ataset, ConcatDataset):
"""
@@ -608,7 +628,7 @@ def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_i
if len(use_sessions) > 0:
print("Warning: sessions specified do not span all sessions. Skipping {} sessions.".format(
len(use_sessions)))
return training, validating, testing
return self._dn3_or_none(training), self._dn3_or_none(validating), self._dn3_or_none(testing)

# Split up the rest if there is anything left
if len(use_sessions) > 0:
@@ -622,8 +642,14 @@ def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_i
validating, remainder = rand_split(remainder, frac=validation_frac)

training = remainder if training is None else training

return training, validating, testing

return self._dn3_or_none(training), self._dn3_or_none(validating), self._dn3_or_none(testing)

def _dn3_or_none(self, subset: Optional[DN3ataset]) -> Optional[DN3ataset]:
if subset is None or type(subset) is DN3ataset:
return subset

return DN3ataSubSet(self, subset)

def preprocess(self, preprocessor: Preprocessor, apply_transform=True, sessions=None, **kwargs):
"""
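The motivation for DN3ataSubSet and _dn3_or_none: rand_split (see dn3/utils.py below) hands back plain torch Subset objects, which lack get_targets(), so split() now re-wraps them before returning. A rough usage sketch; the Thinker instance is assumed to have been assembled elsewhere:

# Hypothetical usage with default split fractions.
training, validating, testing = thinker.split()

for name, part in (('train', training), ('valid', validating), ('test', testing)):
    if part is not None:
        # Randomly split pieces now come back as DN3ataSubSet rather than a bare
        # torch Subset, so get_targets() keeps working on every returned piece.
        print(name, len(part), part.get_targets()[:5])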
10 changes: 6 additions & 4 deletions dn3/trainable/models.py
@@ -185,7 +185,7 @@ def save(self, filename, ignore_classifier=False):

class StrideClassifier(Classifier, metaclass=ABCMeta):

def __init__(self, targets, samples, channels, stride_width=2, return_features=False):
def __init__(self, targets, samples, channels, stride_width=2, return_features=False, bias_output=False):
"""
Instead of summarizing the entire temporal dimension into a single prediction, a prediction kernel is swept over
the final sequence representation and generates predictions at each step.
@@ -199,13 +199,15 @@ def __init__(self, targets, samples, channels, return_features=F
return_features
"""
self.stride_width = stride_width
self.biased_output = bias_output
super(StrideClassifier, self).__init__(targets, samples, channels, return_features=return_features)

def make_new_classification_layer(self):
self.classifier = torch.nn.Conv1d(self.num_features_for_classification, self.targets,
kernel_size=self.stride_width)
kernel_size=self.stride_width, bias=self.biased_output)
torch.nn.init.xavier_normal_(self.classifier.weight)
self.classifier.bias.data.zero_()
if self.biased_output:
self.classifier.bias.data.zero_()


class LogRegNetwork(Classifier):
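The new bias_output flag controls whether the stride-wise prediction kernel carries a bias term; previously the Conv1d always had one and it was zeroed at initialisation. A minimal sketch of the layer construction this toggles, with made-up dimensions:

import torch

num_features, targets, stride_width = 64, 4, 2   # made-up sizes
bias_output = False

classifier = torch.nn.Conv1d(num_features, targets, kernel_size=stride_width, bias=bias_output)
torch.nn.init.xavier_normal_(classifier.weight)
if bias_output:
    classifier.bias.data.zero_()   # the bias tensor only exists when requested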
@@ -425,6 +427,6 @@ def easy_parallel(self):
def features_forward(self, x):
x = self.encoder(x)
x = self.contextualizer(x)
return x[0]
return x[:, :, 0]
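Assuming the contextualizer emits a (batch, features, sequence) tensor, x[0] returned the first example's whole sequence and dropped the batch dimension, while the classification head expects one feature vector per example; indexing the leading sequence position keeps the batch intact. A shape-only sketch with made-up sizes:

import torch

x = torch.randn(8, 512, 65)   # hypothetical (batch, features, sequence) output
per_example = x[:, :, 0]      # shape (8, 512): features at the leading position
first_item = x[0]             # shape (512, 65): only the first example, batch lost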


2 changes: 1 addition & 1 deletion dn3/trainable/processes.py
@@ -569,7 +569,7 @@ def _validation(epoch, iteration=None):
if callable(step_callback):
step_callback(train_metrics)

if iteration % train_log_interval == 0 and pbar.total != iteration:
if iteration % train_log_interval == 0 and pbar.total >= iteration:
print_training_metrics(epoch, iteration)
train_metrics['epoch'] = epoch
train_metrics['iteration'] = iteration
1 change: 1 addition & 0 deletions dn3/utils.py
@@ -53,6 +53,7 @@ def rand_split(dataset, frac=0.75):
if frac >= 1:
return dataset
samples = len(dataset)

return random_split(dataset, lengths=[round(x) for x in [samples*frac, samples*(1-frac)]])
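rand_split delegates to torch.utils.data.random_split, so each piece it returns is a plain torch Subset, which is what the DN3ataSubSet wrapper above exists to re-dress. A small sketch of calling it directly; the toy TensorDataset is only for illustration:

import torch
from torch.utils.data import TensorDataset
from dn3.utils import rand_split

ds = TensorDataset(torch.arange(100))
train, rest = rand_split(ds, frac=0.75)   # two torch Subsets, here 75 and 25 samples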


28 changes: 27 additions & 1 deletion tests/testTrainables.py
@@ -6,7 +6,7 @@

from torch.utils.data import DataLoader
from dn3.trainable.processes import StandardClassification
from dn3.trainable.models import EEGNetStrided
from dn3.trainable.models import *
from dn3.metrics.base import balanced_accuracy
from tests.dummy_data import create_dummy_dataset, retrieve_underlying_dummy_data, EVENTS

@@ -89,6 +89,32 @@ def test_EvaluationMetrics(self):
self.assertIn('BAC', val_metrics)
self.assertIn('loss', val_metrics)

class TestIncludedModels(unittest.TestCase):

_BATCH_SIZES = [1, 2, 4, 8, 11]

def setUp(self) -> None:
mne.set_log_level(False)
self.dataset = create_dummy_dataset()

def test_TIDNet(self):
model = TIDNet.from_dataset(self.dataset)
process = StandardClassification(model)

for bs in self._BATCH_SIZES:
with self.subTest(batch_size=bs):
process.predict(self.dataset, batch_size=bs)
self.assertTrue(True)

def test_BENDRClassifier(self):
model = BENDRClassifier.from_dataset(self.dataset)
process = StandardClassification(model)

for bs in self._BATCH_SIZES:
with self.subTest(batch_size=bs):
process.predict(self.dataset, batch_size=bs)
self.assertTrue(True)


if __name__ == '__main__':
unittest.main()