Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update mm-sink tutorial and plot #572

Merged
merged 10 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -1016,3 +1016,13 @@ @misc{kassraie:24
title = {Progressive Entropic Optimal Transport Solvers},
year = {2024},
}

@article{lin:22,
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
title={On the complexity of approximating multimarginal optimal transport},
author={Lin, Tianyi and Ho, Nhat and Cuturi, Marco and Jordan, Michael I},
journal={Journal of Machine Learning Research},
volume={23},
number={65},
pages={1--43},
year={2022}
}
259,322 changes: 259,196 additions & 126 deletions docs/tutorials/linear/600_mmsink.ipynb

Large diffs are not rendered by default.

175 changes: 173 additions & 2 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,37 @@
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np
import scipy

from ott.experimental import mmsinkhorn
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein

try:
import matplotlib
import matplotlib.patches as ptc
import matplotlib.pyplot as plt
from matplotlib import animation
matplotlib.rcParams["animation.embed_limit"] = 2 ** 128
zoepiran marked this conversation as resolved.
Show resolved Hide resolved
except ImportError:
plt = animation = None

# TODO(michalk8): make sure all outputs conform to a unified transport interface
# Union of the solver output types used to annotate the ``transports``
# argument of the ``animate`` methods in this module.
Transport = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput,
                  gromov_wasserstein.GWOutput, mmsinkhorn.MMSinkhornOutput]


def ccworder(A: jnp.ndarray) -> jnp.ndarray:
  """Helper function to plot good looking polygons.

  Returns the permutation that sorts the rows of ``A`` by increasing polar
  angle around their centroid, i.e. in counter-clockwise order. Drawing a
  polygon through the points in that order avoids self-intersecting edges.

  https://stackoverflow.com/questions/5040412/how-to-draw-the-largest-polygon-from-a-set-of-points

  Args:
    A: array-like of shape ``[n, d]``; only the first two columns are used
      to compute the angles.

  Returns:
    Integer array of shape ``[n]`` with the row indices of ``A`` ordered by
    angle in ``[-pi, pi]``.
  """
  # Accept plain (nested) sequences as well as arrays.
  A = jnp.asarray(A)
  # Angles are measured relative to the centroid of the points.
  centered = A - jnp.mean(A, axis=0, keepdims=True)
  return jnp.argsort(jnp.arctan2(centered[:, 1], centered[:, 0]))


def bidimensional(x: jnp.ndarray,
Expand Down Expand Up @@ -228,7 +242,8 @@ def animate(
self,
transports: Sequence[Transport],
titles: Optional[Sequence[str]] = None,
frame_rate: float = 10.0
frame_rate: float = 10.0,
**kwargs
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
) -> "animation.FuncAnimation":
"""Make an animation from several transports."""
_ = self(transports[0])
Expand All @@ -245,3 +260,159 @@ def animate(
interval=1000 / frame_rate,
blit=True
)


# TODO(zoepiran): add support for data of d > 2 (PCA on all k's)
class PlotMM(Plot):
  """Plot an optimal transport map for MM-Sinkhorn.

  It enables to either plot or update a plot in a single object, offering the
  possibilities to create animations as a
  :class:`~matplotlib.animation.FuncAnimation`, which can in turn be saved to
  disk at will. There are two design principles here:

  #. we do not rely on saving to/loading from disk to create animations
  #. we try as much as possible to disentangle the transport problem from
     its visualization.

  Each of the ``top_k`` largest entries of the coupling tensor is drawn as a
  polygon whose vertices are the corresponding points, one per marginal.
  Only 2-D point clouds are supported.

  Args:
    fig: Specify figure object. Created by default
    ax: Specify axes objects. Created by default
    cmap: color map, forwarded to the parent class.
    scale_alpha_by_coupling: use or not the coupling's value as proxy for
      alpha. Forwarded to the parent class.
    alpha: default alpha value, forwarded to the parent class. The polygons
      drawn by this class use fixed alphas (``0.6``, fading to ``0.2`` for
      tuples ranked beyond the first ``n``).
    title: title of the plot.
  """

  # One marker per marginal; reused cyclically when there are more marginals
  # than markers.
  _MARKERS = "svopxdh"

  def __init__(
      self,
      fig: Optional["plt.Figure"] = None,
      ax: Optional["plt.Axes"] = None,
      cmap: str = "cividis_r",
      scale_alpha_by_coupling: bool = False,
      alpha: float = 0.7,
      title: Optional[str] = None
  ):
    super().__init__(
        fig=fig,
        ax=ax,
        cmap=cmap,
        scale_alpha_by_coupling=scale_alpha_by_coupling,
        alpha=alpha,
        title=title
    )

    # Artists created by ``__call__`` and moved around by ``update``.
    self._patches = []
    self._points = []
    # Size of the first marginal, number of marginals and number of tuples to
    # draw. Set by ``animate``, or lazily from the first plotted output.
    self._n, self._k, self._top_k = None, None, None

  def _ensure_sizes(self, ot: mmsinkhorn.MMSinkhornOutput) -> None:
    """Fill in ``_k``, ``_n`` and ``_top_k`` from ``ot`` when still unset."""
    if self._k is None:
      self._k = len(ot.tensor.shape)
    if self._n is None:
      self._n = ot.tensor.shape[0]
    if self._top_k is None:
      self._top_k = self._n

  def _top_polygons(self, ot: mmsinkhorn.MMSinkhornOutput) -> List[tuple]:
    """Return ``(vertices, color, alpha)`` for the ``top_k`` largest entries."""
    n_s = [len(x) for x in ot.x_s]
    n_tuples = int(np.prod(n_s))
    assert self._n < n_tuples, "Intended number of tuples too large."
    if self._top_k > n_tuples:
      raise ValueError(
          f"Cannot plot top_k={self._top_k} tuples, "
          f"the coupling tensor only has {n_tuples} entries."
      )

    # Extract top_k largest entries in the tensor, and their indices.
    _, flat_idx = jax.lax.top_k(ot.tensor.ravel(), self._top_k)
    indices = jnp.unravel_index(flat_idx, n_s)

    # Tuples ranked beyond the first ``n`` fade out; ``max`` keeps
    # ``linspace`` valid when fewer than ``n`` tuples are requested.
    alphas = np.linspace(0.6, 0.2, max(self._top_k - self._n, 0))

    polygons = []
    for j in range(self._top_k):
      corners = jnp.array([
          ot.x_s[i][indices[i][j], :] for i in range(self._k)
      ])
      # reorder to ensure polygons have maximal area
      vertices = np.asarray(corners[ccworder(corners)])
      alpha = 0.6 if j < self._n else alphas[j - self._n]
      # NOTE(review): ``self._cmap`` is set by the parent class, which is not
      # visible here; it is indexed with a boolean, and with ``>`` whereas the
      # alpha above switches at ``>=`` -- confirm both are intended.
      polygons.append((vertices, self._cmap[j > self._n], alpha))
    return polygons

  def __call__(self, ot: mmsinkhorn.MMSinkhornOutput) -> List["plt.Artist"]:
    """Plot 2-D couplings. Does not support higher dimensional data.

    Args:
      ot: MM-Sinkhorn output; its ``tensor``, ``x_s`` and ``a_s`` attributes
        are read.

    Returns:
      The scatter artists followed by the polygon patches.
    """
    self._ensure_sizes(ot)

    for j, (vertices, color, alpha) in enumerate(self._top_polygons(ot)):
      patch = ptc.Polygon(
          vertices,
          fill=True,
          linewidth=2,
          color=color,
          alpha=alpha,
          # larger entries are drawn on top of smaller ones
          zorder=-j,
      )
      self._patches.append(self.ax.add_patch(patch))

    # One scatter artist per point, appended marginal by marginal; ``update``
    # relies on this order.
    for i in range(self._k):
      marker = self._MARKERS[i % len(self._MARKERS)]
      n_points = len(ot.a_s[i])
      for j, xy in enumerate(ot.x_s[i]):
        self._points.append(
            self.ax.scatter(
                xy[0],
                xy[1],
                # marker area is proportional to the weight of the point
                s=200 * ot.a_s[i][j] * n_points,
                marker=marker,
                c="black",
                linewidth=0.0,
                edgecolor=None,
                label=str(i)
            )
        )

    if self._title is not None:
      self.ax.set_title(self._title)

    return self._points + self._patches

  def update(
      self,
      ot: mmsinkhorn.MMSinkhornOutput,
      title: Optional[str] = None
  ) -> List["plt.Artist"]:
    """Update a plot with a transport instance.

    The artists must already exist, i.e. the object must have been called
    once on an output with the same number of points and tuples.

    Args:
      ot: MM-Sinkhorn output to display.
      title: new title of the plot, left unchanged if ``None``.

    Returns:
      The scatter artists followed by the polygon patches.
    """
    self._ensure_sizes(ot)

    for patch, (vertices, color, alpha) in zip(
        self._patches, self._top_polygons(ot)
    ):
      patch.set_xy(vertices)
      patch.set_color(color)
      patch.set_alpha(alpha)

    # Walk the scatter artists in creation order (see ``__call__``), so that
    # marginals of different sizes are handled as well.
    flat = 0
    for i in range(self._k):
      for xy in ot.x_s[i]:
        self._points[flat].set_offsets(xy)
        flat += 1

    if title is not None:
      self.ax.set_title(title)

    # NOTE(review): fixed limits, presumably the points are expected to lie
    # in the unit square -- confirm.
    self.ax.set_ylim(-2.5e-2, 1 + 2.5e-2)
    self.ax.set_xlim(-2.5e-2, 1 + 2.5e-2)
    return self._points + self._patches

  def animate(
      self,
      transports: Sequence[Transport],
      titles: Optional[Sequence[str]] = None,
      frame_rate: float = 10.0,
      **kwargs
  ) -> "animation.FuncAnimation":
    """Make an animation from several transports.

    Args:
      transports: MM-Sinkhorn outputs, one per frame. The first one fixes the
        number of marginals and of points.
      titles: title of each frame, empty by default.
      frame_rate: number of frames per second.
      kwargs: only ``top_k``, the number of largest tensor entries drawn as
        polygons, is read; it defaults to the size of the first marginal.
        Any other keyword is ignored.

    Returns:
      The animation object.
    """
    first = transports[0]
    self._k = len(first.tensor.shape)
    self._n = first.tensor.shape[0]

    self._top_k = kwargs.pop("top_k", self._n)
    _ = self(ot=first)

    titles = titles if titles is not None else [""] * len(transports)
    return animation.FuncAnimation(
        self.fig,
        lambda i: self.update(ot=transports[i], title=titles[i]),
        np.arange(0, len(transports)),
        init_func=lambda: self.update(ot=first, title=titles[0]),
        interval=1000 / frame_rate,
        blit=True,
    )
Loading