Skip to content

Commit

Permalink
Add shape assertions, update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Aug 3, 2022
1 parent 7485da3 commit fc52cfd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
38 changes: 32 additions & 6 deletions ott/core/bar_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BarycenterProblem:
the points within the measures that define the barycenter problem.
Similarly as ``y``, segmented array of weights of shape
``[num_measures, max_measure_size]`` can be passed.
If ``y`` is already pre-segmented, this array must be passed.
If ``y`` is already pre-segmented, this array must be always specified.
weights: Array of shape ``[num_measures,]`` containing the weights of the
measures.
cost_fn: Cost function used. If `None`,
Expand All @@ -53,8 +53,8 @@ class BarycenterProblem:
needs to be smaller than the maximum measure size for parallelization to
operate efficiently.
kwargs: Keyword arguments :func:`~ott.core.segment.segment_point_cloud`.
Only used when ``y`` is not already segmented.
For jitting, 2 arguments must be specified:
Only used when ``y`` is not already segmented. When passing
``segment_ids``, 2 arguments must be specified for jitting to work:
- ``num_segments`` - the total number of measures.
- ``max_measure_size`` - maximum of support sizes of these measures.
Expand All @@ -72,14 +72,23 @@ def __init__(
):
self._y = y
if y.ndim == 3 and b is None:
raise ValueError("Specify weights if `y` is segmented.")
raise ValueError("Specify weights if `y` is already segmented.")
self._b = b
self._weights = weights
self.cost_fn = costs.Euclidean() if cost_fn is None else cost_fn
self.epsilon = epsilon
self.debiased = debiased
self._kwargs = kwargs

if self._is_segmented:
# (num_measures, max_measure_size, ndim)
# (num_measures, max_measure_size)
assert self._y.shape[:2] == self._b.shape
else:
# (num_total_points, ndim)
# (num_total_points,)
assert self._b is None or self._y.shape[0] == self._b.shape[0]

@property
def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Tuple of arrays containing the segmented measures and weights.
Expand Down Expand Up @@ -107,9 +116,9 @@ def _add_slice_for_debiased(
self, y: jnp.ndarray, b: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
y, b = self._y, self._b
n = y.shape[1] # (num_measures, max_measure_size, dim)
_, n, ndim = y.shape # (num_measures, max_measure_size, ndim)
# yapf: disable
y = jnp.concatenate((y, jnp.zeros((1, n, self.ndim))), axis=0)
y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0)
b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0)
# yapf: enable
return y, b
Expand Down Expand Up @@ -230,6 +239,23 @@ def __init__(
self.scale_cost = scale_cost
self._y_as_costs = costs is not None

if self._y_as_costs:
# (num_measures, max_measure_size, max_measure_size)
_, n, m = self._y.shape
assert n == m, "Cost matrices must be square."
if self.is_fused:
seg_y = self._is_segmented
seg_fused = self._y_fused.ndim == 3
if seg_y and seg_fused:
# (num_measures, max_measure_size, ndim_fused)
# (num_measures, max_measure_size, ndim)
assert self._y_fused.shape[:2] == self._y.shape[:2]
if not seg_y and not seg_fused:
# (num_total_points, ndim_fused), (num_total_points, ndim)
assert self._y_fused.shape[0] == self._y.shape[0]
# TODO(michalk8): in the future, consider checking the other 2 cases
# using `segmented_y` and `segmented_y_fused`?

def update_barycenter(
self, transports: jnp.ndarray, a: jnp.ndarray
) -> jnp.ndarray:
Expand Down
4 changes: 2 additions & 2 deletions ott/core/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def solve_linear_ot(
pointcloud.PointCloud(
x,
y,
src_mask=a != 0.0,
tgt_mask=b != 0.0,
src_mask=a > 0.,
tgt_mask=b > 0.,
cost_fn=bar_prob.cost_fn,
epsilon=bar_prob.epsilon
), a, b
Expand Down

0 comments on commit fc52cfd

Please sign in to comment.