Skip to content

Commit

Permalink
A few changes to increase Coverage (#390)
Browse files Browse the repository at this point in the history
* utils + fix potentials' distance.

* remove stale debiased flag in continuous bary

* increase coverage of fixed barycenter

* fix pydocs

* reformat pydoc + clean up assumptions for corr

* fix

* fix
  • Loading branch information
marcocuturi authored Jul 7, 2023
1 parent db9f58a commit 172ba77
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 67 deletions.
5 changes: 4 additions & 1 deletion src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:

@jax.tree_util.register_pytree_node_class
class SqEuclidean(TICost):
"""Squared Euclidean distance."""
r"""Squared Euclidean distance.
Implemented as a translation invariant cost, :math:`h(z) = \|z\|^2`.
"""

def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
"""Compute squared Euclidean norm for vector."""
Expand Down
43 changes: 7 additions & 36 deletions src/ott/problems/linear/barycenter_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ class FreeBarycenterProblem:
cost_fn: Cost function used. If `None`,
use the :class:`~ott.geometry.costs.SqEuclidean` cost.
epsilon: Epsilon regularization used to solve reg-OT problems.
debiased: **Currently not implemented.**
Whether the problem is debiased, in the sense that
the regularized transportation cost of barycenter to itself will
be considered when computing gradient. Note that if the debiased option
is used, the barycenter size needs to be smaller than the maximum measure
size for parallelization to operate efficiently.
kwargs: Keyword arguments :func:`~ott.geometry.segment.segment_point_cloud`.
Only used when ``y`` is not already segmented. When passing
``segment_ids``, 2 arguments must be specified for jitting to work:
Expand All @@ -61,7 +55,6 @@ def __init__(
weights: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
debiased: bool = False,
**kwargs: Any,
):
self._y = y
Expand All @@ -71,7 +64,6 @@ def __init__(
self._weights = weights
self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
self.epsilon = epsilon
self.debiased = debiased
self._kwargs = kwargs

if self._is_segmented:
Expand All @@ -87,10 +79,8 @@ def __init__(
def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Tuple of arrays containing the segmented measures and weights.
Additional segment may be added when the problem is debiased.
- Segmented measures of shape ``[num_measures, max_measure_size, ndim]``.
- Segmented weights of shape ``[num_measures, max_measure_size]``.
- Segmented measures of shape ``[num_measures, max_measure_size, ndim]``.
- Segmented weights of shape ``[num_measures, max_measure_size]``.
"""
if self._is_segmented:
y, b = self._y, self._b
Expand All @@ -101,20 +91,6 @@ def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
padding_vector=self.cost_fn._padder(self.ndim),
**self._kwargs
)

if self.debiased:
return self._add_slice_for_debiased(y, b)
return y, b

@staticmethod
def _add_slice_for_debiased(
y: jnp.ndarray, b: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
_, n, ndim = y.shape # (num_measures, max_measure_size, ndim)
# yapf: disable
y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0)
b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0)
# yapf: enable
return y, b

@property
Expand Down Expand Up @@ -148,15 +124,11 @@ def ndim(self) -> int:
def weights(self) -> jnp.ndarray:
"""Barycenter weights of shape ``[num_measures,]`` that sum to 1."""
if self._weights is None:
weights = jnp.ones((self.num_measures,)) / self.num_measures
else:
# Check that the number of measures coincides with the weights' size.
assert self._weights.shape[0] == self.num_measures
# By default, we assume that weights sum to 1, and enforce this if needed.
weights = self._weights / jnp.sum(self._weights)
if self.debiased:
return jnp.concatenate((weights, jnp.array([-0.5])))
return weights
return jnp.ones((self.num_measures,)) / self.num_measures
# Check that the number of measures coincides with the weights' size.
assert self._weights.shape[0] == self.num_measures
# By default, we assume that weights sum to 1, and enforce this if needed.
return self._weights / jnp.sum(self._weights)

@property
def _is_segmented(self) -> bool:
Expand All @@ -166,7 +138,6 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return ([self._y, self._b, self._weights], {
"cost_fn": self.cost_fn,
"epsilon": self.epsilon,
"debiased": self.debiased,
**self._kwargs,
})

Expand Down
42 changes: 22 additions & 20 deletions src/ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -28,11 +27,9 @@
import jax.tree_util as jtu
import numpy as np

from ott.geometry import costs
from ott.problems.linear import linear_problem

if TYPE_CHECKING:
from ott.geometry import costs

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -64,11 +61,14 @@ def __init__(
f: Potential_t,
g: Potential_t,
*,
cost_fn: "costs.CostFn",
cost_fn: costs.CostFn,
corr: bool = False
):
self._f = f
self._g = g
assert (
not corr or type(cost_fn) == costs.SqEuclidean
), "Duals in `corr` form can only be used with a squared-Euclidean cost."
self.cost_fn = cost_fn
self._corr = corr

Expand Down Expand Up @@ -106,32 +106,34 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
return vec - self._grad_h_inv(self._grad_g(vec))

def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
"""Evaluate 2-Wasserstein distance between samples using dual potentials.
r"""Evaluate Wasserstein distance between samples using dual potentials.
This uses direct estimation of potentials against measures when dual
functions are provided in usual form. This expression is valid for any
cost function.
Uses Eq. 5 from :cite:`makkuva:20` when given in `corr` form, direct
estimation by integrating dual function against points when using dual form.
When potentials are given in correlation form, as specified by the flag
``corr``, the dual potentials solve the dual problem corresponding to the
minimization of the primal OT problem where the ground cost is
:math:`-2\langle x,y\rangle`. To recover the (squared) 2-Wasserstein
distance, terms are re-arranged and contributions from squared norms are
taken into account.
Args:
src: Samples from the source distribution, array of shape ``[n, d]``.
tgt: Samples from the target distribution, array of shape ``[m, d]``.
Returns:
Wasserstein distance.
Wasserstein distance using specified cost function.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)
f = jax.vmap(self.f)

if self._corr:
grad_g_y = self._grad_g(tgt)
term1 = -jnp.mean(f(src))
term2 = -jnp.mean(jnp.sum(tgt * grad_g_y, axis=-1) - f(grad_g_y))

C = jnp.mean(jnp.sum(src ** 2, axis=-1))
C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
return 2. * (term1 + term2) + C

g = jax.vmap(self.g)
return jnp.mean(f(src)) + jnp.mean(g(tgt))
out = jnp.mean(f(src)) + jnp.mean(g(tgt))
if self._corr:
out = -2.0 * out + jnp.mean(jnp.sum(src ** 2, axis=-1))
out += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
return out

@property
def f(self) -> Potential_t:
Expand Down
6 changes: 0 additions & 6 deletions src/ott/solvers/linear/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ def solve_linear_ot(
out.errors if store_errors else None
)

if bar_prob.debiased:
raise NotImplementedError(
"Debiased version of continuous Wasserstein barycenter "
"not yet implemented."
)

reg_ot_costs, convergeds, matrices, errors = solve_linear_ot(
self.a, self.x, seg_b, seg_y
)
Expand Down
20 changes: 17 additions & 3 deletions tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,31 @@ def test_entropic_potentials_dist(
x = g1.sample(rng1, n1)
y = g2.sample(rng2, n2)

geom = pointcloud.PointCloud(x, y, epsilon=eps)
g1.sample(rng3, n1)
g2.sample(rng4, n2)

geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=costs.SqEuclidean())
prob = linear_problem.LinearProblem(geom)
out = sinkhorn.Sinkhorn()(prob)
assert out.converged
potentials = out.to_dual_potentials()
dual_potentials = out.to_dual_potentials()

expected_dist = jnp.sum(out.matrix * geom.cost_matrix)
actual_dist = potentials.distance(x, y)
actual_dist = dual_potentials.distance(x, y)
rel_error = jnp.abs(expected_dist - actual_dist) / expected_dist
assert rel_error < 2 * eps

# Try with potentials in correlation form
f_cor = lambda x: 0.5 * jnp.sum(x ** 2) - 0.5 * dual_potentials.f(x)
g_cor = lambda x: 0.5 * jnp.sum(x ** 2) - 0.5 * dual_potentials.g(x)
dual_potentials_corr = potentials.DualPotentials(
f=f_cor, g=g_cor, cost_fn=dual_potentials.cost_fn, corr=True
)
actual_dist_cor = dual_potentials_corr.distance(x, y)
rel_error = jnp.abs(expected_dist - actual_dist_cor) / expected_dist
assert rel_error < 2 * eps
assert jnp.abs(actual_dist_cor - actual_dist) < 1e-5

@pytest.mark.fast.with_args(forward=[False, True], only_fast=0)
def test_entropic_potentials_displacement(
self, rng: jax.random.PRNGKeyArray, forward: bool
Expand Down
4 changes: 3 additions & 1 deletion tests/solvers/linear/discrete_barycenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def test_discrete_barycenter_grid(
a2 /= jnp.sum(a2)
threshold = 1e-2

fixed_bp = bp.FixedBarycenterProblem(geom=grid_3d, a=jnp.stack((a1, a2)))
fixed_bp = bp.FixedBarycenterProblem(
geom=grid_3d, a=jnp.stack((a1, a2)), weights=jnp.array([0.5, 0.5])
)
solver = db.FixedBarycenter(
threshold=threshold, lse_mode=lse_mode, debiased=debiased
)
Expand Down
7 changes: 7 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Optional

import jax.numpy as jnp
import pytest
from ott import utils

Expand All @@ -36,3 +37,9 @@ def func() -> int:
with pytest.warns(DeprecationWarning, match=expected_msg):
res = func()
assert res == 42


def test_is_jax_array():
x = jnp.array([0.0])
assert utils.is_jax_array(x)
assert not utils.is_jax_array(0)

0 comments on commit 172ba77

Please sign in to comment.