diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index b33f40d1..93804a99 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -935,7 +935,7 @@ def _adapt_model( dataset, is_multisession = self._prepare_data(X, y) - if is_multisession: + if is_multisession or isinstance(self.model_, nn.ModuleList): raise NotImplementedError( "The adapt option with a multisession training is not handled. Please use adapt=True for single-trained estimators only." ) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index f8b4b0ad..36ffd2fd 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -675,6 +675,48 @@ def check_first_layer_dim(model, X): cebra_model.fit([X, X_s2], [y_c1, y_c1_s2], adapt=True) +@_util.parametrize_slow( + arg_names="model_architecture,device", + fast_arguments=list( + itertools.islice( + itertools.product( + cebra_sklearn_cebra.CEBRA.supported_model_architectures(), + _DEVICES), + 2, + )), + slow_arguments=list( + itertools.product( + cebra_sklearn_cebra.CEBRA.supported_model_architectures(), + _DEVICES)), +) +def test_sklearn_adapt_multisession(model_architecture, device): + num_hidden_units = 32 + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture=model_architecture, + time_offsets=10, + learning_rate=3e-4, + max_iterations=5, + max_adapt_iterations=1, + device=device, + output_dimension=4, + num_hidden_units=num_hidden_units, + batch_size=42, + verbose=True, + ) + + # example dataset + Xs = [np.random.uniform(0, 1, (1000, 50)) for i in range(3)] + ys = [np.random.uniform(0, 1, (1000, 5)) for i in range(3)] + + X_new = np.random.uniform(0, 1, (1000, 50)) + y_new = np.random.uniform(0, 1, (1000, 5)) + + cebra_model.fit(Xs, ys) + + with pytest.raises(NotImplementedError, match=".*multisession.*"): + cebra_model.fit(X_new, y_new, adapt=True) + + @_util.parametrize_slow( arg_names="model_architecture,device,pad_before_transform", fast_arguments=list(