From 962d09c36d0acd26480ca08b448a151a30db72b7 Mon Sep 17 00:00:00 2001 From: Guillem Date: Wed, 27 Mar 2024 16:15:04 +0100 Subject: [PATCH 01/20] Implemented discrete multisession. --- cebra/data/datasets.py | 6 +- cebra/data/multi_session.py | 13 ++- cebra/distributions/multisession.py | 122 ++++++++++++++++++++++++++++ cebra/integrations/sklearn/cebra.py | 9 +- 4 files changed, 143 insertions(+), 7 deletions(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 8fc990b0..6d6097a2 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -222,9 +222,9 @@ def __init__( else: self._cindex = None if discrete: - raise NotImplementedError( - "Multisession implementation does not support discrete index yet." - ) + self._dindex = torch.cat(list( + self._iter_property("discrete_index")), + dim=0) else: self._dindex = None diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 8cd74286..a85336c2 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -136,7 +136,7 @@ def get_indices(self, num_samples: int) -> List[BatchIndex]: ref_idx = torch.from_numpy(ref_idx) neg_idx = torch.from_numpy(neg_idx) pos_idx = torch.from_numpy(pos_idx) - + return BatchIndex( reference=ref_idx, positive=pos_idx, @@ -160,7 +160,16 @@ def index(self): @dataclasses.dataclass class DiscreteMultiSessionDataLoader(MultiSessionLoader): - pass + """Contrastive learning conditioned on a discrete behavior variable.""" + + # Overwrite sampler with the discrete implementation + # Generalize MultisessionSampler to avoid doing this? + def __post_init__(self): + self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset) + + @property + def index(self): + return self.dataset.discrete_index @dataclasses.dataclass diff --git a/cebra/distributions/multisession.py b/cebra/distributions/multisession.py index 08595d90..83a0cdee 100644 --- a/cebra/distributions/multisession.py +++ b/cebra/distributions/multisession.py @@ -259,3 +259,125 @@ def __getitem__(self, pos_idx): for i in range(self.num_sessions): pos_samples[i] = self.data[i][pos_idx[i]] return pos_samples + +class DiscreteMultisessionSampler(cebra_distr.PriorDistribution, + cebra_distr.ConditionalDistribution): + """Discrete multi-session sampling. + + Align embeddings across multiple sessions, using a discrete + index. The transitions between index samples are computed across + all sessions. + + After data processing, the dimensionality of the returned features + matches. The resulting embeddings can be concatenated, and shuffling + (across the session axis) can be applied to the reference samples, or + reversed for the positive samples. + + TODO: + * Add better CUDA support and refactor ``numpy`` implementation. + """ + + def __init__(self, dataset): + self.dataset = dataset + + # TODO(stes): implement in pytorch + self.all_data = self.dataset.discrete_index.cpu().numpy() + self.session_lengths = np.array(self.dataset.session_lengths) + + self.lengths = np.cumsum(self.session_lengths) + self.lengths[1:] = self.lengths[:-1] + self.lengths[0] = 0 + + self.index = [ + cebra_distr.DiscreteUniform( + dataset.discrete_index.int().to(_device)) + for dataset in self.dataset.iter_sessions() + ] + + @property + def num_sessions(self) -> int: + """The number of sessions in the index.""" + return len(self.lengths) + + def mix(self, array: np.ndarray, idx: np.ndarray): + """Re-order array elements according to the given index mapping. + + The given array should be of the shape ``(session, batch, ...)`` and the + indices should have length ``session x batch``, representing a mapping + between indices. + + The resulting array will be rearranged such that + ``out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]`` + + For the inverse mapping, convert the indices first using ``_invert_index`` + function. + + Args: + array: A 2D matrix containing samples for each session. + idx: A list of indexes to re-order ``array`` on. + """ + n, m = array.shape[:2] + return array.reshape(n * m, -1)[idx].reshape(array.shape) + + def sample_prior(self, num_samples): + # TODO(stes) implement empirical/uniform resampling + ref_idx = np.random.uniform(0, 1, (self.num_sessions, num_samples)) + ref_idx = (ref_idx * self.session_lengths[:, None]).astype(int) + return ref_idx + + def sample_conditional(self, idx: torch.Tensor) -> torch.Tensor: + """Sample from the conditional distribution. + + Note: + * Reference samples are sampled equally between sessions. + * In order to guarantee the same number of positive samples per + session, reference samples are randomly assigned to a session and its + corresponding positive sample is searched in that session only. + * As a result, ref/pos pairing is shuffled and can be recovered + the reverse shuffle operation. + + Args: + idx: Reference indices, with dimension ``(session, batch)``. + + Returns: + Positive indices (1st return value), which will be grouped by + session and *not* match the reference indices. + In addition, a mapping will be returned to apply the same shuffle operation + that was applied to assign a query to a session along session/batch dimension + to the reference indices (2nd return value), or reverse the shuffle operation + (3rd return value). + Returned shapes are ``(session, batch), (session, batch), (session, batch)``. + + TODO: + * re-implement in pytorch for additional speed gains + """ + + shape = idx.shape + # TODO(stes) unclear why we cannot restrict to 2d overall + # idx has shape (2, #samples per batch) + s = idx.shape[:2] + idx_all = (idx + self.lengths[:, None]).flatten() + + query = self.all_data[idx_all] + + # shuffle operation to assign each query to a session + idx = np.random.permutation(len(query)) + + # TODO this part fails in Pytorch + query = query[idx.reshape(s)] + query = torch.from_numpy(query).to(_device) + + pos_idx = torch.zeros(shape, device=_device).long() + for i in range(self.num_sessions): + pos_idx[i] = self.index[i].sample_conditional(query[i]) + pos_idx = pos_idx.cpu().numpy() + + # reverse indices to recover the ref/pos samples matching + idx_rev = _invert_index(idx) + return pos_idx, idx, idx_rev + + def __getitem__(self, pos_idx): + pos_samples = np.zeros(pos_idx.shape[:2] + (self.data.shape[2],)) + for i in range(self.num_sessions): + pos_samples[i] = self.data[i][pos_idx[i]] + return pos_samples \ No newline at end of file diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 077d3c47..5c302ecb 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -153,6 +153,9 @@ def _require_arg(key): # Discrete behavior contrastive training is selected with the default dataloader if not is_cont and is_disc: + kwargs = dict( + **shared_kwargs, + ) if is_full: if is_hybrid: raise_not_implemented_error = True @@ -162,7 +165,10 @@ def _require_arg(key): if is_hybrid: raise_not_implemented_error = True else: - raise_not_implemented_error = True + return ( + cebra.data.DiscreteMultiSessionDataLoader(**kwargs), + "multi-session", + ) # Mixed behavior contrastive training is selected with the default dataloader if is_cont and is_disc: @@ -1030,7 +1036,6 @@ def _partial_fit( if callback is None: raise ValueError( "callback_frequency requires to specify a callback.") - model.train() solver.fit( From 1af1fe419826daf11f01c4c40c5cdfea71fd3331 Mon Sep 17 00:00:00 2001 From: Guillem Date: Wed, 27 Mar 2024 16:15:56 +0100 Subject: [PATCH 02/20] Added tests for discrete multisession. --- cebra/datasets/demo.py | 11 +++++---- tests/test_distributions.py | 21 +++++++++++++++++ tests/test_loader.py | 46 +++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index 3e943b07..bf0a7134 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -113,17 +113,20 @@ def discrete_index(self): # TODO(stes) remove this from the demo datasets until multi-session training # with discrete indices is implemented in the sklearn API. -# @register("demo-discrete-multisession") +@register("demo-discrete-multisession") class MultiDiscrete(cebra.data.DatasetCollection): """Demo dataset for testing.""" - def __init__(self, nums_neural=[3, 4, 5]): + def __init__( + self, + nums_neural=[3, 4, 5], + num_timepoints=_DEFAULT_NUM_TIMEPOINTS, + ): super().__init__(*[ - DemoDatasetDiscrete(_DEFAULT_NUM_TIMEPOINTS, num_neural) + DemoDatasetDiscrete(num_timepoints, num_neural) for num_neural in nums_neural ]) - @register("demo-continuous-multisession") class MultiContinuous(cebra.data.DatasetCollection): diff --git a/tests/test_distributions.py b/tests/test_distributions.py index f47fc630..a93031e3 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -297,6 +297,27 @@ def test_multi_session_time_contrastive(time_offset): assert (idx.flatten()[rev_idx.flatten()].all() == np.arange( len(rev_idx.flatten())).all()) +def test_multi_session_discrete(): + dataset = cebra_datasets.init("demo-discrete-multisession") + sampler = cebra_distr.DiscreteMultisessionSampler(dataset) + + num_samples = 5 + sample = sampler.sample_prior(num_samples) + assert sample.shape == (dataset.num_sessions, num_samples) + + positive, idx, rev_idx = sampler.sample_conditional(sample) + assert positive.shape == (dataset.num_sessions, num_samples) + assert idx.shape == (dataset.num_sessions * num_samples,) + assert rev_idx.shape == (dataset.num_sessions * num_samples,) + # NOTE(celia): test the private function ``_inverse_idx()``, with idx arrays flat + assert (idx.flatten()[rev_idx.flatten()].all() == np.arange( + len(rev_idx.flatten())).all()) + + # Check positive samples' labels match reference samples' labels + sample_labels = sampler.all_data[(sample + sampler.lengths[:, None]).flatten()] + sample_labels = sample_labels[idx.reshape(sample.shape[:2])].flatten() + positive_labels = sampler.all_data[(positive + sampler.lengths[:, None]).flatten()] + assert (sample_labels == positive_labels).all() class OldDeltaDistribution(cebra_distr_base.JointDistribution, cebra_distr_base.HasGenerator): diff --git a/tests/test_loader.py b/tests/test_loader.py index c7548338..51cde8a6 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -267,6 +267,51 @@ def _process(batch, feature_dim=1): assert dummy_prediction.shape == (3, 32, 6) _mix(dummy_prediction, batch[0].index) +def test_multisession_disc_loader(): + data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5], + num_timepoints=100) + loader = cebra.data.DiscreteMultiSessionDataLoader( + data, + num_steps=10, + batch_size=32, + ) + + # Check the sampler + assert hasattr(loader, "sampler") + ref_idx = loader.sampler.sample_prior(1000) + assert len(ref_idx) == 3 # num_sessions + + # Check sample points are in session length range + for session in range(3): + assert ref_idx[session].max() < loader.sampler.session_lengths[session] + pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) + + assert pos_idx is not None + assert idx is not None + assert idx_rev is not None + + batch = next(iter(loader)) + + def _mix(array, idx): + shape = array.shape + n, m = shape[:2] + mixed = array.reshape(n * m, -1)[idx] + print(mixed.shape, array.shape, idx.shape) + return mixed.reshape(shape) + + def _process(batch, feature_dim=1): + """Given list_i[(N,d_i)] batch, return (#session, N, feature_dim) tensor""" + return torch.stack( + [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], + dim=0).repeat(1, 1, feature_dim) + + assert batch[0].reference.shape == (32, 3, 10) + assert batch[1].reference.shape == (32, 4, 10) + assert batch[2].reference.shape == (32, 5, 10) + + dummy_prediction = _process(batch, feature_dim=6) + assert dummy_prediction.shape == (3, 32, 6) + _mix(dummy_prediction, batch[0].index) @parametrize_device @pytest.mark.parametrize( @@ -293,3 +338,4 @@ def test_multisession_loader(data_name, loader_initfunc, device): _check_attributes(batch, is_list=True) for session_batch in batch: assert len(session_batch.positive) == 32 + From 3d13e506337b46067770d74824d6220917245fef Mon Sep 17 00:00:00 2001 From: Guillem Date: Wed, 27 Mar 2024 17:33:14 +0100 Subject: [PATCH 03/20] Added tests for discrete multisession. --- cebra/datasets/demo.py | 11 +++++--- tests/test_distributions.py | 21 ++++++++++++++ tests/test_loader.py | 46 +++++++++++++++++++++++++++++++ tests/test_sklearn.py | 55 ++++++++++++++++++++++++++++++++++--- tests/test_solver.py | 34 ++++++++++++++++++++--- 5 files changed, 155 insertions(+), 12 deletions(-) diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index 3e943b07..bf0a7134 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -113,17 +113,20 @@ def discrete_index(self): # TODO(stes) remove this from the demo datasets until multi-session training # with discrete indices is implemented in the sklearn API. -# @register("demo-discrete-multisession") +@register("demo-discrete-multisession") class MultiDiscrete(cebra.data.DatasetCollection): """Demo dataset for testing.""" - def __init__(self, nums_neural=[3, 4, 5]): + def __init__( + self, + nums_neural=[3, 4, 5], + num_timepoints=_DEFAULT_NUM_TIMEPOINTS, + ): super().__init__(*[ - DemoDatasetDiscrete(_DEFAULT_NUM_TIMEPOINTS, num_neural) + DemoDatasetDiscrete(num_timepoints, num_neural) for num_neural in nums_neural ]) - @register("demo-continuous-multisession") class MultiContinuous(cebra.data.DatasetCollection): diff --git a/tests/test_distributions.py b/tests/test_distributions.py index f47fc630..a93031e3 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -297,6 +297,27 @@ def test_multi_session_time_contrastive(time_offset): assert (idx.flatten()[rev_idx.flatten()].all() == np.arange( len(rev_idx.flatten())).all()) +def test_multi_session_discrete(): + dataset = cebra_datasets.init("demo-discrete-multisession") + sampler = cebra_distr.DiscreteMultisessionSampler(dataset) + + num_samples = 5 + sample = sampler.sample_prior(num_samples) + assert sample.shape == (dataset.num_sessions, num_samples) + + positive, idx, rev_idx = sampler.sample_conditional(sample) + assert positive.shape == (dataset.num_sessions, num_samples) + assert idx.shape == (dataset.num_sessions * num_samples,) + assert rev_idx.shape == (dataset.num_sessions * num_samples,) + # NOTE(celia): test the private function ``_inverse_idx()``, with idx arrays flat + assert (idx.flatten()[rev_idx.flatten()].all() == np.arange( + len(rev_idx.flatten())).all()) + + # Check positive samples' labels match reference samples' labels + sample_labels = sampler.all_data[(sample + sampler.lengths[:, None]).flatten()] + sample_labels = sample_labels[idx.reshape(sample.shape[:2])].flatten() + positive_labels = sampler.all_data[(positive + sampler.lengths[:, None]).flatten()] + assert (sample_labels == positive_labels).all() class OldDeltaDistribution(cebra_distr_base.JointDistribution, cebra_distr_base.HasGenerator): diff --git a/tests/test_loader.py b/tests/test_loader.py index c7548338..51cde8a6 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -267,6 +267,51 @@ def _process(batch, feature_dim=1): assert dummy_prediction.shape == (3, 32, 6) _mix(dummy_prediction, batch[0].index) +def test_multisession_disc_loader(): + data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5], + num_timepoints=100) + loader = cebra.data.DiscreteMultiSessionDataLoader( + data, + num_steps=10, + batch_size=32, + ) + + # Check the sampler + assert hasattr(loader, "sampler") + ref_idx = loader.sampler.sample_prior(1000) + assert len(ref_idx) == 3 # num_sessions + + # Check sample points are in session length range + for session in range(3): + assert ref_idx[session].max() < loader.sampler.session_lengths[session] + pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) + + assert pos_idx is not None + assert idx is not None + assert idx_rev is not None + + batch = next(iter(loader)) + + def _mix(array, idx): + shape = array.shape + n, m = shape[:2] + mixed = array.reshape(n * m, -1)[idx] + print(mixed.shape, array.shape, idx.shape) + return mixed.reshape(shape) + + def _process(batch, feature_dim=1): + """Given list_i[(N,d_i)] batch, return (#session, N, feature_dim) tensor""" + return torch.stack( + [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], + dim=0).repeat(1, 1, feature_dim) + + assert batch[0].reference.shape == (32, 3, 10) + assert batch[1].reference.shape == (32, 4, 10) + assert batch[2].reference.shape == (32, 5, 10) + + dummy_prediction = _process(batch, feature_dim=6) + assert dummy_prediction.shape == (3, 32, 6) + _mix(dummy_prediction, batch[0].index) @parametrize_device @pytest.mark.parametrize( @@ -293,3 +338,4 @@ def test_multisession_loader(data_name, loader_initfunc, device): _check_attributes(batch, is_list=True) for session_batch in batch: assert len(session_batch.positive) == 32 + diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 649a7c93..af68b6dd 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -367,7 +367,31 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) - # multi-session behavior contrastive + # multi-session discrete behavior contrastive + cebra_model.fit([X, X_s2], [y_d, y_d_s2]) + assert cebra_model.num_sessions == 2 + + embedding = cebra_model.transform(X, session_id=0) + assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(torch.Tensor(X), session_id=0) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X.shape[0], output_dimension) + embedding = cebra_model.transform(X_s2, session_id=1) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X_s2.shape[0], output_dimension) + + with pytest.raises(ValueError, match="shape"): + embedding = cebra_model.transform(X_s2, session_id=0) + with pytest.raises(ValueError, match="shape"): + embedding = cebra_model.transform(X, session_id=1) + with pytest.raises(RuntimeError, match="No.*session_id"): + embedding = cebra_model.transform(X) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = cebra_model.transform(X, session_id=2) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = cebra_model.transform(X, session_id=-1) + + # multi-session continuous behavior contrastive cebra_model.fit([X, X_s2], [y_c1, y_c1_s2]) assert cebra_model.num_sessions == 2 @@ -397,7 +421,32 @@ def test_sklearn(model_architecture, device): [torch.Tensor(y_c1), torch.Tensor(y_c1_s2)], ) - # multi-session behavior contrastive, more than two sessions + # multi-session discrete behavior contrastive, more than two sessions + cebra_model.fit([X, X_s2, X], [y_d, y_d_s2, y_d]) + assert cebra_model.num_sessions == 3 + + embedding = cebra_model.transform(X, session_id=0) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X.shape[0], output_dimension) + embedding = cebra_model.transform(X_s2, session_id=1) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X_s2.shape[0], output_dimension) + embedding = cebra_model.transform(X, session_id=2) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X.shape[0], output_dimension) + + with pytest.raises(ValueError, match="shape"): + embedding = cebra_model.transform(X_s2, session_id=0) + with pytest.raises(ValueError, match="shape"): + embedding = cebra_model.transform(X_s2, session_id=2) + with pytest.raises(ValueError, match="shape"): + embedding = cebra_model.transform(X, session_id=1) + with pytest.raises(RuntimeError, match="No.*session_id"): + embedding = cebra_model.transform(X) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = cebra_model.transform(X, session_id=3) + + # multi-session continuous behavior contrastive, more than two sessions cebra_model.fit([X, X_s2, X], [y_c1, y_c1_s2, y_c1]) assert cebra_model.num_sessions == 3 @@ -432,8 +481,6 @@ def test_sklearn(model_architecture, device): cebra_model.fit([X, X, X_s2], [y_c1, y_c2]) with pytest.raises(ValueError, match="Invalid.*sessions"): cebra_model.fit([X, X_s2], [y_c1, y_c1, y_c2]) - with pytest.raises(NotImplementedError, match="discrete"): - cebra_model.fit([X, X_s2], [y_d, y_d_s2]) with pytest.raises(ValueError, match="Invalid.*samples"): cebra_model.fit([X, X_s2], [y_c1_s2, y_c1_s2]) diff --git a/tests/test_solver.py b/tests/test_solver.py index 46efd319..57d14db1 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -46,13 +46,13 @@ (*args, cebra.solver.SingleSessionHybridSolver)) multi_session_tests = [] -for args in [("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader)]: +for args in [ + ("demo-continuous-multisession", cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader) +]: multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) -print(single_session_tests) - def _get_loader(data_name, loader_initfunc): data = cebra.datasets.init(data_name) @@ -168,3 +168,29 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): assert isinstance(log, dict) solver.fit(loader) + +@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", + multi_session_tests) +def test_multi_session(data_name, loader_initfunc, solver_initfunc): + loader = _get_loader(data_name, loader_initfunc) + criterion = cebra.models.InfoNCE() + model = nn.ModuleList( + [_make_model(dataset) for dataset in loader.dataset.iter_sessions()]) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer, + tqdm_on=True) + + batch = next(iter(loader)) + for session_id, dataset in enumerate(loader.dataset.iter_sessions()): + assert batch[session_id].reference.shape == (32, + dataset.input_dimension, + 10) + assert batch[session_id].index is not None + + log = solver.step(batch) + assert isinstance(log, dict) + + solver.fit(loader) From 48a7100b41a5871971d4d9ec95cad80d7bc1c23f Mon Sep 17 00:00:00 2001 From: Guillem Date: Wed, 27 Mar 2024 17:36:48 +0100 Subject: [PATCH 04/20] Add sklearn integration test. --- tests/test_sklearn_metrics.py | 61 +++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 58e12010..b23640b0 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -324,6 +324,67 @@ def test_sklearn_datasets_consistency(): dataset_ids=["achilles", "buddy"], between="datasets", ) + + # Example data with discrete labels + labels1 = np.random.randint(100, size=(10000,)) + labels1_invalid = np.random.randint(100, size=(10000, 3)) + labels2 = np.random.randint(100, size=(10000,)) + labels3 = np.random.randint(100, size=(8000,)) + labels4 = np.random.randint(100, size=(5000,)) + labels_datasets = [labels1, labels2, labels3, labels4] + + dataset_ids = ["achilles", "buddy", "cicero", "gatsby"] + + # random embeddings provide R2 close to 0 + scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( + embeddings_datasets, + dataset_ids=dataset_ids, + labels=labels_datasets, + between="datasets", + ) + assert scores.shape == (12,) + assert pairs.shape == (12, 2) + assert len(datasets) == 4 + assert math.isclose(scores[0], 0, abs_tol=0.05) + + # no labels + scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( + embeddings_datasets, labels=labels_datasets, between="datasets") + assert scores.shape == (12,) + assert pairs.shape == (12, 2) + assert len(datasets) == 4 + + # identical embeddings provide R2 close to 1 + scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( + [embedding1, embedding1], + dataset_ids=["achilles", "buddy"], + labels=[labels1, labels1], + between="datasets", + ) + assert scores.shape == (2,) + assert pairs.shape == (2, 2) + assert len(datasets) == 2 + assert math.isclose(scores[0], 1, abs_tol=1e-9) + + # Tensor + scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( + [torch.Tensor(embedding) for embedding in embeddings_datasets], + dataset_ids=dataset_ids, + labels=[torch.Tensor(label) for label in labels_datasets], + between="datasets", + ) + assert scores.shape == (12,) + assert pairs.shape == (12, 2) + assert len(datasets) == 4 + + with pytest.raises(ValueError, match="Labels.*value"): + _, _, _ = cebra_sklearn_metrics.consistency_score( + [embedding1, embedding2], + labels=[np.random.randint(5, size=(10000,)),np.random.randint(10, size=(10000,))], + between="datasets") + + +test_sklearn_datasets_consistency() def test_sklearn_runs_consistency(): From 0487231776f80967c0f8bf9b87385d1e7341da89 Mon Sep 17 00:00:00 2001 From: Guillem Date: Wed, 27 Mar 2024 18:10:13 +0100 Subject: [PATCH 05/20] Added comments. --- cebra/distributions/multisession.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cebra/distributions/multisession.py b/cebra/distributions/multisession.py index 83a0cdee..62d7307c 100644 --- a/cebra/distributions/multisession.py +++ b/cebra/distributions/multisession.py @@ -264,9 +264,8 @@ class DiscreteMultisessionSampler(cebra_distr.PriorDistribution, cebra_distr.ConditionalDistribution): """Discrete multi-session sampling. - Align embeddings across multiple sessions, using a discrete - index. The transitions between index samples are computed across - all sessions. + Discrete indices don't need to be aligned. Positive pairs are found + by matching the discrete index in randomly assigned sessions. After data processing, the dimensionality of the returned features matches. The resulting embeddings can be concatenated, and shuffling @@ -343,9 +342,8 @@ def sample_conditional(self, idx: torch.Tensor) -> torch.Tensor: Positive indices (1st return value), which will be grouped by session and *not* match the reference indices. In addition, a mapping will be returned to apply the same shuffle operation - that was applied to assign a query to a session along session/batch dimension - to the reference indices (2nd return value), or reverse the shuffle operation - (3rd return value). + that was applied to assign reference samples to a session along session/batch dimension + (2nd return value), or reverse the shuffle operation (3rd return value). Returned shapes are ``(session, batch), (session, batch), (session, batch)``. TODO: @@ -358,15 +356,18 @@ def sample_conditional(self, idx: torch.Tensor) -> torch.Tensor: s = idx.shape[:2] idx_all = (idx + self.lengths[:, None]).flatten() + # get discrete indices query = self.all_data[idx_all] - # shuffle operation to assign each query to a session + # shuffle operation to assign each index to a session idx = np.random.permutation(len(query)) # TODO this part fails in Pytorch + # apply shuffle query = query[idx.reshape(s)] query = torch.from_numpy(query).to(_device) + # sample conditional for each assigned session pos_idx = torch.zeros(shape, device=_device).long() for i in range(self.num_sessions): pos_idx[i] = self.index[i].sample_conditional(query[i]) From 89d99e4fe71b8f6f32c3ee64517e766ea6016ec6 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 27 Mar 2024 19:45:25 +0100 Subject: [PATCH 06/20] Updating test setup --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index cd508ee3..da54ceb6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -105,3 +105,4 @@ dev = [bdist_wheel] universal=1 + From cf2897b66c8de038d7b2874bc71c8c423a5d3677 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 27 Mar 2024 19:55:24 +0100 Subject: [PATCH 07/20] Pin pytest-sphinx to 0.5.0 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index da54ceb6..ed6cac27 100644 --- a/setup.cfg +++ b/setup.cfg @@ -93,7 +93,7 @@ dev = pytest-benchmark pytest-xdist pytest-timeout - pytest-sphinx + pytest-sphinx==0.5.0 tables<=3.8 licenseheaders # TODO(stes) Add back once upstream issue From ccd6f76b556ff6ecf63098fd331f62439987bc3b Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 27 Mar 2024 20:01:53 +0100 Subject: [PATCH 08/20] Pin pytest to 7.4.4 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index ed6cac27..5d6899ad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -89,7 +89,7 @@ dev = isort toml coverage - pytest + pytest==7.4.4 pytest-benchmark pytest-xdist pytest-timeout From e67165a21d2ea6493518a93f169bb6e2da10ea60 Mon Sep 17 00:00:00 2001 From: Guillem Date: Thu, 28 Mar 2024 10:14:33 +0100 Subject: [PATCH 09/20] Revert unintended commit. --- tests/test_sklearn_metrics.py | 60 ----------------------------------- 1 file changed, 60 deletions(-) diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index b23640b0..ab86effe 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -325,66 +325,6 @@ def test_sklearn_datasets_consistency(): between="datasets", ) - # Example data with discrete labels - labels1 = np.random.randint(100, size=(10000,)) - labels1_invalid = np.random.randint(100, size=(10000, 3)) - labels2 = np.random.randint(100, size=(10000,)) - labels3 = np.random.randint(100, size=(8000,)) - labels4 = np.random.randint(100, size=(5000,)) - labels_datasets = [labels1, labels2, labels3, labels4] - - dataset_ids = ["achilles", "buddy", "cicero", "gatsby"] - - # random embeddings provide R2 close to 0 - scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( - embeddings_datasets, - dataset_ids=dataset_ids, - labels=labels_datasets, - between="datasets", - ) - assert scores.shape == (12,) - assert pairs.shape == (12, 2) - assert len(datasets) == 4 - assert math.isclose(scores[0], 0, abs_tol=0.05) - - # no labels - scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( - embeddings_datasets, labels=labels_datasets, between="datasets") - assert scores.shape == (12,) - assert pairs.shape == (12, 2) - assert len(datasets) == 4 - - # identical embeddings provide R2 close to 1 - scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( - [embedding1, embedding1], - dataset_ids=["achilles", "buddy"], - labels=[labels1, labels1], - between="datasets", - ) - assert scores.shape == (2,) - assert pairs.shape == (2, 2) - assert len(datasets) == 2 - assert math.isclose(scores[0], 1, abs_tol=1e-9) - - # Tensor - scores, pairs, datasets = cebra_sklearn_metrics.consistency_score( - [torch.Tensor(embedding) for embedding in embeddings_datasets], - dataset_ids=dataset_ids, - labels=[torch.Tensor(label) for label in labels_datasets], - between="datasets", - ) - assert scores.shape == (12,) - assert pairs.shape == (12, 2) - assert len(datasets) == 4 - - with pytest.raises(ValueError, match="Labels.*value"): - _, _, _ = cebra_sklearn_metrics.consistency_score( - [embedding1, embedding2], - labels=[np.random.randint(5, size=(10000,)),np.random.randint(10, size=(10000,))], - between="datasets") - - -test_sklearn_datasets_consistency() def test_sklearn_runs_consistency(): From 78b776458a69ab5c8edfdc57d0b12f6c6b7c498c Mon Sep 17 00:00:00 2001 From: Guillem Date: Thu, 28 Mar 2024 11:46:55 +0100 Subject: [PATCH 10/20] Fixed tests. --- tests/test_integration_train.py | 8 +++++++- tests/test_loader.py | 3 ++- tests/test_sklearn.py | 15 +++++++++++---- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/test_integration_train.py b/tests/test_integration_train.py index 6e25f116..06e6da40 100644 --- a/tests/test_integration_train.py +++ b/tests/test_integration_train.py @@ -35,6 +35,11 @@ import cebra.models import cebra.solver +if torch.cuda.is_available(): + _DEVICE = "cuda" +else: + _DEVICE = "cpu" + def _init_single_session_solver(loader, args): """Train a single session CEBRA model.""" @@ -77,6 +82,7 @@ def _list_data_loaders(): cebra.data.HybridDataLoader, cebra.data.FullDataLoader, cebra.data.ContinuousMultiSessionDataLoader, + cebra.data.DiscreteMultiSessionDataLoader, ] # TODO limit this to the valid combinations---however this # requires to adapt the dataset API slightly; it is currently @@ -95,7 +101,7 @@ def _list_data_loaders(): @pytest.mark.requires_dataset @pytest.mark.parametrize("dataset_name, loader_type", _list_data_loaders()) def test_train(dataset_name, loader_type): - args = cebra.config.Config(num_steps=1, device="cuda").as_namespace() + args = cebra.config.Config(num_steps=1, device=_DEVICE).as_namespace() dataset = cebra.datasets.init(dataset_name) if loader_type not in cebra_data_helper.get_loader_options(dataset): diff --git a/tests/test_loader.py b/tests/test_loader.py index 51cde8a6..1dc1935b 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -317,7 +317,8 @@ def _process(batch, feature_dim=1): @pytest.mark.parametrize( "data_name, loader_initfunc", [ - # ('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader), + ('demo-discrete-multisession', + cebra.data.DiscreteMultiSessionDataLoader), ("demo-continuous-multisession", cebra.data.ContinuousMultiSessionDataLoader) ], diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index af68b6dd..9c1c3f47 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -112,6 +112,7 @@ def test_sklearn_dataset(): # multisession num_sessions = 3 + # continuous sessions = [] for i in range(num_sessions): sessions.append(cebra_sklearn_dataset.SklearnDataset(X, (yc,))) @@ -126,11 +127,15 @@ def test_sklearn_dataset(): with pytest.raises(ValueError): cebra_data.datasets.DatasetCollection(sessions) + # discrete sessions = [] for i in range(num_sessions): sessions.append(cebra_sklearn_dataset.SklearnDataset(X, (yd,))) - with pytest.raises(NotImplementedError): - cebra_data.datasets.DatasetCollection(*sessions) + data = cebra_data.datasets.DatasetCollection(*sessions) + assert data.num_sessions == num_sessions + for i in range(num_sessions): + assert data.get_input_dimension(i) == X.shape[1] + assert len(data.get_session(i)) == len(X) @pytest.mark.parametrize("int_type", [np.uint8, np.int8, np.int32]) @@ -160,13 +165,15 @@ def test_sklearn_dataset_type_index(int_type, float_type): slow_arguments=list(itertools.product(*[[False, True]] * 5)), ) def test_init_loader(is_cont, is_disc, is_full, is_multi, is_hybrid): - if is_multi: + if is_multi and is_cont: # TODO(celia): change to a MultiDemoDataset class when discrete/mixed index implemented class __Dataset(cebra.datasets.MultiContinuous): neural = torch.zeros((50, 10), dtype=torch.float) continuous_index = torch.zeros((50, 10), dtype=torch.float) + elif is_multi and is_disc: + class __Dataset(cebra.datasets.MultiDiscrete): + neural = torch.zeros((50, 10), dtype=torch.float) discrete_index = torch.zeros((50,), dtype=torch.int) - else: class __Dataset(cebra.datasets.DemoDataset): From c7943f2855e9340323ccd5397f3ee0ab2918eb3e Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 28 Mar 2024 15:31:49 +0100 Subject: [PATCH 11/20] Limit pandas < 2.2.0 for docs build --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 5d6899ad..f0f42306 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,7 +66,7 @@ docs = nbconvert ipykernel matplotlib<=3.5.2 - pandas + pandas<2.2.0 seaborn scikit-learn<1.3 demos = From ca74f8aa010356b0785dd891959b97f0818337fd Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 28 Mar 2024 15:40:16 +0100 Subject: [PATCH 12/20] Update intersphinx mapping for pandas --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index c35ed9a5..be839ddf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -108,7 +108,7 @@ def get_years(start_year=2021): "sklearn": ("https://scikit-learn.org/stable", None), "numpy": ("https://numpy.org/doc/stable/", None), "matplotlib": ("https://matplotlib.org/stable/", None), - "pandas": ("http://pandas.pydata.org/pandas-docs/dev", None), + "pandas": ("https://pandas.pydata.org/docs/", None), "scipy": ("http://docs.scipy.org/doc/scipy/reference/", None), "joblib": ("https://joblib.readthedocs.io/en/latest/", None), "plotly": ("https://plotly.com/python-api-reference/", None) From 8fde61296da8ec23c81aa9ab92875e478b3630f4 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 28 Mar 2024 15:41:22 +0100 Subject: [PATCH 13/20] Unpin pandas --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index f0f42306..5d6899ad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,7 +66,7 @@ docs = nbconvert ipykernel matplotlib<=3.5.2 - pandas<2.2.0 + pandas seaborn scikit-learn<1.3 demos = From 91f07d99d2b6d9af0dc91e1916e0e5a33f3f03e6 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 28 Mar 2024 16:00:48 +0100 Subject: [PATCH 14/20] apply pre-commit --- cebra/data/multi_session.py | 2 +- cebra/datasets/demo.py | 9 +++++---- cebra/distributions/multisession.py | 9 +++++---- cebra/integrations/sklearn/cebra.py | 4 +--- setup.cfg | 1 - tests/test_distributions.py | 10 +++++++--- tests/test_loader.py | 14 ++++++-------- tests/test_sklearn.py | 1 + tests/test_sklearn_metrics.py | 1 - tests/test_solver.py | 9 +++++---- 10 files changed, 31 insertions(+), 29 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index a85336c2..7bf225a0 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -136,7 +136,7 @@ def get_indices(self, num_samples: int) -> List[BatchIndex]: ref_idx = torch.from_numpy(ref_idx) neg_idx = torch.from_numpy(neg_idx) pos_idx = torch.from_numpy(pos_idx) - + return BatchIndex( reference=ref_idx, positive=pos_idx, diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index bf0a7134..1c348219 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -118,15 +118,16 @@ class MultiDiscrete(cebra.data.DatasetCollection): """Demo dataset for testing.""" def __init__( - self, - nums_neural=[3, 4, 5], - num_timepoints=_DEFAULT_NUM_TIMEPOINTS, - ): + self, + nums_neural=[3, 4, 5], + num_timepoints=_DEFAULT_NUM_TIMEPOINTS, + ): super().__init__(*[ DemoDatasetDiscrete(num_timepoints, num_neural) for num_neural in nums_neural ]) + @register("demo-continuous-multisession") class MultiContinuous(cebra.data.DatasetCollection): diff --git a/cebra/distributions/multisession.py b/cebra/distributions/multisession.py index 62d7307c..647044f2 100644 --- a/cebra/distributions/multisession.py +++ b/cebra/distributions/multisession.py @@ -260,8 +260,9 @@ def __getitem__(self, pos_idx): pos_samples[i] = self.data[i][pos_idx[i]] return pos_samples + class DiscreteMultisessionSampler(cebra_distr.PriorDistribution, - cebra_distr.ConditionalDistribution): + cebra_distr.ConditionalDistribution): """Discrete multi-session sampling. Discrete indices don't need to be aligned. Positive pairs are found @@ -370,9 +371,9 @@ def sample_conditional(self, idx: torch.Tensor) -> torch.Tensor: # sample conditional for each assigned session pos_idx = torch.zeros(shape, device=_device).long() for i in range(self.num_sessions): - pos_idx[i] = self.index[i].sample_conditional(query[i]) + pos_idx[i] = self.index[i].sample_conditional(query[i]) pos_idx = pos_idx.cpu().numpy() - + # reverse indices to recover the ref/pos samples matching idx_rev = _invert_index(idx) return pos_idx, idx, idx_rev @@ -381,4 +382,4 @@ def __getitem__(self, pos_idx): pos_samples = np.zeros(pos_idx.shape[:2] + (self.data.shape[2],)) for i in range(self.num_sessions): pos_samples[i] = self.data[i][pos_idx[i]] - return pos_samples \ No newline at end of file + return pos_samples diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 5c302ecb..bf038237 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -153,9 +153,7 @@ def _require_arg(key): # Discrete behavior contrastive training is selected with the default dataloader if not is_cont and is_disc: - kwargs = dict( - **shared_kwargs, - ) + kwargs = dict(**shared_kwargs,) if is_full: if is_hybrid: raise_not_implemented_error = True diff --git a/setup.cfg b/setup.cfg index 5d6899ad..474ba8ea 100644 --- a/setup.cfg +++ b/setup.cfg @@ -105,4 +105,3 @@ dev = [bdist_wheel] universal=1 - diff --git a/tests/test_distributions.py b/tests/test_distributions.py index a93031e3..d7151fd1 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -297,6 +297,7 @@ def test_multi_session_time_contrastive(time_offset): assert (idx.flatten()[rev_idx.flatten()].all() == np.arange( len(rev_idx.flatten())).all()) + def test_multi_session_discrete(): dataset = cebra_datasets.init("demo-discrete-multisession") sampler = cebra_distr.DiscreteMultisessionSampler(dataset) @@ -312,13 +313,16 @@ def test_multi_session_discrete(): # NOTE(celia): test the private function ``_inverse_idx()``, with idx arrays flat assert (idx.flatten()[rev_idx.flatten()].all() == np.arange( len(rev_idx.flatten())).all()) - + # Check positive samples' labels match reference samples' labels - sample_labels = sampler.all_data[(sample + sampler.lengths[:, None]).flatten()] + sample_labels = sampler.all_data[(sample + + sampler.lengths[:, None]).flatten()] sample_labels = sample_labels[idx.reshape(sample.shape[:2])].flatten() - positive_labels = sampler.all_data[(positive + sampler.lengths[:, None]).flatten()] + positive_labels = sampler.all_data[(positive + + sampler.lengths[:, None]).flatten()] assert (sample_labels == positive_labels).all() + class OldDeltaDistribution(cebra_distr_base.JointDistribution, cebra_distr_base.HasGenerator): """ diff --git a/tests/test_loader.py b/tests/test_loader.py index 1dc1935b..562f64a7 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -267,9 +267,10 @@ def _process(batch, feature_dim=1): assert dummy_prediction.shape == (3, 32, 6) _mix(dummy_prediction, batch[0].index) + def test_multisession_disc_loader(): data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5], - num_timepoints=100) + num_timepoints=100) loader = cebra.data.DiscreteMultiSessionDataLoader( data, num_steps=10, @@ -313,15 +314,13 @@ def _process(batch, feature_dim=1): assert dummy_prediction.shape == (3, 32, 6) _mix(dummy_prediction, batch[0].index) + @parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", - [ - ('demo-discrete-multisession', - cebra.data.DiscreteMultiSessionDataLoader), - ("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader) - ], + [('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader), + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader)], ) def test_multisession_loader(data_name, loader_initfunc, device): # TODO change number of timepoints across the sessions @@ -339,4 +338,3 @@ def test_multisession_loader(data_name, loader_initfunc, device): _check_attributes(batch, is_list=True) for session_batch in batch: assert len(session_batch.positive) == 32 - diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 9c1c3f47..e409c0e3 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -171,6 +171,7 @@ class __Dataset(cebra.datasets.MultiContinuous): neural = torch.zeros((50, 10), dtype=torch.float) continuous_index = torch.zeros((50, 10), dtype=torch.float) elif is_multi and is_disc: + class __Dataset(cebra.datasets.MultiDiscrete): neural = torch.zeros((50, 10), dtype=torch.float) discrete_index = torch.zeros((50,), dtype=torch.int) diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index ab86effe..58e12010 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -324,7 +324,6 @@ def test_sklearn_datasets_consistency(): dataset_ids=["achilles", "buddy"], between="datasets", ) - def test_sklearn_runs_consistency(): diff --git a/tests/test_solver.py b/tests/test_solver.py index 57d14db1..3107be30 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -46,10 +46,10 @@ (*args, cebra.solver.SingleSessionHybridSolver)) multi_session_tests = [] -for args in [ - ("demo-continuous-multisession", cebra.data.ContinuousMultiSessionDataLoader), - ("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader) -]: +for args in [("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", + cebra.data.DiscreteMultiSessionDataLoader)]: multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) @@ -169,6 +169,7 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): solver.fit(loader) + @pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", multi_session_tests) def test_multi_session(data_name, loader_initfunc, solver_initfunc): From 617dd361d3be519eacdbe67ef7b8c0bd6813f694 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 26 May 2024 23:18:46 +0200 Subject: [PATCH 15/20] Remove outdated TODO statement --- cebra/datasets/demo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index 1c348219..90ba5367 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -111,8 +111,6 @@ def discrete_index(self): return self.dindex -# TODO(stes) remove this from the demo datasets until multi-session training -# with discrete indices is implemented in the sklearn API. @register("demo-discrete-multisession") class MultiDiscrete(cebra.data.DatasetCollection): """Demo dataset for testing.""" From 9e5c8c2367f3c2fd42f4a674bff4597f46d28c4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillem=20Fern=C3=A1ndez?= <68448023+introspective-swallow@users.noreply.github.com> Date: Mon, 27 May 2024 18:51:29 +0200 Subject: [PATCH 16/20] Update usage.rst --- docs/source/usage.rst | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 2e7124cd..61430f36 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -419,10 +419,10 @@ We can now fit the model in different modes. .. rubric:: Multi-session training -For multi-sesson training, lists of data are provided instead of a single dataset and eventual corresponding auxiliary variables. +For multi-session training, lists of data are provided instead of a single dataset and eventual corresponding auxiliary variables. .. warning:: - For now, multi-session training can only handle a **unique set of continuous labels**. All other combinations will raise an error. + For now, multi-session training can only handle a **unique set of continuous labels** or a **unique discrete label**. All other combinations will raise an error. For the continuous case we provide the following example: .. testcode:: @@ -450,6 +450,29 @@ Once you defined your CEBRA model, you can run: multi_cebra_model.fit([neural_session1, neural_session2], [continuous_label1, continuous_label2]) +Similarly, for the discrete case a discrete label can be provided and the CEBRA model will use the discrete multisession mode: + +.. testcode:: + + timesteps1 = 5000 + timesteps2 = 3000 + neurons1 = 50 + neurons2 = 30 + out_dim = 8 + + neural_session1 = np.random.normal(0,1,(timesteps1, neurons1)) + neural_session2 = np.random.normal(0,1,(timesteps2, neurons2)) + discrete_label1 = np.random.randint(0,10,(timesteps1, )) + discrete_label2 = np.random.randint(0,10,(timesteps2, )) + + multi_cebra_model = cebra.CEBRA(batch_size=512, + output_dimension=out_dim, + max_iterations=10, + max_adapt_iterations=10) + + + multi_cebra_model.fit([neural_session1, neural_session2], [discrete_label1, discrete_label2]) + .. admonition:: See API docs :class: dropdown From ce66b94682935451b6ada37a5491ae1b0f2d3c6d Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Mon, 27 May 2024 23:37:09 +0200 Subject: [PATCH 17/20] Update usage.rst old spelling error, caught by new tests :) --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 61430f36..ff59d665 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -12,7 +12,7 @@ Firstly, why use CEBRA? CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE `_ and `UMAP `_. We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent. -That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for toplogical exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023). +That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023). The CEBRA workflow ------------------ From 0a6c107a0d3ad915345c5fe8fe838490fd066492 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Mon, 27 May 2024 23:40:43 +0200 Subject: [PATCH 18/20] Update CODE_OF_CONDUCT.md - spelling --- CODE_OF_CONDUCT.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index ef0a8c30..737c54f2 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -6,7 +6,7 @@ In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, -level of experience, education, socio-economic status, nationality, personal +level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards From 7a47e77742b58527b9aa67369c569c589e37a769 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Mon, 27 May 2024 23:42:21 +0200 Subject: [PATCH 19/20] Update make_neuropixel.py - spelling --- cebra/datasets/make_neuropixel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/datasets/make_neuropixel.py b/cebra/datasets/make_neuropixel.py index 27745ee3..7c097f38 100644 --- a/cebra/datasets/make_neuropixel.py +++ b/cebra/datasets/make_neuropixel.py @@ -171,7 +171,7 @@ def read_neuropixel( ): """Load 120Hz Neuropixels data recorded in the specified cortex during the movie1 stimulus. - The Neuropixels recordin is filtered and transformed to spike counts in a bin size specified by the sampling rat. + The Neuropixels recording is filtered and transformed to spike counts in a bin size specified by the sampling rat. Args: path: The wildcard file path where the neuropixels .nwb files are located. From e2682357ab99c77ef3b4713c8927bf63f37df0ad Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Tue, 28 May 2024 14:09:24 +0200 Subject: [PATCH 20/20] Update usage.rst