Skip to content

Commit

Permalink
Merge pull request #98 from fabian-paul/fix_ratematrix_assertions
Browse files Browse the repository at this point in the history
Fix ratematrix assertions
  • Loading branch information
franknoe authored Jan 3, 2017
2 parents f0feb28 + 770d1b5 commit dc134b3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 40 deletions.
92 changes: 53 additions & 39 deletions msmtools/estimation/dense/ratematrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,41 @@ def sum1(M):
return x


def raise_or_warn(msg, on_error, warning=UserWarning, exception=RuntimeError):
    """Report a problem either by raising an exception or by issuing a warning.

    Parameters
    ----------
    msg : str
        Text of the exception or warning.
    on_error : str
        'raise' raises ``exception(msg)``; 'warn' issues ``msg`` as a warning
        of category ``warning``.
    warning : Warning subclass, optional, default=UserWarning
        Category used when ``on_error`` is 'warn'.
    exception : Exception subclass, optional, default=RuntimeError
        Exception type used when ``on_error`` is 'raise'.

    Raises
    ------
    exception
        If ``on_error`` is 'raise'.
    ValueError
        If ``on_error`` is neither 'raise' nor 'warn'.
    """
    if on_error == 'raise':
        raise exception(msg)
    elif on_error == 'warn':
        warnings.warn(msg, warning)
    else:
        # Wrap the operand in a 1-tuple: a bare tuple value would be unpacked
        # by '%' and fail with TypeError instead of the intended ValueError.
        raise ValueError('Unsupported value of on_error (%s). Should be "raise" or "warn".' % (on_error,))


class _RateMatrixEstimator(object):
# base class: includes parametrization of K matrix and the basic class interface
def __init__(self, C, dt=1.0, sparsity=None, t_agg=None, pi=None, tol=1.0E7, maxiter=100000, on_error='raise'):
assert np.all(C >= 0)
assert C.shape[0] == C.shape[1]
if not np.all(C >= 0):
raise ValueError('Elements of C matrix should be non-negative.')
if not C.shape[0] == C.shape[1]:
raise ValueError('C matrix should be square.')
self.zero_C = np.where(C == 0)
self.nonzero_C = np.where(C != 0)
assert dt > 0.0
if not dt > 0.0:
raise ValueError('dt should be positive.')
if sparsity is not None:
assert np.all(sparsity >= 0)
assert sparsity.shape[0] == sparsity.shape[1] == C.shape[0]
if not np.all(sparsity >= 0):
raise ValueError('Elements of sparsity matrix should be non-negative.')
if not sparsity.shape[0] == sparsity.shape[1] == C.shape[0]:
raise ValueError('sparsity matrix should be square and of the same dimension as the C matrix.')
if pi is not None:
assert np.all(pi > 0.0)
assert pi.shape[0] == C.shape[0]
assert np.allclose(np.sum(pi), 1.0)
if not np.all(pi > 0.0):
raise ValueError('pi should be positive.')
if not pi.shape[0] == C.shape[0]:
raise ValueError('pi and the C matrix should have compatible dimensions.')
if not np.allclose(np.sum(pi), 1.0):
raise ValueError('pi should sum to one.')
if t_agg is not None:
assert t_agg > 0
if not t_agg > 0:
raise ValueError('t_agg should be positive.')
self.t_agg = t_agg
else:
self.t_agg = dt*C.sum()
Expand Down Expand Up @@ -153,7 +171,7 @@ def __init__(self, C, pi, dt=1.0, sparsity=None, t_agg=None, tol=1.0E7, maxiter=
self.lower_bounds = np.zeros(len(self.I))
for i, j, n in zip(self.I, self.J, range(len(self.I))):
self.lower_bounds[n] = 1.0 / (self.t_agg * (1.0 / self.pi[i] + 1.0 / self.pi[j]))
assert self.lower_bounds[n] > 0.0
assert self.lower_bounds[n] > 0.0 # self-consistency test
self.bounds[n] = (self.lower_bounds[n], None)

# for matrix derivatives
Expand Down Expand Up @@ -202,10 +220,7 @@ def run(self):
logging.info('l_bfgs_b says: '+str(d))
logging.info('objective function value reached: %f' % f)
if d['warnflag'] != 0:
if self.on_error == 'raise':
raise NotConvergedError(str(d))
else:
warnings.warn(str(d), NotConvergedWarning)
raise_or_warn(str(d), on_error=self.on_error, warning=NotConvergedWarning, exception=NotConvergedError)

K = np.zeros((self.N, self.N))
K[self.I, self.J] = theta / self.pi[self.I]
Expand Down Expand Up @@ -246,6 +261,9 @@ def run(self, maxiter=100000, on_error='raise'):
self.T = transition_matrix(self.C, reversible=True, mu=self.pi)

self.K = np.maximum(np.array(sp.linalg.logm(np.dot(self.T, self.T))/(2.0*self.dt)), 0)
np.fill_diagonal(self.K, 0)
np.fill_diagonal(self.K, -sum1(self.K))

return self.K


Expand Down Expand Up @@ -311,12 +329,17 @@ def __init__(self, T, K0, pi, dt=1.0, sparsity=None, t_agg=None, tol=1.0E7, maxi

super(CrommelinVandenEijndenEstimator, self).__init__(T, pi, dt=dt, sparsity=sparsity, t_agg=t_agg, tol=tol, maxiter=maxiter, on_error=on_error)

assert K0.shape[0] == K0.shape[1] == self.N
assert is_transition_matrix(T)
if not K0.shape[0] == K0.shape[1] == self.N:
raise ValueError('Shapes of K0 matrix (initial guess) and count matrix do not match.')
if not is_transition_matrix(T):
raise_or_warn('T is not a valid transition matrix.', self.on_error)

evals, self.U, self.Uinv = eigen_decomposition(T, self.pi)
assert np.all(np.abs(evals) > 0.0) # don't allow eigenvalue==exactly zero
assert np.allclose(self.Uinv.dot(T).dot(self.U), np.diag(evals)) # debug
if not np.all(np.abs(evals) > 0.0): # don't allow eigenvalue==exactly zero
raise ValueError('T has eigenvalues that are exactly zero, can\'t proceed with rate matrix estimation. '
'If the CVE method is only used to intitialize the KL method, you might try to call the KL '
'method with an initial guess of the rate matrix (K0) instead of intializing with CVE.')
assert np.allclose(self.Uinv.dot(T).dot(self.U), np.diag(evals)) # self-consistency test

self.c = np.abs(evals)
self.L = np.diag(np.log(np.abs(evals)) / self.dt)
Expand All @@ -325,10 +348,8 @@ def __init__(self, T, K0, pi, dt=1.0, sparsity=None, t_agg=None, tol=1.0E7, maxi
self.initial = np.maximum(theta, self.lower_bounds)

def function(self, x):
if self.sparsity is None:
assert np.all(x >= 0)
else:
assert np.all(x > 0)
if not np.all(x>=self.lower_bounds):
raise_or_warn('Optimizer violated the lower bounds for rate matrix elements.', self.on_error)

# compute K
K = np.zeros((self.N, self.N))
Expand All @@ -341,11 +362,8 @@ def function(self, x):
return f

def function_and_gradient(self, x):
assert np.all(x>=self.lower_bounds) # debug
if self.sparsity is None:
assert np.all(x >= 0)
else:
assert np.all(x > 0)
if not np.all(x>=self.lower_bounds):
raise_or_warn('Optimizer violated the lower bounds for rate matrix elements.', self.on_error)

# compute K
K = np.zeros((self.N, self.N))
Expand Down Expand Up @@ -428,7 +446,8 @@ class KalbfleischLawlessEstimator(_ReversibleRateMatrixEstimator):
def __init__(self, C, K0, pi, dt=1.0, sparsity=None, t_agg=None, tol=1.0E7, maxiter=100000, on_error='raise'):
super(KalbfleischLawlessEstimator, self).__init__(C, pi, dt=dt, sparsity=sparsity, t_agg=t_agg, tol=tol, maxiter=maxiter, on_error=on_error)

assert K0.shape[0] == K0.shape[1] == self.N
if not K0.shape[0] == K0.shape[1] == self.N:
raise ValueError('Shapes of K0 matrix (initial guess) and count matrix do not match.')

# specific variables for KL estimator
self.sqrt_pi = np.sqrt(pi)
Expand All @@ -438,10 +457,8 @@ def __init__(self, C, K0, pi, dt=1.0, sparsity=None, t_agg=None, tol=1.0E7, maxi

def function(self, x):
self.count += 1
if self.sparsity is None:
assert np.all(x >= 0)
else:
assert np.all(x > 0)
if not np.all(x>=self.lower_bounds):
raise_or_warn('Optimizer violated the lower bounds for rate matrix elements.', self.on_error)

# compute function
K = np.zeros((self.N, self.N))
Expand All @@ -452,7 +469,7 @@ def function(self, x):
T[self.zero] = 1.0 # set unused elements to dummy to avoid division by 0
# check T!=0 for C!=0
nonzero_C = np.where(self.C != 0)
if np.any(np.abs(T[nonzero_C]) <= 1E-15):
if np.any(np.abs(T[nonzero_C]) <= 1E-20):
warnings.warn('Warning: during iteration T_ij became very small while C(tau)_ij > 0.', NotConnectedWarning)
f = ksum(self.C * np.log(T))

Expand All @@ -461,11 +478,8 @@ def function(self, x):
return -f

def function_and_gradient(self, x):
assert np.all(x>=self.lower_bounds) # debug
if self.sparsity is None:
assert np.all(x >= 0)
else:
assert np.all(x > 0)
if not np.all(x>=self.lower_bounds):
raise_or_warn('Optimizer violated the lower bounds for rate matrix elements.', self.on_error)

# compute function
K = np.zeros((self.N, self.N))
Expand Down Expand Up @@ -610,7 +624,7 @@ def estimate_rate_matrix(C, dt=1.0, method='KL', sparsity=None,
15:1474, 2010.
"""
if method not in ['pseudo', 'truncated_log', 'CVE', 'KL']:
raise Exception("method must be one of 'KL', 'CVE', 'pseudo' or 'truncated_log'")
raise ValueError("method must be one of 'KL', 'CVE', 'pseudo' or 'truncated_log'")

# special case: truncated matrix logarithm
if method == 'truncated_log':
Expand Down
13 changes: 12 additions & 1 deletion msmtools/estimation/tests/test_ratematrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import scipy as sp
from msmtools.util import kahandot
import unittest
import sys
import msmtools
import warnings

class TestLowlevelNumerics(unittest.TestCase):
def test_kdot(self):
Expand Down Expand Up @@ -73,6 +73,17 @@ def test_api_without_connectivity_without_pi_with_guess(self):
K_est = msmtools.estimation.rate_matrix(self.C, dt=self.tau, tol=100.0, K0=self.K)
assert np.allclose(self.K, K_est, rtol=5.0E-3, atol=1.0E-3)

def test_raise(self):
    # maxiter=1 is too few iterations for the CVE optimizer to converge, so
    # on_error='raise' is expected to surface NotConvergedError.
    expected_error = msmtools.estimation.dense.ratematrix.NotConvergedError
    with self.assertRaises(expected_error):
        msmtools.estimation.rate_matrix(
            self.C,
            dt=self.tau,
            method='CVE',
            maxiter=1,
            on_error='raise',
        )

def test_warn(self):
    # With on_error='warn', a run that cannot converge (maxiter=1) must emit
    # exactly one NotConvergedWarning instead of raising.
    expected_category = msmtools.estimation.dense.ratematrix.NotConvergedWarning
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        msmtools.estimation.rate_matrix(self.C, dt=self.tau, method='CVE', maxiter=1, on_error='warn')
    # Count only the category under test: simplefilter("always") also records
    # unrelated warnings (e.g. deprecations raised by dependencies), which
    # must not make this test fail.
    not_converged = [rec for rec in w if issubclass(rec.category, expected_category)]
    # unittest assertions instead of bare assert, so the check survives python -O
    self.assertEqual(len(not_converged), 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit dc134b3

Please sign in to comment.