Skip to content

Commit

Permalink
Add further type annotations to multinomialhmm
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 23, 2024
1 parent 6197671 commit b0877ed
Showing 1 changed file with 42 additions and 19 deletions.
61 changes: 42 additions & 19 deletions dynamax/hidden_markov_model/models/multinomial_hmm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Array, Float
from jaxtyping import Array, Float, Int

from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions
from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState
Expand All @@ -23,11 +24,11 @@ class ParamsMultinomialHMMEmissions(NamedTuple):
class MultinomialHMMEmissions(HMMEmissions):

def __init__(self,
num_states,
emission_dim,
num_classes,
num_trials,
emission_prior_concentration=1.1):
num_states: int,
emission_dim: int,
num_classes: int,
num_trials: int,
emission_prior_concentration: Union[Scalar, Float[Array, " num_classes"]] = 1.1):
self.num_states = num_states
self.emission_dim = emission_dim
self.num_classes = num_classes
Expand All @@ -38,7 +39,11 @@ def __init__(self,
def emission_shape(self):
return (self.emission_dim, self.num_classes)

def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
def initialize(self,
key: Array = jr.PRNGKey(0),
method: str = "prior",
emission_probs: Optional[Float[Array, "num_states emission_dim num_classes"]] = None
) -> Tuple[ParamsMultinomialHMMEmissions, ParamsMultinomialHMMEmissions]:
# Initialize the emission probabilities
if emission_probs is None:
if method.lower() == "prior":
Expand All @@ -58,26 +63,44 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
props = ParamsMultinomialHMMEmissions(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
return params, props

def distribution(self, params, state, inputs=None):
def distribution(
self,
params: ParamsMultinomialHMMEmissions,
state: int,
inputs: Optional[Array] = None
) -> tfd.Distribution:
return tfd.Independent(
tfd.Multinomial(self.num_trials, probs=params.probs[state]),
reinterpreted_batch_ndims=1)

def log_prior(self, params):
def log_prior(self, params: ParamsMultinomialHMMEmissions) -> Float[Array, ""]:
return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum()

def collect_suff_stats(self, params, posterior, emissions, inputs=None):
def collect_suff_stats(
self,
params: ParamsMultinomialHMMEmissions,
posterior: HMMPosterior,
emissions: Int[Array, "num_timesteps emission_dim num_classes"],
inputs: Optional[Array] = None
) -> Dict[str, Float[Array, "num_states emission_dim num_classes"]]:
expected_states = posterior.smoothed_probs
return dict(sum_x=jnp.einsum("tk, tdi->kdi", expected_states, emissions))

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(self, params, props) -> None:
return None

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(
self,
params: ParamsMultinomialHMMEmissions,
props: ParamsMultinomialHMMEmissions,
batch_stats: Dict[str, Float[Array, "num_states emission_dim num_classes"]],
m_step_state: Any
) -> Tuple[ParamsMultinomialHMMEmissions, Any]:
if props.probs.trainable:
emission_stats = pytree_sum(batch_stats, axis=0)
probs = tfd.Dirichlet(
self.emission_prior_concentration + emission_stats['sum_x']).mode()
self.emission_prior_concentration + emission_stats['sum_x']
).mode()
params = params._replace(probs=probs)
return params, m_step_state

Expand Down Expand Up @@ -111,14 +134,14 @@ class MultinomialHMM(HMM):
"""
def __init__(self,
num_states,
emission_dim,
num_classes,
num_trials,
num_states: int,
emission_dim: int,
num_classes: int,
num_trials: int,
initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
transition_matrix_stickiness: Scalar=0.0,
emission_prior_concentration: Scalar=1.1):
emission_prior_concentration: Union[Scalar, Float[Array, " num_classes"]]=1.1):
self.emission_dim = emission_dim
self.num_classes = num_classes
self.num_trials = num_trials
Expand Down

0 comments on commit b0877ed

Please sign in to comment.