Skip to content

Commit

Permalink
Addition of gangbo-mccann map estimators using twist operator (#500)
Browse files Browse the repository at this point in the history
* addition of gangbo-mccann map estimators using twist operator

* docs

* docs

* link

* changes after review
  • Loading branch information
marcocuturi authored Mar 12, 2024
1 parent cbff1c7 commit b0e809a
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
# linkcheck
linkcheck_ignore = [
# 403 Client Error
"https://doi.org/10.1089/cmb.2021.0446"
"https://www.jstor.org/stable/3647580",
"https://doi.org/10.1137/19M1301047",
"https://doi.org/10.1137/17M1140431",
Expand Down
1 change: 0 additions & 1 deletion docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ @article{vayer:20

@article{demetci:22,
author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Björn and Noble, William Stafford and Singh, Ritambhara},
doi = {10.1089/cmb.2021.0446},

This comment has been minimized.

Copy link
@michalk8

michalk8 Mar 12, 2024

Collaborator

@marcocuturi this should be undone.

journal = {Journal of Computational Biology},
note = {PMID: 35050714},
number = {1},
Expand Down
3 changes: 3 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Datasets
Dykstra
Fenchel
Frobenius
Gangbo
Gangbo-McCann
Gaussians
Gromov
Hessians
Expand All @@ -19,6 +21,7 @@ Kantorovich
Kullback
Leibler
Mahalanobis
McCann
Monge
Moreau
SGD
Expand Down
29 changes: 29 additions & 0 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,27 @@ def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""
return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x)

def twist_operator(
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool
) -> jnp.ndarray:
r"""Twist inverse operator of the cost function.
Given a cost function :math:`c`, the twist operator returns
:math:`\nabla_{1}c(x, \cdot)^{-1}(z)` if ``variable`` is ``0``,
and :math:`\nabla_{2}c(\cdot, y)^{-1}(z)` if ``variable`` is ``1``, for
:math:`x=y` equal to ``vec`` and :math:`z` equal to ``dual_vec``.
Args:
vec: ``[p,]`` point at which the twist inverse operator is evaluated.
dual_vec: ``[q,]`` point to invert by the operator.
variable: apply twist inverse operator on first (i.e. value set to ``0``

This comment has been minimized.

Copy link
@michalk8

michalk8 Mar 12, 2024

Collaborator

Since variable is bool, let's drop mentioning 0/1 in the docstring. Also I still maintain that the name variable is too generic.

or equivalently ``False``) or second (``1`` or ``True``) variable.
Returns:
A vector.
"""
raise NotImplementedError("Twist operator is not implemented.")

def tree_flatten(self): # noqa: D102
return (), None

Expand Down Expand Up @@ -182,6 +203,14 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost as evaluation of :func:`h` on :math:`x-y`."""
return self.h(x - y)

def twist_operator(
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool
) -> jnp.ndarray:
# Note: when `h` is pair, i.e. h(z) = h(-z), the expressions below coincide
if variable:
return vec + jax.grad(self.h_legendre)(-dual_vec)
return vec - jax.grad(self.h_legendre)(dual_vec)


@jax.tree_util.register_pytree_node_class
class SqPNorm(TICost):
Expand Down
43 changes: 23 additions & 20 deletions src/ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,35 @@ def __init__(
self._corr = corr

def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`.
r"""Transport ``vec`` according to Gangbo-McCann Brenier :cite:`brenier:91`.
Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when
given the Legendre transform of the dual potentials.
Uses Proposition 1.15 from :cite:`santambrogio:15` to compute an OT map when
applying the inverse gradient of cost.
That OT map can be recovered as :math:`x- (\nabla h^*)\circ \nabla f(x)`,
When the cost is a general cost, the operator uses the
:meth:`~ott.geometry.costs.CostFn.twist_operator` associated of the
corresponding :class:`~ott.geometry.costs.CostFn`.
When the cost is a translation invariant :class:`~ott.geometry.costs.TICost`
cost, :math:`c(x,y)=h(x-y)`, and the twist operator translates to the
application of the convex conjugate of :math:`h` to the
gradient of the dual potentials, namely
:math:`x- (\nabla h^*)\circ \nabla f(x)` for the forward map,
where :math:`h^*` is the Legendre transform of :math:`h`. For instance,
in the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`,
one has :math:`h^*(\cdot) = \|.\|^2 / 4`, and therefore
:math:`\nabla h^*(\cdot) = 0.5 \cdot\,`.
When the dual potentials are solved in correlation form (only in the Sq.
Euclidean distance case), the maps are :math:`\nabla g` for forward,
:math:`\nabla f` for backward.
Note:
When the dual potentials are solved in correlation form, and marked
accordingly by setting ``corr`` to ``True``, the maps are
:math:`\nabla g` for forward, :math:`\nabla f` for backward map. This can
only make sense when using the squared-Euclidean
:class:`~ott.geometry.costs.SqEuclidean` cost.
Args:
vec: Points to transport, array of shape ``[n, d]``.
forward: Whether to transport the points from source to the target
forward: Whether to transport the points from source to the target
distribution or vice-versa.
Returns:
Expand All @@ -99,11 +110,13 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
from ott.geometry import costs

vec = jnp.atleast_2d(vec)

if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
return self._grad_f(vec) if forward else self._grad_g(vec)
twist_op = jax.vmap(self.cost_fn.twist_operator, in_axes=[0, 0, None])
if forward:
return vec - self._grad_h_inv(self._grad_f(vec))
return vec - self._grad_h_inv(self._grad_g(vec))
return twist_op(vec, self._grad_f(vec), 0)

This comment has been minimized.

Copy link
@michalk8

michalk8 Mar 12, 2024

Collaborator

Unless jax breaks, why not pass False/True.

return twist_op(vec, self._grad_g(vec), 1)

def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
r"""Evaluate Wasserstein distance between samples using dual potentials.
Expand Down Expand Up @@ -155,16 +168,6 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
"""Vectorized gradient of the potential function :attr:`g`."""
return jax.vmap(jax.grad(self.g, argnums=0))

@property
def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
from ott.geometry import costs

assert isinstance(self.cost_fn, costs.TICost), (
"Cost must be a `TICost` and "
"provide access to Legendre transform of `h`."
)
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [], {
"f": self._f,
Expand Down

0 comments on commit b0e809a

Please sign in to comment.