# Support for HMMs with num_states=1 #380
Hi @umeshksingla! This is an interesting scenario. We do make the assumption that there are at least 2 hidden states (I suppose a model with only one hidden state is not really an HMM). In practice, some behaviours seem to work fine with one hidden state. However, as you have found, we run into an error whenever we try to interact with … I am not totally sure that we want to add complexity to handle this somewhat niche and potentially out-of-scope use case; however, perhaps it is a good idea, and if not, at the very least we should indicate that … For your present purposes you can avoid the call to …:

```python
from jax import numpy as jnp
from jax import random as jr
from dynamax.hidden_markov_model import GaussianHMM

hmm = GaussianHMM(num_states=1, emission_dim=1)
initial_probs = jnp.array([1.0])
transition_matrix = jnp.array([[1.0]])
params, props = hmm.initialize(initial_probs=initial_probs, transition_matrix=transition_matrix)
z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=10)
# z is Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
```

You should be able to use this model for sampling as normal. However, parameter learning (e.g. …) fails:

```python
hmm = GaussianHMM(num_states=1, emission_dim=1)
initial_probs = jnp.array([1.0])
transition_matrix = jnp.array([[1.0]])
params, props = hmm.initialize(initial_probs=initial_probs, transition_matrix=transition_matrix)
z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=100)

params_inf, props = hmm.initialize(key=jr.PRNGKey(10), initial_probs=initial_probs, transition_matrix=transition_matrix)
props.initial.probs.trainable = False
props.transitions.transition_matrix.trainable = False

try:
    hmm.fit_em(params_inf, props, emissions=x)
except ValueError as e:
    print(f"Error: {e}")
```

One work-around for this is to make a model with `num_states=2` but specify the initial state distribution and transition matrix so that the model behaves as if it has only one state. Here is an example:

```python
from jax import numpy as jnp
from jax import random as jr
from dynamax.hidden_markov_model import GaussianHMM

hmm = GaussianHMM(num_states=2, emission_dim=1)
initial_probs = jnp.array([1.0, 0.0])
transition_matrix = jnp.array([[1.0, 0.0], [1.0, 0.0]])
params, props = hmm.initialize(key=jr.PRNGKey(0), initial_probs=initial_probs, transition_matrix=transition_matrix)
params_inf, props = hmm.initialize(key=jr.PRNGKey(100), initial_probs=initial_probs, transition_matrix=transition_matrix)
props.initial.probs.trainable = False
props.transitions.transition_matrix.trainable = False
hmm.fit_em(params_inf, props, emissions=x)
```

This might get okay parameter results; however, the logprob calculations aren't fond of this setup and you may get …
I am trying to fit various HMM classes (`LinearRegressionHMM` or `GaussianHMM`) to my data, but it does not let me pass `num_states=1`. For `num_states > 2`, everything works as expected. I wanted to know whether the lack of support for `num_states=1` is intended behaviour. It's easy enough to write code for simple linear regression outside dynamax; however, it still makes the comparison with the `num_states > 2` cases error-prone (as one might be using different constants in log-likelihood calculations, etc.).
If it helps, the error occurs while trying to initialize the Dirichlet distribution.
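On the comparability concern: with a single hidden state the HMM log-likelihood from the forward algorithm reduces exactly to the iid sum of emission log-densities, because the 1x1 transition matrix contributes log 1 = 0 at every step, so the "constants" should agree. A minimal check, independent of dynamax and using hypothetical Gaussian emission parameters:

```python
import numpy as np

def forward_loglik(log_initial, log_transition, log_liks):
    """Log-space forward algorithm for a discrete-state HMM.

    log_liks has shape (T, K): per-timestep emission log-densities.
    """
    alpha = log_initial + log_liks[0]
    for t in range(1, len(log_liks)):
        alpha = np.logaddexp.reduce(alpha[:, None] + log_transition, axis=0) + log_liks[t]
    return np.logaddexp.reduce(alpha)

def gaussian_logpdf(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=2.0, size=50)

# One state: initial distribution [1.0] and transition matrix [[1.0]],
# i.e. all zeros in log space.
log_liks = gaussian_logpdf(x, 1.0, 4.0)[:, None]   # shape (T, 1)
lp_hmm = forward_loglik(np.zeros(1), np.zeros((1, 1)), log_liks)
lp_iid = gaussian_logpdf(x, 1.0, 4.0).sum()
print(np.isclose(lp_hmm, lp_iid))  # True
```

So any discrepancy against an external simple-regression fit would point to a bookkeeping difference rather than a genuine modelling one.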