diff --git a/docs/conf.py b/docs/conf.py index 964d349f8..69ef540ee 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -120,6 +120,7 @@ # linkcheck linkcheck_ignore = [ # 403 Client Error + "https://doi.org/10.1089/cmb.2021.0446", "https://www.jstor.org/stable/3647580", "https://doi.org/10.1137/19M1301047", "https://doi.org/10.1137/17M1140431", diff --git a/docs/references.bib b/docs/references.bib index 4497a8f43..0c9899bdb 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -169,7 +169,6 @@ @article{vayer:20 @article{demetci:22, author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Björn and Noble, William Stafford and Singh, Ritambhara}, - doi = {10.1089/cmb.2021.0446}, journal = {Journal of Computational Biology}, note = {PMID: 35050714}, number = {1}, diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index aba1108e2..7c7ba4ae9 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -9,6 +9,8 @@ Datasets Dykstra Fenchel Frobenius +Gangbo +Gangbo-McCann Gaussians Gromov Hessians @@ -19,6 +21,7 @@ Kantorovich Kullback Leibler Mahalanobis +McCann Monge Moreau SGD diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 0125af75c..7e1cfc57c 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -138,6 +138,27 @@ def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """ return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x) + def twist_operator( + self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool + ) -> jnp.ndarray: + r"""Twist inverse operator of the cost function. + + Given a cost function :math:`c`, the twist operator returns + :math:`\nabla_{1}c(x, \cdot)^{-1}(z)` if ``variable`` is ``0``, + and :math:`\nabla_{2}c(\cdot, y)^{-1}(z)` if ``variable`` is ``1``, for + :math:`x=y` equal to ``vec`` and :math:`z` equal to ``dual_vec``. + + Args: + vec: ``[p,]`` point at which the twist inverse operator is evaluated. 
+ dual_vec: ``[q,]`` point to invert by the operator. + variable: apply twist inverse operator on first (i.e. value set to ``0`` + or equivalently ``False``) or second (``1`` or ``True``) variable. + + Returns: + A vector. + """ + raise NotImplementedError("Twist operator is not implemented.") + def tree_flatten(self): # noqa: D102 return (), None @@ -182,6 +203,14 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) + def twist_operator( + self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool + ) -> jnp.ndarray: + # Note: when `h` is even, i.e. h(z) = h(-z), the expressions below coincide + if variable: + return vec + jax.grad(self.h_legendre)(-dual_vec) + return vec - jax.grad(self.h_legendre)(dual_vec) + @jax.tree_util.register_pytree_node_class class SqPNorm(TICost): diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 688515277..20db89cc7 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -73,24 +73,35 @@ def __init__( self._corr = corr def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: - r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`. + r"""Transport ``vec`` according to Gangbo-McCann Brenier :cite:`brenier:91`. - Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when - given the Legendre transform of the dual potentials. + Uses Proposition 1.15 from :cite:`santambrogio:15` to compute an OT map when + applying the inverse gradient of cost. - That OT map can be recovered as :math:`x- (\nabla h^*)\circ \nabla f(x)`, + When the cost is a general cost, the operator uses the + :meth:`~ott.geometry.costs.CostFn.twist_operator` associated with the + corresponding :class:`~ott.geometry.costs.CostFn`. 
+ + When the cost is a translation invariant :class:`~ott.geometry.costs.TICost` + cost, :math:`c(x,y)=h(x-y)`, and the twist operator translates to the + application of the gradient of the convex conjugate of :math:`h` to the + gradient of the dual potentials, namely + :math:`x- (\nabla h^*)\circ \nabla f(x)` for the forward map, where :math:`h^*` is the Legendre transform of :math:`h`. For instance, in the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`, one has :math:`h^*(\cdot) = \|.\|^2 / 4`, and therefore :math:`\nabla h^*(\cdot) = 0.5 \cdot\,`. - When the dual potentials are solved in correlation form (only in the Sq. - Euclidean distance case), the maps are :math:`\nabla g` for forward, - :math:`\nabla f` for backward. + Note: + When the dual potentials are solved in correlation form, and marked + accordingly by setting ``corr`` to ``True``, the maps are + :math:`\nabla g` for forward, :math:`\nabla f` for backward map. This can + only make sense when using the squared-Euclidean + :class:`~ott.geometry.costs.SqEuclidean` cost. Args: vec: Points to transport, array of shape ``[n, d]``. - forward: Whether to transport the points from source to the target + forward: Whether to transport the points from source to the target distribution or vice-versa. Returns: @@ -99,11 +110,13 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: from ott.geometry import costs vec = jnp.atleast_2d(vec) + if self._corr and isinstance(self.cost_fn, costs.SqEuclidean): return self._grad_f(vec) if forward else self._grad_g(vec) + twist_op = jax.vmap(self.cost_fn.twist_operator, in_axes=[0, 0, None]) if forward: - return vec - self._grad_h_inv(self._grad_f(vec)) - return vec - self._grad_h_inv(self._grad_g(vec)) + return twist_op(vec, self._grad_f(vec), 0) + return twist_op(vec, self._grad_g(vec), 1) def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float: r"""Evaluate Wasserstein distance between samples using dual potentials. 
@@ -155,16 +168,6 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]: """Vectorized gradient of the potential function :attr:`g`.""" return jax.vmap(jax.grad(self.g, argnums=0)) - @property - def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: - from ott.geometry import costs - - assert isinstance(self.cost_fn, costs.TICost), ( - "Cost must be a `TICost` and " - "provide access to Legendre transform of `h`." - ) - return jax.vmap(jax.grad(self.cost_fn.h_legendre)) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [], { "f": self._f,