diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py
index 8d0c968cf..21588b4d4 100644
--- a/src/ott/problems/linear/potentials.py
+++ b/src/ott/problems/linear/potentials.py
@@ -1,4 +1,13 @@
-from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Sequence, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+)
 
 import jax
 import jax.numpy as jnp
@@ -154,65 +163,93 @@ class EntropicPotentials(DualPotentials):
   """Dual potential functions from finite samples :cite:`pooladian:21`.
 
   Args:
-    f: The first dual potential vector of shape ``[n,]``.
-    g: The second dual potential vector of shape ``[m,]``.
+    f_xy: The first dual potential vector of shape ``[n,]``.
+    g_xy: The second dual potential vector of shape ``[m,]``.
     prob: Linear problem with :class:`~ott.geometry.pointcloud.PointCloud`
       geometry that was used to compute the dual potentials using, e.g.,
       :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
+    f_xx: The first dual potential vector of shape ``[n,]`` used for debiasing
+      :cite:`pooladian:22`.
+    g_yy: The second dual potential vector of shape ``[m,]`` used for debiasing.
   """
 
   def __init__(
       self,
-      f: jnp.ndarray,
-      g: jnp.ndarray,
+      f_xy: jnp.ndarray,
+      g_xy: jnp.ndarray,
       prob: linear_problem.LinearProblem,
+      f_xx: Optional[jnp.ndarray] = None,
+      g_yy: Optional[jnp.ndarray] = None,
   ):
     # we pass directly the arrays and override the properties
     # since only the properties need to be callable
-    super().__init__(f, g, cost_fn=prob.geom.cost_fn, corr=False)
+    super().__init__(f_xy, g_xy, cost_fn=prob.geom.cost_fn, corr=False)
     self._prob = prob
+    self._f_xx = f_xx
+    self._g_yy = g_yy
 
   @property
   def f(self) -> Potential_t:
-    return self._create_potential_function(kind="f")
+    return self._potential_fn(kind="f")
 
   @property
   def g(self) -> Potential_t:
-    return self._create_potential_function(kind="g")
+    return self._potential_fn(kind="g")
 
-  def _create_potential_function(
-      self, *, kind: Literal["f", "g"]
-  ) -> Potential_t:
+  def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t:
     from ott.geometry import pointcloud
 
-    def callback(x: jnp.ndarray) -> float:
-      cost = pointcloud.PointCloud(
-          jnp.atleast_2d(x),
-          y,
-          cost_fn=self.cost_fn,
-      ).cost_matrix
+    def callback(
+        x: jnp.ndarray,
+        *,
+        potential: jnp.ndarray,
+        y: jnp.ndarray,
+        weights: jnp.ndarray,
+        epsilon: float,
+    ) -> float:
+      x = jnp.atleast_2d(x)
+      assert x.shape[-1] == y.shape[-1], (x.shape, y.shape)
+      geom = pointcloud.PointCloud(x, y, cost_fn=self.cost_fn)
+      cost = geom.cost_matrix
       z = (potential - cost) / epsilon
-      lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
+      lse = -epsilon * jsp.special.logsumexp(z, b=weights, axis=-1)
       return jnp.squeeze(lse)
 
     assert isinstance(
         self._prob.geom, pointcloud.PointCloud
     ), f"Expected point cloud geometry, found `{type(self._prob.geom)}`."
-    epsilon = self.epsilon
+    x, y = self._prob.geom.x, self._prob.geom.y
+    a, b = self._prob.a, self._prob.b
 
-    if kind == "g":
-      # When seeking to evaluate 2nd potential function, 1st set of potential
-      # values and support should be used,
+    if kind == "f":
+      # When seeking to evaluate 1st potential function,
+      # the 2nd set of potential values and support should be used,
       # see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
-      potential = self._f
-      y = self._prob.geom.x
-      prob_weights = self._prob.a
+      potential, arr, weights = self._g, y, b
     else:
-      potential = self._g
-      y = self._prob.geom.y
-      prob_weights = self._prob.b
+      potential, arr, weights = self._f, x, a
+
+    potential_xy = jax.tree_util.Partial(
+        callback,
+        potential=potential,
+        y=arr,
+        weights=weights,
+        epsilon=self.epsilon,
+    )
+    if not self.is_debiased:
+      return potential_xy
+
+    ep = EntropicPotentials(self._f_xx, self._g_yy, prob=self._prob)
+    # switch the order because for `kind='f'` we require `f/x/a` in `other`
+    # which is accessed when `kind='g'`
+    potential_other = ep._potential_fn(kind="g" if kind == "f" else "f")
 
-    return callback
+    return lambda x: (potential_xy(x) - potential_other(x))
+
+  @property
+  def is_debiased(self) -> bool:
+    """Whether the entropic map is debiased."""
+    return self._f_xx is not None and self._g_yy is not None
 
   @property
   def epsilon(self) -> float:
@@ -220,4 +257,4 @@ def epsilon(self) -> float:
     return self._prob.geom.epsilon
 
   def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
-    return [self._f, self._g, self._prob], {}
+    return [self._f, self._g, self._prob, self._f_xx, self._g_yy], {}
diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py
index a362937e4..9d4da7f93 100644
--- a/src/ott/tools/sinkhorn_divergence.py
+++ b/src/ott/tools/sinkhorn_divergence.py
@@ -40,13 +40,11 @@ class SinkhornDivergenceOutput(NamedTuple):
   def to_dual_potentials(self) -> "potentials.EntropicPotentials":
     """Return dual estimators :cite:`pooladian:22`, eq. 8."""
     geom_xy, *_ = self.geoms
-    prob = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)
-
-    (f_xy, g_xy), (f_x, g_x), (f_y, g_y) = self.potentials
-    f = f_xy - f_x
-    g = g_xy if g_y is None else (g_xy - g_y)  # case when `static_b=True`
-
-    return potentials.EntropicPotentials(f, g, prob)
+    prob_xy = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)
+    (f_xy, g_xy), (f_x, _), (_, g_y) = self.potentials
+    return potentials.EntropicPotentials(
+        f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y
+    )
 
 
 def sinkhorn_divergence(
diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py
index 02866a130..b5535eb33 100644
--- a/tests/problems/linear/potentials_test.py
+++ b/tests/problems/linear/potentials_test.py
@@ -198,13 +198,11 @@ def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool):
 
     actual = 2. * jnp.vdot(v_x, dx)
     np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)
 
-  @pytest.mark.parametrize("static_b", [False, True])
-  def test_potentials_sinkhorn_divergence(
-      self, rng: jnp.ndarray, static_b: bool
-  ):
+  @pytest.mark.parametrize("eps", [None, 1e-1, 1e1, 1e2, 1e3])
+  def test_potentials_sinkhorn_divergence(self, rng: jnp.ndarray, eps: float):
     key1, key2, key3 = jax.random.split(rng, 3)
     n, m, d = 32, 36, 4
-    eps, fwd = 1., True
+    fwd = True
     mu0, mu1 = -5., 5.
     x = jax.random.normal(key1, (n, d)) + mu0
@@ -218,6 +216,9 @@ def test_potentials_sinkhorn_divergence(
         type(geom), x, y, epsilon=eps
     ).to_dual_potentials()
 
+    assert not sink_pots.is_debiased
+    assert div_pots.is_debiased
+
     sink_dist = sink_pots.distance(x, y)
     div_dist = div_pots.distance(x, y)
     assert div_dist < sink_dist
@@ -227,4 +228,12 @@ def test_potentials_sinkhorn_divergence(
 
     with pytest.raises(AssertionError):
      np.testing.assert_allclose(sink_points, div_points)
-    np.testing.assert_allclose(sink_points, div_points, rtol=0.08, atol=0.31)
+
+    # test collapse for high epsilon
+    if eps is not None and eps >= 1e2:
+      sink_ref = jnp.repeat(sink_points[:1], n, axis=0)
+      div_ref = jnp.repeat(div_points[:1], n, axis=0)
+
+      np.testing.assert_allclose(sink_ref, sink_points, rtol=1e-1, atol=1e-1)
+      with pytest.raises(AssertionError):
+        np.testing.assert_allclose(div_ref, div_points, rtol=1e-1, atol=1e-1)
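Not part of the patch: a minimal usage sketch of the debiased-potentials flow introduced above, to make the intended call pattern concrete. It assumes the ott-jax modules referenced in the diff (`ott.geometry.pointcloud`, `ott.problems.linear.linear_problem`, `ott.solvers.linear.sinkhorn`, `ott.tools.sinkhorn_divergence`), that the plain `SinkhornOutput` also exposes `to_dual_potentials()`, and the `distance`/`transport` methods inherited from `DualPotentials`; exact import paths and return types may differ between ott-jax versions.

```python
import jax

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (32, 4)) - 5.0  # source samples
y = jax.random.normal(key2, (36, 4)) + 5.0  # target samples

# Plain Sinkhorn potentials: f_xx/g_yy are never set, so no debiasing.
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
sink_pots = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom)).to_dual_potentials()
assert not sink_pots.is_debiased

# Sinkhorn-divergence potentials: the (x, x) and (y, y) potentials are passed
# as `f_xx`/`g_yy`, so the resulting entropic map is debiased.
div_pots = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, epsilon=1e-1
).to_dual_potentials()
assert div_pots.is_debiased

# Both objects expose the same downstream API.
print(sink_pots.distance(x, y), div_pots.distance(x, y))
x_pushed = div_pots.transport(x, forward=True)  # debiased estimate of T(x)
```

The sketch mirrors the updated test: only potentials obtained from a Sinkhorn divergence carry the `(x, x)` and `(y, y)` potentials needed for debiasing, which is exactly what the new `is_debiased` property reports.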