refactor sinkhorn_divergence #577

Merged
merged 2 commits on Sep 15, 2024
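In short, this refactor makes `ott.tools.sinkhorn_divergence.sinkhorn_divergence` return a `(divergence, SinkhornDivergenceOutput)` tuple instead of the output object alone; every call site below is updated accordingly. A minimal before/after sketch of the change, assuming two small illustrative point clouds (the shapes and `epsilon` are placeholders, not values from the diff):

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (30, 2))        # illustrative source points
y = jax.random.normal(rng_y, (40, 2)) + 1.0  # illustrative target points

# Before this PR: the call returned the output object, and the scalar
# had to be read off its `.divergence` attribute:
#   div = sinkhorn_divergence.sinkhorn_divergence(
#       pointcloud.PointCloud, x, y, epsilon=1e-1
#   ).divergence

# After this PR: the scalar comes first, the output object second.
div, out = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, epsilon=1e-1
)
print(div, out.n_iters)  # `out` still carries diagnostics such as n_iters
```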
35 changes: 16 additions & 19 deletions docs/tutorials/linear/000_One_Sinkhorn.ipynb
@@ -8,7 +8,7 @@
"source": [
"# Focus on Sinkhorn\n",
"\n",
"We provide in this example a detailed walk-through some of the functionalities of the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm, including the computation of {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`."
"We provide in this example a detailed walk-through some of the functionalities of the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm and its convenience API packaged in {func}`~ott.solvers.linear.solve`"
]
},
{
@@ -43,8 +43,7 @@
"from ott import problems\n",
"from ott.geometry import geometry, pointcloud\n",
"from ott.solvers import linear\n",
"from ott.solvers.linear import acceleration, sinkhorn\n",
"from ott.tools.sinkhorn_divergence import sinkhorn_divergence"
"from ott.solvers.linear import acceleration, sinkhorn"
]
},
{
@@ -117,7 +116,7 @@
"id": "piPqXHXFN3vy"
},
"source": [
"## Pairwise Sinkhorn divergences\n",
"## Pairwise Sinkhorn\n",
"\n",
"Before setting a value for `epsilon`, let's get a feel of what the {class}`~ott.geometry.pointcloud.PointCloud` of embeddings looks like in terms of distances."
]
@@ -308,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"sink_div_2vmap = jax.jit(\n",
"sink_2vmap = jax.jit(\n",
" jax.vmap(jax.vmap(sink, [0] + [None] * 5, 0), [None, 0] + [None] * 4, 1),\n",
" static_argnums=[4, 5],\n",
")"
@@ -320,7 +319,7 @@
"id": "_wFVc5qWEAJc"
},
"source": [
"Compute now a pairwise $44 \\times 37$ matrix of Sinkhorn divergences (about 1000 divergences in total). We pick 30 different texts twice."
"Compute now a pairwise $44 \\times 37$ matrix of Sinkhorn distances (about 1000 distances in total). We pick 30 different texts twice."
]
},
{
@@ -365,7 +364,7 @@
},
"outputs": [],
"source": [
"DIV = sink_div_2vmap(HIST_a, HIST_b, cost, 1, 0, 100)"
"DIV = sink_2vmap(HIST_a, HIST_b, cost, 1, 0, 100)"
]
},
{
@@ -374,7 +373,7 @@
"id": "DTwxN8_IExxD"
},
"source": [
"We now carry out divergence computations and plot their matrix for various `epsilon`."
"We now carry out distance computations and plot their matrix for various `epsilon`."
]
},
{
@@ -385,14 +384,12 @@
},
"outputs": [],
"source": [
"DIV, ran_in = [], []\n",
"DIS, ran_in = [], []\n",
"epsilons = [None, 1e-2, 1e-1]\n",
"for epsilon in epsilons:\n",
" tic = time.perf_counter()\n",
" DIV.append(\n",
" sink_div_2vmap(\n",
" HIST_a, HIST_b, cost, epsilon, 0, 100\n",
" ).block_until_ready()\n",
" DIS.append(\n",
" sink_2vmap(HIST_a, HIST_b, cost, epsilon, 0, 100).block_until_ready()\n",
" )\n",
" toc = time.perf_counter()\n",
" ran_in.append(toc - tic)"
@@ -434,11 +431,11 @@
"fig, axes = plt.subplots(1, 3, figsize=(12, 6))\n",
"fig.tight_layout()\n",
"axes = [axes[0], axes[1], axes[2]]\n",
"vmin = min([jnp.min(div) for div in DIV])\n",
"vmax = max([jnp.max(div) for div in DIV])\n",
"vmin = min([jnp.min(dis) for dis in DIS])\n",
"vmax = max([jnp.max(dis) for dis in DIS])\n",
"\n",
"for epsilon, DIV_, ran_in_, ax_ in zip(epsilons, DIV, ran_in, axes):\n",
" im = ax_.imshow(DIV_, vmin=vmin, vmax=vmax)\n",
"for epsilon, dis, ran_in_, ax_ in zip(epsilons, DIS, ran_in, axes):\n",
" im = ax_.imshow(dis, vmin=vmin, vmax=vmax)\n",
" eps = f\" ({geom.epsilon:.4f})\" if epsilon is None else \"\"\n",
" ax_.set_title(\n",
" r\"$\\varepsilon$ = \" + str(epsilon) + eps + f\"\\n {ran_in_:.2f} s\"\n",
@@ -480,7 +477,7 @@
"source": [
"epsilon = 1e-2\n",
"# Naive Vmapping\n",
"%time out_1 = DIV = sink_div_2vmap(HIST_a, HIST_b, cost, epsilon, 0, 100).block_until_ready()"
"%time out_1 = sink_2vmap(HIST_a, HIST_b, cost, epsilon, 0, 100).block_until_ready()"
]
},
{
@@ -506,7 +503,7 @@
],
"source": [
"# Vmapping while forcing the number of iterations to be fixed.\n",
"%time out_2 = DIV = sink_div_2vmap(HIST_a, HIST_b, cost, epsilon, 100, 100).block_until_ready()"
"%time out_2 = sink_2vmap(HIST_a, HIST_b, cost, epsilon, 100, 100).block_until_ready()"
]
},
{
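As an aside on the `sink_2vmap` construction in the notebook above: the two nested `jax.vmap` calls map a single-pair solver over the rows of both histogram batches, yielding the full pairwise matrix in one jitted call. A self-contained sketch of the same pattern, where `sink` is a hypothetical stand-in for the notebook's per-pair solver (histogram sizes and the cost matrix are illustrative):

```python
import jax

from ott.geometry import geometry
from ott.solvers import linear


def sink(a, b, cost, epsilon, min_iters, max_iters):
  # Hypothetical per-pair solver mirroring the notebook's `sink`:
  # solve one entropy-regularized OT problem and return its cost.
  geom = geometry.Geometry(cost_matrix=cost, epsilon=epsilon)
  out = linear.solve(
      geom, a=a, b=b, min_iterations=min_iters, max_iterations=max_iters
  )
  return out.reg_ot_cost


# Inner vmap sweeps rows of `a` (output axis 0); outer vmap sweeps rows
# of `b` (output axis 1); the cost matrix and solver knobs stay shared.
sink_2vmap = jax.jit(
    jax.vmap(jax.vmap(sink, [0] + [None] * 5, 0), [None, 0] + [None] * 4, 1),
    static_argnums=[4, 5],
)

key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
HIST_a = jax.nn.softmax(jax.random.normal(key_a, (5, 16)), axis=-1)
HIST_b = jax.nn.softmax(jax.random.normal(key_b, (4, 16)), axis=-1)
cost = jax.random.uniform(key_c, (16, 16))

DIS = sink_2vmap(HIST_a, HIST_b, cost, 1e-1, 0, 100)  # shape (5, 4)
```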
29,353 changes: 14,673 additions & 14,680 deletions docs/tutorials/linear/200_sinkhorn_divergence_gradient_flow.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions docs/tutorials/linear/400_Hessians.ipynb
@@ -104,7 +104,7 @@
"outputs": [],
"source": [
"def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float:\n",
" return sinkhorn_divergence.sinkhorn_divergence(\n",
" div, _ = sinkhorn_divergence.sinkhorn_divergence(\n",
" pointcloud.PointCloud,\n",
" x,\n",
" y, # this part defines geometry\n",
@@ -114,7 +114,8 @@
" \"implicit_diff\": imp_diff.ImplicitDiff() if implicit else None,\n",
" \"use_danskin\": False,\n",
" },\n",
" ).divergence"
" )\n",
" return div"
]
},
{
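For context on the `loss` above: the `sinkhorn_kwargs` dictionary is what toggles implicit differentiation of the solver's fixed point, which is the point of the Hessians tutorial. A hedged sketch of the full pattern under the new tuple-return API (shapes and `epsilon` are illustrative):

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.solvers.linear import implicit_differentiation as imp_diff
from ott.tools import sinkhorn_divergence

y = jax.random.normal(jax.random.PRNGKey(1), (12, 2))


def loss(x: jnp.ndarray, implicit: bool = True) -> jnp.ndarray:
  # Unpack the scalar; the diagnostics are not needed to differentiate.
  div, _ = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud,
      x,
      y,
      epsilon=1e-1,
      sinkhorn_kwargs={
          "implicit_diff": imp_diff.ImplicitDiff() if implicit else None,
          "use_danskin": False,
      },
  )
  return div


x = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
hess = jax.hessian(loss)(x)  # differentiate twice through the solver
```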
4 changes: 2 additions & 2 deletions docs/tutorials/neural/000_neural_dual.ipynb
@@ -396,10 +396,10 @@
" a = jnp.ones(len(x)) / len(x)\n",
" b = jnp.ones(len(y)) / len(y)\n",
"\n",
" sdiv = sinkhorn_divergence.sinkhorn_divergence(\n",
" sdiv, _ = sinkhorn_divergence.sinkhorn_divergence(\n",
" pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
" )\n",
" return sdiv.divergence"
" return sdiv"
]
},
{
4 changes: 2 additions & 2 deletions docs/tutorials/neural/200_Monge_Gap.ipynb
@@ -372,10 +372,10 @@
"\n",
" @jax.jit\n",
" def fitting_loss(x, y):\n",
" out = sinkhorn_divergence.sinkhorn_divergence(\n",
" div, out = sinkhorn_divergence.sinkhorn_divergence(\n",
" pointcloud.PointCloud, x, y, epsilon=epsilon_fitting, static_b=True\n",
" )\n",
" return out.divergence, out.n_iters\n",
" return div, out.n_iters\n",
"\n",
" if cost_fn is None:\n",
" regularizer = None\n",
4 changes: 2 additions & 2 deletions docs/tutorials/neural/300_ENOT.ipynb
@@ -311,10 +311,10 @@
" a = jnp.ones(len(x)) / len(x)\n",
" b = jnp.ones(len(y)) / len(y)\n",
"\n",
" sdiv = sinkhorn_divergence.sinkhorn_divergence(\n",
" sdiv, _ = sinkhorn_divergence.sinkhorn_divergence(\n",
" pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
" )\n",
" return sdiv.divergence"
" return sdiv"
]
},
{
3 changes: 2 additions & 1 deletion src/ott/tools/progot.py
@@ -414,7 +414,7 @@ def _sinkhorn_divergence(
eps: Optional[float],
**kwargs: Any,
) -> sd.SinkhornDivergenceOutput:
return sd.sinkhorn_divergence(
_, out = sd.sinkhorn_divergence(
pointcloud.PointCloud,
x,
y,
@@ -423,6 +423,7 @@ def _sinkhorn_divergence(
share_epsilon=False,
sinkhorn_kwargs=kwargs,
)
return out


def _interpolate(
16 changes: 9 additions & 7 deletions src/ott/tools/sinkhorn_divergence.py
@@ -86,7 +86,7 @@ def sinkhorn_divergence(
share_epsilon: bool = True,
symmetric_sinkhorn: bool = True,
**kwargs: Any,
) -> SinkhornDivergenceOutput:
) -> Tuple[jnp.ndarray, SinkhornDivergenceOutput]:
"""Compute Sinkhorn divergence defined by a geometry, weights, parameters.

Args:
@@ -115,7 +115,7 @@
geometry.

Returns:
Sinkhorn divergence value, three pairs of potentials, three costs.
The Sinkhorn divergence value and an output object detailing the computation.
"""
geoms = geom.prepare_divergences(*args, static_b=static_b, **kwargs)
geom_xy, geom_x, geom_y, *_ = geoms + (None,) * 3
@@ -129,7 +129,7 @@

a = jnp.ones(num_a) / num_a if a is None else a
b = jnp.ones(num_b) / num_b if b is None else b
return _sinkhorn_divergence(
out = _sinkhorn_divergence(
geom_xy,
geom_x,
geom_y,
@@ -138,6 +138,7 @@
symmetric_sinkhorn=symmetric_sinkhorn,
**sinkhorn_kwargs
)
return out.divergence, out


def _sinkhorn_divergence(
@@ -172,7 +173,7 @@ def _sinkhorn_divergence(
kwargs: Keyword arguments to :func:`~ott.solvers.linear.solve`.

Returns:
SinkhornDivergenceOutput named tuple.
The output object.
"""
kwargs_symmetric = kwargs.copy()
is_low_rank = kwargs.get("rank", -1) > 0
@@ -317,7 +318,7 @@ def segment_sinkhorn_divergence(
instance entropy regularization float, scheduler or normalization.

Returns:
An array of Sinkhorn divergences for each segment.
An array of Sinkhorn divergence values for each segment.
"""
# instantiate padding vector
dim = x.shape[1]
@@ -335,7 +336,7 @@ def eval_fn(
) -> float:
mask_x = padded_weight_x > 0.0
mask_y = padded_weight_y > 0.0
return sinkhorn_divergence(
div, _ = sinkhorn_divergence(
pointcloud.PointCloud,
padded_x,
padded_y,
Expand All @@ -349,7 +350,8 @@ def eval_fn(
src_mask=mask_x,
tgt_mask=mask_y,
**kwargs
).divergence
)
return div

return segment._segment_interface(
x,
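To illustrate the masked call inside `eval_fn` above: padded points get zero weight, and the `src_mask`/`tgt_mask` arguments keep them out of the cost and epsilon normalization. A hedged standalone sketch of the same idea (the padding scheme and sizes are illustrative, not the exact segment logic):

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

# Two clouds padded to a common size of 8 points; only the first
# 5 (resp. 6) rows are real, the rest is padding with zero weight.
x = jnp.zeros((8, 2)).at[:5].set(jax.random.normal(jax.random.PRNGKey(0), (5, 2)))
y = jnp.zeros((8, 2)).at[:6].set(jax.random.normal(jax.random.PRNGKey(1), (6, 2)))
a = jnp.where(jnp.arange(8) < 5, 1.0 / 5, 0.0)
b = jnp.where(jnp.arange(8) < 6, 1.0 / 6, 0.0)

div, _ = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud,
    x,
    y,
    a=a,
    b=b,
    src_mask=a > 0.0,  # masks mirror the weights, as in eval_fn
    tgt_mask=b > 0.0,
)
print(div)
```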
4 changes: 2 additions & 2 deletions tests/neural/methods/monge_gap_test.py
@@ -144,11 +144,11 @@ def fitting_loss(
mapped_samples: jnp.ndarray,
) -> Optional[float]:
r"""Sinkhorn divergence fitting loss."""
div = sinkhorn_divergence.sinkhorn_divergence(
div, _ = sinkhorn_divergence.sinkhorn_divergence(
pointcloud.PointCloud,
x=samples,
y=mapped_samples,
).divergence
)
return div, None

def regularizer(x, y):
17 changes: 9 additions & 8 deletions tests/problems/linear/potentials_test.py
@@ -164,12 +164,12 @@ def test_entropic_potentials_sqpnorm(

if forward:
z = potentials.transport(x_test, forward=forward)
div = sdiv(z, y).divergence
div, _ = sdiv(z, y)
else:
z = potentials.transport(y_test, forward=forward)
div = sdiv(x, z).divergence
div, _ = sdiv(x, z)

div_0 = sdiv(x, y).divergence
div_0, _ = sdiv(x, y)
mult = 0.1 if p > 1.0 else 0.25
# check we have moved points much closer to target
assert div < mult * div_0
@@ -210,12 +210,12 @@
else:
if forward:
z = potentials.transport(x_test, forward=forward)
div = sdiv(z, y).divergence
div, _ = sdiv(z, y)
else:
z = potentials.transport(y_test, forward=forward)
div = sdiv(x, z).divergence
div, _ = sdiv(x, z)

div_0 = sdiv(x, y).divergence
div_0, _ = sdiv(x, y)
# check we have moved points much closer to target
assert div < 0.1 * div_0

@@ -255,9 +255,10 @@ def test_potentials_sinkhorn_divergence(self, rng: jax.Array, eps: float):
prob = linear_problem.LinearProblem(geom)

sink_pots = sinkhorn.Sinkhorn()(prob).to_dual_potentials()
div_pots = sinkhorn_divergence.sinkhorn_divergence(
_, out = sinkhorn_divergence.sinkhorn_divergence(
type(geom), x, y, epsilon=eps
).to_dual_potentials()
)
div_pots = out.to_dual_potentials()

assert not sink_pots.is_debiased
assert div_pots.is_debiased
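Lastly, the potentials test above reflects the same API shift: `to_dual_potentials` is now reached through the returned output object rather than the call result. A minimal sketch of the updated pattern (point clouds and `epsilon` are illustrative):

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (64, 2))
y = jax.random.normal(rng_y, (96, 2)) + 0.5

# New convention: the output object carries the dual potentials.
_, out = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, epsilon=1e-1
)
pots = out.to_dual_potentials()
assert pots.is_debiased    # divergence-based potentials are debiased
z = pots.transport(x)      # entropic map pushing x toward y
```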