-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Addition of gangbo-mccann map estimators using twist operator (#500)
* addition of gangbo-mccann map estimators using twist operator * docs * docs * link * changes after review
- Loading branch information
1 parent
cbff1c7
commit b0e809a
Showing
5 changed files
with
56 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,6 +138,27 @@ def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: | |
""" | ||
return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x) | ||
|
||
def twist_operator( | ||
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool | ||
) -> jnp.ndarray: | ||
r"""Twist inverse operator of the cost function. | ||
Given a cost function :math:`c`, the twist operator returns | ||
:math:`\nabla_{1}c(x, \cdot)^{-1}(z)` if ``variable`` is ``0``, | ||
and :math:`\nabla_{2}c(\cdot, y)^{-1}(z)` if ``variable`` is ``1``, for | ||
:math:`x=y` equal to ``vec`` and :math:`z` equal to ``dual_vec``. | ||
Args: | ||
vec: ``[p,]`` point at which the twist inverse operator is evaluated. | ||
dual_vec: ``[q,]`` point to invert by the operator. | ||
variable: apply twist inverse operator on first (i.e. value set to ``0`` | ||
This comment has been minimized.
Sorry, something went wrong.
michalk8
Collaborator
|
||
or equivalently ``False``) or second (``1`` or ``True``) variable. | ||
Returns: | ||
A vector. | ||
""" | ||
raise NotImplementedError("Twist operator is not implemented.") | ||
|
||
def tree_flatten(self): # noqa: D102 | ||
return (), None | ||
|
||
|
@@ -182,6 +203,14 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: | |
"""Compute cost as evaluation of :func:`h` on :math:`x-y`.""" | ||
return self.h(x - y) | ||
|
||
def twist_operator( | ||
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool | ||
) -> jnp.ndarray: | ||
# Note: when `h` is pair, i.e. h(z) = h(-z), the expressions below coincide | ||
if variable: | ||
return vec + jax.grad(self.h_legendre)(-dual_vec) | ||
return vec - jax.grad(self.h_legendre)(dual_vec) | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class | ||
class SqPNorm(TICost): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@marcocuturi this should be undone.