Skip to content

Commit

Permalink
Tram performance improvements (#217)
Browse files Browse the repository at this point in the history
MaaikeG authored Mar 15, 2022
1 parent fe8b289 commit 38b0158
Showing 9 changed files with 83 additions and 30 deletions.
12 changes: 6 additions & 6 deletions deeptime/markov/msm/tram/_tram.py
Original file line number Diff line number Diff line change
@@ -228,7 +228,7 @@ def _make_tram_estimator(self, model, dataset):
else:
if self.init_strategy == "MBAR":
# initialize free energies using MBAR.
with callbacks.ProgressCallback(self.progress, "Initializing free energies using MBAR",
with callbacks.IterationErrorProgressCallback(self.progress, "Initializing free energies using MBAR",
self.init_maxiter) as callback:
free_energies = tram.initialize_free_energies_mbar(np.concatenate(dataset.bias_matrices),
dataset.state_counts.sum(axis=1),
@@ -258,7 +258,7 @@ def _run_estimation(self, tram_input):
f"Last increment: {callback.last_increment}", ConvergenceWarning)


class TRAMCallback(callbacks.ProgressCallback):
class TRAMCallback(callbacks.IterationErrorProgressCallback):
"""Callback for the TRAM estimate process. Increments a progress bar and saves iteration increments in the free
energies and log-likelihoods to a list.
@@ -278,7 +278,7 @@ def __init__(self, progress, total, log_likelihoods_list=None, increments=None):
self.increments = increments
self.last_increment = 0

def __call__(self, n_iterations, increment, log_likelihood=0):
def __call__(self, inc, error, log_likelihood=0):
"""Call the callback. Increment a progress bar (if available) and store convergence information.
Parameters
@@ -290,12 +290,12 @@ def __call__(self, n_iterations, increment, log_likelihood=0):
log_likelihood : float
The current log-likelihood, or 0. when the tram estimator is not configured to calculate log-likelihoods.
"""
super().__call__(n_iterations)
super().__call__(inc, error=error)

if self.log_likelihoods is not None:
self.log_likelihoods.append(log_likelihood)

if self.increments is not None:
self.increments.append(increment)
self.increments.append(error)

self.last_increment = increment
self.last_increment = error
35 changes: 21 additions & 14 deletions deeptime/src/include/deeptime/common.h
Original file line number Diff line number Diff line change
@@ -71,13 +71,8 @@ class Index {
std::copy(shapeBegin, shapeEnd, begin(dims));
auto n_elems = std::accumulate(begin(dims), end(dims), static_cast<value_type>(1), std::multiplies<value_type>());

GridDims strides {};
if (n_elems > 0) {
strides[0] = n_elems / dims[0];
for (std::size_t d = 0; d < Dims - 1; ++d) {
strides[d + 1] = strides[d] / dims[d + 1];
}
}
GridDims strides = computeStrides(dims, n_elems);

return Index<Dims, GridDims>{dims, strides, n_elems};
}

@@ -96,13 +91,7 @@ class Index {
: _size(), n_elems(std::accumulate(begin(size), end(size), 1u, std::multiplies<value_type>())) {
std::copy(begin(size), end(size), begin(_size));

GridDims strides {};
if (n_elems > 0) {
strides[0] = n_elems / size[0];
for (std::size_t d = 0; d < Dims - 1; ++d) {
strides[d + 1] = strides[d] / size[d + 1];
}
}
GridDims strides = computeStrides(size, n_elems);
_cum_size = std::move(strides);
}

@@ -185,6 +174,24 @@ class Index {
return result;
}

/**
* compute strides for each dim
* @param size size of each dimension
* @param n_elems total number of elements
* @return the strides
*/
template<typename Shape>
static GridDims computeStrides(const Shape &size, value_type n_elems) {
GridDims strides {};
if (n_elems > 0) {
strides[0] = n_elems / size[0];
for (std::size_t d = 0; d < Dims - 1; ++d) {
strides[d + 1] = strides[d] / size[d + 1];
}
}
return strides;
}

private:
GridDims _size;
GridDims _cum_size;
6 changes: 3 additions & 3 deletions deeptime/src/include/deeptime/markov/msm/tram/mbar.h
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@

#include "common.h"


using namespace pybind11::literals;
namespace deeptime::markov::tram {


@@ -60,7 +60,7 @@ void selfConsistentUpdate(ExchangeableArray<dtype, 1> &thermStateEnergies,
template<typename dtype>
np_array <dtype>
initialize_MBAR(BiasMatrix <dtype> biasMatrix, CountsMatrix stateCounts, std::size_t maxIter = 1000,
double maxErr = 1e-6, std::size_t callbackInterval = 1, const py::object *callback = nullptr) {
double maxErr = 1e-6, std::size_t callbackInterval = 1, const py::function *callback = nullptr) {
// get dimensions...
auto nThermStates = stateCounts.shape(0);
auto nSamples = biasMatrix.shape(0);
@@ -93,7 +93,7 @@ initialize_MBAR(BiasMatrix <dtype> biasMatrix, CountsMatrix stateCounts, std::si
// keep the python user up to date on the progress by a callback
if (callback && callbackInterval > 0 && iterationCount % callbackInterval == 0) {
py::gil_scoped_acquire guard;
(*callback)(callbackInterval, iterationError);
(*callback)("inc"_a=callbackInterval, "error"_a=iterationError);
}

if (iterationError < maxErr) {
8 changes: 4 additions & 4 deletions deeptime/src/include/deeptime/markov/msm/tram/tram.h
Original file line number Diff line number Diff line change
@@ -36,8 +36,8 @@ static const dtype computeSampleLikelihood(const TRAMInput<dtype> &input,
auto modifiedStateCountsLogPtr = &modifiedStateCountsLogBuf;

auto inputPtr = &input;
#pragma omp parallel for default(none) firstprivate(nThermStates, inputPtr, biasMatrixPtr, sampleWeights, \
modifiedStateCountsLogPtr, cumNSamples)
#pragma omp parallel for default(none) firstprivate(nThermStates, inputPtr, biasMatrixPtr, \
modifiedStateCountsLogPtr, cumNSamples) shared(sampleWeights)
for (auto i = 0; i < inputPtr->nMarkovStates(); ++i) {
std::vector<dtype> scratch(nThermStates);
for (auto x = 0; x < inputPtr->nSamples(i); ++x) {
@@ -47,8 +47,8 @@ static const dtype computeSampleLikelihood(const TRAMInput<dtype> &input,
scratch[o++] = (*modifiedStateCountsLogPtr)(l, i) - biasMatrixPtr[i](x, l);
}
}
auto log_divisor = numeric::kahan::logsumexp_sort_kahan_inplace(scratch.begin(), o);
sampleWeights[cumNSamples[i] + x] = -log_divisor;
auto logDivisor = numeric::kahan::logsumexp_sort_kahan_inplace(scratch.begin(), o);
sampleWeights[cumNSamples[i] + x] = -logDivisor;
}
}
return numeric::kahan::logsumexp_sort_kahan_inplace(sampleWeights.begin(), sampleWeights.end());
40 changes: 38 additions & 2 deletions deeptime/util/callbacks.py
Original file line number Diff line number Diff line change
@@ -49,12 +49,12 @@ class ProgressCallback:
def __init__(self, progress, description=None, total=None):
self.progress_bar = handle_progress_bar(progress)(total=total)
self.total = total
self.set_description(description)

assert supports_progress_interface(self.progress_bar), \
f"Progress bar did not satisfy interface! It should at least have " \
f"the method(s) {supports_progress_interface.required_methods} and " \
f"the attribute(s) {supports_progress_interface.required_attributes}."
if description is not None:
self.progress_bar.set_description(description)

def __call__(self, inc=1, *args, **kw):
self.progress_bar.update(inc)
@@ -66,3 +66,39 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self.progress_bar.total = self.progress_bar.n # force finish
self.progress_bar.close()

def set_description(self, value):
self.progress_bar.set_description(value)


class IterationErrorProgressCallback(ProgressCallback):
r"""Callback function for the c++ bindings to indicate progress by incrementing a progress bar and showing the
iteration error on each iteration.
Parameters
----------
progress : object
Tested for a tqdm progress bar. Should implement `update()`, `set_description()`, and `close()`. Should
also possess a `total` constructor keyword argument.
total : int
Number of iterations to completion.
description : string
text to display in front of the progress bar.
Notes
-----
To display the iteration error, the error needs to be passed to `__call__()` as keyword argument `error`.
See Also
--------
supports_progress_interface, ProgressCallback
"""

def __init__(self, progress, description=None, total=None):
super().__init__(progress, description, total)
self.description = description

def __call__(self, inc=1, *args, **kw):
super().__call__(inc)
if 'error' in kw:
super().set_description("{} - [inc: {:.1e}]".format(self.description, kw.get('error')))
2 changes: 2 additions & 0 deletions tests/markov/hmm/test_integration.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_almost_equal
from flaky import flaky

import deeptime
from deeptime.data import prinz_potential
@@ -96,6 +97,7 @@ def test_observation_probabilities(hmm_scenario):
assert_almost_equal(minerr, 0, decimal=2)


@flaky(max_runs=3)
def test_stationary_distribution(hmm_scenario):
model = hmm_scenario.hmm
minerr = 1e6
2 changes: 2 additions & 0 deletions tests/markov/msm/test_mlmsm.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
"""
import collections
import unittest
from flaky import flaky

from numpy.testing import assert_equal, assert_raises, assert_, assert_array_almost_equal, assert_array_equal

@@ -143,6 +144,7 @@ def test_strongly_connected_count_matrix():
assert_equal(msm.count_model.state_symbols, [4, 6])


@flaky(max_runs=3)
@pytest.mark.parametrize("sparse", [False, True], ids=["dense", "sparse"])
def test_birth_death_chain(fixed_seed, sparse):
"""Meta-stable birth-death chain"""
4 changes: 4 additions & 0 deletions tests/markov/msm/test_tram.py
Original file line number Diff line number Diff line change
@@ -248,6 +248,8 @@ def __new__(cls, *args, **kwargs): return progress

# update() should be called 5 times
np.testing.assert_equal(progress.n_update_calls, 5)
# description should be set once initially and on each update call
np.testing.assert_equal(progress.n_description_updates, progress.n_update_calls + 1)
np.testing.assert_equal(progress.n, 10)
# and close() one time
np.testing.assert_equal(progress.n_close_calls, 1)
@@ -264,6 +266,8 @@ def __new__(cls, *args, **kwargs): return progress

# update() should be called 10 times
np.testing.assert_equal(progress.n_update_calls, 10)
# description should be set once initially for both MBAR and TRAM (=2) plus once on each update call
np.testing.assert_equal(progress.n_description_updates, progress.n_update_calls + 2)
np.testing.assert_equal(progress.n, 20)
# and close() one time
np.testing.assert_equal(progress.n_close_calls, 2)
4 changes: 3 additions & 1 deletion tests/testing_utilities.py
Original file line number Diff line number Diff line change
@@ -10,8 +10,10 @@ def __init__(self):
self.n = 0
self.n_close_calls = 0
self.n_update_calls = 0
self.n_description_updates = 0

def set_description(self, *_): ...
def set_description(self, *_):
self.n_description_updates += 1

def update(self, n=1):
self.n += n

0 comments on commit 38b0158

Please sign in to comment.