Skip to content

Commit

Permalink
Tram docs and a unit test (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaaikeG authored Jan 12, 2022
1 parent 86eb11b commit 65fc482
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 9 deletions.
14 changes: 7 additions & 7 deletions deeptime/markov/msm/tram/_tram.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class TRAM(_MSMBaseEstimator):
callback_interval : int, optional, default=0
Every callback_interval iteration steps, the callback function is calles and error increments are stored. If
track_log_likelihoods=true, the log-likelihood are also stored. If 0, no call to the callback function is done.
progress_bar : object
progress : object
Progress bar object that TRAM will call to indicate progress to the user.
Tested for a tqdm progress bar. Should implement update() and close() and have .total and .desc properties.
Expand All @@ -96,8 +96,8 @@ class TRAM(_MSMBaseEstimator):
def __init__(
self, lagtime=1, count_mode='sliding',
maxiter=10000, maxerr: float = 1e-8,
track_log_likelihoods=False, callback_interval=0,
progress_bar=None):
track_log_likelihoods=False, callback_interval=1,
progress=None):

super(TRAM, self).__init__()

Expand All @@ -108,7 +108,7 @@ def __init__(
self.maxerr = maxerr
self.track_log_likelihoods = track_log_likelihoods
self.callback_interval = callback_interval
self.progress_bar = progress_bar
self.progress = progress
self._largest_connected_set = None
self.log_likelihoods = []
self.increments = []
Expand Down Expand Up @@ -203,7 +203,7 @@ def fit(self, data, model=None, *args, **kw):
def _run_estimation(self, tram_input):
""" Estimate the free energies using self-consistent iteration as described in the TRAM paper.
"""
with TRAMCallback(self.progress_bar, self.maxiter, self.log_likelihoods, self.increments,
with TRAMCallback(self.progress, self.maxiter, self.log_likelihoods, self.increments,
self.callback_interval > 0) as callback:
self._tram_estimator.estimate(tram_input, self.maxiter, self.maxerr,
track_log_likelihoods=self.track_log_likelihoods,
Expand All @@ -229,8 +229,8 @@ class TRAMCallback(callbacks.Callback):
If True, log_likelihoods and increments are appended to their respective lists each time callback.__call__() is
called. If false, no values are appended, only the last increment is stored.
"""
def __init__(self, progress_bar, n_iter, log_likelihoods_list=None, increments=None, store_convergence_info=False):
super().__init__(progress_bar, n_iter, "Running TRAM estimate")
def __init__(self, progress, n_iter, log_likelihoods_list=None, increments=None, store_convergence_info=False):
super().__init__(progress, n_iter, "Running TRAM estimate")
self.log_likelihoods = log_likelihoods_list
self.increments = increments
self.store_convergence_info = store_convergence_info
Expand Down
2 changes: 1 addition & 1 deletion deeptime/markov/msm/tram/_tram_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _find_largest_connected_set(self, connectivity, connectivity_factor, progres
(i_s, j_s) = connectivity_fn(self.ttrajs, self.dtrajs, self.bias_matrices, all_state_counts,
self.n_therm_states, self.n_markov_states, connectivity_factor,
callback)
print((i_s, j_s))

# add transitions that occurred within each thermodynamic state. These are simply the connected sets:
for k in range(self.n_therm_states):
for cset in estimator.fit_fetch(self.dtrajs[k]).connected_sets():
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks
Submodule notebooks updated 1 files
+92 −105 tram.ipynb
27 changes: 27 additions & 0 deletions tests/markov/msm/test_tram.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,30 @@ def test_callback_called(track_log_likelihoods):
np.testing.assert_((np.asarray(tram.log_likelihoods) < 0).all())
else:
np.testing.assert_((np.asarray(tram.log_likelihoods) == 0).all())


def test_progress_bar_update_called():
class ProgressMock:
def __init__(self, _tracking_ints):
self.total = 1
self.desc = 0

self.tracking_ints = tracking_ints

def update(self, _):
self.tracking_ints[0] += 1

def close(self):
self.tracking_ints[1] += 1

# workaround to track function calls because the progress bar is copied internally
tracking_ints = [0, 0]

progress_mock = ProgressMock(tracking_ints)
tram = TRAM(callback_interval=2, maxiter=10, progress=progress_mock)
tram.fit(make_random_input_data(5, 5))

# update() should be called 5 times
np.testing.assert_equal(tracking_ints[0], 5)
# and close() one time
np.testing.assert_equal(tracking_ints[1], 1)
2 changes: 2 additions & 0 deletions tests/markov/test_cktest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
import deeptime as dt
from flaky import flaky
from numpy.testing import assert_allclose, assert_equal, assert_raises


Expand All @@ -12,6 +13,7 @@ def test_invalid_mlags():
est.chapman_kolmogorov_validator(2, mlags=[0, 1, -10])


@flaky(max_runs=3, min_passes=1)
@pytest.mark.parametrize("n_jobs", [1, 2], ids=lambda x: f"n_jobs={x}")
@pytest.mark.parametrize("mlags", [2, [0, 1, 10]], ids=lambda x: f"mlags={x}")
@pytest.mark.parametrize("estimator_type", ["MLMSM", "BMSM", "HMM", "BHMM"])
Expand Down

0 comments on commit 65fc482

Please sign in to comment.