Fix/debiased entropic map (#238)
* [ci skip] Add `environment.yml` for `binder`

* Fix debiasing in `EntropicMap`

* Test for collapse in non-debiased case

* [ci skip] Add assertion
michalk8 authored Jan 28, 2023
1 parent fe882a7 commit 08491f0
Showing 3 changed files with 87 additions and 43 deletions.
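Background for the fix: the debiased potentials of :cite:`pooladian:22`, eq. 8, subtract a self-correlation term from the cross term,

    f_debiased(x) = f_xy(x) - f_xx(x),

where each term extends finite-sample potential values out of sample through the log-sum-exp identity implemented by the `callback` in the diff below,

    f_xy(x) = -ε log Σ_j b_j exp((g_xy[j] - c(x, y_j)) / ε).

Before this commit the subtraction was done once on the finite-sample vectors (`f = f_xy - f_x`) prior to a single log-sum-exp extension; the fix keeps both potential pairs and subtracts the two extended functions instead, which the new `is_debiased` property signals.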
97 changes: 67 additions & 30 deletions src/ott/problems/linear/potentials.py
@@ -1,4 +1,13 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Sequence, Tuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Tuple,
)

import jax
import jax.numpy as jnp
@@ -154,70 +163,98 @@ class EntropicPotentials(DualPotentials):
"""Dual potential functions from finite samples :cite:`pooladian:21`.
Args:
f: The first dual potential vector of shape ``[n,]``.
g: The second dual potential vector of shape ``[m,]``.
f_xy: The first dual potential vector of shape ``[n,]``.
g_xy: The second dual potential vector of shape ``[m,]``.
prob: Linear problem with :class:`~ott.geometry.pointcloud.PointCloud`
geometry that was used to compute the dual potentials using, e.g.,
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
f_xx: The first dual potential vector of shape ``[n,]`` used for debiasing
:cite:`pooladian:22`.
g_yy: The second dual potential vector of shape ``[m,]`` used for debiasing.
"""

def __init__(
self,
f: jnp.ndarray,
g: jnp.ndarray,
f_xy: jnp.ndarray,
g_xy: jnp.ndarray,
prob: linear_problem.LinearProblem,
f_xx: Optional[jnp.ndarray] = None,
g_yy: Optional[jnp.ndarray] = None,
):
# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g, cost_fn=prob.geom.cost_fn, corr=False)
super().__init__(f_xy, g_xy, cost_fn=prob.geom.cost_fn, corr=False)
self._prob = prob
self._f_xx = f_xx
self._g_yy = g_yy

@property
def f(self) -> Potential_t:
return self._create_potential_function(kind="f")
return self._potential_fn(kind="f")

@property
def g(self) -> Potential_t:
return self._create_potential_function(kind="g")
return self._potential_fn(kind="g")

def _create_potential_function(
self, *, kind: Literal["f", "g"]
) -> Potential_t:
def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t:
from ott.geometry import pointcloud

def callback(x: jnp.ndarray) -> float:
cost = pointcloud.PointCloud(
jnp.atleast_2d(x),
y,
cost_fn=self.cost_fn,
).cost_matrix
def callback(
x: jnp.ndarray,
*,
potential: jnp.ndarray,
y: jnp.ndarray,
weights: jnp.ndarray,
epsilon: float,
) -> float:
x = jnp.atleast_2d(x)
assert x.shape[-1] == y.shape[-1], (x.shape, y.shape)
geom = pointcloud.PointCloud(x, y, cost_fn=self.cost_fn)
cost = geom.cost_matrix
z = (potential - cost) / epsilon
lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
lse = -epsilon * jsp.special.logsumexp(z, b=weights, axis=-1)
return jnp.squeeze(lse)

assert isinstance(
self._prob.geom, pointcloud.PointCloud
), f"Expected point cloud geometry, found `{type(self._prob.geom)}`."
epsilon = self.epsilon
x, y = self._prob.geom.x, self._prob.geom.y
a, b = self._prob.a, self._prob.b

if kind == "g":
# When seeking to evaluate 2nd potential function, 1st set of potential
# values and support should be used,
if kind == "f":
# When seeking to evaluate 1st potential function,
# the 2nd set of potential values and support should be used,
# see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
potential = self._f
y = self._prob.geom.x
prob_weights = self._prob.a
potential, arr, weights = self._g, y, b
else:
potential = self._g
y = self._prob.geom.y
prob_weights = self._prob.b
potential, arr, weights = self._f, x, a

potential_xy = jax.tree_util.Partial(
callback,
potential=potential,
y=arr,
weights=weights,
epsilon=self.epsilon,
)
if not self.is_debiased:
return potential_xy

ep = EntropicPotentials(self._f_xx, self._g_yy, prob=self._prob)
# switch the order because for `kind='f'` we require `f/x/a` in `other`
# which is accessed when `kind='g'`
potential_other = ep._potential_fn(kind="g" if kind == "f" else "f")

return callback
return lambda x: (potential_xy(x) - potential_other(x))

@property
def is_debiased(self) -> bool:
"""Whether the entropic map is debiased."""
return self._f_xx is not None and self._g_yy is not None

@property
def epsilon(self) -> float:
"""Entropy regularizer."""
return self._prob.geom.epsilon

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self._prob], {}
return [self._f, self._g, self._prob, self._f_xx, self._g_yy], {}
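For orientation, here is a minimal self-contained sketch of the log-sum-exp extension that `_potential_fn` assembles via `jax.tree_util.Partial`; the squared-Euclidean cost, the function names, and the commented debiased composition are illustrative assumptions, not the library's API:

import jax.numpy as jnp
from jax.scipy.special import logsumexp


def extend_potential(potential, support, weights, epsilon):
  """Extends finite-sample potential values on `support` to arbitrary queries."""

  def fn(x):
    x = jnp.atleast_2d(x)
    # Squared-Euclidean cost between each query and each support point.
    cost = jnp.sum((x[:, None, :] - support[None, :, :]) ** 2, axis=-1)
    z = (potential - cost) / epsilon
    # Weighted soft-min, exactly as in `callback` above.
    return jnp.squeeze(-epsilon * logsumexp(z, b=weights, axis=-1))

  return fn


# Debiased first potential, mirroring `_potential_fn(kind="f")`:
# cross term on (g_xy, y, b) minus self term on (f_xx, x, a).
# f_debiased = lambda q: (extend_potential(g_xy, y, b, eps)(q) -
#                         extend_potential(f_xx, x, a, eps)(q))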
12 changes: 5 additions & 7 deletions src/ott/tools/sinkhorn_divergence.py
@@ -40,13 +40,11 @@ class SinkhornDivergenceOutput(NamedTuple):
def to_dual_potentials(self) -> "potentials.EntropicPotentials":
"""Return dual estimators :cite:`pooladian:22`, eq. 8."""
geom_xy, *_ = self.geoms
prob = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)

(f_xy, g_xy), (f_x, g_x), (f_y, g_y) = self.potentials
f = f_xy - f_x
g = g_xy if g_y is None else (g_xy - g_y) # case when `static_b=True`

return potentials.EntropicPotentials(f, g, prob)
prob_xy = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)
(f_xy, g_xy), (f_x, _), (_, g_y) = self.potentials
return potentials.EntropicPotentials(
f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y
)


def sinkhorn_divergence(
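After this change the divergence output wires the self-potentials through, so downstream code gets debiasing for free. A hedged usage sketch (shapes and epsilon are illustrative; `transport` is assumed to evaluate the entropic map induced by the potentials):

import jax
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (32, 4)) - 5.0  # source samples
y = jax.random.normal(rng_y, (36, 4)) + 5.0  # target samples

out = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, epsilon=1e-1
)
pots = out.to_dual_potentials()  # now also carries `f_xx`/`g_yy`
assert pots.is_debiased
x_mapped = pots.transport(x)  # debiased entropic map estimate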
21 changes: 15 additions & 6 deletions tests/problems/linear/potentials_test.py
@@ -198,13 +198,11 @@ def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool):
actual = 2. * jnp.vdot(v_x, dx)
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("static_b", [False, True])
def test_potentials_sinkhorn_divergence(
self, rng: jnp.ndarray, static_b: bool
):
@pytest.mark.parametrize("eps", [None, 1e-1, 1e1, 1e2, 1e3])
def test_potentials_sinkhorn_divergence(self, rng: jnp.ndarray, eps: float):
key1, key2, key3 = jax.random.split(rng, 3)
n, m, d = 32, 36, 4
eps, fwd = 1., True
fwd = True
mu0, mu1 = -5., 5.

x = jax.random.normal(key1, (n, d)) + mu0
@@ -218,6 +216,9 @@ def test_potentials_sinkhorn_divergence(
type(geom), x, y, epsilon=eps
).to_dual_potentials()

assert not sink_pots.is_debiased
assert div_pots.is_debiased

sink_dist = sink_pots.distance(x, y)
div_dist = div_pots.distance(x, y)
assert div_dist < sink_dist
@@ -227,4 +228,12 @@

with pytest.raises(AssertionError):
np.testing.assert_allclose(sink_points, div_points)
np.testing.assert_allclose(sink_points, div_points, rtol=0.08, atol=0.31)

# test collapse for high epsilon
if eps is not None and eps >= 1e2:
sink_ref = jnp.repeat(sink_points[:1], n, axis=0)
div_ref = jnp.repeat(div_points[:1], n, axis=0)

np.testing.assert_allclose(sink_ref, sink_points, rtol=1e-1, atol=1e-1)
with pytest.raises(AssertionError):
np.testing.assert_allclose(div_ref, div_points, rtol=1e-1, atol=1e-1)
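The new assertions capture the failure mode this commit fixes: with squared-Euclidean cost the biased map drifts toward the target barycenter Σ_j b_j y_j as ε grows, so at ε ≥ 1e2 all images nearly coincide, while the debiased map keeps them spread out. A sketch of that check, reusing `x` and `y` from the snippet above (the plain-Sinkhorn route is assumed to yield non-debiased potentials, as in the test):

import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

geom = pointcloud.PointCloud(x, y, epsilon=1e3)  # deliberately large epsilon
sink_out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
sink_pots = sink_out.to_dual_potentials()  # no `f_xx`/`g_yy` passed -> biased
assert not sink_pots.is_debiased

sink_points = sink_pots.transport(x)
# Collapse shows up as (near-)zero per-coordinate spread across the images.
print(jnp.ptp(sink_points, axis=0))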
