From bcf42c840915b431bbd0faa5ace12fca12e2c71e Mon Sep 17 00:00:00 2001
From: zeyus
Date: Mon, 25 Apr 2022 20:29:17 +0200
Subject: [PATCH 01/13] Fixed issue with epoch ID -> class label mapping.

---
 dn3/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dn3/data/dataset.py b/dn3/data/dataset.py
index 2a397b6..e5c7dd8 100644
--- a/dn3/data/dataset.py
+++ b/dn3/data/dataset.py
@@ -381,7 +381,7 @@ def __init__(self, epochs: mne.Epochs, session_id=0, person_id=0, force_label=No
             self.epoch_codes_to_class_labels = event_mapping
         else:
             reverse_mapping = {v: k for k, v in event_mapping.items()}
-            self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.keys()))}
+            self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.values()))}
         skip_epochs = list() if skip_epochs is None else skip_epochs
         self._skip_map = [i for i in range(len(self.epochs.events)) if i not in skip_epochs]
         self._skip_map = dict(zip(range(len(self._skip_map)), self._skip_map))

From 13008828d021f0aeafdc4a3333f20e6255e00b49 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Tue, 26 Apr 2022 21:45:04 +0200
Subject: [PATCH 02/13] added notch filter

---
 dn3/configuratron/config.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/dn3/configuratron/config.py b/dn3/configuratron/config.py
index 10010be..182e4a3 100644
--- a/dn3/configuratron/config.py
+++ b/dn3/configuratron/config.py
@@ -231,7 +231,8 @@ def get_pop(key, default=None):
         self.dumped = get_pop('pre-dumped')
         self.hpf = get_pop('hpf', None)
         self.lpf = get_pop('lpf', None)
-        self.filter_data = self.hpf is not None or self.lpf is not None
+        self.notch_freq = get_pop('notch_freq', None)
+        self.filter_data = self.hpf is not None or self.lpf is not None or self.notch_freq is not None
         if self.filter_data:
             self.preload = True
         self.stride = get_pop('stride', 1)
@@ -506,9 +507,12 @@ def _load_raw(self, path: Path):
     @staticmethod
     def _prepare_session(raw, tlen, decimate, desired_sfreq, desired_samples, picks, exclude_channels, rename_channels,
-                         hpf, lpf):
+                         hpf, lpf, notch_freq):
         if hpf is not None or lpf is not None:
             raw = raw.filter(hpf, lpf)
+
+        if notch_freq is not None:
+            raw.notch_filter(notch_freq)
 
         lowpass = raw.info.get('lowpass', None)
         raw_sfreq = raw.info['sfreq']
@@ -564,7 +568,7 @@ def load_and_prepare(sess):
             sess = Path(sess)
             r = self._load_raw(sess)
             return (sess, *self._prepare_session(r, self.tlen, self.decimate, self._sfreq, self._samples, self.picks,
-                                                 self.exclude_channels, self.rename_channels, self.hpf, self.lpf))
+                                                 self.exclude_channels, self.rename_channels, self.hpf, self.lpf, self.notch_freq))
         sess, raw, tlen, picks, new_sfreq = load_and_prepare(session)
 
         # Fixme - deprecate the decimate option in favour of specifying desired sfreq's
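Patch 01 re-keys `epoch_codes_to_class_labels` by the MNE event codes rather than by the class names, which is what `get_targets()` later indexes with. A toy illustration of the difference — the sample mapping is made up, and it assumes `event_mapping` maps event codes to class names, as the attribute name and the `get_targets()` lookup (visible in patch 09's context lines) suggest:

```python
# Illustrative only: event codes -> class names.
event_mapping = {10: 'left_hand', 20: 'right_hand'}
reverse_mapping = {v: k for k, v in event_mapping.items()}  # {'left_hand': 10, 'right_hand': 20}

# Before the fix: keyed by class names, so later lookups by event code miss.
old = {v: i for i, v in enumerate(sorted(reverse_mapping.keys()))}    # {'left_hand': 0, 'right_hand': 1}
# After the fix: keyed by event codes, mapping each code to a dense 0..n-1 class label.
new = {v: i for i, v in enumerate(sorted(reverse_mapping.values()))}  # {10: 0, 20: 1}
print(old, new)
```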
From 8785d90da164c8d7d0f69a3b70e80f60719530f8 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Tue, 26 Apr 2022 22:42:09 +0200
Subject: [PATCH 03/13] moved notch location

---
 dn3/configuratron/config.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/dn3/configuratron/config.py b/dn3/configuratron/config.py
index 182e4a3..f1325ab 100644
--- a/dn3/configuratron/config.py
+++ b/dn3/configuratron/config.py
@@ -508,11 +508,12 @@ def _load_raw(self, path: Path):
     @staticmethod
     def _prepare_session(raw, tlen, decimate, desired_sfreq, desired_samples, picks, exclude_channels, rename_channels,
                          hpf, lpf, notch_freq):
+        if notch_freq is not None:
+            raw.notch_filter(notch_freq)
         if hpf is not None or lpf is not None:
             raw = raw.filter(hpf, lpf)
-
-        if notch_freq is not None:
-            raw.notch_filter(notch_freq)
+
 
         lowpass = raw.info.get('lowpass', None)
         raw_sfreq = raw.info['sfreq']

From 23a97590708f1440016d03a6345b4953415f2cf9 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Wed, 27 Apr 2022 18:56:06 +0200
Subject: [PATCH 04/13] added option to use average reference channel

---
 dn3/configuratron/config.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/dn3/configuratron/config.py b/dn3/configuratron/config.py
index f1325ab..f9bcd80 100644
--- a/dn3/configuratron/config.py
+++ b/dn3/configuratron/config.py
@@ -232,7 +232,8 @@ def get_pop(key, default=None):
         self.hpf = get_pop('hpf', None)
         self.lpf = get_pop('lpf', None)
         self.notch_freq = get_pop('notch_freq', None)
-        self.filter_data = self.hpf is not None or self.lpf is not None or self.notch_freq is not None
+        self.create_avg_ref = get_pop('create_avg_ref', False)
+        self.filter_data = self.hpf is not None or self.lpf is not None or self.notch_freq is not None or self.create_avg_ref
         if self.filter_data:
             self.preload = True
         self.stride = get_pop('stride', 1)
@@ -507,7 +508,9 @@ def _load_raw(self, path: Path):
     @staticmethod
     def _prepare_session(raw, tlen, decimate, desired_sfreq, desired_samples, picks, exclude_channels, rename_channels,
-                         hpf, lpf, notch_freq):
+                         hpf, lpf, notch_freq, create_avg_ref):
+        if create_avg_ref:
+            raw.set_eeg_reference(ref_channels='average')
         if notch_freq is not None:
             raw.notch_filter(notch_freq)
         if hpf is not None or lpf is not None:
@@ -569,7 +572,7 @@ def load_and_prepare(sess):
             sess = Path(sess)
             r = self._load_raw(sess)
             return (sess, *self._prepare_session(r, self.tlen, self.decimate, self._sfreq, self._samples, self.picks,
-                                                 self.exclude_channels, self.rename_channels, self.hpf, self.lpf, self.notch_freq))
+                                                 self.exclude_channels, self.rename_channels, self.hpf, self.lpf, self.notch_freq, self.create_avg_ref))
         sess, raw, tlen, picks, new_sfreq = load_and_prepare(session)
 
         # Fixme - deprecate the decimate option in favour of specifying desired sfreq's
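Patches 02–04 leave `_prepare_session` re-referencing first, then notching, then band-pass filtering. A standalone sketch of that order on a plain `mne.io.Raw` object — the helper name `prepare_raw` and its defaults are illustrative, not dn3 API; the MNE calls are the same ones the patches add:

```python
import mne

def prepare_raw(raw: mne.io.BaseRaw, hpf=None, lpf=None, notch_freq=None, create_avg_ref=False):
    """Mirror of the preprocessing order in the patched _prepare_session."""
    if create_avg_ref:
        raw.set_eeg_reference(ref_channels='average')  # re-reference to the channel mean
    if notch_freq is not None:
        raw.notch_filter(notch_freq)                   # e.g. 50 or 60 Hz mains interference
    if hpf is not None or lpf is not None:
        raw = raw.filter(hpf, lpf)                     # band-pass between hpf and lpf (Hz)
    return raw
```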
From a2f6c0a1a1270e611f10402f54d2cceb4764cc38 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Wed, 27 Apr 2022 19:43:05 +0200
Subject: [PATCH 05/13] updates from dn3 upstream PRs

---
 dn3/trainable/models.py    |  2 +-
 dn3/trainable/processes.py |  2 +-
 tests/testTrainables.py    | 28 +++++++++++++++++++++++++++-
 3 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/dn3/trainable/models.py b/dn3/trainable/models.py
index eb2ee7d..9e53474 100644
--- a/dn3/trainable/models.py
+++ b/dn3/trainable/models.py
@@ -425,6 +425,6 @@ def easy_parallel(self):
     def features_forward(self, x):
         x = self.encoder(x)
         x = self.contextualizer(x)
-        return x[0]
+        return x[:, :, 0]
 
diff --git a/dn3/trainable/processes.py b/dn3/trainable/processes.py
index 5f2ccce..cafcd3b 100644
--- a/dn3/trainable/processes.py
+++ b/dn3/trainable/processes.py
@@ -569,7 +569,7 @@ def _validation(epoch, iteration=None):
                 if callable(step_callback):
                     step_callback(train_metrics)
 
-                if iteration % train_log_interval == 0 and pbar.total != iteration:
+                if iteration % train_log_interval == 0 and pbar.total >= iteration:
                     print_training_metrics(epoch, iteration)
                     train_metrics['epoch'] = epoch
                     train_metrics['iteration'] = iteration
diff --git a/tests/testTrainables.py b/tests/testTrainables.py
index 276c11c..d0b2c73 100644
--- a/tests/testTrainables.py
+++ b/tests/testTrainables.py
@@ -6,7 +6,7 @@
 from torch.utils.data import DataLoader
 
 from dn3.trainable.processes import StandardClassification
-from dn3.trainable.models import EEGNetStrided
+from dn3.trainable.models import *
 from dn3.metrics.base import balanced_accuracy
 
 from tests.dummy_data import create_dummy_dataset, retrieve_underlying_dummy_data, EVENTS
@@ -89,6 +89,32 @@ def test_EvaluationMetrics(self):
         self.assertIn('BAC', val_metrics)
         self.assertIn('loss', val_metrics)
 
+
+class TestIncludedModels(unittest.TestCase):
+
+    _BATCH_SIZES = [1, 2, 4, 8, 11]
+
+    def setUp(self) -> None:
+        mne.set_log_level(False)
+        self.dataset = create_dummy_dataset()
+
+    def test_TIDNet(self):
+        model = TIDNet.from_dataset(self.dataset)
+        process = StandardClassification(model)
+
+        for bs in self._BATCH_SIZES:
+            with self.subTest(batch_size=bs):
+                process.predict(self.dataset, batch_size=bs)
+        self.assertTrue(True)
+
+    def test_BENDRClassifier(self):
+        model = BENDRClassifier.from_dataset(self.dataset)
+        process = StandardClassification(model)
+
+        for bs in self._BATCH_SIZES:
+            with self.subTest(batch_size=bs):
+                process.predict(self.dataset, batch_size=bs)
+        self.assertTrue(True)
+
 
 if __name__ == '__main__':
     unittest.main()

From a66c14a9c48fd3bd085f76dab43805029fb00311 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Wed, 27 Apr 2022 19:52:24 +0200
Subject: [PATCH 06/13] renamed config option to use_avg_ref

---
 dn3/configuratron/config.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/dn3/configuratron/config.py b/dn3/configuratron/config.py
index f9bcd80..6d445ce 100644
--- a/dn3/configuratron/config.py
+++ b/dn3/configuratron/config.py
@@ -232,8 +232,8 @@ def get_pop(key, default=None):
         self.hpf = get_pop('hpf', None)
         self.lpf = get_pop('lpf', None)
         self.notch_freq = get_pop('notch_freq', None)
-        self.create_avg_ref = get_pop('create_avg_ref', False)
-        self.filter_data = self.hpf is not None or self.lpf is not None or self.notch_freq is not None or self.create_avg_ref
+        self.use_avg_ref = get_pop('use_avg_ref', False)
+        self.filter_data = self.hpf is not None or self.lpf is not None or self.notch_freq is not None or self.use_avg_ref
         if self.filter_data:
             self.preload = True
         self.stride = get_pop('stride', 1)
@@ -508,8 +508,8 @@ def _load_raw(self, path: Path):
     @staticmethod
     def _prepare_session(raw, tlen, decimate, desired_sfreq, desired_samples, picks, exclude_channels, rename_channels,
-                         hpf, lpf, notch_freq, create_avg_ref):
-        if create_avg_ref:
+                         hpf, lpf, notch_freq, use_avg_ref):
+        if use_avg_ref:
             raw.set_eeg_reference(ref_channels='average')
         if notch_freq is not None:
             raw.notch_filter(notch_freq)
         if hpf is not None or lpf is not None:
@@ -572,7 +572,7 @@ def load_and_prepare(sess):
             sess = Path(sess)
             r = self._load_raw(sess)
             return (sess, *self._prepare_session(r, self.tlen, self.decimate, self._sfreq, self._samples, self.picks,
-                                                 self.exclude_channels, self.rename_channels, self.hpf, self.lpf, self.notch_freq, self.create_avg_ref))
+                                                 self.exclude_channels, self.rename_channels, self.hpf, self.lpf, self.notch_freq, self.use_avg_ref))
         sess, raw, tlen, picks, new_sfreq = load_and_prepare(session)
 
         # Fixme - deprecate the decimate option in favour of specifying desired sfreq's
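The keys `get_pop` reads above live in the per-dataset section of a configuratron YAML file. A Python mirror of such an entry — only `hpf`, `lpf`, `notch_freq` and `use_avg_ref` come from these patches; the remaining key and all values are illustrative assumptions:

```python
# Hypothetical dataset entry (values invented for illustration).
example_dataset_config = {
    'tlen': 4,             # existing dn3 option: trial length in seconds
    'hpf': 0.1,            # existing dn3 option: high-pass edge in Hz
    'lpf': 40.0,           # existing dn3 option: low-pass edge in Hz
    'notch_freq': 50.0,    # added in patch 02: mains notch frequency
    'use_avg_ref': True,   # added in patch 04, renamed here in patch 06
}
```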
From ba73d6ba509cb69c4d7e0d02170e6bbc7768b196 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Wed, 27 Apr 2022 22:02:38 +0200
Subject: [PATCH 07/13] [WIP] Adding return Thinker.split() as DN3ataset for use in fit() with sampling methods.

---
 dn3/data/dataset.py | 52 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 50 insertions(+), 2 deletions(-)

diff --git a/dn3/data/dataset.py b/dn3/data/dataset.py
index e5c7dd8..2a438b2 100644
--- a/dn3/data/dataset.py
+++ b/dn3/data/dataset.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import mne
 import torch
 import copy
@@ -15,6 +16,7 @@
 from pathlib import Path
 from torch.utils.data import Dataset as TorchDataset
 from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.data.dataset import Subset as TorchSubset
 
 
 class DN3ataset(TorchDataset):
@@ -181,6 +183,43 @@ def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4,
 
         return loaded
 
+
+class DN3Subset(DN3ataset, TorchSubset):
+
+    def __init__(self, dataset, indices):
+        super().__init__()
+        self.dataset = dataset
+        self.indices = indices
+
+    def __getitem__(self, idx):
+        return TorchSubset.__getitem__(self, idx)
+
+    def __len__(self):
+        return TorchSubset.__len__(self)
+
+    @staticmethod
+    def init_from_torch_subset(subset: TorchSubset):
+        return DN3Subset(subset.dataset, subset.indices)
+
+    @property
+    def sfreq(self):
+        return self.dataset.sfreq
+
+    @property
+    def channels(self):
+        return self.dataset.channels
+
+    @property
+    def sequence_length(self):
+        return self.dataset.sequence_length
+
+    def clone(self):
+        return DN3Subset(self.dataset, self.indices)
+
+    def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
+        return self.dataset.preprocess(preprocessor, apply_transform)
+
+    def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4, **dataloader_kwargs):
+        return self.dataset.to_numpy(batch_size, batch_transforms, num_workers, **dataloader_kwargs)
 
 class _Recording(DN3ataset, ABC):
     """
     Abstract base class for any supported recording
@@ -608,7 +647,7 @@ def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_i
             if len(use_sessions) > 0:
                 print("Warning: sessions specified do not span all sessions. Skipping {} sessions.".format(
                     len(use_sessions)))
-            return training, validating, testing
+            return self._to_dn3_or_none(training), self._to_dn3_or_none(validating), self._to_dn3_or_none(testing)
 
         # Split up the rest if there is anything left
         if len(use_sessions) > 0:
@@ -623,7 +662,16 @@ def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_i
 
         training = remainder if training is None else training
 
-        return training, validating, testing
+        return self._to_dn3_or_none(training), self._to_dn3_or_none(validating), self._to_dn3_or_none(testing)
+
+    def _to_dn3_or_none(self, x) -> Optional[DN3ataset]:
+        if isinstance(x, DN3ataset):
+            return x
+        elif x is None:
+            return x
+        else:
+            print("type of x is {}".format(type(x)))
+            return DN3ataset.__init__(x)
 
     def preprocess(self, preprocessor: Preprocessor, apply_transform=True, sessions=None, **kwargs):
         """
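The reason for wrapping the torch `Subset` is that downstream code keeps asking the split results for `sfreq`, `channels` and `sequence_length`, which a bare `Subset` does not expose. A usage sketch of what this WIP change is aiming for — `validation_frac` appears in the hunks above, while `test_frac`, the fraction values and the helper name are assumptions:

```python
from dn3.data.dataset import Thinker

def split_and_inspect(thinker: Thinker):
    # Fractions are illustrative; keyword names assumed from Thinker.split().
    training, validating, testing = thinker.split(test_frac=0.25, validation_frac=0.25)
    # With the wrapper, the split results still behave like DN3atasets:
    print(training.sfreq, training.channels, training.sequence_length)
    return training, validating, testing
```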
From f27ea333048fc8664cef9fec3e4d0acc4fc93fd2 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Wed, 27 Apr 2022 22:07:41 +0200
Subject: [PATCH 08/13] [WIP] added todo to remember where i was going with this.

---
 dn3/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dn3/utils.py b/dn3/utils.py
index 851dd40..a0a9b8f 100644
--- a/dn3/utils.py
+++ b/dn3/utils.py
@@ -53,6 +53,7 @@ def rand_split(dataset, frac=0.75):
     if frac >= 1:
         return dataset
     samples = len(dataset)
+    # @TODO: return DN3ataset wrapped torch.utils.data.dataset.Subset
     return random_split(dataset, lengths=[round(x) for x in [samples*frac, samples*(1-frac)]])
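The TODO marks the underlying problem the next patch tackles: `torch.utils.data.random_split` returns bare `Subset` objects, which drop every DN3ataset property and method. A minimal demonstration with a dummy tensor dataset:

```python
import torch
from torch.utils.data import TensorDataset, random_split

dummy = TensorDataset(torch.randn(100, 8, 256), torch.randint(0, 2, (100,)))
train, rest = random_split(dummy, lengths=[75, 25])

print(type(train))               # <class 'torch.utils.data.dataset.Subset'>
print(hasattr(train, 'sfreq'))   # False -- no DN3ataset properties survive the split
```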
From 8808d70e726e2040aa28016114e5a3dd70e25a3d Mon Sep 17 00:00:00 2001
From: zeyus
Date: Thu, 28 Apr 2022 16:06:53 +0200
Subject: [PATCH 09/13] new subset class

---
 dn3/data/dataset.py | 42 +++++++++++++++++++++++++++++++-----------
 dn3/utils.py        |  1 -
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/dn3/data/dataset.py b/dn3/data/dataset.py
index 2a438b2..a8eec91 100644
--- a/dn3/data/dataset.py
+++ b/dn3/data/dataset.py
@@ -183,6 +183,7 @@ def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4,
 
         return loaded
 
+<<<<<<< Updated upstream
 class DN3Subset(DN3ataset, TorchSubset):
 
     def __init__(self, dataset, indices):
         super().__init__()
         self.dataset = dataset
         self.indices = indices
@@ -221,6 +222,8 @@ def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
     def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4, **dataloader_kwargs):
         return self.dataset.to_numpy(batch_size, batch_transforms, num_workers, **dataloader_kwargs)
 
+=======
+>>>>>>> Stashed changes
 class _Recording(DN3ataset, ABC):
     """
     Abstract base class for any supported recording
@@ -472,6 +475,26 @@ def get_targets(self):
         return np.apply_along_axis(lambda x: self.epoch_codes_to_class_labels[x[0]], 1,
                                    self.epochs.events[list(self._skip_map.values()), -1, np.newaxis]).squeeze()
 
+
+class DN3ataSubSet(DN3ataset):
+    """
+    Wrap a torch subset of a DN3ataset.
+    """
+    def __init__(self, dn3ata: DN3ataset, subset: TorchSubset):
+        DN3ataset.__init__(self)
+        self.dataset = subset.dataset
+        self.indices = subset.indices
+        if not hasattr(dn3ata, 'get_targets'):
+            raise ValueError("dn3ata must have a get_targets method")
+        self.targets = dn3ata.get_targets()[subset.indices]
+
+    def __getitem__(self, idx):
+        return TorchSubset.__getitem__(self, idx)
+
+    def __len__(self):
+        return TorchSubset.__len__(self)
+
+    def get_targets(self):
+        return self.targets
 
 class Thinker(DN3ataset, ConcatDataset):
     """
@@ -647,7 +670,7 @@ def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_i
             if len(use_sessions) > 0:
                 print("Warning: sessions specified do not span all sessions. Skipping {} sessions.".format(
                     len(use_sessions)))
-            return self._to_dn3_or_none(training), self._to_dn3_or_none(validating), self._to_dn3_or_none(testing)
+            return self._dn3_or_none(training), self._dn3_or_none(validating), self._dn3_or_none(testing)
 
         # Split up the rest if there is anything left
         if len(use_sessions) > 0:
@@ -661,17 +684,14 @@ def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_i
                 validating, remainder = rand_split(remainder, frac=validation_frac)
 
         training = remainder if training is None else training
+
+        return self._dn3_or_none(training), self._dn3_or_none(validating), self._dn3_or_none(testing)
 
-        return self._to_dn3_or_none(training), self._to_dn3_or_none(validating), self._to_dn3_or_none(testing)
-
-    def _to_dn3_or_none(self, x) -> Optional[DN3ataset]:
-        if isinstance(x, DN3ataset):
-            return x
-        elif x is None:
-            return x
-        else:
-            print("type of x is {}".format(type(x)))
-            return DN3ataset.__init__(x)
+    def _dn3_or_none(self, subset: Optional[DN3ataset]) -> Optional[DN3ataset]:
+        if subset is None or type(subset) is DN3ataset:
+            return subset
+
+        return DN3ataSubSet(self, subset)
 
     def preprocess(self, preprocessor: Preprocessor, apply_transform=True, sessions=None, **kwargs):
         """
diff --git a/dn3/utils.py b/dn3/utils.py
index a0a9b8f..851dd40 100644
--- a/dn3/utils.py
+++ b/dn3/utils.py
@@ -53,7 +53,6 @@ def rand_split(dataset, frac=0.75):
     if frac >= 1:
         return dataset
     samples = len(dataset)
-    # @TODO: return DN3ataset wrapped torch.utils.data.dataset.Subset
     return random_split(dataset, lengths=[round(x) for x in [samples*frac, samples*(1-frac)]])

From f3c8875884af3bb91dfad04ed39ce387b1732115 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Thu, 28 Apr 2022 16:10:04 +0200
Subject: [PATCH 10/13] whoops

---
 dn3/data/dataset.py | 42 ------------------------------------------
 dn3/utils.py        |  2 ++
 2 files changed, 2 insertions(+), 42 deletions(-)

diff --git a/dn3/data/dataset.py b/dn3/data/dataset.py
index a8eec91..d1c3476 100644
--- a/dn3/data/dataset.py
+++ b/dn3/data/dataset.py
@@ -182,48 +182,6 @@ def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4,
         loaded = [np.concatenate([loaded[i], batch[i]], axis=0) for i in range(len(batch))]
 
         return loaded
-
-<<<<<<< Updated upstream
-class DN3Subset(DN3ataset, TorchSubset):
-
-    def __init__(self, dataset, indices):
-        super().__init__()
-        self.dataset = dataset
-        self.indices = indices
-
-    def __getitem__(self, idx):
-        return TorchSubset.__getitem__(self, idx)
-
-    def __len__(self):
-        return TorchSubset.__len__(self)
-
-    @staticmethod
-    def init_from_torch_subset(subset: TorchSubset):
-        return DN3Subset(subset.dataset, subset.indices)
-
-    @property
-    def sfreq(self):
-        return self.dataset.sfreq
-
-    @property
-    def channels(self):
-        return self.dataset.channels
-
-    @property
-    def sequence_length(self):
-        return self.dataset.sequence_length
-
-    def clone(self):
-        return DN3Subset(self.dataset, self.indices)
-
-    def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
-        return self.dataset.preprocess(preprocessor, apply_transform)
-
-    def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4, **dataloader_kwargs):
-        return self.dataset.to_numpy(batch_size, batch_transforms, num_workers, **dataloader_kwargs)
-
-=======
->>>>>>> Stashed changes
 class _Recording(DN3ataset, ABC):
     """
     Abstract base class for any supported recording
diff --git a/dn3/utils.py b/dn3/utils.py
index 851dd40..d6a5718 100644
--- a/dn3/utils.py
+++ b/dn3/utils.py
@@ -52,7 +52,9 @@ class DN3atasetNanFound(BaseException):
 def rand_split(dataset, frac=0.75):
     if frac >= 1:
         return dataset
+    print("type of dataset is {}".format(type(dataset)))
     samples = len(dataset)
+
     return random_split(dataset, lengths=[round(x) for x in [samples*frac, samples*(1-frac)]])
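With patches 09 and 10, `Thinker.split()` hands back `DN3ataSubSet` objects whose cached `get_targets()` is exactly what the "sampling methods" mentioned in the WIP subject need. One way that could be used — a sketch, not dn3 API; `balanced_sampler` is a hypothetical helper:

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

def balanced_sampler(subset):
    """Build a class-balanced sampler from a split's cached targets."""
    targets = np.asarray(subset.get_targets())
    class_counts = np.bincount(targets)      # samples per class
    weights = 1.0 / class_counts[targets]    # rarer classes get larger weights
    return WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                 num_samples=len(targets), replacement=True)
```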
From c3637df19d32ead0c2c722b9f030b4fc5631833c Mon Sep 17 00:00:00 2001
From: zeyus
Date: Thu, 28 Apr 2022 16:11:03 +0200
Subject: [PATCH 11/13] remove spurious print

---
 dn3/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dn3/utils.py b/dn3/utils.py
index d6a5718..6ab870c 100644
--- a/dn3/utils.py
+++ b/dn3/utils.py
@@ -52,7 +52,6 @@ class DN3atasetNanFound(BaseException):
 def rand_split(dataset, frac=0.75):
     if frac >= 1:
         return dataset
-    print("type of dataset is {}".format(type(dataset)))
     samples = len(dataset)
 
     return random_split(dataset, lengths=[round(x) for x in [samples*frac, samples*(1-frac)]])

From 8fd88010263acccb11eec6c934a22dbe374d52b7 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Sun, 22 May 2022 19:15:10 +0200
Subject: [PATCH 12/13] allow data loader to access config object.

---
 dn3/configuratron/config.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dn3/configuratron/config.py b/dn3/configuratron/config.py
index 6d445ce..ddbc741 100644
--- a/dn3/configuratron/config.py
+++ b/dn3/configuratron/config.py
@@ -467,7 +467,7 @@ def add_custom_raw_loader(self, custom_loader):
         Parameters
         ----------
         custom_loader: callable
-                       A function that expects a single :any:`pathlib.Path()` instance as argument and returns an
+                       A function that expects a :any:`DatasetConfig` :any:`pathlib.Path()` instance as argument and returns an
                        instance of :any:`mne.io.Raw()`. To gracefully ignore problematic sessions, raise
                        :any:`DN3ConfigException` within.
 
@@ -495,7 +495,7 @@ def add_progress_callbacks(self, session_callback=None, thinker_callback=None):
 
     def _load_raw(self, path: Path):
         if self._custom_raw_loader is not None:
-            return self._custom_raw_loader(path)
+            return self._custom_raw_loader(self, path)
         if path.suffix in self._extension_handlers:
             return self._extension_handlers[path.suffix](str(path), preload=self.preload)
         print("Handler for file {} with extension {} not found.".format(str(path), path.suffix))
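With patch 12 the custom loader is invoked as `custom_loader(config, path)`, so it can honour per-dataset options such as `preload`. A sketch of a matching loader — the `.fif` reader is an assumption; any `mne.io.read_raw_*` function fits the same shape:

```python
from pathlib import Path
import mne

def my_raw_loader(ds_config, path: Path) -> mne.io.BaseRaw:
    # ds_config is the DatasetConfig passing itself in; path is the session file.
    return mne.io.read_raw_fif(str(path), preload=ds_config.preload)

# Registered the same way as before, e.g.:
# dataset_config.add_custom_raw_loader(my_raw_loader)
```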
From b1611356500efa1a537e645846c542e1fc4621e5 Mon Sep 17 00:00:00 2001
From: zeyus
Date: Sun, 29 May 2022 11:13:54 +0200
Subject: [PATCH 13/13] removed bias on classification layer by default

---
 dn3/trainable/models.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/dn3/trainable/models.py b/dn3/trainable/models.py
index 9e53474..2719be2 100644
--- a/dn3/trainable/models.py
+++ b/dn3/trainable/models.py
@@ -185,7 +185,7 @@ class StrideClassifier(Classifier, metaclass=ABCMeta):
 
-    def __init__(self, targets, samples, channels, stride_width=2, return_features=False):
+    def __init__(self, targets, samples, channels, stride_width=2, return_features=False, bias_output=False):
         """
         Instead of summarizing the entire temporal dimension into a single prediction, a prediction kernel is
         swept over the final sequence representation and generates predictions at each step.
@@ -199,13 +199,15 @@ def __init__(self, targets, samples, channels, stride_width=2, return_features=F
         return_features
         """
         self.stride_width = stride_width
+        self.biased_output = bias_output
         super(StrideClassifier, self).__init__(targets, samples, channels, return_features=return_features)
 
     def make_new_classification_layer(self):
         self.classifier = torch.nn.Conv1d(self.num_features_for_classification, self.targets,
-                                          kernel_size=self.stride_width)
+                                          kernel_size=self.stride_width, bias=self.biased_output)
         torch.nn.init.xavier_normal_(self.classifier.weight)
-        self.classifier.bias.data.zero_()
+        if self.biased_output:
+            self.classifier.bias.data.zero_()
 
 
 class LogRegNetwork(Classifier):
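With `bias=False`, `torch.nn.Conv1d` stores `None` in `.bias`, which is why the zeroing is now guarded. A quick check with made-up channel counts:

```python
import torch

clf = torch.nn.Conv1d(in_channels=512, out_channels=4, kernel_size=2, bias=False)
print(clf.bias)        # None -- the old unconditional clf.bias.data.zero_() would raise here

clf = torch.nn.Conv1d(in_channels=512, out_channels=4, kernel_size=2, bias=True)
clf.bias.data.zero_()  # opting back in with bias_output=True keeps the previous behaviour
```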