diff --git a/deeptime/plots/network.py b/deeptime/plots/network.py
index d7cc3fe7a..f4e7a22b8 100644
--- a/deeptime/plots/network.py
+++ b/deeptime/plots/network.py
@@ -29,7 +29,7 @@ class Network:
               \frac{\mathrm{state\_sizes}}{\| \mathrm{state\_sizes} \|_{\max}}
 
         with :math:`\mathrm{state\_sizes}` interpreted as `1` in case of `None`. In particular this means that
-        the states scale their size with respect to the volume. I.e., `state_scale=[1, 2]` leads to the second
+        the states scale their size with respect to the volume. I.e., `state_sizes=[1, 2]` leads to the second
         state drawn as a circle with twice the volume of the first.
     state_scale : float, default=1.
         Uniform scaling factor for `state_sizes`.
@@ -157,12 +157,11 @@ def cmap(self):
     @cmap.setter
     def cmap(self, value):
         from matplotlib import colors
-        from matplotlib import colormaps
+        import matplotlib.pyplot as plt
         if value is None:
             value = default_image_cmap()
         if not isinstance(value, colors.Colormap):
-            value = colormaps.get(value)
-        assert isinstance(value, colors.Colormap)
+            value = plt.get_cmap(value)
         self._cmap = value
 
     @staticmethod
@@ -177,7 +176,6 @@ def _draw_arrow(ax, pos_1, pos_2, label="", width=1.0, arrow_curvature=1.0, colo
         rad = arrow_curvature / dist
         tail_width = width
         head_width = max(2., 2 * width)
-        # head_length = max(1.5, head_width)
 
         arrow_style = patches.ArrowStyle.Simple(head_length=head_width, head_width=head_width, tail_width=tail_width)
         connection_style = patches.ConnectionStyle.Arc3(rad=-rad)
@@ -202,13 +200,13 @@ def edge_base_scale(self):
         element in the adjacency matrix """
         if issparse(self.adjacency_matrix):
             mat = self.adjacency_matrix.tocoo()
-            max_off_diag = -np.inf
+            max_off_diag = 0
             for i, j, v in zip(mat.row, mat.col, mat.data):
                 if i != j:
-                    max_off_diag = max(max_off_diag, v)
+                    max_off_diag = max(max_off_diag, abs(v))
         else:
             diag_mask = np.logical_not(np.eye(self.adjacency_matrix.shape[0], dtype=bool))
-            max_off_diag = np.max(self.adjacency_matrix[diag_mask & (self.adjacency_matrix > 0)])
+            max_off_diag = np.max(np.abs(self.adjacency_matrix[diag_mask]))
         default_lw = default_line_width()
         return 2 * default_lw / max_off_diag
 
@@ -238,6 +236,9 @@ def state_sizes(self) -> np.ndarray:
 
     @state_sizes.setter
     def state_sizes(self, value: Optional[np.ndarray]):
+        if value is not None and len(value) != self.n_nodes:
+            raise ValueError(f"State sizes must correspond to states (# = {self.n_nodes}) but was of "
+                             f"length {len(value)}")
         self._state_sizes = value if value is not None else np.ones(self.n_nodes)
 
     @property
@@ -276,10 +277,11 @@ def state_colors(self, value):
             state_colors = ensure_number_array(value)
             if state_colors.ndim == 1:
                 state_colors = np.array([self.cmap(x) for x in state_colors])
-            elif state_colors.ndim in (3, 4):
-                pass  # ok
+            elif state_colors.ndim == 2 and state_colors.shape[1] in (3, 4):
+                pass  # ok: rgb(a) values
             else:
-                raise ValueError()
+                raise ValueError(f"state color(s) can only be individual color or float range or rgb(a) values "
+                                 f"but was {state_colors}")
         if len(state_colors) != self.n_nodes:
             raise ValueError(f"Mismatch between n_states and #state_colors ({self.n_nodes} vs {len(state_colors)}).")
         self._state_colors = state_colors
@@ -346,18 +348,16 @@ def plot(self, ax=None, **textkwargs):
             if self.state_labels is not None:
                 ax.text(self.pos[i][0], self.pos[i][1], self.state_labels[i], zorder=3, **textkwargs)
 
-        assert len(circles) == self.n_nodes, f"{len(circles)} != {self.n_nodes}"
-
         # draw arrows
         for i, j in zip(*np.triu_indices(self.n_nodes, k=1)):  # upper triangular indices with 0 <= i < j < n_nodes
-            if abs(self.adjacency_matrix[i, j]) > 0:
+            if self.adjacency_matrix[i, j] != 0:
                 label = self.edge_label(i, j)
                 width = edge_scale * self.adjacency_matrix[i, j]
                 self._draw_arrow(ax, self.pos[i], self.pos[j], label=label, width=width,
                                  arrow_curvature=self.edge_curvature, patchA=circles[i], patchB=circles[j],
                                  shrinkA=3, shrinkB=0,
                                  arrow_label_size=arrow_label_size, arrow_label_location=self.edge_label_location)
-            if abs(self.adjacency_matrix[j, i]) > 0:
+            if self.adjacency_matrix[j, i] != 0:
                 label = self.edge_label(j, i)
                 width = edge_scale * self.adjacency_matrix[j, i]
                 self._draw_arrow(ax, self.pos[j], self.pos[i], label=label, width=width,