Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interactive plots using plotly #82

Merged
merged 7 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Python package
on:
push:
branches:
- main
- main
pull_request:
branches:
- main
Expand Down Expand Up @@ -53,7 +53,7 @@ jobs:
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets]'
pip install '.[dev,datasets,integrations]'

- name: Run the formatter
run: |
Expand Down
1 change: 1 addition & 0 deletions PKGBUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ optdepends=(
python-matplotlib
python-h5py
python-argparse
python-plotly
)
license=('custom')
arch=('any')
Expand Down
12 changes: 4 additions & 8 deletions cebra/integrations/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
"""Matplotlib interface to CEBRA."""

import abc
from collections.abc import Iterable
from typing import List, Literal, Optional, Tuple, Union

import matplotlib
import matplotlib.axes
import matplotlib.cm
import matplotlib.colors
Expand Down Expand Up @@ -483,8 +481,8 @@ def plot(self, **kwargs) -> matplotlib.axes.Axes:
self.ax = self._plot_3d(**kwargs)
else:
self.ax = self._plot_2d(**kwargs)

self.ax.set_title(self.title)
if isinstance(self.ax, matplotlib.axes._axes.Axes):
self.ax.set_title(self.title)

return self.ax

Expand Down Expand Up @@ -751,10 +749,8 @@ def plot_overview(
figsize: tuple = (15, 4),
dpi: int = 100,
**kwargs,
) -> Tuple[
matplotlib.figure.Figure,
Tuple[matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes],
]:
) -> Tuple[matplotlib.figure.Figure, Tuple[
matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes]]:
"""Plot an overview of a trained CEBRA model.

Args:
Expand Down
191 changes: 191 additions & 0 deletions cebra/integrations/plotly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#
# (c) All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE,
# Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and
# original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
#
# Source code:
# https://github.com/AdaptiveMotorControlLab/CEBRA
#
# Please see LICENSE.md for the full license document:
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
"""Plotly interface to CEBRA."""
from typing import Optional, Tuple, Union

import matplotlib.colors
import numpy as np
import numpy.typing as npt
import plotly.graph_objects
import torch

from cebra.integrations.matplotlib import _EmbeddingPlot


def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2):
"""Convert matplotlib colormap to plotly colorscale.

Args:
cmap: A registered colormap name from matplotlib.
pl_entries: Number of colors to use in the plotly colorscale.
rdigits: Number of digits to round the colorscale to.

Returns:
pl_colorscale: List of scaled colors to plot the embeddings
"""
scale = np.linspace(0, 1, pl_entries)
colors = (cmap(scale)[:, :3] * 255).astype(np.uint8)
pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"]
for s, color in zip(scale, colors)]
return pl_colorscale


class _EmbeddingInteractivePlot(_EmbeddingPlot):

def __init__(self, **kwargs):
self.figsize = kwargs.get("figsize")
super().__init__(**kwargs)
self.colorscale = self._define_colorscale(self.cmap)

def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]):
"""Define the axis of the plot.

Args:
axis: Optional axis to create the plot on.

Returns:
axis: The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
"""

if axis is None:
self.axis = plotly.graph_objects.Figure(
layout=plotly.graph_objects.Layout(height=100 * self.figsize[0],
width=100 * self.figsize[1]))

else:
self.axis = axis

def _define_colorscale(self, cmap: str):
"""Specify the cmap for plotting the latent space.

Args:
cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A).


Returns:
colorscale: List of scaled colors to plot the embeddings
"""
colorscale = _convert_cmap2colorscale(matplotlib.cm.get_cmap(cmap))

return colorscale

def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure:
"""Plot the embedding in 3d.

Returns:
The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
"""

idx1, idx2, idx3 = self.idx_order
data = [
plotly.graph_objects.Scatter3d(
x=self.embedding[:, idx1],
y=self.embedding[:, idx2],
z=self.embedding[:, idx3],
mode="markers",
marker=dict(
size=self.markersize,
opacity=self.alpha,
color=self.embedding_labels,
colorscale=self.colorscale,
),
)
]
col = kwargs.get("col", None)
row = kwargs.get("row", None)

if col is None or row is None:
self.axis.add_trace(data[0])
else:
self.axis.add_trace(data[0], row=row, col=col)

self.axis.update_layout(
template="plotly_white",
showlegend=False,
title=self.title,
)

return self.axis


def plot_embedding_interactive(
embedding: Union[npt.NDArray, torch.Tensor],
embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey",
axis: Optional[plotly.graph_objects.Figure] = None,
markersize: float = 1,
idx_order: Optional[Tuple[int]] = None,
alpha: float = 0.4,
cmap: str = "cool",
title: str = "Embedding",
figsize: Tuple[int] = (5, 5),
dpi: int = 100,
**kwargs,
) -> plotly.graph_objects.Figure:
"""Plot embedding in a 3D dimensional space.

This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of
dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1).

The function makes use of :py:func:`plotly.graph_objs._scatter.Scatter` and parameters from that function can be provided
as part of ``kwargs``.


Args:
embedding: A matrix containing the feature representation computed with CEBRA.
embedding_labels: The labels used to map the data to color. It can be:

* A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous.
* A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color.
axis: Optional axis to create the plot on.
idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D
embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those
dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation
(e.g., (2, 4, 5)).
markersize: The marker size.
alpha: The marker blending, between 0 (transparent) and 1 (opaque).
cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A).
title: The title on top of the embedding.
figsize: Figure width and height in inches.
dpi: Figure resolution.
kwargs: Optional arguments to customize the plots. See :py:func:`plotly.graph_objs._scatter.Scatter` documentation for more
details on which arguments to use.

Returns:
The plotly figure.


Example:

>>> import cebra
>>> import numpy as np
>>> X = np.random.uniform(0, 1, (100, 50))
>>> y = np.random.uniform(0, 10, (100, 5))
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(X, y)
CEBRA(max_iterations=10)
>>> embedding = cebra_model.transform(X)
>>> cebra_time = np.arange(X.shape[0])
>>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time)

"""
return _EmbeddingInteractivePlot(
embedding=embedding,
embedding_labels=embedding_labels,
axis=axis,
idx_order=idx_order,
markersize=markersize,
alpha=alpha,
cmap=cmap,
title=title,
figsize=figsize,
dpi=dpi,
).plot(**kwargs)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ datasets =
integrations =
jupyter
pandas
plotly
docs =
sphinx==5.3
sphinx-gallery==0.10.1
Expand Down
65 changes: 65 additions & 0 deletions tests/test_plotly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import matplotlib
import numpy as np
import plotly.graph_objects as go
import pytest
from plotly.subplots import make_subplots

import cebra.integrations.plotly as cebra_plotly
import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra


@pytest.mark.parametrize("cmap", ["viridis", "plasma", "inferno", "magma"])
def test_colorscale(cmap):
cmap = matplotlib.cm.get_cmap(cmap)
colorscale = cebra_plotly._convert_cmap2colorscale(cmap)
assert isinstance(colorscale, list)


@pytest.mark.parametrize("output_dimension, idx_order", [(8, (2, 3, 4)),
(3, (0, 1, 2))])
def test_plot_embedding(output_dimension, idx_order):
# example dataset
X = np.random.uniform(0, 1, (1000, 50))
y = np.random.uniform(0, 1, (1000,))

# integration tests
model = cebra_sklearn_cebra.CEBRA(max_iterations=10,
batch_size=512,
output_dimension=output_dimension)

model.fit(X)
embedding = model.transform(X)

fig = cebra_plotly.plot_embedding_interactive(embedding=embedding,
embedding_labels=y)
assert isinstance(fig, go.Figure)
assert len(fig.data) == 1

fig.layout = {}
fig.data = []

fig_subplots = make_subplots(
rows=2,
cols=2,
specs=[
[{
"type": "scatter3d"
}, {
"type": "scatter3d"
}],
[{
"type": "scatter3d"
}, {
"type": "scatter3d"
}],
],
)

fig_subplots = cebra_plotly.plot_embedding_interactive(axis=fig_subplots,
embedding=embedding,
embedding_labels=y,
row=1,
col=1)

fig_subplots.data = []
fig_subplots.layout = {}
Loading