From ee3da88b9f3f430b755e545ae15c4258e19fb234 Mon Sep 17 00:00:00 2001 From: gileshd Date: Wed, 25 Sep 2024 17:13:20 +0100 Subject: [PATCH] Add further type annotations to hmm transitions class --- .../hidden_markov_model/models/transitions.py | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/dynamax/hidden_markov_model/models/transitions.py b/dynamax/hidden_markov_model/models/transitions.py index 5fab5c96..cdb8a6ea 100644 --- a/dynamax/hidden_markov_model/models/transitions.py +++ b/dynamax/hidden_markov_model/models/transitions.py @@ -1,12 +1,13 @@ +from typing import Any, cast, NamedTuple, Optional, Tuple, Union import jax.numpy as jnp -import jax.random as jr +from jaxtyping import Float, Array import tensorflow_probability.substrates.jax.distributions as tfd import tensorflow_probability.substrates.jax.bijectors as tfb + from dynamax.hidden_markov_model.models.abstractions import HMMTransitions +from dynamax.hidden_markov_model.inference import HMMPosterior from dynamax.parameters import ParameterProperties -from dynamax.types import Scalar -from jaxtyping import Float, Array -from typing import cast, NamedTuple, Optional, Union +from dynamax.types import IntScalar, Scalar class ParamsStandardHMMTransitions(NamedTuple): @@ -30,7 +31,12 @@ class StandardHMMTransitions(HMMTransitions): """ - def __init__(self, num_states: int, concentration: Scalar=1.1, stickiness: Scalar=0.0): + def __init__( + self, + num_states: int, + concentration: Union[Scalar, Float[Array, "num_states num_states"]]=1.1, + stickiness: Union[Scalar, Float[Array, " num_states"]]=0.0 + ): """ Args: transition_matrix[j,k]: prob(hidden(t) = k | hidden(t-1)j) @@ -40,10 +46,15 @@ def __init__(self, num_states: int, concentration: Scalar=1.1, stickiness: Scala concentration * jnp.ones((num_states, num_states)) + \ stickiness * jnp.eye(num_states) - def distribution(self, params, state, inputs=None): + def distribution(self, params: ParamsStandardHMMTransitions, state: IntScalar, inputs=None): return tfd.Categorical(probs=params.transition_matrix[state]) - def initialize(self, key: Optional[Array] =None, method="prior", transition_matrix: Optional[Float[Array, "num_states num_states"]]=None): + def initialize( + self, + key: Optional[Array]=None, + method="prior", + transition_matrix: Optional[Float[Array, "num_states num_states"]]=None + ) -> Tuple[ParamsStandardHMMTransitions, ParamsStandardHMMTransitions]: """Initialize the model parameters and their corresponding properties. Args: @@ -58,8 +69,7 @@ def initialize(self, key: Optional[Array] =None, method="prior", transition_matr if key is None: raise ValueError("key must be provided if transition_matrix is not provided.") else: - this_key, key = jr.split(key) - transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=this_key) + transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=key) transition_matrix = cast(Float[Array, "num_states num_states"], transition_matrix_sample) # Package the results into dictionaries @@ -67,19 +77,33 @@ def initialize(self, key: Optional[Array] =None, method="prior", transition_matr props = ParamsStandardHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered())) return params, props - def log_prior(self, params): + def log_prior(self, params: ParamsStandardHMMTransitions) -> Scalar: return tfd.Dirichlet(self.concentration).log_prob(params.transition_matrix).sum() - def _compute_transition_matrices(self, params, inputs=None): + def _compute_transition_matrices( + self, params: ParamsStandardHMMTransitions, inputs=None + ) -> Float[Array, "num_states num_states"]: return params.transition_matrix - def collect_suff_stats(self, params, posterior, inputs=None): + def collect_suff_stats( + self, + params, + posterior: HMMPosterior, + inputs=None + ) -> Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]: return posterior.trans_probs def initialize_m_step_state(self, params, props): return None - def m_step(self, params, props, batch_stats, m_step_state): + def m_step( + self, + params: ParamsStandardHMMTransitions, + props: ParamsStandardHMMTransitions, + batch_stats: Float[Array, "batch num_states num_states"], + m_step_state: Any + ) -> Tuple[ParamsStandardHMMTransitions, Any]: if props.transition_matrix.trainable: if self.num_states == 1: transition_matrix = jnp.array([[1.0]])