From 5886495eebe536f0442260aa9bdfebee3a9f87bf Mon Sep 17 00:00:00 2001 From: nastya236 Date: Sun, 17 Sep 2023 01:12:28 +0200 Subject: [PATCH 1/5] Add interactive plots using plotly - Added interactive plots - Add plotly integration - Fix during testing, add header - Add plotly to integrations - Delete plotly import from matplotlib - Add plotly to optdepends - Add tests and update Makefile - Checking whether pass test with required plotly - Final check - Delete redundent code - Delete plotly from essential dependencies - Final changes and test to interactive plot - Fix docs - apply pre-commit - Adapt docstring --- PKGBUILD | 1 + cebra/integrations/matplotlib.py | 12 +- cebra/integrations/plotly.py | 191 +++++++++++++++++++++++++++++++ setup.cfg | 1 + tests/test_plotly.py | 65 +++++++++++ 5 files changed, 262 insertions(+), 8 deletions(-) create mode 100644 cebra/integrations/plotly.py create mode 100644 tests/test_plotly.py diff --git a/PKGBUILD b/PKGBUILD index 8bf9a603..5acf7d6f 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -24,6 +24,7 @@ optdepends=( python-matplotlib python-h5py python-argparse + python-plotly ) license=('custom') arch=('any') diff --git a/cebra/integrations/matplotlib.py b/cebra/integrations/matplotlib.py index 4ef386d6..839cf42a 100644 --- a/cebra/integrations/matplotlib.py +++ b/cebra/integrations/matplotlib.py @@ -10,12 +10,10 @@ # https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md # """Matplotlib interface to CEBRA.""" - import abc from collections.abc import Iterable from typing import List, Literal, Optional, Tuple, Union -import matplotlib import matplotlib.axes import matplotlib.cm import matplotlib.colors @@ -483,8 +481,8 @@ def plot(self, **kwargs) -> matplotlib.axes.Axes: self.ax = self._plot_3d(**kwargs) else: self.ax = self._plot_2d(**kwargs) - - self.ax.set_title(self.title) + if isinstance(self.ax, matplotlib.axes._axes.Axes): + self.ax.set_title(self.title) return self.ax @@ -736,10 +734,8 @@ def plot_overview( figsize: tuple = (15, 4), dpi: int = 100, **kwargs, -) -> Tuple[ - matplotlib.figure.Figure, - Tuple[matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes], -]: +) -> Tuple[matplotlib.figure.Figure, Tuple[ + matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes]]: """Plot an overview of a trained CEBRA model. Args: diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py new file mode 100644 index 00000000..167a340a --- /dev/null +++ b/cebra/integrations/plotly.py @@ -0,0 +1,191 @@ +# +# (c) All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE, +# Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and +# original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023. +# +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md +# +"""Plotly interface to CEBRA.""" +from typing import Optional, Tuple, Union + +import matplotlib.colors +import numpy as np +import numpy.typing as npt +import plotly.graph_objects +import torch + +from cebra.integrations.matplotlib import _EmbeddingPlot + + +def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2): + """Convert matplotlib colormap to plotly colorscale. + + Args: + cmap: A registered colormap name from matplotlib. + pl_entries: Number of colors to use in the plotly colorscale. + rdigits: Number of digits to round the colorscale to. + + Returns: + pl_colorscale: List of scaled colors to plot the embeddings + """ + scale = np.linspace(0, 1, pl_entries) + colors = (cmap(scale)[:, :3] * 255).astype(np.uint8) + pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"] + for s, color in zip(scale, colors)] + return pl_colorscale + + +class _EmbeddingInteractivePlot(_EmbeddingPlot): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.colorscale = self._define_colorscale(self.cmap) + + def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]): + """Define the axis of the plot. + + Args: + axis: Optional axis to create the plot on. + + Returns: + axis: The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. + """ + + if axis is None: + print(self.figsize[0]) + self.axis = plotly.graph_objects.Figure( + layout=plotly.graph_objects.Layout(height=100 * self.figsize[0], + width=100 * self.figsize[1])) + + else: + self.axis = axis + + def _define_colorscale(self, cmap: str): + """Specify the cmap for plotting the latent space. + + Args: + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + + + Returns: + colorscale: List of scaled colors to plot the embeddings + """ + colorscale = _convert_cmap2colorscale(matplotlib.cm.get_cmap(cmap)) + + return colorscale + + def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: + """Plot the embedding in 3d. + + Returns: + The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. + """ + + idx1, idx2, idx3 = self.idx_order + data = [ + plotly.graph_objects.Scatter3d( + x=self.embedding[:, idx1], + y=self.embedding[:, idx2], + z=self.embedding[:, idx3], + mode="markers", + marker=dict( + size=self.markersize, + opacity=self.alpha, + color=self.embedding_labels, + colorscale=self.colorscale, + ), + ) + ] + col = kwargs.get("col", None) + row = kwargs.get("row", None) + + if col is None or row is None: + self.axis.add_trace(data[0]) + else: + self.axis.add_trace(data[0], row=row, col=col) + + self.axis.update_layout( + template="plotly_white", + showlegend=False, + title=self.title, + ) + + return self.axis + + +def plot_embedding_interactive( + embedding: Union[npt.NDArray, torch.Tensor], + embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", + axis: Optional[plotly.graph_objects.Figure] = None, + markersize: float = 0.05, + idx_order: Optional[Tuple[int]] = None, + alpha: float = 0.4, + cmap: str = "cool", + title: str = "Embedding", + figsize: Tuple[int] = (5, 5), + dpi: int = 100, + **kwargs, +) -> plotly.graph_objects.Figure: + """Plot embedding in a 3D dimensional space. + + This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of + dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). + + The function makes use of :py:func:`plotly.graph_objs._scatter.Scatter` and parameters from that function can be provided + as part of ``kwargs``. + + + Args: + embedding: A matrix containing the feature representation computed with CEBRA. + embedding_labels: The labels used to map the data to color. It can be: + + * A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. + * A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. + axis: Optional axis to create the plot on. + idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D + embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those + dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation + (e.g., (2, 4, 5)). + markersize: The marker size. + alpha: The marker blending, between 0 (transparent) and 1 (opaque). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + title: The title on top of the embedding. + figsize: Figure width and height in inches. + dpi: Figure resolution. + kwargs: Optional arguments to customize the plots. See :py:func:`plotly.graph_objs._scatter.Scatter` documentation for more + details on which arguments to use. + + Returns: + The plotly figure. + + + Example: + + >>> import cebra + >>> import numpy as np + >>> X = np.random.uniform(0, 1, (100, 50)) + >>> y = np.random.uniform(0, 10, (100, 5)) + >>> cebra_model = cebra.CEBRA(max_iterations=10) + >>> cebra_model.fit(X, y) + CEBRA(max_iterations=10) + >>> embedding = cebra_model.transform(X) + >>> cebra_time = np.arange(X.shape[0]) + >>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time) + + """ + return _EmbeddingInteractivePlot( + embedding=embedding, + embedding_labels=embedding_labels, + axis=axis, + idx_order=idx_order, + markersize=markersize, + alpha=alpha, + cmap=cmap, + title=title, + figsize=figsize, + dpi=dpi, + ).plot(**kwargs) diff --git a/setup.cfg b/setup.cfg index 4283daea..4bdf8b77 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,6 +48,7 @@ datasets = integrations = jupyter pandas + plotly docs = sphinx==5.3 sphinx-gallery==0.10.1 diff --git a/tests/test_plotly.py b/tests/test_plotly.py new file mode 100644 index 00000000..b999cf39 --- /dev/null +++ b/tests/test_plotly.py @@ -0,0 +1,65 @@ +import matplotlib +import numpy as np +import plotly.graph_objects as go +import pytest +from plotly.subplots import make_subplots + +import cebra.integrations.plotly as cebra_plotly +import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra + + +@pytest.mark.parametrize("cmap", ["viridis", "plasma", "inferno", "magma"]) +def test_colorscale(cmap): + cmap = matplotlib.cm.get_cmap(cmap) + colorscale = cebra_plotly._convert_cmap2colorscale(cmap) + assert isinstance(colorscale, list) + + +@pytest.mark.parametrize("output_dimension, idx_order", [(8, (2, 3, 4)), + (3, (0, 1, 2))]) +def test_plot_embedding(output_dimension, idx_order): + # example dataset + X = np.random.uniform(0, 1, (1000, 50)) + y = np.random.uniform(0, 1, (1000,)) + + # integration tests + model = cebra_sklearn_cebra.CEBRA(max_iterations=10, + batch_size=512, + output_dimension=output_dimension) + + model.fit(X) + embedding = model.transform(X) + + fig = cebra_plotly.plot_embedding_interactive(embedding=embedding, + embedding_labels=y) + assert isinstance(fig, go.Figure) + assert len(fig.data) == 1 + + fig.layout = {} + fig.data = [] + + fig_subplots = make_subplots( + rows=2, + cols=2, + specs=[ + [{ + "type": "scatter3d" + }, { + "type": "scatter3d" + }], + [{ + "type": "scatter3d" + }, { + "type": "scatter3d" + }], + ], + ) + + fig_subplots = cebra_plotly.plot_embedding_interactive(axis=fig_subplots, + embedding=embedding, + embedding_labels=y, + row=1, + col=1) + + fig_subplots.data = [] + fig_subplots.layout = {} From bbce14caa7342f06f303d5275cbadb55ede828e7 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 2 Oct 2023 23:14:28 +0200 Subject: [PATCH 2/5] Add integrations to test dependencies --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe28edae..251767fa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,7 +3,7 @@ name: Python package on: push: branches: - - main + - main pull_request: branches: - main @@ -53,7 +53,7 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu - pip install '.[dev,datasets]' + pip install '.[dev,datasets,integrations]' - name: Run the formatter run: | From 7df0e98f4faba21ba2c57b7de82cb1a6a1c345e9 Mon Sep 17 00:00:00 2001 From: nastya236 Date: Tue, 3 Oct 2023 00:34:28 +0200 Subject: [PATCH 3/5] fixed an issue with figsize --- cebra/integrations/plotly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index 167a340a..bee6f65b 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -42,6 +42,7 @@ def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2): class _EmbeddingInteractivePlot(_EmbeddingPlot): def __init__(self, **kwargs): + self.figsize = kwargs.get("figsize", (5, 5)) super().__init__(**kwargs) self.colorscale = self._define_colorscale(self.cmap) @@ -56,7 +57,6 @@ def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]): """ if axis is None: - print(self.figsize[0]) self.axis = plotly.graph_objects.Figure( layout=plotly.graph_objects.Layout(height=100 * self.figsize[0], width=100 * self.figsize[1])) From 76ba40b7e08c950100e4a971c806ab48979cda6f Mon Sep 17 00:00:00 2001 From: nastya236 Date: Tue, 3 Oct 2023 01:15:01 +0200 Subject: [PATCH 4/5] Change default params, delete typo --- cebra/integrations/plotly.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index bee6f65b..4c6556a5 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -34,15 +34,15 @@ def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2): """ scale = np.linspace(0, 1, pl_entries) colors = (cmap(scale)[:, :3] * 255).astype(np.uint8) - pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"] - for s, color in zip(scale, colors)] + pl_colorscale = [ + [round(s, rdigits), f"rgb{tuple(color)}"] for s, color in zip(scale, colors) + ] return pl_colorscale class _EmbeddingInteractivePlot(_EmbeddingPlot): - def __init__(self, **kwargs): - self.figsize = kwargs.get("figsize", (5, 5)) + self.figsize = kwargs.get("figsize") super().__init__(**kwargs) self.colorscale = self._define_colorscale(self.cmap) @@ -58,8 +58,10 @@ def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]): if axis is None: self.axis = plotly.graph_objects.Figure( - layout=plotly.graph_objects.Layout(height=100 * self.figsize[0], - width=100 * self.figsize[1])) + layout=plotly.graph_objects.Layout( + height=100 * self.figsize[0], width=100 * self.figsize[1] + ) + ) else: self.axis = axis @@ -121,7 +123,7 @@ def plot_embedding_interactive( embedding: Union[npt.NDArray, torch.Tensor], embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", axis: Optional[plotly.graph_objects.Figure] = None, - markersize: float = 0.05, + markersize: float = 1, idx_order: Optional[Tuple[int]] = None, alpha: float = 0.4, cmap: str = "cool", @@ -171,7 +173,6 @@ def plot_embedding_interactive( >>> y = np.random.uniform(0, 10, (100, 5)) >>> cebra_model = cebra.CEBRA(max_iterations=10) >>> cebra_model.fit(X, y) - CEBRA(max_iterations=10) >>> embedding = cebra_model.transform(X) >>> cebra_time = np.arange(X.shape[0]) >>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time) From 75d27ba24943d127e16a9ee2e16f80312303bc8e Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 3 Oct 2023 13:33:30 +0200 Subject: [PATCH 5/5] Fix example in plotly --- cebra/integrations/plotly.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index 4c6556a5..d3edf509 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -34,13 +34,13 @@ def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2): """ scale = np.linspace(0, 1, pl_entries) colors = (cmap(scale)[:, :3] * 255).astype(np.uint8) - pl_colorscale = [ - [round(s, rdigits), f"rgb{tuple(color)}"] for s, color in zip(scale, colors) - ] + pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"] + for s, color in zip(scale, colors)] return pl_colorscale class _EmbeddingInteractivePlot(_EmbeddingPlot): + def __init__(self, **kwargs): self.figsize = kwargs.get("figsize") super().__init__(**kwargs) @@ -58,10 +58,8 @@ def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]): if axis is None: self.axis = plotly.graph_objects.Figure( - layout=plotly.graph_objects.Layout( - height=100 * self.figsize[0], width=100 * self.figsize[1] - ) - ) + layout=plotly.graph_objects.Layout(height=100 * self.figsize[0], + width=100 * self.figsize[1])) else: self.axis = axis @@ -173,6 +171,7 @@ def plot_embedding_interactive( >>> y = np.random.uniform(0, 10, (100, 5)) >>> cebra_model = cebra.CEBRA(max_iterations=10) >>> cebra_model.fit(X, y) + CEBRA(max_iterations=10) >>> embedding = cebra_model.transform(X) >>> cebra_time = np.arange(X.shape[0]) >>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time)