diff --git a/dynamax/slds/mixture_kalman_filter_demo.py b/dynamax/slds/mixture_kalman_filter_demo.py index 9302b913..18c0d7aa 100644 --- a/dynamax/slds/mixture_kalman_filter_demo.py +++ b/dynamax/slds/mixture_kalman_filter_demo.py @@ -3,11 +3,12 @@ # Author: Gerardo Durán-Martín (@gerdm) +from dataclasses import dataclass import jax -import jax.numpy as jnp from jax import random +import jax.numpy as jnp from jax.scipy.special import logit -from dataclasses import dataclass +from jaxtyping import Array, Float @dataclass @@ -24,12 +25,12 @@ class RBPFParamsDiscrete: noise1_next ~ N(0, Q) noise2_next ~ N(0, R) """ - A: jnp.array - B: jnp.array - C: jnp.array - Q: jnp.array - R: jnp.array - transition_matrix: jnp.array + A: Float[Array, "dim_hidden dim_hidden"] + B: Float[Array, "dim_hidden dim_control"] + C: Float[Array, "dim_emission dim_hidden"] + Q: Float[Array, "dim_hidden dim_hidden"] + R: Float[Array, "dim_emission dim_emission"] + transition_matrix: Float[Array, "dim_control dim_control"] def draw_state(val, key, params): @@ -42,7 +43,7 @@ def draw_state(val, key, params): ---------- val: tuple (int, jnp.array) (latent value of system, state value of system). - params: PRBPFParamsDiscrete + params: RBPFParamsDiscrete key: PRNGKey """ latent_old, state_old = val @@ -158,4 +159,4 @@ def rbpf_optimal(current_config, xt, params, nparticles=100): weights_t = jnp.ones(nparticles) / nparticles - return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp) \ No newline at end of file + return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp)