Skip to content

Commit

Permalink
some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
clonker committed May 4, 2022
1 parent d911892 commit bac36d2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
2 changes: 1 addition & 1 deletion deeptime/plots/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def state_sizes(self, value: Optional[np.ndarray]):
if value is not None and len(value) != self.n_nodes:
raise ValueError(f"State sizes must correspond to states (# = {self.n_nodes}) but was of "
f"length {len(value)}")
self._state_sizes = value if value is not None else np.ones(self.n_nodes)
self._state_sizes = np.asfarray(value) if value is not None else np.ones(self.n_nodes)

@property
def node_sizes(self):
Expand Down
6 changes: 4 additions & 2 deletions tests/plots/test_energy_surface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import matplotlib
import pytest
from numpy.testing import assert_
from numpy.testing import assert_, assert_equal

from deeptime.data import ellipsoids
from deeptime.util import energy2d
Expand All @@ -12,7 +12,9 @@
@pytest.mark.parametrize('cbar', [True, False], ids=lambda x: f"cbar={x}")
def test_energy2d(shift_energy, cbar):
traj = ellipsoids().observations(20000)
ax, contourf, cbar = energy2d(*traj.T, bins=100, shift_energy=shift_energy).plot(cbar=cbar)
plot_object = energy2d(*traj.T, bins=100, shift_energy=shift_energy).plot(cbar=cbar)
assert_equal(len(plot_object), 3)
ax, contourf, cbar = plot_object
assert_(ax is not None)
assert_(contourf is not None)
assert_(cbar is not None if cbar else cbar is None)
37 changes: 31 additions & 6 deletions tests/plots/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import pytest
from numpy.testing import assert_raises

from deeptime.plots.network import Network, plot_markov_model
from deeptime.markov.msm import MarkovStateModel

matplotlib.use('Agg')


def test_network():
@pytest.mark.parametrize("cmap", ["nipy_spectral", plt.cm.nipy_spectral, None])
@pytest.mark.parametrize("labels", [None, np.array([["juhu"] * 5] * 5), 'weights'], ids=['None', 'strs', 'weights'])
def test_network(labels, cmap):
P = np.array([[0.8, 0.15, 0.05, 0.0, 0.0],
[0.1, 0.75, 0.05, 0.05, 0.05],
[0.05, 0.1, 0.8, 0.0, 0.05],
Expand All @@ -20,14 +25,34 @@ def test_network():
flux = MarkovStateModel(Psparse).reactive_flux([2], [3])

positions = nx.planar_layout(nx.from_scipy_sparse_matrix(flux.gross_flux))
labels = np.array([["juhu"] * flux.n_states] * flux.n_states)
pl = Network(flux.gross_flux, positions, edge_curvature=2., edge_labels=labels,
state_colors=np.linspace(0, 1, num=flux.n_states))
state_colors=np.linspace(0, 1, num=flux.n_states), cmap=cmap)
ax = pl.plot()
ax.set_aspect('equal')

f, (ax1, ax2) = plt.subplots(1, 2)

ax = pl.plot(ax=ax1)
ax.set_aspect('equal')
def test_network_invalid_args():
msm = MarkovStateModel(np.eye(3))
with assert_raises(ValueError):
Network(msm.transition_matrix, pos=np.zeros((2, 2))) # not enough positions
network = Network(msm.transition_matrix, pos=np.zeros((3, 2)))
with assert_raises(ValueError):
network.edge_labels = np.array([["hi"]*2]*2) # not enough labels
with assert_raises(ValueError):
network.edge_labels = 'bogus' # may be 'weights'
network.state_sizes = [1., 2., 3.]
with assert_raises(ValueError):
network.state_sizes = [1., 2.]
network.state_labels = ["1", "2", "3"]
with assert_raises(ValueError):
network.state_labels = ["1", "2"]
network.state_colors = np.random.uniform(size=(3, 3))
network.state_colors = np.random.uniform(size=(3, 4))
with assert_raises(ValueError):
network.state_colors = np.random.uniform(size=(3, 5))
with assert_raises(ValueError):
network.state_colors = np.random.uniform(size=(2, 4))
network.plot()


def test_msm():
Expand Down

0 comments on commit bac36d2

Please sign in to comment.