Skip to content

Commit

Permalink
Add further type annotations to hmm transitions class
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 25, 2024
1 parent 47fe33f commit ee3da88
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions dynamax/hidden_markov_model/models/transitions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, cast, NamedTuple, Optional, Tuple, Union
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Float, Array
import tensorflow_probability.substrates.jax.distributions as tfd
import tensorflow_probability.substrates.jax.bijectors as tfb

from dynamax.hidden_markov_model.models.abstractions import HMMTransitions
from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.parameters import ParameterProperties
from dynamax.types import Scalar
from jaxtyping import Float, Array
from typing import cast, NamedTuple, Optional, Union
from dynamax.types import IntScalar, Scalar


class ParamsStandardHMMTransitions(NamedTuple):
Expand All @@ -30,7 +31,12 @@ class StandardHMMTransitions(HMMTransitions):
"""
def __init__(self, num_states: int, concentration: Scalar=1.1, stickiness: Scalar=0.0):
def __init__(
self,
num_states: int,
concentration: Union[Scalar, Float[Array, "num_states num_states"]]=1.1,
stickiness: Union[Scalar, Float[Array, " num_states"]]=0.0
):
"""
Args:
transition_matrix[j,k]: prob(hidden(t) = k | hidden(t-1)j)
Expand All @@ -40,10 +46,15 @@ def __init__(self, num_states: int, concentration: Scalar=1.1, stickiness: Scala
concentration * jnp.ones((num_states, num_states)) + \
stickiness * jnp.eye(num_states)

def distribution(self, params, state, inputs=None):
def distribution(self, params: ParamsStandardHMMTransitions, state: IntScalar, inputs=None):
return tfd.Categorical(probs=params.transition_matrix[state])

def initialize(self, key: Optional[Array] =None, method="prior", transition_matrix: Optional[Float[Array, "num_states num_states"]]=None):
def initialize(
self,
key: Optional[Array]=None,
method="prior",
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None
) -> Tuple[ParamsStandardHMMTransitions, ParamsStandardHMMTransitions]:
"""Initialize the model parameters and their corresponding properties.
Args:
Expand All @@ -58,28 +69,41 @@ def initialize(self, key: Optional[Array] =None, method="prior", transition_matr
if key is None:
raise ValueError("key must be provided if transition_matrix is not provided.")
else:
this_key, key = jr.split(key)
transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=this_key)
transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=key)
transition_matrix = cast(Float[Array, "num_states num_states"], transition_matrix_sample)

# Package the results into dictionaries
params = ParamsStandardHMMTransitions(transition_matrix=transition_matrix)
props = ParamsStandardHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
return params, props

def log_prior(self, params):
def log_prior(self, params: ParamsStandardHMMTransitions) -> Scalar:
return tfd.Dirichlet(self.concentration).log_prob(params.transition_matrix).sum()

def _compute_transition_matrices(self, params, inputs=None):
def _compute_transition_matrices(
self, params: ParamsStandardHMMTransitions, inputs=None
) -> Float[Array, "num_states num_states"]:
return params.transition_matrix

def collect_suff_stats(self, params, posterior, inputs=None):
def collect_suff_stats(
self,
params,
posterior: HMMPosterior,
inputs=None
) -> Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]:
return posterior.trans_probs

def initialize_m_step_state(self, params, props):
return None

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(
self,
params: ParamsStandardHMMTransitions,
props: ParamsStandardHMMTransitions,
batch_stats: Float[Array, "batch num_states num_states"],
m_step_state: Any
) -> Tuple[ParamsStandardHMMTransitions, Any]:
if props.transition_matrix.trainable:
if self.num_states == 1:
transition_matrix = jnp.array([[1.0]])
Expand Down

0 comments on commit ee3da88

Please sign in to comment.