From 36bbc68fdc5a27ebfb0f5f8ed3058e20d89ed655 Mon Sep 17 00:00:00 2001
From: gileshd
Date: Sun, 6 Oct 2024 19:54:33 +0100
Subject: [PATCH] Add further type annotations to categorical glm hmm

---
 .../models/categorical_glm_hmm.py             | 22 ++++++++++++++-----
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/dynamax/hidden_markov_model/models/categorical_glm_hmm.py b/dynamax/hidden_markov_model/models/categorical_glm_hmm.py
index ab30c094..c1127a89 100644
--- a/dynamax/hidden_markov_model/models/categorical_glm_hmm.py
+++ b/dynamax/hidden_markov_model/models/categorical_glm_hmm.py
@@ -5,7 +5,7 @@
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
-from dynamax.types import Scalar
+from dynamax.types import IntScalar, Scalar
 import optax
 from typing import NamedTuple, Optional, Tuple, Union
 
@@ -24,9 +24,9 @@ class ParamsCategoricalRegressionHMM(NamedTuple):
 
 class CategoricalRegressionHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 num_classes,
-                 input_dim,
+                 num_states: int,
+                 num_classes: int,
+                 input_dim: int,
                  m_step_optimizer=optax.adam(1e-2),
                  m_step_num_iters=50):
         """_summary_
@@ -50,7 +50,13 @@ def inputs_shape(self):
     def log_prior(self, params):
         return 0.0
 
-    def initialize(self, key=jr.PRNGKey(0), method="prior", emission_weights=None, emission_biases=None):
+    def initialize(
+        self,
+        key: Array=jr.PRNGKey(0),
+        method: str="prior",
+        emission_weights: Optional[Float[Array, "num_states num_classes input_dim"]]=None,
+        emission_biases: Optional[Float[Array, "num_states num_classes"]]=None,
+    ):
         """Initialize the model parameters and their corresponding properties.
 
         You can either specify parameters manually via the keyword arguments, or you can have
@@ -88,7 +94,11 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_weights=None, e
                               biases=ParameterProperties())
         return params, props
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(
+        self,
+        params: ParamsCategoricalRegressionHMMEmissions,
+        state: IntScalar,
+        inputs: Float[Array, " input_dim"]):
         logits = params.weights[state] @ inputs + params.biases[state]
         return tfd.Categorical(logits=logits)
 
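
Reviewer note: below is a minimal usage sketch showing the API the new annotations describe. It is an illustration, not part of the patch; the sizes (num_states=3, num_classes=4, input_dim=2) are hypothetical, and it only exercises signatures visible in the diff above (initialize returning (params, props), and distribution returning a tfd.Categorical).

    import jax.numpy as jnp
    import jax.random as jr

    from dynamax.hidden_markov_model.models.categorical_glm_hmm import (
        CategoricalRegressionHMMEmissions,
    )

    # Hypothetical sizes, chosen only for illustration.
    num_states, num_classes, input_dim = 3, 4, 2
    emissions = CategoricalRegressionHMMEmissions(num_states, num_classes, input_dim)

    # initialize() now documents its expected array shapes via jaxtyping:
    #   emission_weights: (num_states, num_classes, input_dim)
    #   emission_biases:  (num_states, num_classes)
    params, props = emissions.initialize(key=jr.PRNGKey(0), method="prior")

    # distribution() takes an IntScalar state index and a single (input_dim,)
    # input vector, and returns a categorical over num_classes outcomes.
    dist = emissions.distribution(params, state=0, inputs=jnp.ones(input_dim))
    print(dist.probs_parameter())  # shape: (num_classes,)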