Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes numerical errors in Bures barycenter, and sqrtm, due to low default precision. #205

Merged
merged 11 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import abc
import functools
import math
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Mapping, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -304,34 +304,46 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
def covariance_fixpoint_iter(
self,
covs: jnp.ndarray,
lambdas: jnp.ndarray,
rtol: float = 1e-2
weights: jnp.ndarray,
tolerance: float = 1e-4,
kwargs_sqrtm: Optional[Mapping[str, Any]] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use **kwargs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At that moment, this was used to avoid mixing `threshold` (which is used for the sqrtm fixed point) and `tolerance` (which is used for the barycenter). Both can be defined, though. However, we would run into issues if we wanted to expose other parameters (such as `min_iterations`), since we have two nested fixed-point loops (sqrtm and the covariance barycenter).

) -> jnp.ndarray:
"""Iterate fix-point updates to compute barycenter of Gaussians."""
"""Iterate fix-point updates to compute barycenter of Gaussians.

Args:
covs: [batch, d^2] covariance matrices
weights: simplicial weights (nonnegative, sum to 1)
tolerance: tolerance of the overall fixed-point procedure
kwargs_sqrtm: parameters passed on to the sqrtm (Newton-Schulz)
algorithm to compute matrix square roots.

Returns:
a covariance matrix, the weighted Bures average of the covs matrices.
"""
kwargs_sqrtm = {} if kwargs_sqrtm is None else kwargs_sqrtm

@functools.partial(jax.vmap, in_axes=[None, 0, 0])
def scale_covariances(
cov_sqrt: jnp.ndarray, cov_i: jnp.ndarray, lambda_i: jnp.ndarray
cov_sqrt: jnp.ndarray, cov: jnp.ndarray, weight: jnp.ndarray
) -> jnp.ndarray:
"""Iterate update needed to compute barycenter of covariances."""
return lambda_i * matrix_square_root.sqrtm_only(
(cov_sqrt @ cov_i) @ cov_sqrt
)
"""Rescale covariance in barycenter step."""
return weight * matrix_square_root.sqrtm_only((cov_sqrt @ cov) @ cov_sqrt,
**kwargs_sqrtm)

def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool:
del iteration, constants
_, diff = state
return diff > rtol
return diff > tolerance

def body_fn(
iteration: int, constants: Tuple[Any, ...],
state: Tuple[jnp.ndarray, float], compute_error: bool
) -> Tuple[jnp.ndarray, float]:
del iteration, constants, compute_error
cov, _ = state
cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov)
cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **kwargs_sqrtm)
scaled_cov = jnp.linalg.matrix_power(
jnp.sum(scale_covariances(cov_sqrt, covs, lambdas), axis=0), 2
jnp.sum(scale_covariances(cov_sqrt, covs, weights), axis=0), 2
)
next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt
diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape))
Expand All @@ -353,7 +365,9 @@ def init_state() -> Tuple[jnp.ndarray, float]:
)
return cov

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
def barycenter(
self, weights: jnp.ndarray, xs: jnp.ndarray, **kwargs
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
) -> jnp.ndarray:
"""Compute the Bures barycenter of weighted Gaussian distributions.

Implements the fixed point approach proposed in :cite:`alvarez-esteban:16`
Expand All @@ -365,6 +379,7 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
xs: The points to be used in the computation of the barycenter, where
each point is described by a concatenation of the mean and the
covariance (raveled).
kwargs: Passed on to :meth:`covariance_fixpoint_iter`

Returns:
A concatenation of the mean and the raveled covariance of the barycenter.
Expand All @@ -373,7 +388,9 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
weights = weights / jnp.sum(weights)
mus, covs = x_to_means_and_covs(xs, self._dimension)
mu_bary = jnp.sum(weights[:, None] * mus, axis=0)
cov_bary = self.covariance_fixpoint_iter(covs=covs, lambdas=weights)
cov_bary = self.covariance_fixpoint_iter(
covs=covs, weights=weights, **kwargs
)
barycenter = mean_and_cov_to_x(mu_bary, cov_bary, self._dimension)
return barycenter

Expand Down
78 changes: 60 additions & 18 deletions src/ott/math/matrix_square_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def sqrtm(
x: jnp.ndarray,
threshold: float = 1e-3,
threshold: float = 1e-6,
min_iterations: int = 0,
inner_iterations: int = 10,
max_iterations: int = 1000,
regularization: float = 1e-3
regularization: float = 1e-6
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Higham algorithm to compute matrix square root of p.d. matrix.

Expand Down Expand Up @@ -235,18 +235,37 @@ def sqrtm_bwd(
# These functions have lower complexity gradients than sqrtm.


@jax.custom_vjp
def sqrtm_only(x: jnp.ndarray) -> jnp.ndarray:
return sqrtm(x)[0]


def sqrtm_only_fwd(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
sqrt_x = sqrtm(x)[0]
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def sqrtm_only(
x: jnp.ndarray,
threshold: float = 1e-6,
min_iterations: int = 0,
inner_iterations: int = 10,
max_iterations: int = 1000,
regularization: float = 1e-6
) -> jnp.ndarray:
return sqrtm(
x, threshold, min_iterations, inner_iterations, max_iterations,
regularization
)[0]


def sqrtm_only_fwd(
x: jnp.ndarray, threshold: float, min_iterations: int,
inner_iterations: int, max_iterations: int, regularization: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
sqrt_x = sqrtm(
x, threshold, min_iterations, inner_iterations, max_iterations,
regularization
)[0]
return sqrt_x, sqrt_x


def sqrtm_only_bwd(sqrt_x: jnp.ndarray,
cotangent: jnp.ndarray) -> Tuple[jnp.ndarray]:
def sqrtm_only_bwd(
threshold: float, min_iterations: int, inner_iterations: int,
max_iterations: int, regularization: float, sqrt_x: jnp.ndarray,
cotangent: jnp.ndarray
) -> Tuple[jnp.ndarray]:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
vjp = jnp.swapaxes(
solve_sylvester_bartels_stewart(
a=sqrt_x, b=-sqrt_x, c=jnp.swapaxes(cotangent, axis1=-2, axis2=-1)
Expand All @@ -260,18 +279,41 @@ def sqrtm_only_bwd(sqrt_x: jnp.ndarray,
sqrtm_only.defvjp(sqrtm_only_fwd, sqrtm_only_bwd)


@jax.custom_vjp
def inv_sqrtm_only(x: jnp.ndarray) -> jnp.ndarray:
return sqrtm(x)[1]
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def inv_sqrtm_only(
x: jnp.ndarray,
threshold: float = 1e-6,
min_iterations: int = 0,
inner_iterations: int = 10,
max_iterations: int = 1000,
regularization: float = 1e-6
) -> jnp.ndarray:
return sqrtm(
x, threshold, min_iterations, inner_iterations, max_iterations,
regularization
)[1]


def inv_sqrtm_only_fwd(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
inv_sqrt_x = sqrtm(x)[1]
def inv_sqrtm_only_fwd(
x: jnp.ndarray,
threshold: float,
min_iterations: int,
inner_iterations: int,
max_iterations: int,
regularization: float,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
inv_sqrt_x = sqrtm(
x, threshold, min_iterations, inner_iterations, max_iterations,
regularization
)[1]
return inv_sqrt_x, inv_sqrt_x


def inv_sqrtm_only_bwd(residual: jnp.ndarray,
cotangent: jnp.ndarray) -> Tuple[jnp.ndarray]:
def inv_sqrtm_only_bwd(
threshold: float, min_iterations: int, inner_iterations: int,
max_iterations: int, regularization: float, residual: jnp.ndarray,
cotangent: jnp.ndarray
) -> Tuple[jnp.ndarray]:
inv_sqrt_x = residual
inv_x = jnp.matmul(inv_sqrt_x, inv_sqrt_x)
vjp = jnp.swapaxes(
Expand Down
2 changes: 1 addition & 1 deletion tests/tools/gaussian_mixture/gaussian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_w2_dist(self, rng: jnp.ndarray):
delta_mean = jnp.sum((loc1 - loc0) ** 2., axis=-1)
delta_sigma = jnp.sum((jnp.sqrt(diag0) - jnp.sqrt(diag1)) ** 2.)
expected = delta_mean + delta_sigma
np.testing.assert_allclose(expected, w2)
np.testing.assert_allclose(expected, w2, rtol=1e-6, atol=1e-6)

def test_transport(self, rng: jnp.ndarray):
diag0 = jnp.array([1.])
Expand Down