From 6cdd0359e0bca988fa1016f3e7156fae5ef05c81 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 26 Apr 2022 13:28:17 +0200 Subject: [PATCH 01/18] network plots --- deeptime/plots/network.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 deeptime/plots/network.py diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py new file mode 100644 index 000000000..248867b09 --- /dev/null +++ b/deeptime/plots/network.py @@ -0,0 +1,18 @@ +from scipy.sparse import issparse + +from deeptime.util.decorators import plotting_function + + +def _explicit_layout(graph, positions): + return dict(zip(graph, positions)) + + +@plotting_function(requires_networkx=True) +def plot_adjacency(adjacency_matrix, positions=None): + import networkx as nx + if issparse(adjacency_matrix): + graph = nx.from_scipy_sparse_matrix(adjacency_matrix) + else: + graph = nx.from_numpy_matrix(adjacency_matrix) + if positions is not None: + layout = ... From 345754311f0ec1b04f03d7e80c92f3880968cb0c Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Wed, 27 Apr 2022 09:56:43 +0200 Subject: [PATCH 02/18] lift import --- deeptime/decomposition/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deeptime/decomposition/__init__.py b/deeptime/decomposition/__init__.py index f3c37e14f..704742618 100644 --- a/deeptime/decomposition/__init__.py +++ b/deeptime/decomposition/__init__.py @@ -66,6 +66,7 @@ vamp_score_cv cvsplit_trajs + blocksplit_trajs deep.TVAEEncoder deep.koopman_matrix @@ -76,7 +77,7 @@ deep.kvad_score """ -from ._score import vamp_score, vamp_score_data, vamp_score_cv, cvsplit_trajs +from ._score import vamp_score, vamp_score_data, vamp_score_cv, cvsplit_trajs, blocksplit_trajs from ._tica import TICA from ._vamp import VAMP from ._koopman import TransferOperatorModel, CovarianceKoopmanModel From 1a55018f5e0820bee14dbc1b8b3b247037d787e1 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Wed, 27 Apr 2022 10:13:29 +0200 Subject: [PATCH 03/18] lower version bounds for dependencies --- pyproject.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 757c2fb98..af60350b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,12 @@ authors = [ {name = 'Frank NoƩ'} ] requires-python = ">= 3.7" -dependencies = ['numpy', 'scipy', 'scikit-learn', 'threadpoolctl'] +dependencies = [ + 'numpy>=1.17.3', + 'scipy>=1.1.0', + 'scikit-learn>=0.14.1', + 'threadpoolctl>=2.0.0' +] dynamic = ['version'] [project.urls] From 603ebea1cfd85f684a5ebdbc7047e615da0c001b Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Wed, 27 Apr 2022 10:14:45 +0200 Subject: [PATCH 04/18] units? --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index af60350b9..b01e7a7f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ authors = [ ] requires-python = ">= 3.7" dependencies = [ - 'numpy>=1.17.3', + 'numpy>=1.19.5', 'scipy>=1.1.0', 'scikit-learn>=0.14.1', 'threadpoolctl>=2.0.0' @@ -38,6 +38,7 @@ download = "https://pypi.org/project/deeptime/#files" [project.optional-dependencies] deep-learning = ['pytorch'] plotting = ['matplotlib', 'networkx'] +units = ['pint>=0.19.1'] [build-system] requires = [ From 2a15b8fe26bce4c3822bfa1c94f72d7a1761d42a Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Thu, 28 Apr 2022 11:08:44 +0200 Subject: [PATCH 05/18] minor --- deeptime/covariance/util/_moments.py | 4 ++-- deeptime/plots/energy.py | 14 +++++++++++++- examples/methods/plot_energy_surface.py | 4 ++-- tests/plots/test_energy_surface.py | 6 +++++- tests/requirements.txt | 1 + 5 files changed, 23 insertions(+), 6 deletions(-) diff --git a/deeptime/covariance/util/_moments.py b/deeptime/covariance/util/_moments.py index 21558f4fa..8b2805a14 100644 --- a/deeptime/covariance/util/_moments.py +++ b/deeptime/covariance/util/_moments.py @@ -50,7 +50,7 @@ using the dot product on sliced arrays (option 2). Exceptions are when the data is extremely sparse, such that only a few columns are selected. * Copying and subselecting columns (option 3) is only faster than the full - dot product (option 1), if 50% or less columns are selected. This observation + dot product (option 1), if 50% or fewer columns are selected. This observation is roughly independent of N. * The observations above are valid for matrices (T x N) that are sufficiently large. We assume that "sufficiently large" means that they don't fully fit @@ -66,7 +66,7 @@ number of constant column candidates drops below the minimum number, to avoid wasting time on the decision. 3. Subselect the desired columns and copy the data to a new array X0 (Y0). - 4. Run operation on the new array X0 (Y0), including in-place substraction + 4. Run operation on the new array X0 (Y0), including in-place subtraction of the mean if needed. """ diff --git a/deeptime/plots/energy.py b/deeptime/plots/energy.py index be21fa8b0..d944f1cb6 100644 --- a/deeptime/plots/energy.py +++ b/deeptime/plots/energy.py @@ -3,7 +3,11 @@ class Energy2dPlot: - r""" The result of a :meth:`plot_energy2d` call. + r""" The result of a :meth:`plot_energy2d` call. Instances of this class can be unpacked like a tuple: + + >>> import numpy as np + >>> from deeptime.util import energy2d + >>> ax, contour, cbar = plot_energy2d(energy2d(*np.random.uniform(size=(100 ,2)).T)) Parameters ---------- @@ -25,6 +29,14 @@ def __init__(self, ax, contour, colorbar=None): self.contour = contour self.colorbar = colorbar + # makes object unpackable + def __len__(self): + return 3 + + # makes object unpackable + def __iter__(self): + return iter((self.ax, self.contour, self.colorbar)) + @plotting_function() def plot_energy2d(energies: EnergyLandscape2d, ax=None, levels=100, contourf_kws=None, cbar=True, diff --git a/examples/methods/plot_energy_surface.py b/examples/methods/plot_energy_surface.py index b7c0cd913..873e8c8c1 100644 --- a/examples/methods/plot_energy_surface.py +++ b/examples/methods/plot_energy_surface.py @@ -21,5 +21,5 @@ weights = msm.compute_trajectory_weights(np.concatenate(dtrajs))[0] energies = energy2d(*traj_concat.T, bins=(80, 20), kbt=1, weights=weights, shift_energy=True) -plot = plot_energy2d(energies, contourf_kws=dict(cmap='nipy_spectral')) -plot.colorbar.set_label('energy / kT') +ax, contour, cbar = plot_energy2d(energies, contourf_kws=dict(cmap='nipy_spectral')) +cbar.set_label('energy / kT') diff --git a/tests/plots/test_energy_surface.py b/tests/plots/test_energy_surface.py index 2b9baa059..65322a21b 100644 --- a/tests/plots/test_energy_surface.py +++ b/tests/plots/test_energy_surface.py @@ -1,5 +1,6 @@ import matplotlib import pytest +from numpy.testing import assert_ from deeptime.data import ellipsoids from deeptime.plots import plot_energy2d @@ -13,4 +14,7 @@ def test_energy2d(shift_energy, cbar): traj = ellipsoids().observations(20000) data = energy2d(*traj.T, bins=100, shift_energy=shift_energy) - plot_energy2d(data, cbar=cbar) + ax, contourf, cbar = plot_energy2d(data, cbar=cbar) + assert_(ax is not None) + assert_(contourf is not None) + assert_(cbar is not None if cbar else cbar is None) diff --git a/tests/requirements.txt b/tests/requirements.txt index 62091cee8..d639a61cb 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,6 +3,7 @@ typing_extensions; python_version<"3.8" pybind11==2.9.1 torch>=1.10.0; python_version<"3.10" matplotlib +pint pytest==7.1.1 pytest-cov==3.0.0 From 5e44df3f37f10365409adb97582df5227e6e9bd7 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Thu, 28 Apr 2022 16:29:01 +0200 Subject: [PATCH 06/18] shortcut methods --- deeptime/plots/chapman_kolmogorov.py | 1 + deeptime/util/stats.py | 9 +++++++++ deeptime/util/validation.py | 9 ++++++++- examples/methods/plot_energy_surface.py | 3 +-- tests/plots/test_ck_test.py | 2 +- tests/plots/test_energy_surface.py | 4 +--- 6 files changed, 21 insertions(+), 7 deletions(-) diff --git a/deeptime/plots/chapman_kolmogorov.py b/deeptime/plots/chapman_kolmogorov.py index b343fe150..1b88eb7a7 100644 --- a/deeptime/plots/chapman_kolmogorov.py +++ b/deeptime/plots/chapman_kolmogorov.py @@ -93,6 +93,7 @@ def n_tests(self): return len(self._tests) +@plotting_function() def plot_ck_test(data: ChapmanKolmogorovTest, height=2.5, aspect=1., conf: float = 0.95, color=None, grid: CKTestGrid = None, legend=True, xlabel='lagtime (steps)', ylabel='probability', y01=True, sharey=True, **plot_kwargs): diff --git a/deeptime/util/stats.py b/deeptime/util/stats.py index 523649b8d..22b247c77 100644 --- a/deeptime/util/stats.py +++ b/deeptime/util/stats.py @@ -3,6 +3,7 @@ import numpy as np +from deeptime.util.decorators import plotting_function from deeptime.util.types import ensure_array @@ -384,6 +385,14 @@ def __init__(self, x_meshgrid, y_meshgrid, energies, kbt): self.energies = energies self.kbt = kbt + @plotting_function() + def plot(self, ax=None, levels=100, contourf_kws=None, cbar=True, + cbar_kws=None, cbar_ax=None): + r""" Plot estimated energy landscape directly. See :func:`deeptime.plots.plot_energy2d`. """ + from deeptime.plots import plot_energy2d + return plot_energy2d(self, ax=ax, levels=levels, contourf_kws=contourf_kws, cbar=cbar, + cbar_kws=cbar_kws, cbar_ax=cbar_ax) + def energy2d(x: np.ndarray, y: np.ndarray, bins=100, kbt: float = 1., weights=None, shift_energy=True): r""" Compute a two-dimensional energy landscape based on data arrays `x` and `y`. diff --git a/deeptime/util/validation.py b/deeptime/util/validation.py index 36b4576ec..231e676c6 100644 --- a/deeptime/util/validation.py +++ b/deeptime/util/validation.py @@ -190,7 +190,7 @@ def n_samples(self, lagtime_index: int, process_index: int) -> int: def plot(self, *args, **kw): r""" Dispatches to :meth:`plot_implied_timescales`. """ from deeptime.plots import plot_implied_timescales - plot_implied_timescales(self, *args, **kw) + return plot_implied_timescales(self, *args, **kw) def ck_test(models, observable: Observable, test_model=None, include_lag0=True, err_est=False, progress=None): @@ -358,6 +358,13 @@ def err_est(self): r""" Whether the estimated models contain samples """ return self.estimates_samples is not None and len(self.estimates_samples) > 0 + def plot(self, height=2.5, aspect=1., conf: float = 0.95, color=None, grid = None, legend=True, + xlabel='lagtime (steps)', ylabel='probability', y01=True, sharey=True, **plot_kwargs): + r""" Shortcut to :func:`deeptime.plots.plot_ck_test`. """ + from deeptime.plots import plot_ck_test + return plot_ck_test(self, height=height, aspect=aspect, conf=conf, color=color, grid=grid, legend=legend, + xlabel=xlabel, ylabel=ylabel, y01=y01, sharey=sharey, **plot_kwargs) + class DeprecatedCKValidator(Estimator): diff --git a/examples/methods/plot_energy_surface.py b/examples/methods/plot_energy_surface.py index 873e8c8c1..a46fc5d62 100644 --- a/examples/methods/plot_energy_surface.py +++ b/examples/methods/plot_energy_surface.py @@ -11,7 +11,6 @@ from deeptime.clustering import KMeans from deeptime.markov.msm import MaximumLikelihoodMSM from deeptime.util import energy2d -from deeptime.plots import plot_energy2d trajs = triple_well_2d(h=1e-3, n_steps=100).trajectory(x0=[[-1, 0], [1, 0], [0, 0]], length=5000) traj_concat = np.concatenate(trajs, axis=0) @@ -21,5 +20,5 @@ weights = msm.compute_trajectory_weights(np.concatenate(dtrajs))[0] energies = energy2d(*traj_concat.T, bins=(80, 20), kbt=1, weights=weights, shift_energy=True) -ax, contour, cbar = plot_energy2d(energies, contourf_kws=dict(cmap='nipy_spectral')) +ax, contour, cbar = energies.plot(contourf_kws=dict(cmap='nipy_spectral')) cbar.set_label('energy / kT') diff --git a/tests/plots/test_ck_test.py b/tests/plots/test_ck_test.py index ddc121f61..9d68f4560 100644 --- a/tests/plots/test_ck_test.py +++ b/tests/plots/test_ck_test.py @@ -27,7 +27,7 @@ def test_sanity_msm(hidden, bayesian): cktest = test_model.ck_test(models, n_metastable_sets=2) else: cktest = test_model.ck_test(models) - plot_ck_test(cktest, conf=1) + cktest.plot(conf=1) @pytest.mark.parametrize("fractional", [False, True]) diff --git a/tests/plots/test_energy_surface.py b/tests/plots/test_energy_surface.py index 65322a21b..fb4673f78 100644 --- a/tests/plots/test_energy_surface.py +++ b/tests/plots/test_energy_surface.py @@ -3,7 +3,6 @@ from numpy.testing import assert_ from deeptime.data import ellipsoids -from deeptime.plots import plot_energy2d from deeptime.util import energy2d matplotlib.use('Agg') @@ -13,8 +12,7 @@ @pytest.mark.parametrize('cbar', [True, False], ids=lambda x: f"cbar={x}") def test_energy2d(shift_energy, cbar): traj = ellipsoids().observations(20000) - data = energy2d(*traj.T, bins=100, shift_energy=shift_energy) - ax, contourf, cbar = plot_energy2d(data, cbar=cbar) + ax, contourf, cbar = energy2d(*traj.T, bins=100, shift_energy=shift_energy).plot(cbar=cbar) assert_(ax is not None) assert_(contourf is not None) assert_(cbar is not None if cbar else cbar is None) From 4742511dc2614794dbfe7a49959017c8f36589f7 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Thu, 28 Apr 2022 17:27:55 +0200 Subject: [PATCH 07/18] f --- deeptime/plots/network.py | 77 +++++++++++++++++++++++++++++++++++-- tests/plots/plot_network.py | 15 ++++++++ 2 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 tests/plots/plot_network.py diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index 248867b09..aea99350d 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -1,18 +1,87 @@ +import math +from typing import Optional + +import numpy as np from scipy.sparse import issparse from deeptime.util.decorators import plotting_function -def _explicit_layout(graph, positions): - return dict(zip(graph, positions)) +def _draw_arrow( + ax, x1, y1, x2, y2, Dx, Dy, label="", width=1.0, arrow_curvature=1.0, color="grey", + patchA=None, patchB=None, shrinkA=0, shrinkB=0, arrow_label_size=None): + """ + Draws a slightly curved arrow from (x1,y1) to (x2,y2). + Will allow the given patches at start end end. + + """ + # set arrow properties + dist = math.sqrt(((x2 - x1) / float(Dx)) ** 2 + ((y2 - y1) / float(Dy)) ** 2) + arrow_curvature *= 0.075 # standard scale + rad = arrow_curvature / dist + tail_width = width + head_width = max(0.5, 2 * width) + head_length = head_width + ax.annotate( + "", xy=(x2, y2), xycoords='data', xytext=(x1, y1), textcoords='data', + arrowprops=dict( + arrowstyle='simple,head_length=%f,head_width=%f,tail_width=%f' % ( + head_length, head_width, tail_width), + color=color, shrinkA=shrinkA, shrinkB=shrinkB, patchA=patchA, patchB=patchB, + connectionstyle="arc3,rad=%f" % -rad), + zorder=0) + # weighted center position + center = np.array([0.55 * x1 + 0.45 * x2, 0.55 * y1 + 0.45 * y2]) + v = np.array([x2 - x1, y2 - y1]) # 1->2 vector + vabs = np.abs(v) + vnorm = np.array([v[1], -v[0]]) # orthogonal vector + vnorm = np.divide(vnorm, np.linalg.norm(vnorm)) # normalize + # cross product to determine the direction into which vnorm points + z = np.cross(v, vnorm) + if z < 0: + vnorm *= -1 + offset = 0.5 * arrow_curvature * \ + ((vabs[0] / (vabs[0] + vabs[1])) + * Dx + (vabs[1] / (vabs[0] + vabs[1])) * Dy) + ptext = center + offset * vnorm + ax.text( + ptext[0], ptext[1], label, size=arrow_label_size, + horizontalalignment='center', verticalalignment='center', zorder=1) @plotting_function(requires_networkx=True) -def plot_adjacency(adjacency_matrix, positions=None): +def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, node_size=None, + self_loops=False, curved=True): import networkx as nx + if ax is None: + import matplotlib.pyplot as plt + ax = plt.gca() + if positions is not None: + if positions.ndim != 2 or positions.shape[0] != adjacency_matrix.shape[0] or positions.shape[1] != 2: + raise ValueError(f"Unsupported positions array. Has to be ({adjacency_matrix.shape[0]}, 2)-shaped but " + f"was of shape {positions.shape}.") + if issparse(adjacency_matrix): graph = nx.from_scipy_sparse_matrix(adjacency_matrix) else: graph = nx.from_numpy_matrix(adjacency_matrix) if positions is not None: - layout = ... + def layout(g): + return dict(zip(g.nodes(), positions)) + else: + layout = nx.spring_layout if layout is None else layout + assert layout is not None + pos = layout(graph) + + if not self_loops: + graph.remove_edges_from(nx.selfloop_edges(graph)) + nx.draw_networkx_nodes(graph, pos, node_size=node_size, ax=ax) + if curved: + Dx = max(x[0] for x in pos.values()) - min(x[0] for x in pos.values()) + Dy = max(x[1] for x in pos.values()) - min(x[1] for x in pos.values()) + edges = graph.edges() + for (e1, e2) in edges: + _draw_arrow(ax, *pos[e1].T, *pos[e2].T, Dx, Dy) + else: + nx.draw_networkx_edges(graph, pos, ax=ax) + return ax, graph diff --git a/tests/plots/plot_network.py b/tests/plots/plot_network.py new file mode 100644 index 000000000..c72910711 --- /dev/null +++ b/tests/plots/plot_network.py @@ -0,0 +1,15 @@ +import numpy as np + +import matplotlib.pyplot as plt + +from deeptime.markov.msm import MarkovStateModel +from deeptime.plots.network import plot_adjacency + + +def test_sanity(): + X = np.random.uniform(size=(5, 5)) + X /= X.sum(1)[:, None] + msm = MarkovStateModel(X) + positions = np.array([[-1, -1], [0, 0], [1.5, 3], [3., 1.5], [-1., 4.]]) + plot_adjacency(msm.transition_matrix, positions=positions) + plt.show() From 79a4af8c05eacd06a1cac5cb873e546abf843106 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Mon, 2 May 2022 16:28:24 +0200 Subject: [PATCH 08/18] network plots wip --- deeptime/plots/network.py | 279 +++++++++++++++++++++++++++++------- deeptime/plots/util.py | 14 ++ tests/plots/plot_network.py | 32 +++-- 3 files changed, 263 insertions(+), 62 deletions(-) diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index aea99350d..2200eb212 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -1,57 +1,239 @@ -import math -from typing import Optional +from typing import Optional, Union, Dict, Tuple import numpy as np from scipy.sparse import issparse +from deeptime.plots.util import default_image_cmap, default_line_width from deeptime.util.decorators import plotting_function +from deeptime.util.types import ensure_number_array -def _draw_arrow( - ax, x1, y1, x2, y2, Dx, Dy, label="", width=1.0, arrow_curvature=1.0, color="grey", - patchA=None, patchB=None, shrinkA=0, shrinkB=0, arrow_label_size=None): - """ - Draws a slightly curved arrow from (x1,y1) to (x2,y2). - Will allow the given patches at start end end. +class NetworkPlot: + r"""Plot of a network with nodes and arcs. + + Parameters + ---------- + adjacency_matrix : ndarray + weight matrix or adjacency matrix of the network to visualize + pos : ndarray or dict[int, ndarray] + user-defined positions as (n,2) array + + Examples + -------- + We define first define a reactive flux by taking the following transition + matrix and computing TPT from state 2 to 3. + + >>> import numpy as np + >>> P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], + ... [0.1, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.1, 0.8, 0.0, 0.05], + ... [0.0, 0.2, 0.0, 0.8, 0.0], + ... [0.0, 0.02, 0.02, 0.0, 0.96]]) + >>> from deeptime.markov.msm import MarkovStateModel + >>> flux = MarkovStateModel(P).reactive_flux([2], [3]) + now plot the gross flux + >>> import networkx as nx + >>> positions = nx.spring_layout(nx.from_numpy_array(flux.gross_flux)) + >>> NetworkPlot(flux.gross_flux, positions).plot_network() # doctest: +ELLIPSIS + <...Figure... """ - # set arrow properties - dist = math.sqrt(((x2 - x1) / float(Dx)) ** 2 + ((y2 - y1) / float(Dy)) ** 2) - arrow_curvature *= 0.075 # standard scale - rad = arrow_curvature / dist - tail_width = width - head_width = max(0.5, 2 * width) - head_length = head_width - ax.annotate( - "", xy=(x2, y2), xycoords='data', xytext=(x1, y1), textcoords='data', - arrowprops=dict( - arrowstyle='simple,head_length=%f,head_width=%f,tail_width=%f' % ( - head_length, head_width, tail_width), - color=color, shrinkA=shrinkA, shrinkB=shrinkB, patchA=patchA, patchB=patchB, - connectionstyle="arc3,rad=%f" % -rad), - zorder=0) - # weighted center position - center = np.array([0.55 * x1 + 0.45 * x2, 0.55 * y1 + 0.45 * y2]) - v = np.array([x2 - x1, y2 - y1]) # 1->2 vector - vabs = np.abs(v) - vnorm = np.array([v[1], -v[0]]) # orthogonal vector - vnorm = np.divide(vnorm, np.linalg.norm(vnorm)) # normalize - # cross product to determine the direction into which vnorm points - z = np.cross(v, vnorm) - if z < 0: - vnorm *= -1 - offset = 0.5 * arrow_curvature * \ - ((vabs[0] / (vabs[0] + vabs[1])) - * Dx + (vabs[1] / (vabs[0] + vabs[1])) * Dy) - ptext = center + offset * vnorm - ax.text( - ptext[0], ptext[1], label, size=arrow_label_size, - horizontalalignment='center', verticalalignment='center', zorder=1) + + def __init__(self, adjacency_matrix, pos): + self.adjacency_matrix = adjacency_matrix + self.pos = pos + + @property + def pos(self) -> np.ndarray: + return self._pos + + @pos.setter + def pos(self, value: Union[np.ndarray, Dict[int, np.ndarray]]): + if len(value) < self.n_nodes: + raise ValueError(f'Given less positions ({len(value)}) than states ({self.n_nodes})') + if isinstance(value, dict): + value = np.stack((value[i] for i in range(len(value)))) + self._pos = value + + @property + def n_nodes(self): + return self.adjacency_matrix.shape[0] + + @property + def bounds(self) -> Tuple[Tuple[float, float], Tuple[float, float]]: + return ((np.min(self.pos[:, 0]), np.max(self.pos[:, 0])), + (np.min(self.pos[:, 1]), np.max(self.pos[:, 1]))) + + @property + def d_x(self): + return self.bounds[0][1] - self.bounds[0][0] + + @property + def d_y(self): + return self.bounds[1][1] - self.bounds[1][0] + + @staticmethod + def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, color="grey", + patchA=None, patchB=None, shrinkA=2, shrinkB=1, arrow_label_size=None, arrow_label_location=.55): + r""" + Draws a slightly curved arrow from (x1,y1) to (x2,y2). + Will allow the given patches at start and end. + """ + from matplotlib import patches + + # set arrow properties + dist = np.linalg.norm(pos_2 - pos_1) + arrow_curvature *= 0.075 # standard scale + rad = arrow_curvature / dist + tail_width = width + head_width = max(2., 2 * width) + # head_length = max(1.5, head_width) + + arrow_style = patches.ArrowStyle.Simple(head_length=head_width, head_width=head_width, tail_width=tail_width) + connection_style = patches.ConnectionStyle.Arc3(rad=-rad) + arr = patches.FancyArrowPatch(posA=pos_1, posB=pos_2, arrowstyle=arrow_style, + connectionstyle=connection_style, color=color, + shrinkA=shrinkA, shrinkB=shrinkB, patchA=patchA, patchB=patchB, + zorder=0, transform=ax.transData) + ax.add_patch(arr) + + # Bezier control point + control_vertex = np.array(arr.get_connectionstyle().connect(pos_1, pos_2).vertices[1]) + # quadratic Bezier at slightly shifted midpoint t = arrow_label_location + t = arrow_label_location # shorthand + ptext = (1 - t)**2 * pos_1 + 2 * (1 - t) * t * control_vertex + t**2 * pos_2 + + ax.text( + ptext[0], ptext[1], label, size=arrow_label_size, + horizontalalignment='center', verticalalignment='center', zorder=1, + transform=ax.transData) + + def plot_network(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', + arrow_scale=1.0, arrow_curvature=1.0, arrow_labels='weights', arrow_label_format='{:.1e}', + arrow_label_location=0.55, cmap=None, **textkwargs): + r"""Draws a network using discs and curved arrows. + + The thicknesses and labels of the arrows are taken from the off-diagonal matrix elements + in A. + + Parameters + ---------- + ax : matplotlib.axes.Axes, optional, default=None + The axes to plot on. + state_sizes : ndarray, optional, default=None + List of state sizes to plot, must be of length `n_states`. It effectively evaluates as + + .. math:: + \frac{\mathrm{state\_scale} \min (d_x, d_y)^2}{2 n_\mathrm{nodes}} + \frac{\mathrm{state_sizes}}{\| \mathrm{state_sizes} \|_{\max}} + + with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that + the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second + state drawn as a circle with twice the volume of the first. + state_scale : float, default=1. + Uniform scaling factor for `state_sizes`. + state_colors : str or list of float, default='#ff5500' + The color to use for drawing states. If given as list of float, uses the colormap (`cmap` argument) to + determine the color. + state_labels : None or 'auto' or list of str, default='auto' + The state labels. If 'auto', just enumerates the states. In case of `None` no state labels are depicted, + otherwise assigns each state its label based on the list. + """ + # Set the default values for the text dictionary + from matplotlib import pyplot as plt + from matplotlib import colors + from matplotlib import patches + if ax is None: + ax = plt.gca() + if cmap is None: + cmap = default_image_cmap() + textkwargs.setdefault('size', None) + textkwargs.setdefault('horizontalalignment', 'center') + textkwargs.setdefault('verticalalignment', 'center') + textkwargs.setdefault('color', 'black') + # remove the temporary key 'arrow_label_size' as it cannot be parsed by plt.text! + arrow_label_size = textkwargs.pop('arrow_label_size', textkwargs['size']) + # sizes of nodes + if state_sizes is None: + state_sizes = 0.5 * state_scale * min(self.d_x, self.d_y) ** 2 * np.ones(self.n_nodes) / float(self.n_nodes) + else: + state_sizes = 0.5 * state_scale * min(self.d_x, self.d_y) ** 2 * state_sizes \ + / (np.max(state_sizes) * float(self.n_nodes)) + # automatic arrow rescaling + diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool)) + default_lw = default_line_width() + arrow_scale *= 2 * default_lw / np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) + + # set node labels + if state_labels is None: + pass + elif isinstance(state_labels, str) and state_labels == 'auto': + state_labels = [str(i) for i in np.arange(self.n_nodes)] + else: + if len(state_labels) != self.n_nodes: + raise ValueError(f"length of state_labels({len(state_labels)}) has to match " + f"length of states({self.n_nodes}).") + # set node colors + if isinstance(state_colors, str): + state_colors = np.array([colors.to_rgb(state_colors)] * self.n_nodes) + else: + state_colors = ensure_number_array(state_colors, ndim=1) + if not isinstance(cmap, colors.Colormap): + cmap = plt.get_cmap(cmap) + assert isinstance(cmap, colors.Colormap) + state_colors = np.array([cmap(x) for x in state_colors]) + if len(state_colors) != self.n_nodes: + raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).") + + # set arrow labels + if isinstance(arrow_labels, np.ndarray): + L = arrow_labels + if isinstance(arrow_labels[0, 0], str): + arrow_label_format = '{}' + elif isinstance(arrow_labels, str) and arrow_labels.lower() == 'weights': + L = np.copy(self.adjacency_matrix) + elif arrow_labels is None: + L = np.full(self.adjacency_matrix.shape, fill_value='', dtype=object) + arrow_label_format = '{}' + else: + raise ValueError('invalid arrow labels') + + # draw circles + circles = [] + for i in range(self.n_nodes): + circles.append( + patches.Circle(self.pos[i], radius=np.sqrt(0.5 * state_sizes[i]) / 2.0, + color=state_colors[i], zorder=2) + ) + ax.add_patch(circles[-1]) + + # add annotation + if state_labels is not None: + ax.text(self.pos[i][0], self.pos[i][1], state_labels[i], zorder=3, **textkwargs) + + assert len(circles) == self.n_nodes, f"{len(circles)} != {self.n_nodes}" + + # draw arrows + for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes + if abs(self.adjacency_matrix[i, j]) > 0: + self._draw_arrow( + ax, self.pos[i], self.pos[j], + label=arrow_label_format.format(L[i, j]), width=arrow_scale * self.adjacency_matrix[i, j], + arrow_curvature=arrow_curvature, patchA=circles[i], patchB=circles[j], + shrinkA=3, shrinkB=0, arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) + if abs(self.adjacency_matrix[j, i]) > 0: + self._draw_arrow( + ax, self.pos[j], self.pos[i], + label=arrow_label_format.format(L[j, i]), width=arrow_scale * self.adjacency_matrix[j, i], + arrow_curvature=arrow_curvature, patchA=circles[j], patchB=circles[i], + shrinkA=3, shrinkB=0, arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) + + ax.autoscale_view() + return ax @plotting_function(requires_networkx=True) -def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, node_size=None, - self_loops=False, curved=True): +def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, node_size=None): import networkx as nx if ax is None: import matplotlib.pyplot as plt @@ -73,15 +255,6 @@ def layout(g): assert layout is not None pos = layout(graph) - if not self_loops: - graph.remove_edges_from(nx.selfloop_edges(graph)) - nx.draw_networkx_nodes(graph, pos, node_size=node_size, ax=ax) - if curved: - Dx = max(x[0] for x in pos.values()) - min(x[0] for x in pos.values()) - Dy = max(x[1] for x in pos.values()) - min(x[1] for x in pos.values()) - edges = graph.edges() - for (e1, e2) in edges: - _draw_arrow(ax, *pos[e1].T, *pos[e2].T, Dx, Dy) - else: - nx.draw_networkx_edges(graph, pos, ax=ax) + plot = NetworkPlot(adjacency_matrix, pos=pos) + ax = plot.plot_network(ax=ax, state_sizes=node_size) return ax, graph diff --git a/deeptime/plots/util.py b/deeptime/plots/util.py index 2f38001fa..403c719a6 100644 --- a/deeptime/plots/util.py +++ b/deeptime/plots/util.py @@ -6,3 +6,17 @@ def default_colors(): r""" Yields matplotlib default color cycle as per rc param 'axes.prop_cycle'. """ from matplotlib import rcParams return rcParams['axes.prop_cycle'].by_key()['color'] + + +@plotting_function() +def default_image_cmap() -> "matplotlib.colors.Colormap": + r""" Yields the default image color map. """ + import matplotlib.pyplot as plt + from matplotlib import rcParams + return plt.get_cmap(rcParams["image.cmap"]) + + +@plotting_function() +def default_line_width() -> float: + from matplotlib import rcParams + return rcParams['lines.linewidth'] diff --git a/tests/plots/plot_network.py b/tests/plots/plot_network.py index c72910711..931a0464c 100644 --- a/tests/plots/plot_network.py +++ b/tests/plots/plot_network.py @@ -1,15 +1,29 @@ import numpy as np - import matplotlib.pyplot as plt - +import networkx as nx +from deeptime.plots.network import NetworkPlot from deeptime.markov.msm import MarkovStateModel -from deeptime.plots.network import plot_adjacency - def test_sanity(): - X = np.random.uniform(size=(5, 5)) - X /= X.sum(1)[:, None] - msm = MarkovStateModel(X) - positions = np.array([[-1, -1], [0, 0], [1.5, 3], [3., 1.5], [-1., 4.]]) - plot_adjacency(msm.transition_matrix, positions=positions) + P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], + [0.1, 0.75, 0.05, 0.05, 0.05], + [0.05, 0.1, 0.8, 0.0, 0.05], + [0.0, 0.2, 0.0, 0.8, 0.0], + [0.0, 0.02, 0.02, 0.0, 0.96]]) + + flux = MarkovStateModel(P).reactive_flux([2], [3]) + + positions = nx.planar_layout(nx.from_numpy_array(flux.gross_flux)) + pl = NetworkPlot(flux.gross_flux, positions) + + f, (ax1, ax2) = plt.subplots(1, 2) + + ax = pl.plot_network(state_colors=np.linspace(0, 1, num=flux.n_states), + ax=ax1) + ax.set_aspect('equal') + + from pyemma.plots import plot_network + + plot_network(flux.gross_flux, pos=np.array([positions[i] for i in range(flux.n_states)]), ax=ax2) plt.show() + From 03abc66b45fc37cffc6c843752e9f1050b3a5b56 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Mon, 2 May 2022 17:10:18 +0200 Subject: [PATCH 09/18] network plots wip --- deeptime/plots/__init__.py | 2 + deeptime/plots/network.py | 95 ++++++++++++++++++++++++------------- tests/plots/plot_network.py | 7 ++- 3 files changed, 67 insertions(+), 37 deletions(-) diff --git a/deeptime/plots/__init__.py b/deeptime/plots/__init__.py index 1e36a9f40..051e7f0a8 100644 --- a/deeptime/plots/__init__.py +++ b/deeptime/plots/__init__.py @@ -9,8 +9,10 @@ plot_ck_test plot_energy2d Energy2dPlot + Network """ from .implied_timescales import plot_implied_timescales from .chapman_kolmogorov import plot_ck_test from .energy import plot_energy2d, Energy2dPlot +from .network import Network diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index 2200eb212..b17a26106 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -8,7 +8,7 @@ from deeptime.util.types import ensure_number_array -class NetworkPlot: +class Network: r"""Plot of a network with nodes and arcs. Parameters @@ -32,10 +32,11 @@ class NetworkPlot: >>> from deeptime.markov.msm import MarkovStateModel >>> flux = MarkovStateModel(P).reactive_flux([2], [3]) - now plot the gross flux + Now plot the gross flux using networkx spring layout. + >>> import networkx as nx >>> positions = nx.spring_layout(nx.from_numpy_array(flux.gross_flux)) - >>> NetworkPlot(flux.gross_flux, positions).plot_network() # doctest: +ELLIPSIS + >>> Network(flux.gross_flux, positions).plot() # doctest: +ELLIPSIS <...Figure... """ @@ -45,6 +46,13 @@ def __init__(self, adjacency_matrix, pos): @property def pos(self) -> np.ndarray: + r""" Position array. If the object was constructed with a dict-style layout (as generated by networkx), + the positions are converted back into an array format. + + :getter: Yields the positions. + :setter: Sets the positions, can also be provided as dict. + :type: ndarray + """ return self._pos @pos.setter @@ -57,28 +65,29 @@ def pos(self, value: Union[np.ndarray, Dict[int, np.ndarray]]): @property def n_nodes(self): + r""" Number of nodes in the network. """ return self.adjacency_matrix.shape[0] @property def bounds(self) -> Tuple[Tuple[float, float], Tuple[float, float]]: + r""" The bounds of node positions. Yields `((xmin, xmax), (ymin, ymax))`. """ return ((np.min(self.pos[:, 0]), np.max(self.pos[:, 0])), (np.min(self.pos[:, 1]), np.max(self.pos[:, 1]))) @property def d_x(self): + r""" Width of the network. """ return self.bounds[0][1] - self.bounds[0][0] @property def d_y(self): + r""" Height of the network. """ return self.bounds[1][1] - self.bounds[1][0] @staticmethod def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, color="grey", patchA=None, patchB=None, shrinkA=2, shrinkB=1, arrow_label_size=None, arrow_label_location=.55): - r""" - Draws a slightly curved arrow from (x1,y1) to (x2,y2). - Will allow the given patches at start and end. - """ + r""" Draws a slightly curved arrow from (x1,y1) to (x2,y2). Will allow the given patches at start and end. """ from matplotlib import patches # set arrow properties @@ -101,16 +110,14 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo control_vertex = np.array(arr.get_connectionstyle().connect(pos_1, pos_2).vertices[1]) # quadratic Bezier at slightly shifted midpoint t = arrow_label_location t = arrow_label_location # shorthand - ptext = (1 - t)**2 * pos_1 + 2 * (1 - t) * t * control_vertex + t**2 * pos_2 + ptext = (1 - t) ** 2 * pos_1 + 2 * (1 - t) * t * control_vertex + t ** 2 * pos_2 - ax.text( - ptext[0], ptext[1], label, size=arrow_label_size, - horizontalalignment='center', verticalalignment='center', zorder=1, - transform=ax.transData) + ax.text(*ptext, label, size=arrow_label_size, horizontalalignment='center', verticalalignment='center', + zorder=3, transform=ax.transData) - def plot_network(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', - arrow_scale=1.0, arrow_curvature=1.0, arrow_labels='weights', arrow_label_format='{:.1e}', - arrow_label_location=0.55, cmap=None, **textkwargs): + def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', + arrow_scale=1.0, arrow_curvature=1.0, arrow_labels='weights', arrow_label_format='{:.1e}', + arrow_label_location=0.55, cmap=None, **textkwargs): r"""Draws a network using discs and curved arrows. The thicknesses and labels of the arrows are taken from the off-diagonal matrix elements @@ -124,8 +131,8 @@ def plot_network(self, ax=None, state_sizes=None, state_scale=1.0, state_colors= List of state sizes to plot, must be of length `n_states`. It effectively evaluates as .. math:: - \frac{\mathrm{state\_scale} \min (d_x, d_y)^2}{2 n_\mathrm{nodes}} - \frac{\mathrm{state_sizes}}{\| \mathrm{state_sizes} \|_{\max}} + \frac{\mathrm{state\_scale} \cdot (\min (d_x, d_y))^2}{2 n_\mathrm{nodes}} + \frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}} with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second @@ -138,6 +145,24 @@ def plot_network(self, ax=None, state_sizes=None, state_scale=1.0, state_colors= state_labels : None or 'auto' or list of str, default='auto' The state labels. If 'auto', just enumerates the states. In case of `None` no state labels are depicted, otherwise assigns each state its label based on the list. + arrow_scale : float, optional, default=1. + Linear scaling coefficient for all arrow widths. Takes the default line width `rcParams['lines.linewidth']` + into account. + arrow_curvature : float, optional, default=1. + Linear scaling coefficient for arrow curvature. Setting it to `0` produces straight arrows. + arrow_labels : 'weights' or ndarray or None, default='weights' + If 'weights', arrows obtain labels according to the weights in the adjacency matrix. If ndarray the dtype + is expected to be object and the argument should be a `(n, n)` matrix with labels. If None, no labels are + printed. + arrow_label_format : str, default='{:.1e}' + Format string for arrow labels. Only has an effect if arrow_labels is set to `weights`. + arrow_label_location : float, default=0.55 + Location of the arrow labels on the curve. Should be between 0 (meaning on the source state) and 1 (meaning + on the target state). Defaults to 0.55, i.e., slightly shifted toward the target state from the midpoint. + cmap : matplotlib.colors.Colormap or str, default=None + The colormap for `state_color`s. + **textkwargs + Optional arguments for state labels. """ # Set the default values for the text dictionary from matplotlib import pyplot as plt @@ -187,16 +212,18 @@ def plot_network(self, ax=None, state_sizes=None, state_scale=1.0, state_colors= # set arrow labels if isinstance(arrow_labels, np.ndarray): - L = arrow_labels + label_matrix = arrow_labels + assert label_matrix.shape == self.adjacency_matrix.shape, \ + f"Arrow labels matrix has shape {label_matrix.shape} =/= {self.adjacency_matrix.shape}" if isinstance(arrow_labels[0, 0], str): arrow_label_format = '{}' elif isinstance(arrow_labels, str) and arrow_labels.lower() == 'weights': - L = np.copy(self.adjacency_matrix) + label_matrix = np.copy(self.adjacency_matrix) elif arrow_labels is None: - L = np.full(self.adjacency_matrix.shape, fill_value='', dtype=object) + label_matrix = np.full(self.adjacency_matrix.shape, fill_value='', dtype=object) arrow_label_format = '{}' else: - raise ValueError('invalid arrow labels') + raise ValueError("Invalid arrow labels, should be 'weights', ndarray of strings or None.") # draw circles circles = [] @@ -216,17 +243,19 @@ def plot_network(self, ax=None, state_sizes=None, state_scale=1.0, state_colors= # draw arrows for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes if abs(self.adjacency_matrix[i, j]) > 0: - self._draw_arrow( - ax, self.pos[i], self.pos[j], - label=arrow_label_format.format(L[i, j]), width=arrow_scale * self.adjacency_matrix[i, j], - arrow_curvature=arrow_curvature, patchA=circles[i], patchB=circles[j], - shrinkA=3, shrinkB=0, arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) + label = arrow_label_format.format(label_matrix[i, j]) + width = arrow_scale * self.adjacency_matrix[i, j] + self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width, + arrow_curvature=arrow_curvature, patchA=circles[i], patchB=circles[j], + shrinkA=3, shrinkB=0, + arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) if abs(self.adjacency_matrix[j, i]) > 0: - self._draw_arrow( - ax, self.pos[j], self.pos[i], - label=arrow_label_format.format(L[j, i]), width=arrow_scale * self.adjacency_matrix[j, i], - arrow_curvature=arrow_curvature, patchA=circles[j], patchB=circles[i], - shrinkA=3, shrinkB=0, arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) + label = arrow_label_format.format(label_matrix[j, i]) + width = arrow_scale * self.adjacency_matrix[j, i] + self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width, + arrow_curvature=arrow_curvature, patchA=circles[j], patchB=circles[i], + shrinkA=3, shrinkB=0, + arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) ax.autoscale_view() return ax @@ -255,6 +284,6 @@ def layout(g): assert layout is not None pos = layout(graph) - plot = NetworkPlot(adjacency_matrix, pos=pos) - ax = plot.plot_network(ax=ax, state_sizes=node_size) + plot = Network(adjacency_matrix, pos=pos) + ax = plot.plot(ax=ax, state_sizes=node_size) return ax, graph diff --git a/tests/plots/plot_network.py b/tests/plots/plot_network.py index 931a0464c..5cc38e546 100644 --- a/tests/plots/plot_network.py +++ b/tests/plots/plot_network.py @@ -1,7 +1,7 @@ import numpy as np import matplotlib.pyplot as plt import networkx as nx -from deeptime.plots.network import NetworkPlot +from deeptime.plots.network import Network from deeptime.markov.msm import MarkovStateModel def test_sanity(): @@ -14,12 +14,11 @@ def test_sanity(): flux = MarkovStateModel(P).reactive_flux([2], [3]) positions = nx.planar_layout(nx.from_numpy_array(flux.gross_flux)) - pl = NetworkPlot(flux.gross_flux, positions) + pl = Network(flux.gross_flux, positions) f, (ax1, ax2) = plt.subplots(1, 2) - ax = pl.plot_network(state_colors=np.linspace(0, 1, num=flux.n_states), - ax=ax1) + ax = pl.plot(state_colors=np.linspace(0, 1, num=flux.n_states), ax=ax1, arrow_curvature=2.) ax.set_aspect('equal') from pyemma.plots import plot_network From 37bfe19fa700e114cde13df89c8a317a370612a2 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 09:50:37 +0200 Subject: [PATCH 10/18] refactor --- deeptime/markov/msm/_markov_state_model.py | 2 +- deeptime/plots/network.py | 66 ++++++++++++------- .../methods/plot_sindy_rossler_attractor.py | 18 +---- tests/plots/plot_network.py | 2 +- 4 files changed, 45 insertions(+), 43 deletions(-) diff --git a/deeptime/markov/msm/_markov_state_model.py b/deeptime/markov/msm/_markov_state_model.py index 92b54aeb3..85bceea36 100644 --- a/deeptime/markov/msm/_markov_state_model.py +++ b/deeptime/markov/msm/_markov_state_model.py @@ -771,7 +771,7 @@ def correlation(self, a, b=None, maxtime=None, k=None, ncv=None): >>> times, acf = M.correlation(a) >>> >>> import matplotlib.pylab as plt # doctest: +SKIP - >>> plt.plot(times, acf) # doctest: +SKIP + >>> plt.plot(times,acf) # doctest: +SKIP """ # input checking is done in low-level API # compute number of tau steps diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index b17a26106..112e234dc 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -116,8 +116,8 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo zorder=3, transform=ax.transData) def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', - arrow_scale=1.0, arrow_curvature=1.0, arrow_labels='weights', arrow_label_format='{:.1e}', - arrow_label_location=0.55, cmap=None, **textkwargs): + edge_scale=1.0, edge_curvature=1.0, edge_labels='weights', edge_label_format='{:.1e}', + edge_label_location=0.55, cmap=None, **textkwargs): r"""Draws a network using discs and curved arrows. The thicknesses and labels of the arrows are taken from the off-diagonal matrix elements @@ -145,18 +145,18 @@ def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500 state_labels : None or 'auto' or list of str, default='auto' The state labels. If 'auto', just enumerates the states. In case of `None` no state labels are depicted, otherwise assigns each state its label based on the list. - arrow_scale : float, optional, default=1. + edge_scale : float, optional, default=1. Linear scaling coefficient for all arrow widths. Takes the default line width `rcParams['lines.linewidth']` into account. - arrow_curvature : float, optional, default=1. + edge_curvature : float, optional, default=1. Linear scaling coefficient for arrow curvature. Setting it to `0` produces straight arrows. - arrow_labels : 'weights' or ndarray or None, default='weights' + edge_labels : 'weights' or ndarray or None, default='weights' If 'weights', arrows obtain labels according to the weights in the adjacency matrix. If ndarray the dtype is expected to be object and the argument should be a `(n, n)` matrix with labels. If None, no labels are printed. - arrow_label_format : str, default='{:.1e}' + edge_label_format : str, default='{:.1e}' Format string for arrow labels. Only has an effect if arrow_labels is set to `weights`. - arrow_label_location : float, default=0.55 + edge_label_location : float, default=0.55 Location of the arrow labels on the curve. Should be between 0 (meaning on the source state) and 1 (meaning on the target state). Defaults to 0.55, i.e., slightly shifted toward the target state from the midpoint. cmap : matplotlib.colors.Colormap or str, default=None @@ -187,7 +187,7 @@ def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500 # automatic arrow rescaling diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool)) default_lw = default_line_width() - arrow_scale *= 2 * default_lw / np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) + edge_scale *= 2 * default_lw / np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) # set node labels if state_labels is None: @@ -211,17 +211,17 @@ def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500 raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).") # set arrow labels - if isinstance(arrow_labels, np.ndarray): - label_matrix = arrow_labels + if isinstance(edge_labels, np.ndarray): + label_matrix = edge_labels assert label_matrix.shape == self.adjacency_matrix.shape, \ f"Arrow labels matrix has shape {label_matrix.shape} =/= {self.adjacency_matrix.shape}" - if isinstance(arrow_labels[0, 0], str): - arrow_label_format = '{}' - elif isinstance(arrow_labels, str) and arrow_labels.lower() == 'weights': + if isinstance(edge_labels[0, 0], str): + edge_label_format = '{}' + elif isinstance(edge_labels, str) and edge_labels.lower() == 'weights': label_matrix = np.copy(self.adjacency_matrix) - elif arrow_labels is None: + elif edge_labels is None: label_matrix = np.full(self.adjacency_matrix.shape, fill_value='', dtype=object) - arrow_label_format = '{}' + edge_label_format = '{}' else: raise ValueError("Invalid arrow labels, should be 'weights', ndarray of strings or None.") @@ -243,26 +243,42 @@ def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500 # draw arrows for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes if abs(self.adjacency_matrix[i, j]) > 0: - label = arrow_label_format.format(label_matrix[i, j]) - width = arrow_scale * self.adjacency_matrix[i, j] + label = edge_label_format.format(label_matrix[i, j]) + width = edge_scale * self.adjacency_matrix[i, j] self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width, - arrow_curvature=arrow_curvature, patchA=circles[i], patchB=circles[j], + arrow_curvature=edge_curvature, patchA=circles[i], patchB=circles[j], shrinkA=3, shrinkB=0, - arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) + arrow_label_size=arrow_label_size, arrow_label_location=edge_label_location) if abs(self.adjacency_matrix[j, i]) > 0: - label = arrow_label_format.format(label_matrix[j, i]) - width = arrow_scale * self.adjacency_matrix[j, i] + label = edge_label_format.format(label_matrix[j, i]) + width = edge_scale * self.adjacency_matrix[j, i] self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width, - arrow_curvature=arrow_curvature, patchA=circles[j], patchB=circles[i], + arrow_curvature=edge_curvature, patchA=circles[j], patchB=circles[i], shrinkA=3, shrinkB=0, - arrow_label_size=arrow_label_size, arrow_label_location=arrow_label_location) + arrow_label_size=arrow_label_size, arrow_label_location=edge_label_location) ax.autoscale_view() return ax @plotting_function(requires_networkx=True) -def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, node_size=None): +def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, scale_nodes=True, + scale_edges=True): + r"""Plot an adjacency matrix. The edges are + + Parameters + ---------- + adjacency_matrix + positions + layout + ax + scale_nodes + scale_edges + + Returns + ------- + + """ import networkx as nx if ax is None: import matplotlib.pyplot as plt @@ -286,4 +302,4 @@ def layout(g): plot = Network(adjacency_matrix, pos=pos) ax = plot.plot(ax=ax, state_sizes=node_size) - return ax, graph + return ax diff --git a/examples/methods/plot_sindy_rossler_attractor.py b/examples/methods/plot_sindy_rossler_attractor.py index 017888392..94b6d9a7c 100644 --- a/examples/methods/plot_sindy_rossler_attractor.py +++ b/examples/methods/plot_sindy_rossler_attractor.py @@ -59,24 +59,10 @@ def rossler(z, t): # Plot test data fig = plt.figure(figsize=(8, 6)) ax = fig.add_subplot(111, projection="3d") -ax.plot( - x_test[:, 0], - x_test[:, 1], - x_test[:, 2], - label="True solution", - color="firebrick", - alpha=0.7, -) +ax.plot(x_test[:, 0], x_test[:, 1], x_test[:, 2], label="True solution", color="firebrick", alpha=0.7) ax.set(xlabel="x", ylabel="y", zlabel="z", title="Testing data (Rossler system)") # Simulate data with SINDy model and plot x_sim = model.simulate(x0_test, t_test) -ax.plot( - x_sim[:, 0], - x_sim[:, 1], - x_sim[:, 2], - label="Model simulation", - color="royalblue", - linestyle="dashed", -) +ax.plot(x_sim[:, 0], x_sim[:, 1], x_sim[:, 2], label="Model simulation", color="royalblue", linestyle="dashed") ax.legend() diff --git a/tests/plots/plot_network.py b/tests/plots/plot_network.py index 5cc38e546..d6db8ab19 100644 --- a/tests/plots/plot_network.py +++ b/tests/plots/plot_network.py @@ -18,7 +18,7 @@ def test_sanity(): f, (ax1, ax2) = plt.subplots(1, 2) - ax = pl.plot(state_colors=np.linspace(0, 1, num=flux.n_states), ax=ax1, arrow_curvature=2.) + ax = pl.plot(ax=ax1, state_colors=np.linspace(0, 1, num=flux.n_states), edge_curvature=2.) ax.set_aspect('equal') from pyemma.plots import plot_network From 59ca64a721e50ab703f81fe1e6fa9a5fdc6bcded Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 12:06:32 +0200 Subject: [PATCH 11/18] plot_msm --- deeptime/plots/__init__.py | 5 +- deeptime/plots/network.py | 442 +++++++++++++++++++++++++++--------- tests/plots/plot_network.py | 28 --- tests/plots/test_network.py | 54 +++++ 4 files changed, 388 insertions(+), 141 deletions(-) delete mode 100644 tests/plots/plot_network.py create mode 100644 tests/plots/test_network.py diff --git a/deeptime/plots/__init__.py b/deeptime/plots/__init__.py index 051e7f0a8..0f67e003d 100644 --- a/deeptime/plots/__init__.py +++ b/deeptime/plots/__init__.py @@ -9,10 +9,13 @@ plot_ck_test plot_energy2d Energy2dPlot + + plot_adjacency + plot_markov_model Network """ from .implied_timescales import plot_implied_timescales from .chapman_kolmogorov import plot_ck_test from .energy import plot_energy2d, Energy2dPlot -from .network import Network +from .network import Network, plot_adjacency, plot_markov_model diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index 112e234dc..e84a4d5e4 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -1,8 +1,10 @@ -from typing import Optional, Union, Dict, Tuple +from typing import Optional, Union, Dict, Tuple, List import numpy as np +import scipy from scipy.sparse import issparse +from deeptime.markov.msm import MarkovStateModel from deeptime.plots.util import default_image_cmap, default_line_width from deeptime.util.decorators import plotting_function from deeptime.util.types import ensure_number_array @@ -17,6 +19,40 @@ class Network: weight matrix or adjacency matrix of the network to visualize pos : ndarray or dict[int, ndarray] user-defined positions as (n,2) array + cmap : matplotlib.colors.Colormap or str, default=None + The colormap for `state_color`s. + state_sizes : ndarray, optional, default=None + List of state sizes to plot, must be of length `n_states`. It effectively evaluates as + + .. math:: + \frac{\mathrm{state\_scale} \cdot (\min (d_x, d_y))^2}{2 n_\mathrm{nodes}} + \frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}} + + with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that + the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second + state drawn as a circle with twice the volume of the first. + state_scale : float, default=1. + Uniform scaling factor for `state_sizes`. + state_colors : str or list of float, default='#ff5500' + The color to use for drawing states. If given as list of float, uses the colormap (`cmap` argument) to + determine the color. + state_labels : None or 'auto' or list of str, default='auto' + The state labels. If 'auto', just enumerates the states. In case of `None` no state labels are depicted, + otherwise assigns each state its label based on the list. + edge_scale : float, optional, default=1. + Linear scaling coefficient for all arrow widths. Takes the default line width `rcParams['lines.linewidth']` + into account. + edge_curvature : float, optional, default=1. + Linear scaling coefficient for arrow curvature. Setting it to `0` produces straight arrows. + edge_labels : 'weights' or ndarray or None, default='weights' + If 'weights', arrows obtain labels according to the weights in the adjacency matrix. If ndarray the dtype + is expected to be object and the argument should be a `(n, n)` matrix with labels. If None, no labels are + printed. + edge_label_format : str, default='{:.1e}' + Format string for arrow labels. Only has an effect if arrow_labels is set to `weights`. + edge_label_location : float, default=0.55 + Location of the arrow labels on the curve. Should be between 0 (meaning on the source state) and 1 (meaning + on the target state). Defaults to 0.55, i.e., slightly shifted toward the target state from the midpoint. Examples -------- @@ -37,12 +73,38 @@ class Network: >>> import networkx as nx >>> positions = nx.spring_layout(nx.from_numpy_array(flux.gross_flux)) >>> Network(flux.gross_flux, positions).plot() # doctest: +ELLIPSIS - <...Figure... + <...Axes... """ - def __init__(self, adjacency_matrix, pos): + def __init__(self, adjacency_matrix, pos, cmap=None, + state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', + edge_scale: float = 1., edge_curvature: float = 1.0, + edge_labels: Optional[Union[str, np.ndarray]] = 'weights', edge_label_format: str = '{:.1e}', + edge_label_location: float = 0.55): self.adjacency_matrix = adjacency_matrix self.pos = pos + self.edge_scale = edge_scale + self.edge_curvature = edge_curvature + self.edge_labels = edge_labels + self.edge_label_format = edge_label_format + self.edge_label_location = edge_label_location + self.cmap = cmap + self.state_sizes = state_sizes + self.state_scale = state_scale + self.state_colors = state_colors + self.state_labels = state_labels + + @property + def adjacency_matrix(self): + r""" The adjacency matrix. Can be sparse. """ + return self._adjacency_matrix + + @adjacency_matrix.setter + def adjacency_matrix(self, value): + if issparse(value): + self._adjacency_matrix = scipy.sparse.csr_matrix(value) + else: + self._adjacency_matrix = value @property def pos(self) -> np.ndarray: @@ -84,6 +146,25 @@ def d_y(self): r""" Height of the network. """ return self.bounds[1][1] - self.bounds[1][0] + @property + def cmap(self): + r""" The colormap to use for states if colors are given as one-dimensional array of floats. + + :type: matplotlib.colors.Colormap + """ + return self._cmap + + @cmap.setter + def cmap(self, value): + from matplotlib import colors + from matplotlib import colormaps + if value is None: + value = default_image_cmap() + if not isinstance(value, colors.Colormap): + value = colormaps.get(value) + assert isinstance(value, colors.Colormap) + self._cmap = value + @staticmethod def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, color="grey", patchA=None, patchB=None, shrinkA=2, shrinkB=1, arrow_label_size=None, arrow_label_location=.55): @@ -115,9 +196,116 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo ax.text(*ptext, label, size=arrow_label_size, horizontalalignment='center', verticalalignment='center', zorder=3, transform=ax.transData) - def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', - edge_scale=1.0, edge_curvature=1.0, edge_labels='weights', edge_label_format='{:.1e}', - edge_label_location=0.55, cmap=None, **textkwargs): + @property + def edge_base_scale(self): + r""" Base scale for edges depending on the matplotlib default line width and the maximum off-diagonal + element in the adjacency matrix """ + if issparse(self.adjacency_matrix): + mat = self.adjacency_matrix.tocoo() + max_off_diag = -np.inf + for i, j, v in zip(mat.row, mat.col, mat.data): + if i != j: + max_off_diag = max(max_off_diag, v) + else: + diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool)) + max_off_diag = np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) + default_lw = default_line_width() + return 2 * default_lw / max_off_diag + + @property + def edge_labels(self) -> Optional[np.ndarray]: + r""" Edge labels. Can be left None for no labels, otherwise must be matrix of same shape as adjacency matrix. + containing numerical values (in conjunction with :attr:`edge_label_format`) or strings.""" + return self._edge_labels + + @edge_labels.setter + def edge_labels(self, value): + if isinstance(value, np.ndarray): + if value.shape != self.adjacency_matrix.shape: + raise ValueError(f"Arrow labels matrix has shape {value.shape} =/= {self.adjacency_matrix.shape}") + self._edge_labels = value + elif isinstance(value, str) and value.lower() == 'weights': + self._edge_labels = self.adjacency_matrix + elif value is None: + self._edge_labels = None + else: + raise ValueError("Invalid edge labels, should be 'weights', ndarray of strings or None.") + + @property + def state_sizes(self) -> np.ndarray: + r""" Sizes of states. Can be left `None` which amounts to state sizes of 1. """ + return self._state_sizes + + @state_sizes.setter + def state_sizes(self, value: Optional[np.ndarray]): + self._state_sizes = value if value is not None else np.ones(self.n_nodes) + + @property + def node_sizes(self): + r""" The effective node sizes. Rescales to account for size of plot. """ + return 0.5 * self.state_scale * min(self.d_x, self.d_y) ** 2 * self.state_sizes \ + / (np.max(self.state_sizes) * float(self.n_nodes)) + + @property + def state_labels(self) -> List[str]: + r""" State labels. """ + return self._state_labels + + @state_labels.setter + def state_labels(self, value): + if isinstance(value, str) and value == 'auto': + value = [str(i) for i in np.arange(self.n_nodes)] + else: + if len(value) != self.n_nodes: + raise ValueError(f"length of state_labels({len(value)}) has to match " + f"length of states({self.n_nodes}).") + self._state_labels = value + + @property + def state_colors(self) -> np.ndarray: + """ The state colors in (N, rgb(a))-shaped array. """ + return self._state_colors + + @state_colors.setter + def state_colors(self, value): + # set node colors + from matplotlib import colors + if isinstance(value, str): + state_colors = np.array([colors.to_rgba(value)] * self.n_nodes) + else: + state_colors = ensure_number_array(value) + if state_colors.ndim == 1: + state_colors = np.array([self.cmap(x) for x in state_colors]) + elif state_colors.ndim in (3, 4): + pass # ok + else: + raise ValueError() + if len(state_colors) != self.n_nodes: + raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).") + self._state_colors = state_colors + + def edge_label(self, i, j) -> str: + r""" Yields the formatted edge label for edge i->j. + + Parameters + ---------- + i : int + edge i + j : int + edge j + + Returns + ------- + label : str + The edge label. + """ + if self.edge_labels is None: + return "" + else: + fmt = self.edge_label_format if self.edge_labels.dtype.type is np.string_ else "{}" + return fmt.format(self.edge_labels[i, j]) + + def plot(self, ax=None, **textkwargs): r"""Draws a network using discs and curved arrows. The thicknesses and labels of the arrows are taken from the off-diagonal matrix elements @@ -127,162 +315,91 @@ def plot(self, ax=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500 ---------- ax : matplotlib.axes.Axes, optional, default=None The axes to plot on. - state_sizes : ndarray, optional, default=None - List of state sizes to plot, must be of length `n_states`. It effectively evaluates as - - .. math:: - \frac{\mathrm{state\_scale} \cdot (\min (d_x, d_y))^2}{2 n_\mathrm{nodes}} - \frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}} - - with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that - the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second - state drawn as a circle with twice the volume of the first. - state_scale : float, default=1. - Uniform scaling factor for `state_sizes`. - state_colors : str or list of float, default='#ff5500' - The color to use for drawing states. If given as list of float, uses the colormap (`cmap` argument) to - determine the color. - state_labels : None or 'auto' or list of str, default='auto' - The state labels. If 'auto', just enumerates the states. In case of `None` no state labels are depicted, - otherwise assigns each state its label based on the list. - edge_scale : float, optional, default=1. - Linear scaling coefficient for all arrow widths. Takes the default line width `rcParams['lines.linewidth']` - into account. - edge_curvature : float, optional, default=1. - Linear scaling coefficient for arrow curvature. Setting it to `0` produces straight arrows. - edge_labels : 'weights' or ndarray or None, default='weights' - If 'weights', arrows obtain labels according to the weights in the adjacency matrix. If ndarray the dtype - is expected to be object and the argument should be a `(n, n)` matrix with labels. If None, no labels are - printed. - edge_label_format : str, default='{:.1e}' - Format string for arrow labels. Only has an effect if arrow_labels is set to `weights`. - edge_label_location : float, default=0.55 - Location of the arrow labels on the curve. Should be between 0 (meaning on the source state) and 1 (meaning - on the target state). Defaults to 0.55, i.e., slightly shifted toward the target state from the midpoint. - cmap : matplotlib.colors.Colormap or str, default=None - The colormap for `state_color`s. **textkwargs Optional arguments for state labels. """ # Set the default values for the text dictionary from matplotlib import pyplot as plt - from matplotlib import colors from matplotlib import patches if ax is None: ax = plt.gca() - if cmap is None: - cmap = default_image_cmap() textkwargs.setdefault('size', None) textkwargs.setdefault('horizontalalignment', 'center') textkwargs.setdefault('verticalalignment', 'center') textkwargs.setdefault('color', 'black') # remove the temporary key 'arrow_label_size' as it cannot be parsed by plt.text! arrow_label_size = textkwargs.pop('arrow_label_size', textkwargs['size']) - # sizes of nodes - if state_sizes is None: - state_sizes = 0.5 * state_scale * min(self.d_x, self.d_y) ** 2 * np.ones(self.n_nodes) / float(self.n_nodes) - else: - state_sizes = 0.5 * state_scale * min(self.d_x, self.d_y) ** 2 * state_sizes \ - / (np.max(state_sizes) * float(self.n_nodes)) # automatic arrow rescaling - diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool)) - default_lw = default_line_width() - edge_scale *= 2 * default_lw / np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) - - # set node labels - if state_labels is None: - pass - elif isinstance(state_labels, str) and state_labels == 'auto': - state_labels = [str(i) for i in np.arange(self.n_nodes)] - else: - if len(state_labels) != self.n_nodes: - raise ValueError(f"length of state_labels({len(state_labels)}) has to match " - f"length of states({self.n_nodes}).") - # set node colors - if isinstance(state_colors, str): - state_colors = np.array([colors.to_rgb(state_colors)] * self.n_nodes) - else: - state_colors = ensure_number_array(state_colors, ndim=1) - if not isinstance(cmap, colors.Colormap): - cmap = plt.get_cmap(cmap) - assert isinstance(cmap, colors.Colormap) - state_colors = np.array([cmap(x) for x in state_colors]) - if len(state_colors) != self.n_nodes: - raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).") - - # set arrow labels - if isinstance(edge_labels, np.ndarray): - label_matrix = edge_labels - assert label_matrix.shape == self.adjacency_matrix.shape, \ - f"Arrow labels matrix has shape {label_matrix.shape} =/= {self.adjacency_matrix.shape}" - if isinstance(edge_labels[0, 0], str): - edge_label_format = '{}' - elif isinstance(edge_labels, str) and edge_labels.lower() == 'weights': - label_matrix = np.copy(self.adjacency_matrix) - elif edge_labels is None: - label_matrix = np.full(self.adjacency_matrix.shape, fill_value='', dtype=object) - edge_label_format = '{}' - else: - raise ValueError("Invalid arrow labels, should be 'weights', ndarray of strings or None.") + edge_scale = self.edge_base_scale * self.edge_scale # draw circles + node_sizes = self.node_sizes circles = [] for i in range(self.n_nodes): circles.append( - patches.Circle(self.pos[i], radius=np.sqrt(0.5 * state_sizes[i]) / 2.0, - color=state_colors[i], zorder=2) + patches.Circle(self.pos[i], radius=np.sqrt(0.5 * node_sizes[i]) / 2.0, + color=self.state_colors[i], zorder=2) ) ax.add_patch(circles[-1]) # add annotation - if state_labels is not None: - ax.text(self.pos[i][0], self.pos[i][1], state_labels[i], zorder=3, **textkwargs) + if self.state_labels is not None: + ax.text(self.pos[i][0], self.pos[i][1], self.state_labels[i], zorder=3, **textkwargs) assert len(circles) == self.n_nodes, f"{len(circles)} != {self.n_nodes}" # draw arrows for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes if abs(self.adjacency_matrix[i, j]) > 0: - label = edge_label_format.format(label_matrix[i, j]) + label = self.edge_label(i, j) width = edge_scale * self.adjacency_matrix[i, j] self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width, - arrow_curvature=edge_curvature, patchA=circles[i], patchB=circles[j], + arrow_curvature=self.edge_curvature, patchA=circles[i], patchB=circles[j], shrinkA=3, shrinkB=0, - arrow_label_size=arrow_label_size, arrow_label_location=edge_label_location) + arrow_label_size=arrow_label_size, arrow_label_location=self.edge_label_location) if abs(self.adjacency_matrix[j, i]) > 0: - label = edge_label_format.format(label_matrix[j, i]) + label = self.edge_label(j, i) width = edge_scale * self.adjacency_matrix[j, i] self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width, - arrow_curvature=edge_curvature, patchA=circles[j], patchB=circles[i], + arrow_curvature=self.edge_curvature, patchA=circles[j], patchB=circles[i], shrinkA=3, shrinkB=0, - arrow_label_size=arrow_label_size, arrow_label_location=edge_label_location) + arrow_label_size=arrow_label_size, arrow_label_location=self.edge_label_location) ax.autoscale_view() return ax @plotting_function(requires_networkx=True) -def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, scale_nodes=True, - scale_edges=True): - r"""Plot an adjacency matrix. The edges are +def plot_adjacency(adjacency_matrix, positions: Optional[np.ndarray] = None, layout=None, ax=None, scale_states=True): + r"""Plot an adjacency matrix. The edges are scaled according to the respective values. For more fine-grained + control use :class:`Network`. Parameters ---------- - adjacency_matrix - positions - layout - ax - scale_nodes - scale_edges + adjacency_matrix : ndarray or scipy sparse matrix + Adjacency matrix to plot. Could be, e.g., a MSM transition matrix. + positions : ndarray, optional, default=None + A (N, 2)-shaped ndarray containing positions for the nodes of the adjacency matrix. If left as `None`, the + layout is algorithmically determined (based on the algorithm specified in the `layout` parameter). + layout : callable, optional, default=None + The automatic layout to use. Only has an effect, if `positions` is `None`. In that case, it defaults to + the `networkx.spring_layout`. + Can be any callable which takes a networkx graph as first argument and yields a position array or dict. + ax : matplotlib.axes.Axes, optional, default=None + The axes to plot on. Otherwise, uses the current axes (via `plt.gca()`). + scale_states : bool, default=True + Whether to scale nodes according to the value on the diagonal of the adjacency matrix. Returns ------- + ax : matplotlib.axes.Axes + The axes that were plotted on. + See Also + -------- + Network """ import networkx as nx - if ax is None: - import matplotlib.pyplot as plt - ax = plt.gca() if positions is not None: if positions.ndim != 2 or positions.shape[0] != adjacency_matrix.shape[0] or positions.shape[1] != 2: raise ValueError(f"Unsupported positions array. Has to be ({adjacency_matrix.shape[0]}, 2)-shaped but " @@ -300,6 +417,107 @@ def layout(g): assert layout is not None pos = layout(graph) - plot = Network(adjacency_matrix, pos=pos) - ax = plot.plot(ax=ax, state_sizes=node_size) + if scale_states: + state_sizes = adjacency_matrix.diagonal() if issparse(adjacency_matrix) else np.diag(adjacency_matrix) + else: + state_sizes = None + plot = Network(adjacency_matrix, pos=pos, state_sizes=state_sizes) + ax = plot.plot(ax=ax) return ax + + +@plotting_function(requires_networkx=True) +def plot_markov_model(msm: Union[MarkovStateModel, np.ndarray], pos=None, state_sizes=None, state_scale=1.0, + state_colors='#ff5500', state_labels='auto', + minflux=1e-6, edge_scale=1.0, edge_curvature=1.0, edge_labels='weights', + edge_label_format='{2.e}', ax=None, **textkwargs): + r"""Network representation of MSM transition matrix. + + This visualization is not optimized for large matrices. It is meant to be used for the visualization of small + models with up to 10-20 states, e.g., obtained by a HMM coarse-graining. If used with large network, the automatic + node positioning will be very slow and may still look ugly. + + Parameters + ---------- + msm : MarkovStateModel or ndarray + The Markov state model to plot. Can also be the transition matrix. + pos : ndarray(n,2) or dict, optional, default=None + User-defined positions to draw the states on. If not given, will try to place them automatically. + The output of networkx layouts can be used for this argument. + state_sizes : ndarray(n), optional, default=None + User-defined areas of the discs drawn for each state. If not given, the + stationary probability of P will be used. + state_colors : string, ndarray(n), or list, optional, default='#ff5500' (orange) + string : + a Hex code for a single color used for all states + array : + n values in [0,1] which will result in a grayscale plot + list : + of len = nstates, with a color for each state. The list can mix strings, RGB values and + hex codes, e.g. :py:obj:`state_colors` = ['g', 'red', [.23, .34, .35], '#ff5500'] is + possible. + state_labels : list of strings, optional, default is 'auto' + A list with a label for each state, to be displayed at the center + of each node/state. If left to 'auto', the labels are automatically set to the state + indices. + minflux : float, optional, default=1e-6 + The minimal flux (p_i * p_ij) for a transition to be drawn + edge_scale : float, optional, default=1.0 + Relative arrow scale. Set to a value different from 1 to increase or decrease the arrow width. + edge_curvature : float, optional, default=1.0 + Relative arrow curvature. Set to a value different from 1 to make arrows more or less curved. + edge_labels : 'weights', None or a ndarray(n,n) with label strings. Optional, default='weights' + Strings to be placed upon arrows. If None, no labels will be used. + If 'weights', the elements of P will be used. If a matrix of strings is given by the user these will be used. + edge_label_format : str, optional, default='{2.e}' + The numeric format to print the arrow labels. + ax : matplotlib Axes object, optional, default=None + The axes to plot to. When set to None a new Axes (and Figure) object will be used. + textkwargs : optional argument for the text of the state and arrow labels. + See https://matplotlib.org/stable/api/text_api.html for more info. The + parameter 'size' refers to the size of the state and arrow labels and overwrites the + matplotlib default. The parameter 'arrow_label_size' is only used for the arrow labels; + please note that 'arrow_label_size' is not part of matplotlib.text.Text's set of parameters + and will raise an exception when passed to matplotlib.text.Text directly. + + Returns + ------- + ax, pos : matplotlib.axes.Axes, ndarray(n,2) + An axes object containing the plot and the positions of states. + Can be used later to plot a different network representation (e.g. the flux) + + Examples + -------- + >>> import numpy as np + >>> P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], + ... [0.1, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.1, 0.8, 0.0, 0.05], + ... [0.0, 0.2, 0.0, 0.8, 0.0], + ... [0.0, 0.02, 0.02, 0.0, 0.96]]) + >>> plot_markov_model(P) # doctest:+ELLIPSIS + (<...Axes..., array...) + """ + if not isinstance(msm, MarkovStateModel): + msm = MarkovStateModel(msm) + P = msm.transition_matrix.copy() + if state_sizes is None: + state_sizes = msm.stationary_distribution + if minflux > 0: + if msm.sparse: + sddiag = scipy.sparse.diags(msm.stationary_distribution) + else: + sddiag = np.diag(msm.stationary_distribution) + flux = sddiag.dot(msm.transition_matrix) + if msm.sparse: + P = P.multiply(P >= minflux) + P.eliminate_zeros() + else: + P[flux < minflux] = 0.0 + if pos is None: + import networkx as nx + graph = nx.from_scipy_sparse_matrix(P) if msm.sparse else nx.from_numpy_matrix(P) + pos = nx.spring_layout(graph) + network = Network(P, pos=pos, state_scale=state_scale, state_colors=state_colors, state_labels=state_labels, + state_sizes=state_sizes, edge_scale=edge_scale, edge_curvature=edge_curvature, + edge_labels=edge_labels, edge_label_format=edge_label_format) + return network.plot(ax=ax, **textkwargs), pos diff --git a/tests/plots/plot_network.py b/tests/plots/plot_network.py deleted file mode 100644 index d6db8ab19..000000000 --- a/tests/plots/plot_network.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -import networkx as nx -from deeptime.plots.network import Network -from deeptime.markov.msm import MarkovStateModel - -def test_sanity(): - P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], - [0.1, 0.75, 0.05, 0.05, 0.05], - [0.05, 0.1, 0.8, 0.0, 0.05], - [0.0, 0.2, 0.0, 0.8, 0.0], - [0.0, 0.02, 0.02, 0.0, 0.96]]) - - flux = MarkovStateModel(P).reactive_flux([2], [3]) - - positions = nx.planar_layout(nx.from_numpy_array(flux.gross_flux)) - pl = Network(flux.gross_flux, positions) - - f, (ax1, ax2) = plt.subplots(1, 2) - - ax = pl.plot(ax=ax1, state_colors=np.linspace(0, 1, num=flux.n_states), edge_curvature=2.) - ax.set_aspect('equal') - - from pyemma.plots import plot_network - - plot_network(flux.gross_flux, pos=np.array([positions[i] for i in range(flux.n_states)]), ax=ax2) - plt.show() - diff --git a/tests/plots/test_network.py b/tests/plots/test_network.py new file mode 100644 index 000000000..b1035f215 --- /dev/null +++ b/tests/plots/test_network.py @@ -0,0 +1,54 @@ +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +from deeptime.plots.network import Network, plot_markov_model +from deeptime.markov.msm import MarkovStateModel + + +def test_network(): + P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], + [0.1, 0.75, 0.05, 0.05, 0.05], + [0.05, 0.1, 0.8, 0.0, 0.05], + [0.0, 0.2, 0.0, 0.8, 0.0], + [0.0, 0.02, 0.02, 0.0, 0.96]]) + from scipy import sparse + Psparse = sparse.csr_matrix(P) + + flux = MarkovStateModel(Psparse).reactive_flux([2], [3]) + + positions = nx.planar_layout(nx.from_scipy_sparse_matrix(flux.gross_flux)) + labels = np.array([["juhu"] * flux.n_states] * flux.n_states) + pl = Network(flux.gross_flux, positions, edge_curvature=2., edge_labels=labels, + state_colors=np.linspace(0, 1, num=flux.n_states)) + + f, (ax1, ax2) = plt.subplots(1, 2) + + ax = pl.plot(ax=ax1) + ax.set_aspect('equal') + + from pyemma.plots import plot_network + + flux = MarkovStateModel(P).reactive_flux([2], [3]) + plot_network(flux.gross_flux, pos=np.array([positions[i] for i in range(flux.n_states)]), ax=ax2) + plt.show() + + # test with sparse matrix + # test against seaborn cmap + + +def test_msm(): + P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], + [0.1, 0.75, 0.05, 0.05, 0.05], + [0.05, 0.1, 0.8, 0.0, 0.05], + [0.0, 0.2, 0.0, 0.8, 0.0], + [1e-7, 0.02 - 1e-7, 0.02, 0.0, 0.96]]) + from scipy import sparse + Psparse = sparse.csr_matrix(P) + + f, (ax1, ax2) = plt.subplots(1, 2) + ax1.set_aspect('equal') + ax2.set_aspect('equal') + + plot_markov_model(Psparse, ax=ax1) + plot_markov_model(P, ax=ax2) + plt.show() From a3b494f612cb3817804d9bc89090709460cd4d64 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 13:31:34 +0200 Subject: [PATCH 12/18] plot network example --- deeptime/plots/network.py | 6 ++-- examples/methods/plot_network.py | 47 ++++++++++++++++++++++++++++++++ tests/plots/test_network.py | 6 ++-- 3 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 examples/methods/plot_network.py diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index e84a4d5e4..c3f5fcb92 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -122,7 +122,7 @@ def pos(self, value: Union[np.ndarray, Dict[int, np.ndarray]]): if len(value) < self.n_nodes: raise ValueError(f'Given less positions ({len(value)}) than states ({self.n_nodes})') if isinstance(value, dict): - value = np.stack((value[i] for i in range(len(value)))) + value = np.stack([value[i] for i in range(len(value))]) self._pos = value @property @@ -302,7 +302,7 @@ def edge_label(self, i, j) -> str: if self.edge_labels is None: return "" else: - fmt = self.edge_label_format if self.edge_labels.dtype.type is np.string_ else "{}" + fmt = self.edge_label_format if self.edge_labels.dtype.type is not np.string_ else "{}" return fmt.format(self.edge_labels[i, j]) def plot(self, ax=None, **textkwargs): @@ -430,7 +430,7 @@ def layout(g): def plot_markov_model(msm: Union[MarkovStateModel, np.ndarray], pos=None, state_sizes=None, state_scale=1.0, state_colors='#ff5500', state_labels='auto', minflux=1e-6, edge_scale=1.0, edge_curvature=1.0, edge_labels='weights', - edge_label_format='{2.e}', ax=None, **textkwargs): + edge_label_format='{:.2e}', ax=None, **textkwargs): r"""Network representation of MSM transition matrix. This visualization is not optimized for large matrices. It is meant to be used for the visualization of small diff --git a/examples/methods/plot_network.py b/examples/methods/plot_network.py new file mode 100644 index 000000000..22db36ae1 --- /dev/null +++ b/examples/methods/plot_network.py @@ -0,0 +1,47 @@ +""" +Network plots +============= + +We demonstrate different kinds of network plots based on :meth:`plots.Network `. +In particular: + + * plotting a Markov state model where the state sizes depend on the stationary distribution and edges are scaled + according to jump probabilities (:meth:`deeptime.plots.plot_markov_model`) + * plotting the gross flux, in accordance to edge widths and colored according to the forward committor + (:meth:`deeptime.plots.Network`). +""" +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt +import networkx as nx + +from mpl_toolkits.axes_grid1 import make_axes_locatable + +from deeptime.markov.msm import MarkovStateModel +from deeptime.plots import plot_markov_model, Network + +P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], + [0.1, 0.75, 0.05, 0.05, 0.05], + [0.05, 0.1, 0.8, 0.0, 0.05], + [0.0, 0.2, 0.0, 0.8, 0.0], + [1e-7, 0.02 - 1e-7, 0.02, 0.0, 0.96]]) + +f, axes = plt.subplots(1, 2) +for ax in axes.flatten(): + ax.set_aspect('equal') + +axes[0].set_title('Plotting the Markov model') +plot_markov_model(P, ax=axes[0]) + +axes[1].set_title('Plotting the gross flux') +flux = MarkovStateModel(P).reactive_flux(source_states=[2], target_states=[3]) +positions = nx.planar_layout(nx.from_numpy_array(flux.gross_flux)) +cmap = mpl.cm.get_cmap('coolwarm') +network = Network(flux.gross_flux, positions, edge_curvature=2., + state_colors=flux.forward_committor, cmap=cmap) +network.plot(ax=axes[1]) +norm = mpl.colors.Normalize(vmin=np.min(flux.forward_committor), vmax=np.max(flux.forward_committor)) +divider = make_axes_locatable(axes[1]) +cax = divider.append_axes("right", size="5%", pad=0.05) +f.colorbar(mpl.cm.ScalarMappable(norm, cmap), cax=cax) +plt.show() diff --git a/tests/plots/test_network.py b/tests/plots/test_network.py index b1035f215..3038dbe88 100644 --- a/tests/plots/test_network.py +++ b/tests/plots/test_network.py @@ -1,9 +1,12 @@ +import matplotlib import numpy as np import matplotlib.pyplot as plt import networkx as nx from deeptime.plots.network import Network, plot_markov_model from deeptime.markov.msm import MarkovStateModel +matplotlib.use('Agg') + def test_network(): P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], @@ -30,9 +33,7 @@ def test_network(): flux = MarkovStateModel(P).reactive_flux([2], [3]) plot_network(flux.gross_flux, pos=np.array([positions[i] for i in range(flux.n_states)]), ax=ax2) - plt.show() - # test with sparse matrix # test against seaborn cmap @@ -51,4 +52,3 @@ def test_msm(): plot_markov_model(Psparse, ax=ax1) plot_markov_model(P, ax=ax2) - plt.show() From a829b659266a4527859190f8273cccf42df217e7 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 13:36:25 +0200 Subject: [PATCH 13/18] figsize --- examples/methods/plot_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/methods/plot_network.py b/examples/methods/plot_network.py index 22db36ae1..d6889dd8d 100644 --- a/examples/methods/plot_network.py +++ b/examples/methods/plot_network.py @@ -26,7 +26,7 @@ [0.0, 0.2, 0.0, 0.8, 0.0], [1e-7, 0.02 - 1e-7, 0.02, 0.0, 0.96]]) -f, axes = plt.subplots(1, 2) +f, axes = plt.subplots(1, 2, figsize=(16, 12)) for ax in axes.flatten(): ax.set_aspect('equal') From f8af3c7d96e00f344001271dced1d7048940ac51 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 13:59:48 +0200 Subject: [PATCH 14/18] nx required for testing --- tests/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/requirements.txt b/tests/requirements.txt index d639a61cb..2aa8a8557 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -4,6 +4,7 @@ pybind11==2.9.1 torch>=1.10.0; python_version<"3.10" matplotlib pint +networkx pytest==7.1.1 pytest-cov==3.0.0 From 9c2109afe8bd288c2c992fd25d5969914dedc35d Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 14:12:02 +0200 Subject: [PATCH 15/18] fix test --- deeptime/plots/network.py | 2 +- examples/methods/plot_network.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index c3f5fcb92..d7cc3fe7a 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -302,7 +302,7 @@ def edge_label(self, i, j) -> str: if self.edge_labels is None: return "" else: - fmt = self.edge_label_format if self.edge_labels.dtype.type is not np.string_ else "{}" + fmt = self.edge_label_format if np.issubdtype(self.edge_labels.dtype, np.number) else "{}" return fmt.format(self.edge_labels[i, j]) def plot(self, ax=None, **textkwargs): diff --git a/examples/methods/plot_network.py b/examples/methods/plot_network.py index d6889dd8d..aea5196c1 100644 --- a/examples/methods/plot_network.py +++ b/examples/methods/plot_network.py @@ -44,4 +44,3 @@ divider = make_axes_locatable(axes[1]) cax = divider.append_axes("right", size="5%", pad=0.05) f.colorbar(mpl.cm.ScalarMappable(norm, cmap), cax=cax) -plt.show() From 3e11aba1523f5f08cba1cc7b11ebd6fd8c1a957b Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 3 May 2022 14:19:55 +0200 Subject: [PATCH 16/18] ... --- tests/plots/test_network.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/plots/test_network.py b/tests/plots/test_network.py index 3038dbe88..632e4918e 100644 --- a/tests/plots/test_network.py +++ b/tests/plots/test_network.py @@ -29,13 +29,6 @@ def test_network(): ax = pl.plot(ax=ax1) ax.set_aspect('equal') - from pyemma.plots import plot_network - - flux = MarkovStateModel(P).reactive_flux([2], [3]) - plot_network(flux.gross_flux, pos=np.array([positions[i] for i in range(flux.n_states)]), ax=ax2) - - # test against seaborn cmap - def test_msm(): P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], From d911892f4205c87b14331eab7046f4809f898d29 Mon Sep 17 00:00:00 2001 From: clonker Date: Tue, 3 May 2022 21:24:42 +0200 Subject: [PATCH 17/18] incorporate feedback --- deeptime/plots/network.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index d7cc3fe7a..f4e7a22b8 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -29,7 +29,7 @@ class Network: \frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}} with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that - the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second + the states scale their size with respect to the volume. I.e., `state_sizes=[1, 2]` leads to the second state drawn as a circle with twice the volume of the first. state_scale : float, default=1. Uniform scaling factor for `state_sizes`. @@ -157,12 +157,11 @@ def cmap(self): @cmap.setter def cmap(self, value): from matplotlib import colors - from matplotlib import colormaps + import matplotlib.pyplot as plt if value is None: value = default_image_cmap() if not isinstance(value, colors.Colormap): - value = colormaps.get(value) - assert isinstance(value, colors.Colormap) + value = plt.get_cmap(value) self._cmap = value @staticmethod @@ -177,7 +176,6 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo rad = arrow_curvature / dist tail_width = width head_width = max(2., 2 * width) - # head_length = max(1.5, head_width) arrow_style = patches.ArrowStyle.Simple(head_length=head_width, head_width=head_width, tail_width=tail_width) connection_style = patches.ConnectionStyle.Arc3(rad=-rad) @@ -202,13 +200,13 @@ def edge_base_scale(self): element in the adjacency matrix """ if issparse(self.adjacency_matrix): mat = self.adjacency_matrix.tocoo() - max_off_diag = -np.inf + max_off_diag = 0 for i, j, v in zip(mat.row, mat.col, mat.data): if i != j: - max_off_diag = max(max_off_diag, v) + max_off_diag = max(max_off_diag, abs(v)) else: diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool)) - max_off_diag = np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) + max_off_diag = np.max(np.abs(self.adjacency_matrix[diag_mask])) default_lw = default_line_width() return 2 * default_lw / max_off_diag @@ -238,6 +236,9 @@ def state_sizes(self) -> np.ndarray: @state_sizes.setter def state_sizes(self, value: Optional[np.ndarray]): + if value is not None and len(value) != self.n_nodes: + raise ValueError(f"State sizes must correspond to states (# = {self.n_nodes}) but was of " + f"length {len(value)}") self._state_sizes = value if value is not None else np.ones(self.n_nodes) @property @@ -276,10 +277,11 @@ def state_colors(self, value): state_colors = ensure_number_array(value) if state_colors.ndim == 1: state_colors = np.array([self.cmap(x) for x in state_colors]) - elif state_colors.ndim in (3, 4): - pass # ok + elif state_colors.ndim == 2 and state_colors.shape[1] in (3, 4): + pass # ok: rgb(a) values else: - raise ValueError() + raise ValueError(f"state color(s) can only be individual color or float range or rgb(a) values " + f"but was {state_colors}") if len(state_colors) != self.n_nodes: raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).") self._state_colors = state_colors @@ -346,18 +348,16 @@ def plot(self, ax=None, **textkwargs): if self.state_labels is not None: ax.text(self.pos[i][0], self.pos[i][1], self.state_labels[i], zorder=3, **textkwargs) - assert len(circles) == self.n_nodes, f"{len(circles)} != {self.n_nodes}" - # draw arrows for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes - if abs(self.adjacency_matrix[i, j]) > 0: + if self.adjacency_matrix[i, j] != 0: label = self.edge_label(i, j) width = edge_scale * self.adjacency_matrix[i, j] self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width, arrow_curvature=self.edge_curvature, patchA=circles[i], patchB=circles[j], shrinkA=3, shrinkB=0, arrow_label_size=arrow_label_size, arrow_label_location=self.edge_label_location) - if abs(self.adjacency_matrix[j, i]) > 0: + if self.adjacency_matrix[j, i] != 0: label = self.edge_label(j, i) width = edge_scale * self.adjacency_matrix[j, i] self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width, From bac36d2a0d153043718484ab1fccd51aa31490b3 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Wed, 4 May 2022 09:55:04 +0200 Subject: [PATCH 18/18] some more tests --- deeptime/plots/network.py | 2 +- tests/plots/test_energy_surface.py | 6 +++-- tests/plots/test_network.py | 37 +++++++++++++++++++++++++----- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index f4e7a22b8..0eabff06e 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -239,7 +239,7 @@ def state_sizes(self, value: Optional[np.ndarray]): if value is not None and len(value) != self.n_nodes: raise ValueError(f"State sizes must correspond to states (# = {self.n_nodes}) but was of " f"length {len(value)}") - self._state_sizes = value if value is not None else np.ones(self.n_nodes) + self._state_sizes = np.asfarray(value) if value is not None else np.ones(self.n_nodes) @property def node_sizes(self): diff --git a/tests/plots/test_energy_surface.py b/tests/plots/test_energy_surface.py index fb4673f78..15b3f8004 100644 --- a/tests/plots/test_energy_surface.py +++ b/tests/plots/test_energy_surface.py @@ -1,6 +1,6 @@ import matplotlib import pytest -from numpy.testing import assert_ +from numpy.testing import assert_, assert_equal from deeptime.data import ellipsoids from deeptime.util import energy2d @@ -12,7 +12,9 @@ @pytest.mark.parametrize('cbar', [True, False], ids=lambda x: f"cbar={x}") def test_energy2d(shift_energy, cbar): traj = ellipsoids().observations(20000) - ax, contourf, cbar = energy2d(*traj.T, bins=100, shift_energy=shift_energy).plot(cbar=cbar) + plot_object = energy2d(*traj.T, bins=100, shift_energy=shift_energy).plot(cbar=cbar) + assert_equal(len(plot_object), 3) + ax, contourf, cbar = plot_object assert_(ax is not None) assert_(contourf is not None) assert_(cbar is not None if cbar else cbar is None) diff --git a/tests/plots/test_network.py b/tests/plots/test_network.py index 632e4918e..ff6880cf9 100644 --- a/tests/plots/test_network.py +++ b/tests/plots/test_network.py @@ -2,13 +2,18 @@ import numpy as np import matplotlib.pyplot as plt import networkx as nx +import pytest +from numpy.testing import assert_raises + from deeptime.plots.network import Network, plot_markov_model from deeptime.markov.msm import MarkovStateModel matplotlib.use('Agg') -def test_network(): +@pytest.mark.parametrize("cmap", ["nipy_spectral", plt.cm.nipy_spectral, None]) +@pytest.mark.parametrize("labels", [None, np.array([["juhu"] * 5] * 5), 'weights'], ids=['None', 'strs', 'weights']) +def test_network(labels, cmap): P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0], [0.1, 0.75, 0.05, 0.05, 0.05], [0.05, 0.1, 0.8, 0.0, 0.05], @@ -20,14 +25,34 @@ def test_network(): flux = MarkovStateModel(Psparse).reactive_flux([2], [3]) positions = nx.planar_layout(nx.from_scipy_sparse_matrix(flux.gross_flux)) - labels = np.array([["juhu"] * flux.n_states] * flux.n_states) pl = Network(flux.gross_flux, positions, edge_curvature=2., edge_labels=labels, - state_colors=np.linspace(0, 1, num=flux.n_states)) + state_colors=np.linspace(0, 1, num=flux.n_states), cmap=cmap) + ax = pl.plot() + ax.set_aspect('equal') - f, (ax1, ax2) = plt.subplots(1, 2) - ax = pl.plot(ax=ax1) - ax.set_aspect('equal') +def test_network_invalid_args(): + msm = MarkovStateModel(np.eye(3)) + with assert_raises(ValueError): + Network(msm.transition_matrix, pos=np.zeros((2, 2))) # not enough positions + network = Network(msm.transition_matrix, pos=np.zeros((3, 2))) + with assert_raises(ValueError): + network.edge_labels = np.array([["hi"]*2]*2) # not enough labels + with assert_raises(ValueError): + network.edge_labels = 'bogus' # may be 'weights' + network.state_sizes = [1., 2., 3.] + with assert_raises(ValueError): + network.state_sizes = [1., 2.] + network.state_labels = ["1", "2", "3"] + with assert_raises(ValueError): + network.state_labels = ["1", "2"] + network.state_colors = np.random.uniform(size=(3, 3)) + network.state_colors = np.random.uniform(size=(3, 4)) + with assert_raises(ValueError): + network.state_colors = np.random.uniform(size=(3, 5)) + with assert_raises(ValueError): + network.state_colors = np.random.uniform(size=(2, 4)) + network.plot() def test_msm():