-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement SSL samplers
- Loading branch information
Showing
13 changed files
with
983 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,35 @@ | ||
""" -------------------------------------------------------------------- | ||
Copyright [2022] [Wilhelm Ågren] | ||
""" | ||
MIT License | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
Copyright (c) 2023 Neurocode | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
File created: 10-09-2022 | ||
Last edited: 10-09-2022 | ||
File created: 2023-09-10 | ||
Last updated: 2023-09-10 | ||
""" | ||
|
||
Initialize the samplers module and import all the required submodules. | ||
-------------------------------------------------------------------- """ | ||
from .base import PretextTaskSampler | ||
from .relative_positioning import RelativePositioningSampler | ||
from .recording_relative_positioning import RRPSampler | ||
from .scalogram_simclr import ScalogramSampler | ||
from .signal_simclr import SignalSampler | ||
from .temporal_shuffling import TSSampler | ||
from .recording_simclr import RecordingSampler | ||
from .contrastive_view import ContrastiveViewGenerator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,116 @@ | ||
""" -------------------------------------------------------------------- | ||
Copyright [2022] [Wilhelm Ågren] | ||
""" | ||
MIT License | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
Copyright (c) 2023 Neurocode | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
File created: 10-09-2022 | ||
Last edited: 10-09-2022 | ||
File created: 2023-09-10 | ||
Last updated: 2023-09-10 | ||
""" | ||
|
||
Implementation of the base PretextTaskSampler, inheriting from the | ||
PyTorch sampler object. | ||
-------------------------------------------------------------------- """ | ||
import numpy as np | ||
|
||
from torch.utils.data.sampler import Sampler | ||
|
||
|
||
class PretextTaskSampler(Sampler):
    """Base class for pretext-task samplers.

    Subclasses implement ``_sample_pair`` (and optionally ``_parameters``);
    iterating the sampler then yields ``n_samples`` results of
    ``_sample_pair``, either drawn lazily or presampled once at construction.

    Parameters
    ----------
    data
        Indexable collection of recordings; ``data[recording][window]`` is
        indexed with ``[0]`` wherever a window is read in this class.
    labels
        Indexable collection with one label per recording.
    info: dict
        Metadata; this class reads the keys ``"n_recordings"``,
        ``"lengths"`` and ``"sfreq"``.
    **kwargs
        Forwarded to both ``_parameters`` and ``_setup``.
    """

    def __init__(self, data, labels, info, **kwargs):
        self.data = data
        self.labels = labels
        self.info = info
        # Subclass-specific parameters are stored first because ``_setup``
        # may presample, and presampling calls ``_sample_pair``.
        self._parameters(**kwargs)
        self._setup(**kwargs)

    def __len__(self):
        """Return the number of samples produced per iteration pass."""
        return self.n_samples

    def __iter__(self):
        """Yield ``n_samples`` items, presampled or freshly drawn."""
        for i in range(self.n_samples):
            yield self.samples[i] if self.presample else self._sample_pair()

    def _setup(self, seed=1, n_samples=256, batch_size=32, presample=False, **kwargs):
        """Store the common sampling configuration and seed the RNG.

        Extra keyword arguments are accepted and ignored so that the same
        ``kwargs`` dict can be shared with ``_parameters``.
        """
        self.rng = np.random.RandomState(seed=seed)
        self.seed = seed
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.presample = presample

        if presample:
            self._presample()

    def _extract_features(self, emb, device, n_samples_per_recording=None):
        """Embed every window of every recording with ``emb``.

        Returns a tuple ``(X, Y)`` where ``X`` is a numpy array of the
        concatenated embeddings and ``Y`` is a list holding the recording
        label for each embedded window.

        ``n_samples_per_recording`` is accepted but not used by this
        implementation.
        """
        # Local import: the module-level imports visible for this file do
        # not bring ``torch`` itself into scope.
        import torch

        X, Y = [], []
        emb.eval()
        emb._return_features = True
        try:
            with torch.no_grad():
                for reco_idx in range(len(self.data)):
                    for window in self.data[reco_idx]:
                        window = torch.Tensor(window[0][None]).to(device)
                        embedding = emb(window)
                        X.append(embedding[0, :][None])
                        Y.append(self.labels[reco_idx])
        finally:
            # Always hand the model back in its normal mode, even if the
            # forward pass failed part-way through.
            emb._return_features = False
        X = np.concatenate([x.cpu().detach().numpy() for x in X], axis=0)
        return (X, Y)

    def downstream_sample(self, emb, device):
        """Put ``emb`` in eval mode and set ``emb.return_feats = True``.

        NOTE(review): this looks like an unfinished stub -- it returns
        ``None`` and toggles ``return_feats`` whereas ``_extract_features``
        toggles ``_return_features``; confirm which is intended.
        """
        emb.eval()
        emb.return_feats = True

    def _parameters(self, *args, **kwargs):
        """Hook for subclasses to store task-specific parameters."""
        pass

    def _presample(self):
        """Draw and cache ``n_samples`` results of ``_sample_pair``."""
        self.samples = [self._sample_pair() for _ in range(self.n_samples)]

    def _sample_recording(self):
        """Return a random recording index in ``[0, n_recordings)``."""
        return self.rng.randint(0, high=self.info["n_recordings"])

    def _sample_window(self, recording_idx=None, **kwargs):
        """Return a random window index for the given recording.

        A recording is sampled first when ``recording_idx`` is ``None``.
        """
        if recording_idx is None:
            recording_idx = self._sample_recording()
        return self.rng.choice(self.info["lengths"][recording_idx])

    def _sample_pair(self, *args, **kwargs):
        """Sample one item; must be implemented by subclasses."""
        raise NotImplementedError("Please implement window-pair sampling!")

    def _split(self):
        """Split every recording 70/30 into train and validation datasets.

        The first ``ceil(0.7 * n_windows)`` windows of each recording go to
        the training set, the remainder to the validation set. Returns
        ``(train_dataset, valid_dataset)``.
        """
        # Local import: ``defaultdict`` is not among the imports visible
        # for this file.
        from collections import defaultdict

        X_train = defaultdict(list)
        Y_train = defaultdict(list)
        X_valid = defaultdict(list)
        Y_valid = defaultdict(list)
        for recording in range(self.info["n_recordings"]):
            r_len = len(self.data[recording])
            split = np.ceil(r_len * 0.7).astype(int)
            for window in range(split):
                X_train[recording].append(self.data[recording][window])
                Y_train[recording].append(self.labels[recording])
            for window in range(split, r_len):
                X_valid[recording].append(self.data[recording][window])
                Y_valid[recording].append(self.labels[recording])

        # NOTE(review): ``RecordingDataset`` is not imported in the visible
        # part of this file -- confirm its module and add the import.
        train_dataset = RecordingDataset(
            X_train, Y_train, formatted=True, sfreq=self.info["sfreq"]
        )
        valid_dataset = RecordingDataset(
            X_valid, Y_valid, formatted=True, sfreq=self.info["sfreq"]
        )
        return (train_dataset, valid_dataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" | ||
MIT License | ||
Copyright (c) 2023 Neurocode | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
File created: 2023-09-10 | ||
Last updated: 2023-09-10 | ||
""" | ||
|
||
import torch | ||
|
||
|
||
class ContrastiveViewGenerator(object):
    """Callable object that generates n_views of the given input data x.

    Attributes
    ----------
    transforms: tuple | list
        A collection holding the transforms, either torchvision.transform or
        BaseTransform from neurocode.datautil; it must hold at least
        n_views entries. No stochastic choice of transform is made in this
        module: view ``t`` is always produced by ``transforms[t]``.
    n_views: int
        The number of augmentations/transformations to apply to input x;
        this is the length of the list returned by ``__call__``.

    Raises
    ------
    ValueError
        If fewer than n_views transforms are supplied.
    """

    def __init__(self, T, n_views):
        # Fail at construction instead of with an IndexError on first call.
        if len(T) < n_views:
            raise ValueError(
                f"expected at least {n_views} transforms, got {len(T)}"
            )
        self.transforms = T
        self.n_views = n_views

    def __call__(self, x):
        """Return a list with one tensor per view of ``x``."""
        return [torch.Tensor(self.transforms[t](x)) for t in range(self.n_views)]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
MIT License | ||
Copyright (c) 2023 Neurocode | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
File created: 2023-09-10 | ||
Last updated: 2023-09-10 | ||
""" | ||
|
||
import torch | ||
import numpy as np | ||
|
||
from .base import PretextTaskSampler | ||
|
||
|
||
class RRPSampler(PretextTaskSampler):
    """Sampler producing labelled window pairs across and within recordings.

    Each call to ``_sample_pair`` builds one batch around a single anchor
    recording. A pair is labelled ``1.0`` when both windows come from the
    anchor recording and ``0.0`` when the second window comes from another
    recording; in both cases the two window indices differ by less than
    ``tau``.

    Parameters (via keyword arguments)
    ----------------------------------
    tau: int
        Exclusive upper bound on ``|win_idx1 - win_idx2|``; must be positive.
    gamma: float
        Probability of drawing a same-recording (label ``1.0``) pair.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _parameters(self, tau=5, gamma=0.5, **kwargs):
        """Store the task-specific parameters ``tau`` and ``gamma``."""
        self.tau = tau
        self.gamma = gamma

    def _sample_pair(self):
        """Sample one batch and return ``(ANCHORS, SAMPLES, LABELS)`` tensors.

        Raises
        ------
        ValueError
            If ``tau`` is not positive, or if a different-recording pair is
            requested while fewer than two recordings exist. Both cases
            would otherwise make the rejection loops below spin forever.
        """
        if self.tau <= 0:
            raise ValueError(f"tau must be positive, got {self.tau}")

        lengths = self.info["lengths"]
        batch_anchors = []
        batch_samples = []
        batch_labels = []
        # One anchor recording is shared by every pair in the batch.
        reco_idx1 = self._sample_recording()
        for _ in range(self.batch_size):
            pair_type = self.rng.binomial(1, self.gamma)

            if pair_type == 0:
                # Negative sample, from another recording
                if self.info["n_recordings"] < 2:
                    raise ValueError(
                        "sampling a pair from two different recordings "
                        "requires at least 2 recordings"
                    )
                reco_idx2 = self._sample_recording()
                while reco_idx2 == reco_idx1:
                    reco_idx2 = self._sample_recording()

                win_idx1 = self._sample_window(recording_idx=reco_idx1)
                win_idx2 = self._sample_window(recording_idx=reco_idx2)

                if lengths[reco_idx1] < lengths[reco_idx2]:
                    # The first recording is the shorter one, so keep its
                    # window and resample from the second recording until
                    # the indices lie within tau of each other.
                    while np.abs(win_idx1 - win_idx2) >= self.tau:
                        win_idx2 = self._sample_window(recording_idx=reco_idx2)
                else:
                    # The second recording is the shorter one (or equal), so
                    # keep its window and resample from the first recording
                    # until the indices lie within tau of each other.
                    while np.abs(win_idx1 - win_idx2) >= self.tau:
                        win_idx1 = self._sample_window(recording_idx=reco_idx1)
            else:
                # Positive sample, from the same recording
                reco_idx2 = reco_idx1
                win_idx1 = self._sample_window(recording_idx=reco_idx1)
                win_idx2 = self._sample_window(recording_idx=reco_idx2)
                while np.abs(win_idx1 - win_idx2) >= self.tau:
                    win_idx2 = self._sample_window(recording_idx=reco_idx2)

            # ``[0][None]`` takes the first element of the window item and
            # adds a leading axis so the windows concatenate along axis 0.
            batch_anchors.append(self.data[reco_idx1][win_idx1][0][None])
            batch_samples.append(self.data[reco_idx2][win_idx2][0][None])
            batch_labels.append(float(pair_type))

        ANCHORS = torch.Tensor(np.concatenate(batch_anchors, axis=0))
        SAMPLES = torch.Tensor(np.concatenate(batch_samples, axis=0))
        LABELS = torch.Tensor(np.array(batch_labels))

        return (ANCHORS, SAMPLES, LABELS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
MIT License | ||
Copyright (c) 2023 Neurocode | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
File created: 2023-09-10 | ||
Last updated: 2023-09-10 | ||
""" | ||
|
||
import torch | ||
import numpy as np | ||
|
||
from .base import PretextTaskSampler | ||
|
||
|
||
class RecordingSampler(PretextTaskSampler):
    """Sampler pairing two random windows from each of several recordings.

    Every batch uses ``batch_size`` distinct recordings; for each of them
    two window indices are drawn independently (they may coincide).

    Parameters (via keyword arguments)
    ----------------------------------
    n_channels: int
        Stored on the sampler; required.
    n_views: int
        Stored on the sampler; defaults to 2.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _parameters(self, n_channels, n_views=2, **kwargs):
        """Store the task-specific parameters ``n_channels`` and ``n_views``."""
        self.n_channels = n_channels
        self.n_views = n_views

    def _sample_pair(self):
        """Sample one batch and return ``(ANCHORS, SAMPLES)`` tensors.

        Raises
        ------
        ValueError
            If ``batch_size`` exceeds the number of recordings, since the
            recordings of a batch are drawn without replacement.
        """
        n_recordings = self.info["n_recordings"]
        if self.batch_size > n_recordings:
            raise ValueError(
                f"batch_size ({self.batch_size}) cannot exceed the number of "
                f"recordings ({n_recordings}) when sampling without replacement"
            )

        batch_anchors = []
        batch_samples = []
        recordings = self.rng.choice(
            n_recordings, size=self.batch_size, replace=False
        )
        for reco_idx1 in recordings:
            win_idx1 = self._sample_window(recording_idx=reco_idx1)
            win_idx2 = self._sample_window(recording_idx=reco_idx1)

            # ``[0][None]`` takes the first element of the window item and
            # adds a leading axis so the windows concatenate along axis 0.
            batch_anchors.append(self.data[reco_idx1][win_idx1][0][None])
            batch_samples.append(self.data[reco_idx1][win_idx2][0][None])

        ANCHORS = torch.Tensor(np.concatenate(batch_anchors, axis=0))
        SAMPLES = torch.Tensor(np.concatenate(batch_samples, axis=0))

        return (ANCHORS, SAMPLES)
Oops, something went wrong.