Skip to content

Commit

Permalink
add in auto_set_hmm_seq for auto gen hmm_seq
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 committed Oct 14, 2020
1 parent 1845f6b commit 1a49f79
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion pymc3_hmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from scipy.special import logsumexp

from pymc3_hmm.distributions import HMMStateSeq

import pymc3 as pm


vsearchsorted = np.vectorize(np.searchsorted, otypes=[np.int], signature="(n),()->()")

Expand Down Expand Up @@ -208,7 +212,7 @@ def plot_split_timeseries(
drawstyle="steps-pre",
linewidth=0.5,
plot_fn=None,
**plot_kwds
**plot_kwds,
): # pragma: no cover
"""Plot long timeseries by splitting them across multiple rows using a given time frequency.
Expand Down Expand Up @@ -324,3 +328,34 @@ def plot_fn(ax, data, **kwargs):
plt.tight_layout()

return return_axes_data


def auto_set_hmm_seq(N_states, model, states):
"""
Initiate a HMMStateSeq based on the length of the mixture component.
This function require pymc3 and HMMStateSeq.
Parameters
----------
N_states : int
Number of states in the mixture
model : pymc3.model.Model
Model object that we trained on
states : ndarray
Vector sequence of states to set the `test_value` for `HMMStateSeq`
Returns
-------
locals(), a dict of local variables for reference in sampling steps.
"""
with model:
pp = [pm.Dirichlet(f"p_{i}", np.ones(N_states)) for i in range(N_states)]
P_tt = tt.stack(pp)
P_rv = pm.Deterministic("Gamma", tt.shape_padleft(P_tt))
pi_0_tt = compute_steady_state(P_rv)

S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
S_rv.tag.test_value = states

return locals()

0 comments on commit 1a49f79

Please sign in to comment.