diff --git a/pymc3_hmm/utils.py b/pymc3_hmm/utils.py index 694f76b..c194bae 100644 --- a/pymc3_hmm/utils.py +++ b/pymc3_hmm/utils.py @@ -4,6 +4,10 @@ from scipy.special import logsumexp +from pymc3_hmm.distributions import HMMStateSeq + +import pymc3 as pm + vsearchsorted = np.vectorize(np.searchsorted, otypes=[np.int], signature="(n),()->()") @@ -208,7 +212,7 @@ def plot_split_timeseries( drawstyle="steps-pre", linewidth=0.5, plot_fn=None, - **plot_kwds + **plot_kwds, ): # pragma: no cover """Plot long timeseries by splitting them across multiple rows using a given time frequency. @@ -324,3 +328,34 @@ def plot_fn(ax, data, **kwargs): plt.tight_layout() return return_axes_data + + +def auto_set_hmm_seq(N_states, model, states): + """ + Initiate a HMMStateSeq based on the length of the mixture component. + + This function require pymc3 and HMMStateSeq. + + Parameters + ---------- + N_states : int + Number of states in the mixture + model : pymc3.model.Model + Model object that we trained on + states : ndarray + Vector sequence of states to set the `test_value` for `HMMStateSeq` + + Returns + ------- + locals(), a dict of local variables for reference in sampling steps. + """ + with model: + pp = [pm.Dirichlet(f"p_{i}", np.ones(N_states)) for i in range(N_states)] + P_tt = tt.stack(pp) + P_rv = pm.Deterministic("Gamma", tt.shape_padleft(P_tt)) + pi_0_tt = compute_steady_state(P_rv) + + S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0]) + S_rv.tag.test_value = states + + return locals()