diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index e0712379a..2657ad80f 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -14,7 +14,7 @@ import abc import functools import math -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -69,7 +69,8 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: The cost. """ - def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: """Barycentric operator. Args: @@ -77,7 +78,9 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: xs: Points. Returns: - The barycenter of `xs` using `weights` coefficients. + A list, whose first element is the barycenter of `xs` using `weights` + coefficients, followed by auxiliary information on the convergence of + the algorithm. """ raise NotImplementedError("Barycenter is not implemented.") @@ -268,9 +271,10 @@ def h(self, z: jnp.ndarray) -> float: # noqa: D102 def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.25 * jnp.sum(z ** 2) - def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: """Output barycenter of vectors when using squared-Euclidean distance.""" - return jnp.average(xs, weights=weights, axis=0) + return jnp.average(xs, weights=weights, axis=0), None @jax.tree_util.register_pytree_node_class @@ -496,13 +500,14 @@ class Bures(CostFn): Args: dimension: Dimensionality of the data. - kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`. + sqrtm_kw: Dictionary of keyword arguments to control the + behavior of inner calls to :func:`~ott.math.matrix_square_root.sqrtm`. """ - def __init__(self, dimension: int, **kwargs: Any): + def __init__(self, dimension: int, sqrtm_kw: Optional[Dict[str, Any]] = None): super().__init__() self._dimension = dimension - self._sqrtm_kw = kwargs + self._sqrtm_kw = {} if sqrtm_kw is None else sqrtm_kw def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance.""" @@ -528,6 +533,7 @@ def covariance_fixpoint_iter( covs: jnp.ndarray, weights: jnp.ndarray, tolerance: float = 1e-4, + sqrtm_kw: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> jnp.ndarray: """Iterate fix-point updates to compute barycenter of Gaussians. @@ -535,12 +541,23 @@ def covariance_fixpoint_iter( Args: covs: [batch, d^2] covariance matrices weights: simplicial weights (non-negative, sum to 1) - tolerance: tolerance of the overall fixed-point procedure - kwargs: keyword arguments for :func:`ott.math.matrix_square_root.sqrtm`. + tolerance: tolerance of the fixed-point procedure. That tolerance is + applied to the Frobenius norm (normalized by total size) + of two successive iterations of the algorithm + sqrtm_kw: keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm` + kwargs: keyword arguments for the outer fixed-point iteration Returns: - Weighted Bures average of the covariance matrices. + List containing Weighted Bures average of the covariance matrices, and + vector of (normalized) 2-norms of successive differences between iterates, + to monitor convergence. """ + sqrtm_kw = {} if sqrtm_kw is None else sqrtm_kw + # Pop values or set defaults for fixed-point loop. + min_iterations = kwargs.pop("min_iterations", 1) + max_iterations = kwargs.pop("max_iterations", 100) + inner_iterations = kwargs.pop("inner_iterations", 5) + dtype = covs.dtype @functools.partial(jax.vmap, in_axes=[None, 0, 0]) def scale_covariances( @@ -548,51 +565,55 @@ def scale_covariances( ) -> jnp.ndarray: """Rescale covariance in barycenter step.""" return weight * matrix_square_root.sqrtm_only((cov_sqrt @ cov) @ cov_sqrt, - **kwargs) + **sqrtm_kw) def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool: - del iteration, constants - _, diff = state - return diff > tolerance + del constants + _, diffs = state + return diffs[iteration // inner_iterations] > tolerance def body_fn( iteration: int, constants: Tuple[Any, ...], state: Tuple[jnp.ndarray, float], compute_error: bool ) -> Tuple[jnp.ndarray, float]: - del iteration, constants, compute_error - cov, _ = state - cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **kwargs) + del constants, compute_error + cov, diffs = state + cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **sqrtm_kw) scaled_cov = jnp.linalg.matrix_power( jnp.sum(scale_covariances(cov_sqrt, covs, weights), axis=0), 2 ) next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape)) - return next_cov, diff + diffs = diffs.at[iteration // inner_iterations].set(diff) + return next_cov, diffs def init_state() -> Tuple[jnp.ndarray, float]: cov_init = jnp.eye(self._dimension) - diff = jnp.inf - return cov_init, diff - - # TODO(marcocuturi): ideally the integer parameters below should be passed - # by user, if one wants more fine grained control. This could clash with the - # parameters passed on to :func:`ott.math.matrix_square_root.sqrtm` by the - # barycenter call. At the moment, only `tolerance` can be used to control - # computational effort. - cov, _ = fixed_point_loop.fixpoint_iter( + diffs = -jnp.ones( + (np.ceil(max_iterations / inner_iterations).astype(int),), + dtype=dtype + ) + return cov_init, diffs + + cov, diffs = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, - min_iterations=1, - max_iterations=500, - inner_iterations=1, + min_iterations=min_iterations, + max_iterations=max_iterations, + inner_iterations=inner_iterations, constants=(), - state=init_state() + state=init_state(), ) - return cov + return cov, diffs def barycenter( - self, weights: jnp.ndarray, xs: jnp.ndarray, **kwargs: Any - ) -> jnp.ndarray: + self, + weights: jnp.ndarray, + xs: jnp.ndarray, + tolerance: float = 1e-4, + sqrtm_kw: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the Bures barycenter of weighted Gaussian distributions. Implements the fixed point approach proposed in :cite:`alvarez-esteban:16` @@ -604,22 +625,38 @@ def barycenter( xs: The points to be used in the computation of the barycenter, where each point is described by a concatenation of the mean and the covariance (raveled). - kwargs: Passed on to :meth:`covariance_fixpoint_iter`, and by extension to - :func:`ott.math.matrix_square_root.sqrtm`. Note that `tolerance` is used - for the fixed-point iteration of the barycenter, whereas `threshold` - will apply to the fixed point iteration of Newton-Schulz iterations. + tolerance: convergence tolerance to control the termination of the + algorithm. + sqrtm_kw: Arguments passed on to the + :func:`~ott.math.matrix_square_root.sqrtm` function used within + :meth:`covariance_fixpoint_iter`. This defines the precision + (in terms of convergence threshold, and number of iterations) of the + matrix square root calls that are used at each outer iteration of + the computation of Gaussian barycenters. These values are, by default, + the same as those used to define the Bures cost object itself. + kwargs: Passed on to :meth:`covariance_fixpoint_iter`, to specify the + number of iterations and tolerance of the fixed-point iteration of the + barycenter routine, by parameterizing `tolerance` and other relevant + arguments passed on to :func:`~ott.math.fixed_point_loop.fixpoint_iter`, + namely `min_iterations`, `max_iterations` and `inner_iterations`. Returns: - A concatenation of the mean and the raveled covariance of the barycenter. + A list holding a concatenation of the mean and the raveled covariance + of the barycenter as its first element, followed by a vector of + norms of successive differences in iterates. """ # Ensure that barycentric weights sum to 1. weights = weights / jnp.sum(weights) mus, covs = x_to_means_and_covs(xs, self._dimension) mu_bary = jnp.sum(weights[:, None] * mus, axis=0) - cov_bary = self.covariance_fixpoint_iter( - covs=covs, weights=weights, **kwargs + cov_bary, diffs = self.covariance_fixpoint_iter( + covs=covs, + weights=weights, + tolerance=tolerance, + sqrtm_kw=sqrtm_kw if sqrtm_kw is not None else self._sqrtm_kw, + **kwargs ) - return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension) + return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension), diffs @classmethod def _padder(cls, dim: int) -> jnp.ndarray: @@ -635,7 +672,7 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 del children - return cls(aux_data[0], **aux_data[1]) + return cls(*aux_data) @jax.tree_util.register_pytree_node_class diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index a6f3a2383..2e49a47dc 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -529,7 +529,7 @@ def finalize(i: int): def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray: """Compute barycenter of points in self.x using weights.""" - return self.cost_fn.barycenter(self.x, weights) + return self.cost_fn.barycenter(self.x, weights)[0] @classmethod def prepare_divergences( diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 620d9886e..84fa164fa 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -146,7 +146,9 @@ def barycentric_projection( Returns: a vector of shape (n,) containing the barycentric projection of matrix. """ - return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) + return jax.vmap( + lambda m, y: cost_fn.barycenter(m, y)[0], in_axes=[0, None] + )(matrix, y) softmin.defvjp( diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 57b824ad0..6b1bf043b 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -116,7 +116,7 @@ def solve_linear_ot( ) x_new = jax.vmap( - bar_prob.cost_fn.barycenter, in_axes=[None, 1] + lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1] )(bar_prob.weights, barycenters_per_measure) return self.set( diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 80525ed63..4be9b5ad7 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -80,25 +80,45 @@ class TestBuresBarycenter: def test_bures(self, rng: jax.random.PRNGKeyArray): d = 5 - r = jnp.array([0.3206, 0.8825, 0.1113, 0.00052, 0.9454]) + r = jnp.array([1.2036, 0.2825, 0.013, 0.00052, 0.1454]) Sigma1 = r * jnp.eye(d) - s = jnp.array([0.3075, 0.8545, 0.1110, 0.0054, 0.9206]) + s = jnp.array([3.3075, 0.8545, 0.1110, 0.54, 0.9206]) Sigma2 = s * jnp.eye(d) # initializing Bures cost function weights = jnp.array([.3, .7]) - bures = costs.Bures(d) + tolerance = 1e-6 + min_iterations = 13 + inner_iterations = 1 + max_iterations = 123 + bures = costs.Bures(d, sqrtm_kw={"max_iterations": 134, "threshold": 1e-8}) # stacking parameter values xs = jnp.vstack(( costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma1, d), costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma2, d) )) - output = bures.barycenter(weights, xs, tolerance=1e-4, threshold=1e-6) - _, sigma = costs.x_to_means_and_covs(output, 5) + cov, diffs = bures.barycenter( + weights, + xs, + tolerance=tolerance, + min_iterations=min_iterations, + max_iterations=max_iterations, + inner_iterations=inner_iterations + ) + + _, sigma = costs.x_to_means_and_covs(cov, 5) ground_truth = (weights[0] * jnp.sqrt(r) + weights[1] * jnp.sqrt(s)) ** 2 np.testing.assert_allclose( - ground_truth, jnp.diag(sigma), rtol=1e-5, atol=1e-5 + ground_truth, jnp.diag(sigma), rtol=1e-4, atol=1e-4 + ) + # Check that outer loop ran for at leat min_iterations + np.testing.assert_array_less( + 0, diffs[min_iterations // inner_iterations - 1] ) + # Check converged + np.testing.assert_array_less((diffs[diffs > -1])[-1], tolerance) + # Check right output size of difference vectors + np.testing.assert_equal(diffs.shape[0], max_iterations // inner_iterations) @pytest.mark.fast() diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index 2787ca73f..58dd320b8 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -145,7 +145,9 @@ def test_bures_point_cloud( cost_fn = costs.UnbalancedBures(dimension=self.dim, gamma=0.9, sigma=0.98) else: x, y = self.x, self.y - cost_fn = costs.Bures(dimension=self.dim, regularization=1e-4) + cost_fn = costs.Bures( + dimension=self.dim, sqrtm_kw={"regularization": 1e-4} + ) geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=self.eps) prob = linear_problem.LinearProblem(geom, self.a, self.b)