From d911892f4205c87b14331eab7046f4809f898d29 Mon Sep 17 00:00:00 2001 From: clonker Date: Tue, 3 May 2022 21:24:42 +0200 Subject: [PATCH] incorporate feedback --- deeptime/plots/network.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py index d7cc3fe7a..f4e7a22b8 100644 --- a/deeptime/plots/network.py +++ b/deeptime/plots/network.py @@ -29,7 +29,7 @@ class Network: \frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}} with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that - the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second + the states scale their size with respect to the volume. I.e., `state_sizes=[1, 2]` leads to the second state drawn as a circle with twice the volume of the first. state_scale : float, default=1. Uniform scaling factor for `state_sizes`. @@ -157,12 +157,11 @@ def cmap(self): @cmap.setter def cmap(self, value): from matplotlib import colors - from matplotlib import colormaps + import matplotlib.pyplot as plt if value is None: value = default_image_cmap() if not isinstance(value, colors.Colormap): - value = colormaps.get(value) - assert isinstance(value, colors.Colormap) + value = plt.get_cmap(value) self._cmap = value @staticmethod @@ -177,7 +176,6 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo rad = arrow_curvature / dist tail_width = width head_width = max(2., 2 * width) - # head_length = max(1.5, head_width) arrow_style = patches.ArrowStyle.Simple(head_length=head_width, head_width=head_width, tail_width=tail_width) connection_style = patches.ConnectionStyle.Arc3(rad=-rad) @@ -202,13 +200,13 @@ def edge_base_scale(self): element in the adjacency matrix """ if issparse(self.adjacency_matrix): mat = self.adjacency_matrix.tocoo() - max_off_diag = -np.inf + max_off_diag = 0 for i, j, v in zip(mat.row, mat.col, mat.data): if i != j: - max_off_diag = max(max_off_diag, v) + max_off_diag = max(max_off_diag, abs(v)) else: diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool)) - max_off_diag = np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)]) + max_off_diag = np.max(np.abs(self.adjacency_matrix[diag_mask])) default_lw = default_line_width() return 2 * default_lw / max_off_diag @@ -238,6 +236,9 @@ def state_sizes(self) -> np.ndarray: @state_sizes.setter def state_sizes(self, value: Optional[np.ndarray]): + if value is not None and len(value) != self.n_nodes: + raise ValueError(f"State sizes must correspond to states (# = {self.n_nodes}) but was of " + f"length {len(value)}") self._state_sizes = value if value is not None else np.ones(self.n_nodes) @property @@ -276,10 +277,11 @@ def state_colors(self, value): state_colors = ensure_number_array(value) if state_colors.ndim == 1: state_colors = np.array([self.cmap(x) for x in state_colors]) - elif state_colors.ndim in (3, 4): - pass # ok + elif state_colors.ndim == 2 and state_colors.shape[1] in (3, 4): + pass # ok: rgb(a) values else: - raise ValueError() + raise ValueError(f"state color(s) can only be individual color or float range or rgb(a) values " + f"but was {state_colors}") if len(state_colors) != self.n_nodes: raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).") self._state_colors = state_colors @@ -346,18 +348,16 @@ def plot(self, ax=None, **textkwargs): if self.state_labels is not None: ax.text(self.pos[i][0], self.pos[i][1], self.state_labels[i], zorder=3, **textkwargs) - assert len(circles) == self.n_nodes, f"{len(circles)} != {self.n_nodes}" - # draw arrows for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes - if abs(self.adjacency_matrix[i, j]) > 0: + if self.adjacency_matrix[i, j] != 0: label = self.edge_label(i, j) width = edge_scale * self.adjacency_matrix[i, j] self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width, arrow_curvature=self.edge_curvature, patchA=circles[i], patchB=circles[j], shrinkA=3, shrinkB=0, arrow_label_size=arrow_label_size, arrow_label_location=self.edge_label_location) - if abs(self.adjacency_matrix[j, i]) > 0: + if self.adjacency_matrix[j, i] != 0: label = self.edge_label(j, i) width = edge_scale * self.adjacency_matrix[j, i] self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width,