Skip to content

Commit

Permalink
Remove matplotlib dependency (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 authored Feb 20, 2023
1 parent 8e3fc27 commit 4ce187e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 29 deletions.
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ channels:
- conda-forge
dependencies:
- python>=3.8
- ott-jax>=0.3.1
- ott-jax>=0.4
- matplotlib-base>=3.0.0
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies = [
"jaxopt>=0.5.5",
# https://github.com/google/jax/discussions/9951#discussioncomment-3017784
"numpy>=1.18.4, !=1.23.0",
"matplotlib>=3.0.0",
"flax>=0.5.2",
"optax>=0.1.1",
"scipy>=1.7.0",
Expand Down Expand Up @@ -83,6 +82,7 @@ docs = [
"sphinxcontrib-bibtex>=2.5.0",
"sphinxcontrib-spelling>=7.7.0",
"myst-nb>=0.17.1",
"matplotlib>=3.0.0",
]

[tool.setuptools]
Expand All @@ -101,8 +101,9 @@ profile = "black"
include_trailing_comma = true
multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
known_numeric = ["numpy", "scipy", "pandas", "sklearn", "jax", "flax", "optax", "torch"]
known_plotting = ["matplotlib", "mpl_toolkits", "seaborn"]
# also contains what we import in notebooks
known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn"]
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down Expand Up @@ -190,6 +191,7 @@ legacy_tox_ini = """
[testenv:lint-docs]
description = Lint the documentation.
deps =
extras = docs
allowlist_externals =
rm
Expand All @@ -207,6 +209,7 @@ legacy_tox_ini = """
[testenv:build-docs]
description = Build the documentation.
use_develop = true
deps =
extras = docs
allowlist_externals = sphinx-build
commands =
Expand All @@ -216,6 +219,7 @@ legacy_tox_ini = """
[testenv:clean-docs]
description = Remove the documentation.
deps =
skip_install = true
changedir = {tox_root}/docs
allowlist_externals = make
Expand Down
25 changes: 17 additions & 8 deletions src/ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from ott.problems.linear import linear_problem

if TYPE_CHECKING:
from ott.geometry import costs

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
except ImportError:
mpl = plt = None

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]

Expand Down Expand Up @@ -178,10 +182,10 @@ def plot_ot_map(
source: jnp.ndarray,
target: jnp.ndarray,
forward: bool = True,
ax: Optional[matplotlib.axes.Axes] = None,
ax: Optional["plt.Axes"] = None,
legend_kwargs: Optional[Dict[str, Any]] = None,
scatter_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
) -> Tuple["plt.Figure", "plt.Axes"]:
"""Plot data and learned optimal transport map.
Args:
Expand All @@ -190,12 +194,17 @@ def plot_ot_map(
forward: use the forward map from the potentials
if ``True``, otherwise use the inverse map
ax: axis to add the plot to
scatter_kwargs: additional kwargs passed into :meth:`~matplotlib.axes.Axes.scatter`
legend_kwargs: additional kwargs passed into :meth:`~matplotlib.axes.Axes.legend`
scatter_kwargs: additional kwargs passed into
:meth:`~matplotlib.axes.Axes.scatter`
legend_kwargs: additional kwargs passed into
:meth:`~matplotlib.axes.Axes.legend`
Returns:
matplotlib figure and axis with the plots
"""
if mpl is None:
raise RuntimeError("Please install `matplotlib` first.")

if scatter_kwargs is None:
scatter_kwargs = {'alpha': 0.5}
if legend_kwargs is None:
Expand Down Expand Up @@ -263,12 +272,12 @@ def plot_potential(
self,
forward: bool = True,
quantile: float = 0.05,
ax: Optional[matplotlib.axes.Axes] = None,
ax: Optional["mpl.axes.Axes"] = None,
x_bounds: Tuple[float, float] = (-6, 6),
y_bounds: Tuple[float, float] = (-6, 6),
num_grid: int = 50,
contourf_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
) -> Tuple["mpl.figure.Figure", "mpl.axes.Axes"]:
"""Plot the potential.
Args:
Expand Down
44 changes: 27 additions & 17 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,28 @@
from typing import List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy
from matplotlib import animation

from ott import utils
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein

try:
import matplotlib.pyplot as plt
from matplotlib import animation
except ImportError:
plt = animation = None

# TODO(michalk8): make sure all outputs conform to a unified transport interface
Transport = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput,
gromov_wasserstein.GWOutput]


def bidimensional(x: jnp.ndarray,
y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Apply PCA to reduce to bimensional data."""
"""Apply PCA to reduce to bi-dimensional data."""
if x.shape[1] < 3:
return x, y

Expand All @@ -44,25 +48,31 @@ def bidimensional(x: jnp.ndarray,


class Plot:
"""Plot an optimal transport map between two point clouds.
"""Plot an optimal transport map between two \
:class:`PointClouds <ott.geometry.pointcloud.PointCloud>`.
It enables to either plot or update a plot in a single object, offering the
possibilities to create animations as matplotlib.animation.FuncAnimation,
which can in turned be saved to disk at will. There are two design principles
here: 1) we do not rely on saving to/loading from disk to create animations
2) we try as much as possible to disentangle the transport problem(s) from
its visualization(s).
possibilities to create animations as a
:class:`~matplotlib.animation.FuncAnimation`, which can in turned be saved to
disk at will. There are two design principles here:
#. we do not rely on saving to/loading from disk to create animations
#. we try as much as possible to disentangle the transport problem from
its visualization.
"""

def __init__(
self,
fig: Optional[plt.Figure] = None,
ax: Optional[plt.Axes] = None,
fig: Optional["plt.Figure"] = None,
ax: Optional["plt.Axes"] = None,
cost_threshold: float = -1.0, # should be negative for animations.
scale: int = 200,
show_lines: bool = True,
cmap: str = 'cool'
):
if plt is None:
raise RuntimeError("Please install `matplotlib` first.")

if ax is None and fig is None:
fig, ax = plt.subplots()
elif fig is None:
Expand Down Expand Up @@ -102,7 +112,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray):
result.append((xy[i, [0, 2]], xy[i, [1, 3]], strength))
return result

def __call__(self, ot: Transport) -> List[plt.Artist]:
def __call__(self, ot: Transport) -> List["plt.Artist"]:
"""Plot 2-D couplings. Projects via PCA if data is higher dimensional."""
x, y, sx, sy = self._scatter(ot)
self._points_x = self.ax.scatter(
Expand Down Expand Up @@ -130,7 +140,7 @@ def __call__(self, ot: Transport) -> List[plt.Artist]:
self._lines.append(line)
return [self._points_x, self._points_y] + self._lines

def update(self, ot: Transport) -> List[plt.Artist]:
def update(self, ot: Transport) -> List["plt.Artist"]:
"""Update a plot with a transport instance."""
x, y, _, _ = self._scatter(ot)
self._points_x.set_offsets(x)
Expand Down Expand Up @@ -168,7 +178,7 @@ def animate(
self,
transports: Sequence[Transport],
frame_rate: float = 10.0
) -> animation.FuncAnimation:
) -> "animation.FuncAnimation":
"""Make an animation from several transports."""
_ = self(transports[0])
return animation.FuncAnimation(
Expand All @@ -182,13 +192,13 @@ def animate(


def _barycenters(
ax: plt.Axes,
ax: "plt.Axes",
y: jnp.ndarray,
a: jnp.ndarray,
b: jnp.ndarray,
matrix: jnp.ndarray,
scale: int = 200
):
) -> None:
"""Plot 2-D sinkhorn barycenters."""
sa, sb = jnp.min(a) / scale, jnp.min(b) / scale
ax.scatter(*y.T, s=b / sb, edgecolors='k', marker='X', label='y')
Expand All @@ -202,7 +212,7 @@ def barycentric_projections(
a: jnp.ndarray = None,
b: jnp.ndarray = None,
matrix: jnp.ndarray = None,
ax: Optional[plt.Axes] = None,
ax: Optional["plt.Axes"] = None,
**kwargs
):
"""Plot the barycenters, from the Transport object or from arguments."""
Expand Down

0 comments on commit 4ce187e

Please sign in to comment.