Skip to content

Commit

Permalink
Add further type annotations to HMM initial base class
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 25, 2024
1 parent ee3da88 commit 45e3e7f
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions dynamax/hidden_markov_model/models/initial.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dynamax.hidden_markov_model.models.abstractions import HMMInitialState
from dynamax.parameters import ParameterProperties
from typing import Any, cast, NamedTuple, Optional, Tuple, Union
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Float, Array
import tensorflow_probability.substrates.jax.distributions as tfd
import tensorflow_probability.substrates.jax.bijectors as tfb
from typing import cast, NamedTuple, Optional, Union
from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.hidden_markov_model.models.abstractions import HMMInitialState
from dynamax.parameters import ParameterProperties
from dynamax.types import Scalar


class ParamsStandardHMMInitialState(NamedTuple):
Expand All @@ -17,18 +19,23 @@ class StandardHMMInitialState(HMMInitialState):
"""
def __init__(self,
num_states: int,
initial_probs_concentration: Union[float, Float[Array, " num_states"]]=1.1):
initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1):
"""
Args:
initial_probabilities[k]: prob(hidden(1)=k)
"""
self.num_states = num_states
self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states)

def distribution(self, params, inputs=None):
def distribution(self, params: ParamsStandardHMMInitialState, inputs=None) -> tfd.Distribution:
    """Construct the distribution over the initial hidden state.

    Args:
        params: model parameters; ``params.probs`` holds the initial state
            probabilities.
        inputs: unused, accepted for interface compatibility.

    Returns:
        A categorical distribution over the hidden states.
    """
    initial_probs = params.probs
    return tfd.Categorical(probs=initial_probs)

def initialize(self, key: Optional[Array]=None, method="prior", initial_probs: Optional[Float[Array, " num_states"]]=None):
def initialize(
self,
key: Optional[Array]=None,
method="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None
) -> Tuple[ParamsStandardHMMInitialState, ParamsStandardHMMInitialState]:
"""Initialize the model parameters and their corresponding properties.
Args:
Expand All @@ -53,19 +60,27 @@ def initialize(self, key: Optional[Array]=None, method="prior", initial_probs: O
props = ParamsStandardHMMInitialState(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
return params, props

def log_prior(self, params):
def log_prior(self, params: ParamsStandardHMMInitialState) -> Scalar:
    """Evaluate the log prior density of the initial-state parameters.

    The prior is a Dirichlet with concentration
    ``self.initial_probs_concentration``, evaluated at ``params.probs``.
    """
    prior = tfd.Dirichlet(self.initial_probs_concentration)
    return prior.log_prob(params.probs)

def _compute_initial_probs(self, params, inputs=None):
def _compute_initial_probs(
    self, params: ParamsStandardHMMInitialState, inputs=None
) -> Float[Array, " num_states"]:
    """Return the initial state probability vector stored in ``params``."""
    initial_probs = params.probs
    return initial_probs

def collect_suff_stats(self, params, posterior, inputs=None):
def collect_suff_stats(self, params, posterior: HMMPosterior, inputs=None) -> Float[Array, " num_states"]:
    """Gather the sufficient statistics needed by the M-step.

    For the initial-state model these are just the smoothed marginal
    probabilities of the hidden state at the first time step.
    """
    smoothed_initial = posterior.smoothed_probs[0]
    return smoothed_initial

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(self, params, props) -> None:
    """Create any state carried across M-steps; this model maintains none."""
    return None

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(
self,
params: ParamsStandardHMMInitialState,
props: ParamsStandardHMMInitialState,
batch_stats: Float[Array, "batch num_states"],
m_step_state: Any
) -> Tuple[ParamsStandardHMMInitialState, Any]:
if props.probs.trainable:
if self.num_states == 1:
probs = jnp.array([1.0])
Expand Down

0 comments on commit 45e3e7f

Please sign in to comment.