Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation for discrete multisession #135

Merged
merged 23 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
level of experience, education, socioeconomic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards
Expand Down
6 changes: 3 additions & 3 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ def __init__(
else:
self._cindex = None
if discrete:
raise NotImplementedError(
"Multisession implementation does not support discrete index yet."
)
self._dindex = torch.cat(list(
self._iter_property("discrete_index")),
dim=0)
else:
self._dindex = None

Expand Down
11 changes: 10 additions & 1 deletion cebra/data/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,16 @@ def index(self):

@dataclasses.dataclass
class DiscreteMultiSessionDataLoader(MultiSessionLoader):
    """Contrastive learning conditioned on a discrete behavior variable.

    After dataclass initialization, the loader's ``sampler`` is set to a
    ``DiscreteMultisessionSampler`` built from ``self.dataset``, and
    :py:attr:`index` exposes the dataset's discrete index.
    """

    # Overwrite sampler with the discrete implementation
    # Generalize MultisessionSampler to avoid doing this?
    # NOTE(review): the parent ``__post_init__`` is not invoked here;
    # presumably this is deliberate so that only the discrete sampler is
    # created --- confirm against ``MultiSessionLoader``.
    def __post_init__(self):
        self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset)

    @property
    def index(self):
        """The discrete index of the underlying dataset."""
        return self.dataset.discrete_index


@dataclasses.dataclass
Expand Down
12 changes: 7 additions & 5 deletions cebra/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,17 @@ def discrete_index(self):
return self.dindex


@register("demo-discrete-multisession")
class MultiDiscrete(cebra.data.DatasetCollection):
    """Demo dataset for testing.

    Builds one ``DemoDatasetDiscrete`` per entry of ``nums_neural`` and
    groups them into a dataset collection.

    Args:
        nums_neural: One entry per session; each entry is forwarded as the
            second positional argument of ``DemoDatasetDiscrete``.
        num_timepoints: Forwarded as the first positional argument of
            ``DemoDatasetDiscrete`` for every session.
    """

    def __init__(
        self,
        # A tuple (rather than a list) keeps the default immutable, so it
        # cannot be shared and modified between calls.
        nums_neural=(3, 4, 5),
        num_timepoints=_DEFAULT_NUM_TIMEPOINTS,
    ):
        super().__init__(*[
            DemoDatasetDiscrete(num_timepoints, num_neural)
            for num_neural in nums_neural
        ])

Expand Down
2 changes: 1 addition & 1 deletion cebra/datasets/make_neuropixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def read_neuropixel(
):
"""Load 120Hz Neuropixels data recorded in the specified cortex during the movie1 stimulus.

The Neuropixels recordin is filtered and transformed to spike counts in a bin size specified by the sampling rat.
The Neuropixels recording is filtered and transformed to spike counts in a bin size specified by the sampling rate.

Args:
path: The wildcard file path where the neuropixels .nwb files are located.
Expand Down
124 changes: 124 additions & 0 deletions cebra/distributions/multisession.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,127 @@ def __getitem__(self, pos_idx):
for i in range(self.num_sessions):
pos_samples[i] = self.data[i][pos_idx[i]]
return pos_samples


class DiscreteMultisessionSampler(cebra_distr.PriorDistribution,
                                  cebra_distr.ConditionalDistribution):
    """Discrete multi-session sampling.

    Discrete indices don't need to be aligned. Positive pairs are found
    by matching the discrete index in randomly assigned sessions.

    After data processing, the dimensionality of the returned features
    matches. The resulting embeddings can be concatenated, and shuffling
    (across the session axis) can be applied to the reference samples, or
    reversed for the positive samples.

    Args:
        dataset: The multi-session dataset. The sampler reads its
            ``discrete_index`` and ``session_lengths`` attributes and
            iterates over the individual sessions via ``iter_sessions()``.

    TODO:
        * Add better CUDA support and refactor ``numpy`` implementation.
    """

    def __init__(self, dataset):
        self.dataset = dataset

        # TODO(stes): implement in pytorch
        self.all_data = self.dataset.discrete_index.cpu().numpy()
        self.session_lengths = np.array(self.dataset.session_lengths)

        # Despite its name, ``lengths`` holds the start offset of every
        # session inside ``all_data``: the cumulative sum of the session
        # lengths, shifted right by one position, with a leading zero.
        self.lengths = np.cumsum(self.session_lengths)
        self.lengths[1:] = self.lengths[:-1]
        self.lengths[0] = 0

        # One conditional sampler per session, built on that session's own
        # discrete index.
        self.index = [
            cebra_distr.DiscreteUniform(
                session.discrete_index.int().to(_device))
            for session in self.dataset.iter_sessions()
        ]

    @property
    def num_sessions(self) -> int:
        """The number of sessions in the index."""
        return len(self.lengths)

    def mix(self, array: np.ndarray, idx: np.ndarray) -> np.ndarray:
        """Re-order array elements according to the given index mapping.

        The given array should be of the shape ``(session, batch, ...)`` and the
        indices should have length ``session x batch``, representing a mapping
        between indices.

        The resulting array will be rearranged such that
        ``out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]``

        For the inverse mapping, convert the indices first using ``_invert_index``
        function.

        Args:
            array: An array of shape ``(session, batch, ...)`` containing
                samples for each session.
            idx: A list of indexes to re-order ``array`` on.

        Returns:
            The re-ordered array, with the same shape as ``array``.
        """
        n, m = array.shape[:2]
        return array.reshape(n * m, -1)[idx].reshape(array.shape)

    def sample_prior(self, num_samples: int) -> np.ndarray:
        """Draw reference indices uniformly at random within each session.

        Args:
            num_samples: Number of indices to draw per session.

        Returns:
            Integer array of shape ``(session, num_samples)``. Row ``i``
            holds indices relative to the start of session ``i``.
        """
        # TODO(stes) implement empirical/uniform resampling
        ref_idx = np.random.uniform(0, 1, (self.num_sessions, num_samples))
        # Scale the unit-interval draws by each session's length and
        # truncate to integer positions.
        ref_idx = (ref_idx * self.session_lengths[:, None]).astype(int)
        return ref_idx

    def sample_conditional(self, idx: np.ndarray) -> tuple:
        """Sample from the conditional distribution.

        Note:
            * Reference samples are sampled equally between sessions.
            * In order to guarantee the same number of positive samples per
              session, reference samples are randomly assigned to a session and its
              corresponding positive sample is searched in that session only.
            * As a result, ref/pos pairing is shuffled and can be recovered
              with the reverse shuffle operation.

        Args:
            idx: Reference indices, with dimension ``(session, batch)``, as
                returned by :py:meth:`sample_prior`. A ``torch.Tensor`` is
                accepted as well and converted to a ``numpy`` array first.

        Returns:
            A tuple of three ``numpy`` arrays:

            * Positive indices, of shape ``(session, batch)``, which will be
              grouped by session and *not* match the reference indices.
            * The flat mapping, of shape ``(session * batch,)``, to apply the
              same shuffle operation that was applied to assign reference
              samples to a session along the session/batch dimension.
            * The flat mapping, of shape ``(session * batch,)``, to reverse
              the shuffle operation.

        TODO:
            * re-implement in pytorch for additional speed gains
        """
        # The arithmetic and fancy indexing below are numpy-based, so bring
        # tensor inputs to the CPU first.
        if isinstance(idx, torch.Tensor):
            idx = idx.cpu().numpy()

        shape = idx.shape
        # TODO(stes) unclear why we cannot restrict to 2d overall
        # idx has shape (2, #samples per batch)
        session_batch_shape = idx.shape[:2]
        # Shift per-session indices by the session start offsets to obtain
        # positions in the concatenated index.
        idx_all = (idx + self.lengths[:, None]).flatten()

        # get discrete indices
        query = self.all_data[idx_all]

        # shuffle operation to assign each index to a session
        shuffle_idx = np.random.permutation(len(query))

        # TODO this part fails in Pytorch
        # apply shuffle
        query = query[shuffle_idx.reshape(session_batch_shape)]
        query = torch.from_numpy(query).to(_device)

        # sample conditional for each assigned session
        pos_idx = torch.zeros(shape, device=_device).long()
        for i in range(self.num_sessions):
            pos_idx[i] = self.index[i].sample_conditional(query[i])
        pos_idx = pos_idx.cpu().numpy()

        # reverse indices to recover the ref/pos samples matching
        idx_rev = _invert_index(shuffle_idx)
        return pos_idx, shuffle_idx, idx_rev

    def __getitem__(self, pos_idx):
        """Gather the samples at the given per-session positive indices.

        Args:
            pos_idx: Indices of shape ``(session, batch)``.

        Returns:
            Array of shape ``(session, batch, self.data.shape[2])``.
        """
        # NOTE(review): ``self.data`` is not assigned in ``__init__`` of this
        # class. Unless a base class provides it, this method raises an
        # ``AttributeError`` --- TODO confirm and set or remove.
        pos_samples = np.zeros(pos_idx.shape[:2] + (self.data.shape[2],))
        for i in range(self.num_sessions):
            pos_samples[i] = self.data[i][pos_idx[i]]
        return pos_samples
7 changes: 5 additions & 2 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _require_arg(key):

# Discrete behavior contrastive training is selected with the default dataloader
if not is_cont and is_disc:
kwargs = dict(**shared_kwargs,)
if is_full:
if is_hybrid:
raise_not_implemented_error = True
Expand All @@ -162,7 +163,10 @@ def _require_arg(key):
if is_hybrid:
raise_not_implemented_error = True
else:
raise_not_implemented_error = True
return (
cebra.data.DiscreteMultiSessionDataLoader(**kwargs),
"multi-session",
)

# Mixed behavior contrastive training is selected with the default dataloader
if is_cont and is_disc:
Expand Down Expand Up @@ -1030,7 +1034,6 @@ def _partial_fit(
if callback is None:
raise ValueError(
"callback_frequency requires to specify a callback.")

model.train()

solver.fit(
Expand Down
29 changes: 26 additions & 3 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Firstly, why use CEBRA?

CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE <https://www.jmlr.org/papers/v9/vandermaaten08a.html>`_ and `UMAP <https://arxiv.org/abs/1802.03426>`_. We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent.

That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for toplogical exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).
That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).

The CEBRA workflow
------------------
Expand Down Expand Up @@ -419,10 +419,10 @@ We can now fit the model in different modes.

.. rubric:: Multi-session training

For multi-sesson training, lists of data are provided instead of a single dataset and eventual corresponding auxiliary variables.
For multi-session training, lists of data are provided instead of a single dataset and its corresponding auxiliary variables, if any.

.. warning::
For now, multi-session training can only handle a **unique set of continuous labels**. All other combinations will raise an error.
For now, multi-session training can only handle a **unique set of continuous labels** or a **unique discrete label**. All other combinations will raise an error. For the continuous case we provide the following example:


.. testcode::
Expand Down Expand Up @@ -450,6 +450,29 @@ Once you defined your CEBRA model, you can run:
multi_cebra_model.fit([neural_session1, neural_session2], [continuous_label1, continuous_label2])


Similarly, for the discrete case a discrete label can be provided and the CEBRA model will use the discrete multisession mode:

.. testcode::

timesteps1 = 5000
timesteps2 = 3000
neurons1 = 50
neurons2 = 30
out_dim = 8

neural_session1 = np.random.normal(0,1,(timesteps1, neurons1))
neural_session2 = np.random.normal(0,1,(timesteps2, neurons2))
discrete_label1 = np.random.randint(0,10,(timesteps1, ))
discrete_label2 = np.random.randint(0,10,(timesteps2, ))

multi_cebra_model = cebra.CEBRA(batch_size=512,
output_dimension=out_dim,
max_iterations=10,
max_adapt_iterations=10)


multi_cebra_model.fit([neural_session1, neural_session2], [discrete_label1, discrete_label2])

.. admonition:: See API docs
:class: dropdown

Expand Down
25 changes: 25 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,31 @@ def test_multi_session_time_contrastive(time_offset):
len(rev_idx.flatten())).all())


def test_multi_session_discrete():
    """Check shapes, index inversion and label matching of the discrete
    multi-session sampler on the demo dataset."""
    dataset = cebra_datasets.init("demo-discrete-multisession")
    sampler = cebra_distr.DiscreteMultisessionSampler(dataset)

    num_samples = 5
    sample = sampler.sample_prior(num_samples)
    assert sample.shape == (dataset.num_sessions, num_samples)

    positive, idx, rev_idx = sampler.sample_conditional(sample)
    assert positive.shape == (dataset.num_sessions, num_samples)
    assert idx.shape == (dataset.num_sessions * num_samples,)
    assert rev_idx.shape == (dataset.num_sessions * num_samples,)
    # NOTE(celia): test the private function ``_invert_index()``, with idx arrays flat
    # Compare element-wise before reducing with ``.all()``. Reducing both
    # sides first would compare two booleans, which are both ``False``
    # because each array contains the index 0, so the check could not fail.
    assert (idx.flatten()[rev_idx.flatten()] == np.arange(
        len(rev_idx.flatten()))).all()

    # Check positive samples' labels match reference samples' labels
    sample_labels = sampler.all_data[(sample +
                                      sampler.lengths[:, None]).flatten()]
    sample_labels = sample_labels[idx.reshape(sample.shape[:2])].flatten()
    positive_labels = sampler.all_data[(positive +
                                        sampler.lengths[:, None]).flatten()]
    assert (sample_labels == positive_labels).all()


class OldDeltaDistribution(cebra_distr_base.JointDistribution,
cebra_distr_base.HasGenerator):
"""
Expand Down
8 changes: 7 additions & 1 deletion tests/test_integration_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
import cebra.models
import cebra.solver

# Run the integration tests on the GPU when one is available, otherwise
# fall back to the CPU.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _init_single_session_solver(loader, args):
"""Train a single session CEBRA model."""
Expand Down Expand Up @@ -77,6 +82,7 @@ def _list_data_loaders():
cebra.data.HybridDataLoader,
cebra.data.FullDataLoader,
cebra.data.ContinuousMultiSessionDataLoader,
cebra.data.DiscreteMultiSessionDataLoader,
]
# TODO limit this to the valid combinations---however this
# requires to adapt the dataset API slightly; it is currently
Expand All @@ -95,7 +101,7 @@ def _list_data_loaders():
@pytest.mark.requires_dataset
@pytest.mark.parametrize("dataset_name, loader_type", _list_data_loaders())
def test_train(dataset_name, loader_type):
args = cebra.config.Config(num_steps=1, device="cuda").as_namespace()
args = cebra.config.Config(num_steps=1, device=_DEVICE).as_namespace()

dataset = cebra.datasets.init(dataset_name)
if loader_type not in cebra_data_helper.get_loader_options(dataset):
Expand Down
Loading
Loading