Skip to content

Commit

Permalink
Add further type annotations to Gamma HMM
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 20, 2024
1 parent f4ba0a8 commit ccbad20
Showing 1 changed file with 64 additions and 39 deletions.
103 changes: 64 additions & 39 deletions dynamax/hidden_markov_model/models/gamma_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,33 @@ class ParamsGammaHMMEmissions(NamedTuple):


class GammaHMMEmissions(HMMEmissions):
def __init__(self,
num_states,
m_step_optimizer=optax.adam(1e-2),
m_step_num_iters=50):
def __init__(
self,
num_states: int,
m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2),
m_step_num_iters: int = 50,
):
super().__init__(m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters)
self.num_states = num_states

@property
def emission_shape(self):
def emission_shape(self) -> Tuple:
return ()

def initialize(self,
key=jr.PRNGKey(0),
method="prior",
emission_concentrations=None,
emission_rates=None,
emissions=None):
def initialize(
self,
key: Array = jr.PRNGKey(0),
method="prior",
emission_concentrations: Optional[Float[Array, " num_states"]] = None,
emission_rates: Optional[Float[Array, " num_states"]] = None,
emissions: Optional[Float[Array, " num_timesteps"]] = None,
# ) -> Tuple[ParamsGammaHMMEmissions, ParamsGammaHMMEmissions]:
) -> Tuple[ParamsGammaHMMEmissions, ParamsGammaHMMEmissions]:

if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans

key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, 1))
Expand All @@ -57,18 +63,19 @@ def initialize(self,
default = lambda x, x0: x if x is not None else x0
params = ParamsGammaHMMEmissions(
concentration=default(emission_concentrations, _emission_concentrations),
rate=default(emission_rates, _emission_rates))
rate=default(emission_rates, _emission_rates),
)
props = ParamsGammaHMMEmissions(
concentration=ParameterProperties(constrainer=tfb.Softplus()),
rate=ParameterProperties(constrainer=tfb.Softplus()))
rate=ParameterProperties(constrainer=tfb.Softplus()),
)
return params, props

def log_prior(self, params):
def log_prior(self, params) -> float:
return 0.0

def distribution(self, params, state, inputs=None):
return tfd.Gamma(concentration=params.concentration[state],
rate=params.rate[state])
def distribution(self, params: ParamsGammaHMMEmissions, state, inputs=None) -> tfd.Distribution:
return tfd.Gamma(concentration=params.concentration[state], rate=params.rate[state])


class ParamsGammaHMM(NamedTuple):
Expand Down Expand Up @@ -96,27 +103,35 @@ class GammaHMM(HMM):
:param m_step_num_iters: number of optimizer steps per M-step.
"""
def __init__(self,
num_states: int,
initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
transition_matrix_stickiness: Scalar=0.0,
m_step_optimizer: optax.GradientTransformation=optax.adam(1e-2),
m_step_num_iters: int=50):

def __init__(
self,
num_states: int,
initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]] = 1.1,
transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]] = 1.1,
transition_matrix_stickiness: Scalar = 0.0,
m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2),
m_step_num_iters: int = 50,
):
initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
emission_component = GammaHMMEmissions(num_states, m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters)
transition_component = StandardHMMTransitions(
num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness
)
emission_component = GammaHMMEmissions(
num_states, m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters
)
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self,
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
emission_concentrations: Optional[Float[Array, " num_states"]]=None,
emission_rates: Optional[Float[Array, " num_states"]]=None,
emissions: Optional[Float[Array, " num_timesteps"]]=None,
) -> Tuple[HMMParameterSet, HMMPropertySet]:
def initialize(
self,
key: Array = jr.PRNGKey(0),
method: str = "prior",
initial_probs: Optional[Float[Array, " num_states"]] = None,
transition_matrix: Optional[Float[Array, "num_states num_states"]] = None,
emission_concentrations: Optional[Float[Array, " num_states"]] = None,
emission_rates: Optional[Float[Array, " num_states"]] = None,
emissions: Optional[Float[Array, " num_timesteps"]] = None,
) -> Tuple[HMMParameterSet, HMMPropertySet]:
"""Initialize the model parameters and their corresponding properties.
You can either specify parameters manually via the keyword arguments, or you can have
Expand All @@ -136,9 +151,19 @@ def initialize(self,
Model parameters and their properties.
"""
key1, key2, key3 = jr.split(key , 3)
key1, key2, key3 = jr.split(key, 3)
params, props = dict(), dict()
params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_concentrations=emission_concentrations, emission_rates=emission_rates, emissions=emissions)
params["initial"], props["initial"] = self.initial_component.initialize(
key1, method=method, initial_probs=initial_probs
)
params["transitions"], props["transitions"] = self.transition_component.initialize(
key2, method=method, transition_matrix=transition_matrix
)
params["emissions"], props["emissions"] = self.emission_component.initialize(
key3,
method=method,
emission_concentrations=emission_concentrations,
emission_rates=emission_rates,
emissions=emissions,
)
return ParamsGammaHMM(**params), ParamsGammaHMM(**props)

0 comments on commit ccbad20

Please sign in to comment.