Skip to content

Commit

Permalink
add positive sampling options for MixedDataLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Oct 29, 2023
1 parent 0378db0 commit f1894a1
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,27 +261,47 @@ class MixedDataLoader(cebra_data.Loader):
1. Positive pairs always share their discrete variable.
2. Positive pairs are drawn only based on their conditional,
not discrete variable.
Args:
conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
positive_sampling (str): either "discrete_variable" (default) or "conditional"
discrete_sampling_prior (str): either "empirical" (default) or "uniform"
"""

conditional: str = dataclasses.field(default="time_delta")
time_offset: int = dataclasses.field(default=10)
positive_sampling: str = dataclasses.field(default="discrete_variable")
discrete_sampling_prior: str = dataclasses.field(default="uniform")

@property
def dindex(self):
# TODO(stes) rename to discrete_index
def discrete_index(self):
return self.dataset.discrete_index

@property
def cindex(self):
# TODO(stes) rename to continuous_index
def continuous_index(self):
return self.dataset.continuous_index

def __post_init__(self):
super().__post_init__()
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
discrete=self.dindex,
continuous=self.cindex,
time_delta=self.time_offset)
if self.positive_sampling == "conditional":
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
discrete=self.discrete_index,
continuous=self.continuous_index,
time_delta=self.time_offset)
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]:
raise ValueError(
f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but "
f"only accept 'uniform' or 'empirical' as potential values.")
else:
raise ValueError(
f"Invalid positive sampling mode: "
f"{self.positive_sampling} valid options are "
f"'conditional' or 'discrete_variable'.")

def get_indices(self, num_samples: int) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.
Expand All @@ -306,12 +326,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
class.
- Sample the negatives with matching discrete variable
"""
reference_idx = self.distribution.sample_prior(num_samples)
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
positive=self.distribution.sample_conditional(reference_idx),
)
if self.positive_sampling == "conditional":
reference_idx = self.distribution.sample_prior(num_samples)
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
positive=self.distribution.sample_conditional(reference_idx),
)
else:
# taken from the DiscreteDataLoader get_indices function
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference = self.discrete_index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
positive=positive_idx,
negative=negative_idx)


@dataclasses.dataclass
Expand Down

0 comments on commit f1894a1

Please sign in to comment.