diff --git a/docs/tutorials/neural/200_Monge_Gap.ipynb b/docs/tutorials/neural/200_Monge_Gap.ipynb
index 5dac8c025..bb88c1725 100644
--- a/docs/tutorials/neural/200_Monge_Gap.ipynb
+++ b/docs/tutorials/neural/200_Monge_Gap.ipynb
@@ -50,7 +50,7 @@
     "\n",
     "The first requirement (efficiency) can be quantified with the **Monge gap** $\\mathcal{M}_\\mu^c$, a non-negative regularizer defined through $\\mu$ and $c$, and which takes as an argument any map $T : \\mathbb{R}^d \\rightarrow \\mathbb{R}^d$. The value $\\mathcal{M}_\\mu^c(T)$ quantifies how $T$ moves mass efficiently between $\\mu$ and $T \\sharp \\mu$, and only cancels $\\mathcal{M}_\\mu^c(T) = 0$ i.f.f. $T$ is optimal between $\\mu$ and $T \\sharp \\mu$ for the cost $c$.\n",
     "\n",
-    "The second requirement (landing on $\\nu$) is then simply handled using a fitting loss $\\Delta$ between $T \\sharp \\mu$ and $\\nu$. This can be measured, e.g., using a {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`. Introducing a regularization strength $\\lambda_\\mathrm{MG} > 0$, looking for a Monge map can be reformulated as finding a $T$ that minimizes:\n",
+    "The second requirement (landing on $\\nu$) is then simply handled using a fitting loss $\\Delta$ between $T \\sharp \\mu$ and $\\nu$. This can be measured, e.g., using the Sinkhorn divergence, {func}`~ott.tools.sinkhorn_divergence.sinkdiv`. Introducing a regularization strength $\\lambda_\\mathrm{MG} > 0$, looking for a Monge map can be reformulated as finding a $T$ that minimizes:\n",
     "\n",
     "$$\n",
     "\\min_{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d} \\Delta(T\\sharp \\mu, \\nu) + \\lambda_\\mathrm{MG} \\mathcal{M}_\\mu^c(T)\n",
     "$$\n",
@@ -324,8 +324,7 @@
     "$$\n",
     "\\min_{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d} \\Delta(T\\sharp \\mu, \\nu) + \\lambda_\\mathrm{MG} \\mathcal{M}_\\mu^c(T)\n",
     "$$\n",
-    "For all fittings, we use $\\Delta = S_{\\varepsilon, \\ell_2^2}$, the {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` with the {class}`squared Euclidean cost <ott.geometry.costs.SqEuclidean>`\n",
-    "The function considers a ground cost function `cost_fn` (corresponding to $c$), as well as the `epsilon` regularization parameters to compute approximated Wasserstein distances, both for fitting and regularizer."
+    "For all fittings, we use $\\Delta = S_{\\varepsilon, \\ell_2^2}$, the {term}`Sinkhorn divergence` {func}`~ott.tools.sinkhorn_divergence.sinkdiv` with the {class}`squared Euclidean cost <ott.geometry.costs.SqEuclidean>`. The function considers a {term}`ground cost` function `cost_fn` (corresponding to $c$), as well as the `epsilon` regularization parameters used to compute approximated Wasserstein distances, both for the fitting loss and the regularizer."
    ]
   },
   {
@@ -359,8 +358,8 @@
     "\n",
     "    @jax.jit\n",
     "    def fitting_loss(x, y):\n",
-    "        div, out = sinkhorn_divergence.sinkhorn_divergence(\n",
-    "            pointcloud.PointCloud, x, y, epsilon=epsilon_fitting, static_b=True\n",
+    "        div, out = sinkhorn_divergence.sinkdiv(\n",
+    "            x, y, epsilon=epsilon_fitting, static_b=True\n",
     "        )\n",
     "        return div, out.n_iters\n",
     "\n",
diff --git a/src/ott/neural/methods/monge_gap.py b/src/ott/neural/methods/monge_gap.py
index 71e3446fb..6dbf7c0d9 100644
--- a/src/ott/neural/methods/monge_gap.py
+++ b/src/ott/neural/methods/monge_gap.py
@@ -184,7 +184,7 @@ class MongeGapEstimator:
   sets of points.
 
   For instance, :math:`\Delta` can be the
-  :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`
+  :func:`~ott.tools.sinkhorn_divergence.sinkdiv`
   and :math:`R` the :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`
   :cite:`uscidda:23` for a given cost function :math:`c`.
   In that case, it estimates a :math:`c`-OT map, i.e. a map :math:`T`
@@ -260,7 +260,8 @@ def setup(
   def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
     """Regularizer added to the fitting loss.
 
-    Can be, e.g. the :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`.
+    Can be, e.g. the
+    :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`.
     If no regularizer is passed for solver instantiation, or regularization
     weight :attr:`regularizer_strength` is 0, return 0 by default along with
     an empty set of log values.
@@ -273,7 +274,7 @@ def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
   def fitting_loss(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
     """Fitting loss to fit the marginal constraint.
 
-    Can be, e.g. :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`.
+    Can be, e.g. :func:`~ott.tools.sinkhorn_divergence.sinkdiv`.
     If no fitting_loss is passed for solver instantiation, return 0 by default,
     and no log values.
     """
diff --git a/tests/neural/methods/monge_gap_test.py b/tests/neural/methods/monge_gap_test.py
index 8a546d890..397bcaa19 100644
--- a/tests/neural/methods/monge_gap_test.py
+++ b/tests/neural/methods/monge_gap_test.py
@@ -20,7 +20,7 @@
 import numpy as np
 
 from ott import datasets
-from ott.geometry import costs, pointcloud, regularizers
+from ott.geometry import costs, regularizers
 from ott.neural.methods import monge_gap
 from ott.neural.networks import potentials
 from ott.tools import sinkhorn_divergence
@@ -144,8 +144,7 @@ def fitting_loss(
         mapped_samples: jnp.ndarray,
     ) -> Optional[float]:
       r"""Sinkhorn divergence fitting loss."""
-      div, _ = sinkhorn_divergence.sinkhorn_divergence(
-          pointcloud.PointCloud,
+      div, _ = sinkhorn_divergence.sinkdiv(
          x=samples,
          y=mapped_samples,
      )
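
Reviewer note: every hunk above only swaps the call site from `sinkhorn_divergence.sinkhorn_divergence(pointcloud.PointCloud, x, y, ...)` to `sinkhorn_divergence.sinkdiv(x, y, ...)`. The snippet below is a minimal sketch of the new entry point for checking the change locally, assuming only the signature exercised in this diff (two point clouds, plus `epsilon` and `static_b`, returning the divergence and an output object exposing `n_iters`); the toy point clouds and the `epsilon` value are placeholders, not values from the tutorial.

```python
import jax

from ott.tools import sinkhorn_divergence

# Toy point clouds standing in for the source samples and the mapped samples.
rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (64, 2))
y = jax.random.normal(rng_y, (64, 2)) + 1.0

# New-style call used throughout this diff: the two point clouds are passed
# directly, without the PointCloud geometry class as a first argument.
div, out = sinkhorn_divergence.sinkdiv(x, y, epsilon=1e-2, static_b=True)
print(div, out.n_iters)  # divergence value and Sinkhorn iteration counts
```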