Skip to content

Commit

Permalink
Merge #18 Implement SSL samplers
Browse files Browse the repository at this point in the history
Implement SSL samplers
  • Loading branch information
wilhelmagren authored Sep 10, 2023
2 parents 86a0f0e + 816c4f5 commit 0c18ae4
Show file tree
Hide file tree
Showing 13 changed files with 983 additions and 85 deletions.
43 changes: 28 additions & 15 deletions neurocode/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
""" --------------------------------------------------------------------
Copyright [2022] [Wilhelm Ågren]
"""
MIT License
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Copyright (c) 2023 Neurocode
http://www.apache.org/licenses/LICENSE-2.0
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File created: 10-09-2022
Last edited: 10-09-2022
File created: 2023-09-10
Last updated: 2023-09-10
"""

Initialize the samplers module and import all the required submodules.
-------------------------------------------------------------------- """
from .base import PretextTaskSampler
from .relative_positioning import RelativePositioningSampler
from .recording_relative_positioning import RRPSampler
from .scalogram_simclr import ScalogramSampler
from .signal_simclr import SignalSampler
from .temporal_shuffling import TSSampler
from .recording_simclr import RecordingSampler
from .contrastive_view import ContrastiveViewGenerator
124 changes: 103 additions & 21 deletions neurocode/samplers/base.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,116 @@
""" --------------------------------------------------------------------
Copyright [2022] [Wilhelm Ågren]
"""
MIT License
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Copyright (c) 2023 Neurocode
http://www.apache.org/licenses/LICENSE-2.0
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File created: 10-09-2022
Last edited: 10-09-2022
File created: 2023-09-10
Last updated: 2023-09-10
"""

Implementation of the base PretextTaskSampler, inheriting from the
PyTorch sampler object.
-------------------------------------------------------------------- """
import numpy as np

from torch.utils.data.sampler import Sampler


class PretextTaskSampler(Sampler):
def __init__(self, data, labels, *args, **kwargs):
self._data = data
self._labels = labels
self._rng = np.random.RandomState(seed=0)
def __init__(self, data, labels, info, **kwargs):
self.data = data
self.labels = labels
self.info = info
self._parameters(**kwargs)
self._setup(**kwargs)

def __len__(self):
raise NotImplementedError(f"")
return self.n_samples

def __iter__(self):
for i in range(self.n_samples):
yield self.samples[i] if self.presample else self._sample_pair()

def _setup(self, seed=1, n_samples=256, batch_size=32, presample=False, **kwargs):
self.rng = np.random.RandomState(seed=seed)
self.seed = seed
self.n_samples = n_samples
self.batch_size = batch_size
self.presample = presample

if presample:
self._presample()

def _extract_features(self, emb, device, n_samples_per_recording=None):
X, Y = [], []
emb.eval()
emb._return_features = True
with torch.no_grad():
for reco_idx in range(len(self.data)):
for idx, window in enumerate(self.data[reco_idx]):
window = torch.Tensor(window[0][None]).to(device)
embedding = emb(window)
X.append(embedding[0, :][None])
Y.append(self.labels[reco_idx])
X = np.concatenate([x.cpu().detach().numpy() for x in X], axis=0)
emb._return_features = False
return (X, Y)

def downstream_sample(self, emb, device):
X, y = [], []
emb.eval()
emb.return_feats = True

def _parameters(self, *args, **kwargs):
pass

def _presample(self):
self.samples = list(self._sample_pair() for _ in range(self.n_samples))

def _sample_recording(self):
return self.rng.randint(0, high=self.info["n_recordings"])

def _sample_window(self, recording_idx=None, **kwargs):
if recording_idx is None:
recording_idx = self._sample_recording()
return self.rng.choice(self.info["lengths"][recording_idx])

def _sample_pair(self, *args, **kwargs):
raise NotImplementedError("Please implement window-pair sampling!")

def _split(self):
X_train = defaultdict(list)
Y_train = defaultdict(list)
X_valid = defaultdict(list)
Y_valid = defaultdict(list)
for recording in range(self.info["n_recordings"]):
r_len = len(self.data[recording])
split = np.ceil(r_len * 0.7).astype(int)
for window in range(split):
X_train[recording].append(self.data[recording][window])
Y_train[recording].append(self.labels[recording])
for window in range(split, r_len):
X_valid[recording].append(self.data[recording][window])
Y_valid[recording].append(self.labels[recording])

train_dataset = RecordingDataset(
X_train, Y_train, formatted=True, sfreq=self.info["sfreq"]
)
valid_dataset = RecordingDataset(
X_valid, Y_valid, formatted=True, sfreq=self.info["sfreq"]
)
return (train_dataset, valid_dataset)
54 changes: 54 additions & 0 deletions neurocode/samplers/contrastive_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
MIT License
Copyright (c) 2023 Neurocode
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File created: 2023-09-10
Last updated: 2023-09-10
"""

import torch


class ContrastiveViewGenerator(object):
"""callable object that generated n_views
of the given input data x.
Attributes
----------
T: tuple | list
A collection holding the transforms, either torchvision.transform or
BaseTransform from neurocode.datautil, number of transforms should
be the same as n_views. No stochastic choice on transform is made
in this module, but could be implemented.
n_views: int
The number dictating the amount of augmentations/transformations
to apply to input x, and decides the length of the resulting list
after invoking __call__ on the object.
"""

def __init__(self, T, n_views):
self.transforms = T
self.n_views = n_views

def __call__(self, x):
return [torch.Tensor(self.transforms[t](x)) for t in range(self.n_views)]
91 changes: 91 additions & 0 deletions neurocode/samplers/recording_relative_positioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
MIT License
Copyright (c) 2023 Neurocode
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File created: 2023-09-10
Last updated: 2023-09-10
"""

import torch
import numpy as np

from .base import PretextTaskSampler


class RRPSampler(PretextTaskSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _parameters(self, tau=5, gamma=0.5, **kwargs):
self.tau = tau
self.gamma = gamma

def _sample_pair(self):
batch_anchors = list()
batch_samples = list()
batch_labels = list()
reco_idx1 = self._sample_recording()
for _ in range(self.batch_size):
pair_type = self.rng.binomial(1, self.gamma)
win_idx1, win_idx2, reco_idx2 = -1, -1, -1

if pair_type == 0:
# Negative sample, from another recording
reco_idx2 = self._sample_recording()
while reco_idx2 == reco_idx1:
reco_idx2 = self._sample_recording()

win_idx1 = self._sample_window(recording_idx=reco_idx1)
win_idx2 = self._sample_window(recording_idx=reco_idx2)

if self.info["lengths"][reco_idx1] < self.info["lengths"][reco_idx2]:
# The length of the first recording is shorter than the second,
# so let the already sampled window from first recording remain,
# then sample from second until it fits inside the context,
# with respect to parameter tau
while np.abs(win_idx1 - win_idx2) >= self.tau:
win_idx2 = self._sample_window(recording_idx=reco_idx2)

else:
# The length of the second recording is shorter than the first,
# so sample from recording2 first, then find a window index
# from recording1 that lies inside the context
while np.abs(win_idx1 - win_idx2) >= self.tau:
win_idx1 = self._sample_window(recording_idx=reco_idx1)

elif pair_type == 1:
# Positive sample, from the same recording
reco_idx2 = reco_idx1
win_idx1 = self._sample_window(recording_idx=reco_idx1)
win_idx2 = self._sample_window(recording_idx=reco_idx2)
while np.abs(win_idx1 - win_idx2) >= self.tau:
win_idx2 = self._sample_window(recording_idx=reco_idx2)

batch_anchors.append(self.data[reco_idx1][win_idx1][0][None])
batch_samples.append(self.data[reco_idx2][win_idx2][0][None])
batch_labels.append(float(pair_type))

ANCHORS = torch.Tensor(np.concatenate(batch_anchors, axis=0))
SAMPLES = torch.Tensor(np.concatenate(batch_samples, axis=0))
LABELS = torch.Tensor(np.array(batch_labels))

return (ANCHORS, SAMPLES, LABELS)
58 changes: 58 additions & 0 deletions neurocode/samplers/recording_simclr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
MIT License
Copyright (c) 2023 Neurocode
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File created: 2023-09-10
Last updated: 2023-09-10
"""

import torch
import numpy as np

from .base import PretextTaskSampler


class RecordingSampler(PretextTaskSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _parameters(self, n_channels, n_views=2, **kwargs):
self.n_channels = n_channels
self.n_views = n_views

def _sample_pair(self):
batch_anchors = []
batch_samples = []
recordings = self.rng.choice(
self.info["n_recordings"], size=(self.batch_size), replace=False
)
for reco_idx1 in recordings:
win_idx1 = self._sample_window(recording_idx=reco_idx1)
win_idx2 = self._sample_window(recording_idx=reco_idx1)

batch_anchors.append(self.data[reco_idx1][win_idx1][0][None])
batch_samples.append(self.data[reco_idx1][win_idx2][0][None])

ANCHORS = torch.Tensor(np.concatenate(batch_anchors, axis=0))
SAMPLES = torch.Tensor(np.concatenate(batch_samples, axis=0))

return (ANCHORS, SAMPLES)
Loading

0 comments on commit 0c18ae4

Please sign in to comment.