From fc52cfdf81e49349357619a571535338d1d136d5 Mon Sep 17 00:00:00 2001 From: Michal Klein Date: Wed, 3 Aug 2022 12:17:46 +0200 Subject: [PATCH] Add shape assertions, update docs --- ott/core/bar_problems.py | 38 ++++++++++++++++++++++++++----- ott/core/continuous_barycenter.py | 4 ++-- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/ott/core/bar_problems.py b/ott/core/bar_problems.py index 0139e0cc9..ee214ca6d 100644 --- a/ott/core/bar_problems.py +++ b/ott/core/bar_problems.py @@ -38,7 +38,7 @@ class BarycenterProblem: the points within the measures that define the barycenter problem. Similarly as ``y``, segmented array of weights of shape ``[num_measures, max_measure_size]`` can be passed. - If ``y`` is already pre-segmented, this array must be passed. + If ``y`` is already pre-segmented, this array must be always specified. weights: Array of shape ``[num_measures,]`` containing the weights of the measures. cost_fn: Cost function used. If `None`, @@ -53,8 +53,8 @@ class BarycenterProblem: needs to be smaller than the maximum measure size for parallelization to operate efficiently. kwargs: Keyword arguments :func:`~ott.core.segment.segment_point_cloud`. - Only used when ``y`` is not already segmented. - For jitting, 2 arguments must be specified: + Only used when ``y`` is not already segmented. When passing + ``segment_ids``, 2 arguments must be specified for jitting to work: - ``num_segments`` - the total number of measures. - ``max_measure_size`` - maximum of support sizes of these measures. @@ -72,7 +72,7 @@ def __init__( ): self._y = y if y.ndim == 3 and b is None: - raise ValueError("Specify weights if `y` is segmented.") + raise ValueError("Specify weights if `y` is already segmented.") self._b = b self._weights = weights self.cost_fn = costs.Euclidean() if cost_fn is None else cost_fn @@ -80,6 +80,15 @@ def __init__( self.debiased = debiased self._kwargs = kwargs + if self._is_segmented: + # (num_measures, max_measure_size, ndim) + # (num_measures, max_measure_size) + assert self._y.shape[:2] == self._b.shape + else: + # (num_total_points, ndim) + # (num_total_points,) + assert self._b is None or self._y.shape[0] == self._b.shape[0] + @property def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: """Tuple of arrays containing the segmented measures and weights. @@ -107,9 +116,9 @@ def _add_slice_for_debiased( self, y: jnp.ndarray, b: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: y, b = self._y, self._b - n = y.shape[1] # (num_measures, max_measure_size, dim) + _, n, ndim = y.shape # (num_measures, max_measure_size, ndim) # yapf: disable - y = jnp.concatenate((y, jnp.zeros((1, n, self.ndim))), axis=0) + y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0) b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0) # yapf: enable return y, b @@ -230,6 +239,23 @@ def __init__( self.scale_cost = scale_cost self._y_as_costs = costs is not None + if self._y_as_costs: + # (num_measures, max_measure_size, max_measure_size) + _, n, m = self._y.shape + assert n == m, "Cost matrices must be square." + if self.is_fused: + seg_y = self._is_segmented + seg_fused = self._y_fused.ndim == 3 + if seg_y and seg_fused: + # (num_measures, max_measure_size, ndim_fused) + # (num_measures, max_measure_size, ndim) + assert self._y_fused.shape[:2] == self._y.shape[:2] + if not seg_y and not seg_fused: + # (num_total_points, ndim_fused), (num_total_points, ndim) + assert self._y_fused.shape[0] == self._y.shape[0] + # TODO(michalk8): in the future, consider checking the other 2 cases + # using `segmented_y` and `segmented_y_fused`? + def update_barycenter( self, transports: jnp.ndarray, a: jnp.ndarray ) -> jnp.ndarray: diff --git a/ott/core/continuous_barycenter.py b/ott/core/continuous_barycenter.py index caa9129b3..8a328264c 100644 --- a/ott/core/continuous_barycenter.py +++ b/ott/core/continuous_barycenter.py @@ -67,8 +67,8 @@ def solve_linear_ot( pointcloud.PointCloud( x, y, - src_mask=a != 0.0, - tgt_mask=b != 0.0, + src_mask=a > 0., + tgt_mask=b > 0., cost_fn=bar_prob.cost_fn, epsilon=bar_prob.epsilon ), a, b