-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add interactive plots using plotly (#82)
* Added interactive plots * Add plotly integration * Fix during testing, add header * Add plotly to integrations * Delete plotly import from matplotlib * Add plotly to optdepends * Add tests and update Makefile * Checking whether pass test with required plotly * Final check * Delete redundent code * Delete plotly from essential dependencies * Final changes and test to interactive plot * Fix docs * apply pre-commit * Adapt docstring * Add integrations to test dependencies * fixed an issue with figsize * Change default params, delete typo * Fix example in plotly --------- Co-authored-by: nastya236 <[email protected]>
- Loading branch information
Showing
6 changed files
with
264 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# | ||
# (c) All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE, | ||
# Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and | ||
# original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023. | ||
# | ||
# Source code: | ||
# https://github.com/AdaptiveMotorControlLab/CEBRA | ||
# | ||
# Please see LICENSE.md for the full license document: | ||
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md | ||
# | ||
"""Plotly interface to CEBRA.""" | ||
from typing import Optional, Tuple, Union | ||
|
||
import matplotlib.colors | ||
import numpy as np | ||
import numpy.typing as npt | ||
import plotly.graph_objects | ||
import torch | ||
|
||
from cebra.integrations.matplotlib import _EmbeddingPlot | ||
|
||
|
||
def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2): | ||
"""Convert matplotlib colormap to plotly colorscale. | ||
Args: | ||
cmap: A registered colormap name from matplotlib. | ||
pl_entries: Number of colors to use in the plotly colorscale. | ||
rdigits: Number of digits to round the colorscale to. | ||
Returns: | ||
pl_colorscale: List of scaled colors to plot the embeddings | ||
""" | ||
scale = np.linspace(0, 1, pl_entries) | ||
colors = (cmap(scale)[:, :3] * 255).astype(np.uint8) | ||
pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"] | ||
for s, color in zip(scale, colors)] | ||
return pl_colorscale | ||
|
||
|
||
class _EmbeddingInteractivePlot(_EmbeddingPlot): | ||
|
||
def __init__(self, **kwargs): | ||
self.figsize = kwargs.get("figsize") | ||
super().__init__(**kwargs) | ||
self.colorscale = self._define_colorscale(self.cmap) | ||
|
||
def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]): | ||
"""Define the axis of the plot. | ||
Args: | ||
axis: Optional axis to create the plot on. | ||
Returns: | ||
axis: The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. | ||
""" | ||
|
||
if axis is None: | ||
self.axis = plotly.graph_objects.Figure( | ||
layout=plotly.graph_objects.Layout(height=100 * self.figsize[0], | ||
width=100 * self.figsize[1])) | ||
|
||
else: | ||
self.axis = axis | ||
|
||
def _define_colorscale(self, cmap: str): | ||
"""Specify the cmap for plotting the latent space. | ||
Args: | ||
cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). | ||
Returns: | ||
colorscale: List of scaled colors to plot the embeddings | ||
""" | ||
colorscale = _convert_cmap2colorscale(matplotlib.cm.get_cmap(cmap)) | ||
|
||
return colorscale | ||
|
||
def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: | ||
"""Plot the embedding in 3d. | ||
Returns: | ||
The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. | ||
""" | ||
|
||
idx1, idx2, idx3 = self.idx_order | ||
data = [ | ||
plotly.graph_objects.Scatter3d( | ||
x=self.embedding[:, idx1], | ||
y=self.embedding[:, idx2], | ||
z=self.embedding[:, idx3], | ||
mode="markers", | ||
marker=dict( | ||
size=self.markersize, | ||
opacity=self.alpha, | ||
color=self.embedding_labels, | ||
colorscale=self.colorscale, | ||
), | ||
) | ||
] | ||
col = kwargs.get("col", None) | ||
row = kwargs.get("row", None) | ||
|
||
if col is None or row is None: | ||
self.axis.add_trace(data[0]) | ||
else: | ||
self.axis.add_trace(data[0], row=row, col=col) | ||
|
||
self.axis.update_layout( | ||
template="plotly_white", | ||
showlegend=False, | ||
title=self.title, | ||
) | ||
|
||
return self.axis | ||
|
||
|
||
def plot_embedding_interactive( | ||
embedding: Union[npt.NDArray, torch.Tensor], | ||
embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", | ||
axis: Optional[plotly.graph_objects.Figure] = None, | ||
markersize: float = 1, | ||
idx_order: Optional[Tuple[int]] = None, | ||
alpha: float = 0.4, | ||
cmap: str = "cool", | ||
title: str = "Embedding", | ||
figsize: Tuple[int] = (5, 5), | ||
dpi: int = 100, | ||
**kwargs, | ||
) -> plotly.graph_objects.Figure: | ||
"""Plot embedding in a 3D dimensional space. | ||
This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of | ||
dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). | ||
The function makes use of :py:func:`plotly.graph_objs._scatter.Scatter` and parameters from that function can be provided | ||
as part of ``kwargs``. | ||
Args: | ||
embedding: A matrix containing the feature representation computed with CEBRA. | ||
embedding_labels: The labels used to map the data to color. It can be: | ||
* A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. | ||
* A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. | ||
axis: Optional axis to create the plot on. | ||
idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D | ||
embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those | ||
dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation | ||
(e.g., (2, 4, 5)). | ||
markersize: The marker size. | ||
alpha: The marker blending, between 0 (transparent) and 1 (opaque). | ||
cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). | ||
title: The title on top of the embedding. | ||
figsize: Figure width and height in inches. | ||
dpi: Figure resolution. | ||
kwargs: Optional arguments to customize the plots. See :py:func:`plotly.graph_objs._scatter.Scatter` documentation for more | ||
details on which arguments to use. | ||
Returns: | ||
The plotly figure. | ||
Example: | ||
>>> import cebra | ||
>>> import numpy as np | ||
>>> X = np.random.uniform(0, 1, (100, 50)) | ||
>>> y = np.random.uniform(0, 10, (100, 5)) | ||
>>> cebra_model = cebra.CEBRA(max_iterations=10) | ||
>>> cebra_model.fit(X, y) | ||
CEBRA(max_iterations=10) | ||
>>> embedding = cebra_model.transform(X) | ||
>>> cebra_time = np.arange(X.shape[0]) | ||
>>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time) | ||
""" | ||
return _EmbeddingInteractivePlot( | ||
embedding=embedding, | ||
embedding_labels=embedding_labels, | ||
axis=axis, | ||
idx_order=idx_order, | ||
markersize=markersize, | ||
alpha=alpha, | ||
cmap=cmap, | ||
title=title, | ||
figsize=figsize, | ||
dpi=dpi, | ||
).plot(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,7 @@ datasets = | |
integrations = | ||
jupyter | ||
pandas | ||
plotly | ||
docs = | ||
sphinx==5.3 | ||
sphinx-gallery==0.10.1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import matplotlib | ||
import numpy as np | ||
import plotly.graph_objects as go | ||
import pytest | ||
from plotly.subplots import make_subplots | ||
|
||
import cebra.integrations.plotly as cebra_plotly | ||
import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra | ||
|
||
|
||
@pytest.mark.parametrize("cmap", ["viridis", "plasma", "inferno", "magma"]) | ||
def test_colorscale(cmap): | ||
cmap = matplotlib.cm.get_cmap(cmap) | ||
colorscale = cebra_plotly._convert_cmap2colorscale(cmap) | ||
assert isinstance(colorscale, list) | ||
|
||
|
||
@pytest.mark.parametrize("output_dimension, idx_order", [(8, (2, 3, 4)), | ||
(3, (0, 1, 2))]) | ||
def test_plot_embedding(output_dimension, idx_order): | ||
# example dataset | ||
X = np.random.uniform(0, 1, (1000, 50)) | ||
y = np.random.uniform(0, 1, (1000,)) | ||
|
||
# integration tests | ||
model = cebra_sklearn_cebra.CEBRA(max_iterations=10, | ||
batch_size=512, | ||
output_dimension=output_dimension) | ||
|
||
model.fit(X) | ||
embedding = model.transform(X) | ||
|
||
fig = cebra_plotly.plot_embedding_interactive(embedding=embedding, | ||
embedding_labels=y) | ||
assert isinstance(fig, go.Figure) | ||
assert len(fig.data) == 1 | ||
|
||
fig.layout = {} | ||
fig.data = [] | ||
|
||
fig_subplots = make_subplots( | ||
rows=2, | ||
cols=2, | ||
specs=[ | ||
[{ | ||
"type": "scatter3d" | ||
}, { | ||
"type": "scatter3d" | ||
}], | ||
[{ | ||
"type": "scatter3d" | ||
}, { | ||
"type": "scatter3d" | ||
}], | ||
], | ||
) | ||
|
||
fig_subplots = cebra_plotly.plot_embedding_interactive(axis=fig_subplots, | ||
embedding=embedding, | ||
embedding_labels=y, | ||
row=1, | ||
col=1) | ||
|
||
fig_subplots.data = [] | ||
fig_subplots.layout = {} |