Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fine grained control of Bures barycenters #366

Merged
merged 9 commits into from
Jun 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 81 additions & 44 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import abc
import functools
import math
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -69,15 +69,18 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
The cost.
"""

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
def barycenter(self, weights: jnp.ndarray,
xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]:
"""Barycentric operator.

Args:
weights: Convex set of weights.
xs: Points.

Returns:
The barycenter of `xs` using `weights` coefficients.
A list, whose first element is the barycenter of `xs` using `weights`
coefficients, followed by auxiliary information on the convergence of
the algorithm.
"""
raise NotImplementedError("Barycenter is not implemented.")

Expand Down Expand Up @@ -268,9 +271,10 @@ def h(self, z: jnp.ndarray) -> float: # noqa: D102
def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
return 0.25 * jnp.sum(z ** 2)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
def barycenter(self, weights: jnp.ndarray,
xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]:
"""Output barycenter of vectors when using squared-Euclidean distance."""
return jnp.average(xs, weights=weights, axis=0)
return jnp.average(xs, weights=weights, axis=0), None


@jax.tree_util.register_pytree_node_class
Expand Down Expand Up @@ -496,13 +500,14 @@ class Bures(CostFn):

Args:
dimension: Dimensionality of the data.
kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`.
sqrtm_kw: Dictionary of keyword arguments to control the
behavior of inner calls to :func:`~ott.math.matrix_square_root.sqrtm`.
"""

def __init__(self, dimension: int, **kwargs: Any):
def __init__(self, dimension: int, sqrtm_kw: Optional[Dict[str, Any]] = None):
super().__init__()
self._dimension = dimension
self._sqrtm_kw = kwargs
self._sqrtm_kw = {} if sqrtm_kw is None else sqrtm_kw

def norm(self, x: jnp.ndarray) -> jnp.ndarray:
"""Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance."""
Expand All @@ -528,71 +533,87 @@ def covariance_fixpoint_iter(
covs: jnp.ndarray,
weights: jnp.ndarray,
tolerance: float = 1e-4,
sqrtm_kw: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> jnp.ndarray:
"""Iterate fix-point updates to compute barycenter of Gaussians.

Args:
covs: [batch, d^2] covariance matrices
weights: simplicial weights (non-negative, sum to 1)
tolerance: tolerance of the overall fixed-point procedure
kwargs: keyword arguments for :func:`ott.math.matrix_square_root.sqrtm`.
tolerance: tolerance of the fixed-point procedure. That tolerance is
applied to the Frobenius norm (normalized by total size)
of two successive iterations of the algorithm
sqrtm_kw: keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`
kwargs: keyword arguments for the outer fixed-point iteration

Returns:
Weighted Bures average of the covariance matrices.
List containing Weighted Bures average of the covariance matrices, and
vector of (normalized) 2-norms of successive differences between iterates,
to monitor convergence.
"""
sqrtm_kw = {} if sqrtm_kw is None else sqrtm_kw
# Pop values or set defaults for fixed-point loop.
min_iterations = kwargs.pop("min_iterations", 1)
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
max_iterations = kwargs.pop("max_iterations", 100)
inner_iterations = kwargs.pop("inner_iterations", 5)
dtype = covs.dtype

@functools.partial(jax.vmap, in_axes=[None, 0, 0])
def scale_covariances(
cov_sqrt: jnp.ndarray, cov: jnp.ndarray, weight: jnp.ndarray
) -> jnp.ndarray:
"""Rescale covariance in barycenter step."""
return weight * matrix_square_root.sqrtm_only((cov_sqrt @ cov) @ cov_sqrt,
**kwargs)
**sqrtm_kw)

def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool:
del iteration, constants
_, diff = state
return diff > tolerance
del constants
_, diffs = state
return diffs[iteration // inner_iterations] > tolerance

def body_fn(
iteration: int, constants: Tuple[Any, ...],
state: Tuple[jnp.ndarray, float], compute_error: bool
) -> Tuple[jnp.ndarray, float]:
del iteration, constants, compute_error
cov, _ = state
cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **kwargs)
del constants, compute_error
cov, diffs = state
cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **sqrtm_kw)
scaled_cov = jnp.linalg.matrix_power(
jnp.sum(scale_covariances(cov_sqrt, covs, weights), axis=0), 2
)
next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt
diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape))
return next_cov, diff
diffs = diffs.at[iteration // inner_iterations].set(diff)
return next_cov, diffs

def init_state() -> Tuple[jnp.ndarray, float]:
cov_init = jnp.eye(self._dimension)
diff = jnp.inf
return cov_init, diff

# TODO(marcocuturi): ideally the integer parameters below should be passed
# by user, if one wants more fine grained control. This could clash with the
# parameters passed on to :func:`ott.math.matrix_square_root.sqrtm` by the
# barycenter call. At the moment, only `tolerance` can be used to control
# computational effort.
cov, _ = fixed_point_loop.fixpoint_iter(
diffs = -jnp.ones(
(np.ceil(max_iterations / inner_iterations).astype(int),),
dtype=dtype
)
return cov_init, diffs

cov, diffs = fixed_point_loop.fixpoint_iter(
cond_fn=cond_fn,
body_fn=body_fn,
min_iterations=1,
max_iterations=500,
inner_iterations=1,
min_iterations=min_iterations,
max_iterations=max_iterations,
inner_iterations=inner_iterations,
constants=(),
state=init_state()
state=init_state(),
)
return cov
return cov, diffs

def barycenter(
self, weights: jnp.ndarray, xs: jnp.ndarray, **kwargs: Any
) -> jnp.ndarray:
self,
weights: jnp.ndarray,
xs: jnp.ndarray,
tolerance: float = 1e-4,
sqrtm_kw: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Compute the Bures barycenter of weighted Gaussian distributions.

Implements the fixed point approach proposed in :cite:`alvarez-esteban:16`
Expand All @@ -604,22 +625,38 @@ def barycenter(
xs: The points to be used in the computation of the barycenter, where
each point is described by a concatenation of the mean and the
covariance (raveled).
kwargs: Passed on to :meth:`covariance_fixpoint_iter`, and by extension to
:func:`ott.math.matrix_square_root.sqrtm`. Note that `tolerance` is used
for the fixed-point iteration of the barycenter, whereas `threshold`
will apply to the fixed point iteration of Newton-Schulz iterations.
tolerance: convergence tolerance to control the termination of the
algorithm.
sqrtm_kw: Arguments passed on to the
:func:`~ott.math.matrix_square_root.sqrtm` function used within
:meth:`covariance_fixpoint_iter`. This defines the precision
(in terms of convergence threshold, and number of iterations) of the
matrix square root calls that are used at each outer iteration of
the computation of Gaussian barycenters. These values are, by default,
the same as those used to define the Bures cost object itself.
kwargs: Passed on to :meth:`covariance_fixpoint_iter`, to specify the
number of iterations and tolerance of the fixed-point iteration of the
barycenter routine, by parameterizing `tolerance` and other relevant
arguments passed on to :func:`~ott.math.fixed_point_loop.fixpoint_iter`,
namely `min_iterations`, `max_iterations` and `inner_iterations`.

Returns:
A concatenation of the mean and the raveled covariance of the barycenter.
A list holding a concatenation of the mean and the raveled covariance
of the barycenter as its first element, followed by a vector of
norms of successive differences in iterates.
"""
# Ensure that barycentric weights sum to 1.
weights = weights / jnp.sum(weights)
mus, covs = x_to_means_and_covs(xs, self._dimension)
mu_bary = jnp.sum(weights[:, None] * mus, axis=0)
cov_bary = self.covariance_fixpoint_iter(
covs=covs, weights=weights, **kwargs
cov_bary, diffs = self.covariance_fixpoint_iter(
covs=covs,
weights=weights,
tolerance=tolerance,
sqrtm_kw=sqrtm_kw if sqrtm_kw is not None else self._sqrtm_kw,
**kwargs
)
return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension)
return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension), diffs

@classmethod
def _padder(cls, dim: int) -> jnp.ndarray:
Expand All @@ -635,7 +672,7 @@ def tree_flatten(self): # noqa: D102
@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0], **aux_data[1])
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
Expand Down
2 changes: 1 addition & 1 deletion src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def finalize(i: int):

def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray:
"""Compute barycenter of points in self.x using weights."""
return self.cost_fn.barycenter(self.x, weights)
return self.cost_fn.barycenter(self.x, weights)[0]

@classmethod
def prepare_divergences(
Expand Down
4 changes: 3 additions & 1 deletion src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def barycentric_projection(
Returns:
a vector of shape (n,) containing the barycentric projection of matrix.
"""
return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y)
return jax.vmap(
lambda m, y: cost_fn.barycenter(m, y)[0], in_axes=[0, None]
)(matrix, y)


softmin.defvjp(
Expand Down
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def solve_linear_ot(
)

x_new = jax.vmap(
bar_prob.cost_fn.barycenter, in_axes=[None, 1]
lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1]
)(bar_prob.weights, barycenters_per_measure)

return self.set(
Expand Down
32 changes: 26 additions & 6 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,45 @@ class TestBuresBarycenter:

def test_bures(self, rng: jax.random.PRNGKeyArray):
d = 5
r = jnp.array([0.3206, 0.8825, 0.1113, 0.00052, 0.9454])
r = jnp.array([1.2036, 0.2825, 0.013, 0.00052, 0.1454])
Sigma1 = r * jnp.eye(d)
s = jnp.array([0.3075, 0.8545, 0.1110, 0.0054, 0.9206])
s = jnp.array([3.3075, 0.8545, 0.1110, 0.54, 0.9206])
Sigma2 = s * jnp.eye(d)
# initializing Bures cost function
weights = jnp.array([.3, .7])
bures = costs.Bures(d)
tolerance = 1e-6
min_iterations = 13
inner_iterations = 1
max_iterations = 123
bures = costs.Bures(d, sqrtm_kw={"max_iterations": 134, "threshold": 1e-8})
# stacking parameter values
xs = jnp.vstack((
costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma1, d),
costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma2, d)
))

output = bures.barycenter(weights, xs, tolerance=1e-4, threshold=1e-6)
_, sigma = costs.x_to_means_and_covs(output, 5)
cov, diffs = bures.barycenter(
weights,
xs,
tolerance=tolerance,
min_iterations=min_iterations,
max_iterations=max_iterations,
inner_iterations=inner_iterations
)

_, sigma = costs.x_to_means_and_covs(cov, 5)
ground_truth = (weights[0] * jnp.sqrt(r) + weights[1] * jnp.sqrt(s)) ** 2
np.testing.assert_allclose(
ground_truth, jnp.diag(sigma), rtol=1e-5, atol=1e-5
ground_truth, jnp.diag(sigma), rtol=1e-4, atol=1e-4
)
# Check that outer loop ran for at leat min_iterations
np.testing.assert_array_less(
0, diffs[min_iterations // inner_iterations - 1]
)
# Check converged
np.testing.assert_array_less((diffs[diffs > -1])[-1], tolerance)
# Check right output size of difference vectors
np.testing.assert_equal(diffs.shape[0], max_iterations // inner_iterations)


@pytest.mark.fast()
Expand Down
4 changes: 3 additions & 1 deletion tests/solvers/linear/sinkhorn_misc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def test_bures_point_cloud(
cost_fn = costs.UnbalancedBures(dimension=self.dim, gamma=0.9, sigma=0.98)
else:
x, y = self.x, self.y
cost_fn = costs.Bures(dimension=self.dim, regularization=1e-4)
cost_fn = costs.Bures(
dimension=self.dim, sqrtm_kw={"regularization": 1e-4}
)

geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=self.eps)
prob = linear_problem.LinearProblem(geom, self.a, self.b)
Expand Down