Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of gangbo-mccann map estimators using twist operator #500

Merged
merged 5 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ @article{vayer:20

@article{demetci:22,
author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Björn and Noble, William Stafford and Singh, Ritambhara},
doi = {10.1089/cmb.2021.0446},
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
journal = {Journal of Computational Biology},
note = {PMID: 35050714},
number = {1},
Expand Down
3 changes: 3 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Datasets
Dykstra
Fenchel
Frobenius
Gangbo
Gangbo-McCann
Gaussians
Gromov
Hessians
Expand All @@ -19,6 +21,7 @@ Kantorovich
Kullback
Leibler
Mahalanobis
McCann
Monge
Moreau
SGD
Expand Down
29 changes: 29 additions & 0 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,27 @@ def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""
return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x)

def twist_operator(
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
) -> jnp.ndarray:
r"""Twist inverse operator of the cost function.

Given a cost function :math:`c`, the twist operator returns
:math:`\nabla_{1}c(x, \cdot)^{-1}(z)` if ``variable`` is ``False``,
and :math:`\nabla_{2}c(\cdot, y)^{-1}(z)` if ``variable`` is ``True``, for
:math:`x=y` equal to ``vec`` and :math:`z` equal to ``dual_vec``.

Args:
vec: ``[p,]`` point at which the twist inverse operator is evaluated.
dual_vec: ``[q,]`` point to invert by the operator.
variable: apply twist inverse operator on first (``False``) or
second (``True``) variable.

Returns:
A vector.
"""
raise NotImplementedError("Twist operator is not implemented.")

def tree_flatten(self): # noqa: D102
return (), None

Expand Down Expand Up @@ -182,6 +203,14 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost as evaluation of :func:`h` on :math:`x-y`."""
return self.h(x - y)

def twist_operator(
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool
) -> jnp.ndarray:
# Note: when `h` is pair, i.e. h(z) = h(-z), the expressions below coincide
if variable:
return vec + jax.grad(self.h_legendre)(-dual_vec)
return vec - jax.grad(self.h_legendre)(dual_vec)


@jax.tree_util.register_pytree_node_class
class SqPNorm(TICost):
Expand Down
37 changes: 16 additions & 21 deletions src/ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,27 @@ def __init__(
self._corr = corr

def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`.
r"""Transport ``vec`` according to Gangbo-McCann Brenier :cite:`brenier:91`.

Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when
given the Legendre transform of the dual potentials.

That OT map can be recovered as :math:`x- (\nabla h^*)\circ \nabla f(x)`,
Uses Proposition 1.15 from :cite:`santambrogio:15` to compute an OT map when
applying the inverse gradient of cost. When the cost is translation
invariant, :math:`c(x,y)=h(x-y)`, this translates to the application of the
convex conjugate of :math:`h` to the gradient of the dual potentials,
namely :math:`x- (\nabla h^*)\circ \nabla f(x)` for the forward map,
where :math:`h^*` is the Legendre transform of :math:`h`. For instance,
in the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`,
one has :math:`h^*(\cdot) = \|.\|^2 / 4`, and therefore
:math:`\nabla h^*(\cdot) = 0.5 \cdot\,`.

When the dual potentials are solved in correlation form (only in the Sq.
Euclidean distance case), the maps are :math:`\nabla g` for forward,
:math:`\nabla f` for backward.
Note:
When the dual potentials are solved in correlation form (this formulation
is only relevant in the (important) particular case when the cost is
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
the squared-Euclidean distance), the maps are :math:`\nabla g` for
forward, :math:`\nabla f` for backward map.

Args:
vec: Points to transport, array of shape ``[n, d]``.
forward: Whether to transport the points from source to the target
forward: Whether to transport the points from source to the target
distribution or vice-versa.

Returns:
Expand All @@ -99,11 +102,13 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
from ott.geometry import costs

vec = jnp.atleast_2d(vec)

if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
return self._grad_f(vec) if forward else self._grad_g(vec)
twist_op = jax.vmap(self.cost_fn.twist_operator, in_axes=[0, 0, None])
if forward:
return vec - self._grad_h_inv(self._grad_f(vec))
return vec - self._grad_h_inv(self._grad_g(vec))
return twist_op(vec, self._grad_f(vec), 0)
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
return twist_op(vec, self._grad_g(vec), 1)

def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
r"""Evaluate Wasserstein distance between samples using dual potentials.
Expand Down Expand Up @@ -155,16 +160,6 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
"""Vectorized gradient of the potential function :attr:`g`."""
return jax.vmap(jax.grad(self.g, argnums=0))

@property
def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
from ott.geometry import costs

assert isinstance(self.cost_fn, costs.TICost), (
"Cost must be a `TICost` and "
"provide access to Legendre transform of `h`."
)
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [], {
"f": self._f,
Expand Down
Loading