Skip to content

Commit

Permalink
incorporate feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
clonker committed May 3, 2022
1 parent 3e11aba commit d911892
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions deeptime/plots/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Network:
\frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}}
with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that
the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second
the states scale their size with respect to the volume. I.e., `state_sizes=[1, 2]` leads to the second
state drawn as a circle with twice the volume of the first.
state_scale : float, default=1.
Uniform scaling factor for `state_sizes`.
Expand Down Expand Up @@ -157,12 +157,11 @@ def cmap(self):
@cmap.setter
def cmap(self, value):
from matplotlib import colors
from matplotlib import colormaps
import matplotlib.pyplot as plt
if value is None:
value = default_image_cmap()
if not isinstance(value, colors.Colormap):
value = colormaps.get(value)
assert isinstance(value, colors.Colormap)
value = plt.get_cmap(value)
self._cmap = value

@staticmethod
Expand All @@ -177,7 +176,6 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo
rad = arrow_curvature / dist
tail_width = width
head_width = max(2., 2 * width)
# head_length = max(1.5, head_width)

arrow_style = patches.ArrowStyle.Simple(head_length=head_width, head_width=head_width, tail_width=tail_width)
connection_style = patches.ConnectionStyle.Arc3(rad=-rad)
Expand All @@ -202,13 +200,13 @@ def edge_base_scale(self):
element in the adjacency matrix """
if issparse(self.adjacency_matrix):
mat = self.adjacency_matrix.tocoo()
max_off_diag = -np.inf
max_off_diag = 0
for i, j, v in zip(mat.row, mat.col, mat.data):
if i != j:
max_off_diag = max(max_off_diag, v)
max_off_diag = max(max_off_diag, abs(v))
else:
diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool))
max_off_diag = np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)])
max_off_diag = np.max(np.abs(self.adjacency_matrix[diag_mask]))
default_lw = default_line_width()
return 2 * default_lw / max_off_diag

Expand Down Expand Up @@ -238,6 +236,9 @@ def state_sizes(self) -> np.ndarray:

@state_sizes.setter
def state_sizes(self, value: Optional[np.ndarray]):
if value is not None and len(value) != self.n_nodes:
raise ValueError(f"State sizes must correspond to states (# = {self.n_nodes}) but was of "
f"length {len(value)}")
self._state_sizes = value if value is not None else np.ones(self.n_nodes)

@property
Expand Down Expand Up @@ -276,10 +277,11 @@ def state_colors(self, value):
state_colors = ensure_number_array(value)
if state_colors.ndim == 1:
state_colors = np.array([self.cmap(x) for x in state_colors])
elif state_colors.ndim in (3, 4):
pass # ok
elif state_colors.ndim == 2 and state_colors.shape[1] in (3, 4):
pass # ok: rgb(a) values
else:
raise ValueError()
raise ValueError(f"state color(s) can only be individual color or float range or rgb(a) values "
f"but was {state_colors}")
if len(state_colors) != self.n_nodes:
raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).")
self._state_colors = state_colors
Expand Down Expand Up @@ -346,18 +348,16 @@ def plot(self, ax=None, **textkwargs):
if self.state_labels is not None:
ax.text(self.pos[i][0], self.pos[i][1], self.state_labels[i], zorder=3, **textkwargs)

assert len(circles) == self.n_nodes, f"{len(circles)} != {self.n_nodes}"

# draw arrows
for i, j in zip(*np.triu_indices(self.n_nodes, k=1)): # upper triangular indices with 0 <= i < j < n_nodes
if abs(self.adjacency_matrix[i, j]) > 0:
if self.adjacency_matrix[i, j] != 0:
label = self.edge_label(i, j)
width = edge_scale * self.adjacency_matrix[i, j]
self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width,
arrow_curvature=self.edge_curvature, patchA=circles[i], patchB=circles[j],
shrinkA=3, shrinkB=0,
arrow_label_size=arrow_label_size, arrow_label_location=self.edge_label_location)
if abs(self.adjacency_matrix[j, i]) > 0:
if self.adjacency_matrix[j, i] != 0:
label = self.edge_label(j, i)
width = edge_scale * self.adjacency_matrix[j, i]
self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width,
Expand Down

0 comments on commit d911892

Please sign in to comment.