Merge pull request #365 from edeno/main

Use reverse=True keyword argument in lax.scan for smoothers

slinderman authored Jun 26, 2024
2 parents a6b85ba + 84c024b commit 9b3fb2f
Showing 8 changed files with 133 additions and 109 deletions.
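A note for context: jax.lax.scan accepts a reverse keyword. With reverse=True, the scan consumes xs from the last element to the first but still stacks its outputs in the order of xs, so the earlier pattern of flipping inputs before the scan and flipping outputs afterwards becomes unnecessary. A minimal sketch of the equivalence (toy step function and values, not from this commit):

import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # Toy backward recursion: keep a running sum of the suffix.
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(5.0)

# Old pattern: flip the inputs, scan forward, flip the outputs back.
_, rev_out = lax.scan(step, 0.0, xs[::-1])
old_out = rev_out[::-1]

# New pattern: let lax.scan run backward and stack outputs in forward order.
_, new_out = lax.scan(step, 0.0, xs, reverse=True)

assert jnp.array_equal(old_out, new_out)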
19 changes: 11 additions & 8 deletions dynamax/generalized_gaussian_ssm/inference.py
@@ -18,7 +18,6 @@
_jacfwd_2d = lambda f, x: jnp.atleast_2d(jacfwd(f)(x))



class EKFIntegrals(NamedTuple):
""" Lightweight container for EKF Gaussian integrals."""
gaussian_expectation: Callable = lambda f, m, P: jnp.atleast_1d(f(m))
@@ -85,7 +84,7 @@ def compute_weights_and_sigmas(self, m, P):

def _predict(m, P, f, Q, u, g_ev, g_cov):
"""Predict next mean and covariance under an additive-noise Gaussian filter
p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
where
mu_pred = gev(f, m, P)
@@ -337,13 +336,17 @@ def _step(carry, args):
return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)

# Run the smoother
-    init_carry = (filtered_means[-1], filtered_covs[-1])
-    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
-    _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)
+    _, (smoothed_means, smoothed_covs) = lax.scan(
+        _step,
+        (filtered_means[-1], filtered_covs[-1]),
+        (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
+        reverse=True
+    )

-    # Reverse the arrays and return
-    smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
-    smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
+    # Concatenate the last smoothed mean and covariance
+    smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
+    smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))
+
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
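Note on the index bookkeeping in this file: along with reverse=True, the time indices change from jnp.arange(num_timesteps - 2, -1, -1) to jnp.arange(num_timesteps - 1), and the [::-1] flips on the filtered arrays are dropped. Because reverse=True consumes the last element of each xs array first, index t still pairs with filtered_means[t]. A small sketch checking that pairing (illustrative values, not dynamax code):

import jax.numpy as jnp
from jax import lax

num_timesteps = 5
xs = jnp.arange(10.0, 14.0)  # stands in for filtered_means[:-1]

def record(carry, args):
    t, x = args
    return carry, (t, x)  # record what each scan iteration sees

# Old pairing: explicitly reversed indices and inputs.
_, (old_t, old_x) = lax.scan(
    record, 0.0, (jnp.arange(num_timesteps - 2, -1, -1), xs[::-1])
)

# New pairing: forward indices and inputs with reverse=True.
_, (new_t, new_x) = lax.scan(
    record, 0.0, (jnp.arange(num_timesteps - 1), xs), reverse=True
)

# reverse=True stacks outputs in forward order, so flip the old ones to compare.
assert jnp.array_equal(old_t[::-1], new_t)
assert jnp.array_equal(old_x[::-1], new_x)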
37 changes: 20 additions & 17 deletions dynamax/hidden_markov_model/inference.py
@@ -143,7 +143,6 @@ def _step(carry, t):
return post



@partial(jit, static_argnames=["transition_fn"])
def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
@@ -184,9 +183,9 @@ def _step(carry, t):
next_backward_pred_probs = _predict(backward_filt_probs, A.T)
return (log_normalizer, next_backward_pred_probs), backward_pred_probs

-    carry = (0.0, jnp.ones(num_states))
-    (log_normalizer, _), rev_backward_pred_probs = lax.scan(_step, carry, jnp.arange(num_timesteps)[::-1])
-    backward_pred_probs = rev_backward_pred_probs[::-1]
+    (log_normalizer, _), backward_pred_probs = lax.scan(
+        _step, (0.0, jnp.ones(num_states)), jnp.arange(num_timesteps), reverse=True
+    )
return log_normalizer, backward_pred_probs


@@ -273,7 +272,7 @@ def hmm_smoother(
posterior distribution
"""
-    num_timesteps, num_states = log_likelihoods.shape
+    num_timesteps = log_likelihoods.shape[0]

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
@@ -298,12 +297,15 @@ def _step(carry, args):
return smoothed_probs, smoothed_probs

# Run the HMM smoother
-    carry = filtered_probs[-1]
-    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_probs[:-1][::-1], predicted_probs[1:][::-1])
-    _, rev_smoothed_probs = lax.scan(_step, carry, args)
+    _, smoothed_probs = lax.scan(
+        _step,
+        filtered_probs[-1],
+        (jnp.arange(num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:]),
+        reverse=True,
+    )

-    # Reverse the arrays and return
-    smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]])
+    # Concatenate the arrays and return
+    smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])

# Package into a posterior
posterior = HMMPosterior(
@@ -467,10 +469,9 @@ def _backward_pass(best_next_score, t):
return best_next_score, best_next_state

num_states = log_likelihoods.shape[1]
-    best_second_score, rev_best_next_states = lax.scan(
-        _backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 2, -1, -1)
-    )
-    best_next_states = rev_best_next_states[::-1]
+    best_second_score, best_next_states = lax.scan(
+        _backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 1), reverse=True
+    )

# Run the forward pass
def _forward_pass(state, best_next_state):
@@ -530,11 +531,13 @@ def _step(carry, args):
# Run the HMM smoother
rngs = jr.split(rng, num_timesteps)
last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
-    args = (jnp.arange(num_timesteps - 1, 0, -1), rngs[:-1][::-1], filtered_probs[:-1][::-1])
-    _, rev_states = lax.scan(_step, last_state, args)
+    _, states = lax.scan(
+        _step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]),
+        reverse=True
+    )

-    # Reverse the arrays and return
-    states = jnp.concatenate([rev_states[::-1], jnp.array([last_state])])
+    # Add the last state
+    states = jnp.concatenate([states, jnp.array([last_state])])
return log_normalizer, states

def _compute_sum_transition_probs(
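The hidden Markov model changes above all follow the same pattern: hmm_backward_filter, hmm_smoother, the Viterbi backward pass, and the posterior sampler now scan backward in time with reverse=True instead of reversing their inputs and outputs. As intuition, here is a standalone sketch of a backward message recursion written in this style (toy two-state values, not dynamax's API):

import jax.numpy as jnp
from jax import lax

# Backward messages: beta[t] = A @ (lik[t+1] * beta[t+1]), with beta[T-1] = ones.
A = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])    # transition matrix (rows sum to 1)
liks = jnp.array([[0.5, 0.5],
                  [0.8, 0.2],
                  [0.3, 0.7]])  # per-timestep likelihoods, shape (T, num_states)

def _step(beta_next, lik_next):
    beta = A @ (lik_next * beta_next)
    return beta, beta

# reverse=True walks t = T-2, ..., 0 but stacks the betas in forward
# time order, so no [::-1] is needed afterwards.
beta_last = jnp.ones(2)
_, betas = lax.scan(_step, beta_last, liks[1:], reverse=True)
betas = jnp.vstack([betas, beta_last])  # append beta[T-1], as the diffs above do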
101 changes: 53 additions & 48 deletions dynamax/linear_gaussian_ssm/inference.py
@@ -45,21 +45,21 @@ class ParamsLGSSMDynamics(NamedTuple):
:param cov: dynamics covariance $Q$
"""
    weights: Union[ParameterProperties,
                   Float[Array, "state_dim state_dim"],
                   Float[Array, "ntime state_dim state_dim"]]

    bias: Union[ParameterProperties,
                Float[Array, "state_dim"],
                Float[Array, "ntime state_dim"]]

    input_weights: Union[ParameterProperties,
                         Float[Array, "state_dim input_dim"],
                         Float[Array, "ntime state_dim input_dim"]]

    cov: Union[ParameterProperties,
               Float[Array, "state_dim state_dim"],
               Float[Array, "ntime state_dim state_dim"],
               Float[Array, "state_dim_triu"]]


@@ -77,22 +77,22 @@ class ParamsLGSSMEmissions(NamedTuple):
"""
    weights: Union[ParameterProperties,
                   Float[Array, "emission_dim state_dim"],
                   Float[Array, "ntime emission_dim state_dim"]]

    bias: Union[ParameterProperties,
                Float[Array, "emission_dim"],
                Float[Array, "ntime emission_dim"]]

    input_weights: Union[ParameterProperties,
                         Float[Array, "emission_dim input_dim"],
                         Float[Array, "ntime emission_dim input_dim"]]

    cov: Union[ParameterProperties,
               Float[Array, "emission_dim emission_dim"],
               Float[Array, "ntime emission_dim emission_dim"],
               Float[Array, "emission_dim"],
               Float[Array, "ntime emission_dim"],
               Float[Array, "emission_dim_triu"]]


@@ -166,9 +166,9 @@ def _get_params(params, num_timesteps, t):
D = _get_one_param(params.emissions.input_weights, 2, t)
d = _get_one_param(params.emissions.bias, 1, t)

    if len(params.emissions.cov.shape) == 1:
        R = _get_one_param(params.emissions.cov, 1, t)
    elif len(params.emissions.cov.shape) > 2:
R = _get_one_param(params.emissions.cov, 2, t)
elif params.emissions.cov.shape[0] != num_timesteps:
R = _get_one_param(params.emissions.cov, 2, t)
@@ -278,20 +278,20 @@ def _condition_on(m, P, H, D, d, R, u, y):
if R.ndim == 2:
S = R + H @ P @ H.T
K = psd_solve(S, H @ P).T
    else:
# Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I
# (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity)
I = jnp.eye(P.shape[0])
U = H @ jnp.linalg.cholesky(P)
X = U / R[:, None]
        S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
"""
# Could alternatively use U=H and C=P
R_inv = jnp.diag(1.0 / R)
P_inv = psd_solve(P, jnp.eye(P.shape[0]))
S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv)
"""
        K = P @ H.T @ S_inv
S = jnp.diag(R) + H @ P @ H.T

Sigma_cond = P - K @ S @ K.T
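For reference, the identity invoked in the else branch above is the Woodbury formula (A + U C V)^-1 = A^-1 - A^-1 U (C^-1 + V A^-1 U)^-1 V A^-1. With A = R (diagonal), U = H @ chol(P), V = U.T, and C = I, we have S = R + U @ U.T, and since X = U / R[:, None] equals R^-1 @ U, the line S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) matches the identity term for term.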
@@ -361,8 +361,6 @@ def wrapper(*args, **kwargs):
return wrapper




def lgssm_joint_sample(
params: ParamsLGSSM,
key: PRNGKey,
@@ -371,7 +369,7 @@
)-> Tuple[Float[Array, "num_timesteps state_dim"],
Float[Array, "num_timesteps emission_dim"]]:
r"""Sample from the joint distribution to produce state and emission trajectories.
Args:
params: model parameters
inputs: optional array of inputs.
@@ -390,7 +388,7 @@ def _sample_emission(key, H, D, d, R, x, u):
mean = H @ x + D @ u + d
R = jnp.diag(R) if R.ndim==1 else R
return MVN(mean, R).sample(seed=key)

def _sample_initial(key, params, inputs):
key1, key2 = jr.split(key)

@@ -417,7 +415,7 @@ def _step(prev_state, args):

# Sample the initial state
key1, key2 = jr.split(key)

initial_state, initial_emission = _sample_initial(key1, params, inputs)

# Sample the remaining emissions and states
@@ -462,7 +460,7 @@ def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
else:
L = H @ jnp.linalg.cholesky(pred_cov)
return MVNLowRank(m, R, L).log_prob(y)


def _step(carry, t):
ll, pred_mean, pred_cov = carry
@@ -539,14 +537,17 @@ def _step(carry, args):
return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov, smoothed_cross)

# Run the Kalman smoother
-    init_carry = (filtered_means[-1], filtered_covs[-1])
-    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
-    _, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(_step, init_carry, args)
+    _, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(
+        _step,
+        (filtered_means[-1], filtered_covs[-1]),
+        (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
+        reverse=True,
+    )

-    # Reverse the arrays and return
-    smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
-    smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
-    smoothed_cross = smoothed_cross[::-1]
+    # Concatenate the arrays and return
+    smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
+    smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))
+
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
@@ -563,7 +564,7 @@ def lgssm_posterior_sample(
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None,
jitter: Optional[Scalar]=0

) -> Float[Array, "ntime state_dim"]:
r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.
@@ -603,12 +604,16 @@ def _step(carry, args):
key, this_key = jr.split(key, 2)
last_state = MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key)

-    args = (
-        jr.split(key, num_timesteps - 1),
-        filtered_means[:-1][::-1],
-        filtered_covs[:-1][::-1],
-        jnp.arange(num_timesteps - 2, -1, -1),
-    )
-    _, reversed_states = lax.scan(_step, last_state, args)
-    states = jnp.vstack([reversed_states[::-1], last_state])
-    return states
+    _, states = lax.scan(
+        _step,
+        last_state,
+        (
+            jr.split(key, num_timesteps - 1),
+            filtered_means[:-1],
+            filtered_covs[:-1],
+            jnp.arange(num_timesteps - 1),
+        ),
+        reverse=True,
+    )
+
+    return jnp.vstack([states, last_state])
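One behavioral note on the sampler above: the keys from jr.split are now paired with timesteps in the opposite order from the old code. Each key drives an independent draw, so the posterior being sampled is unchanged, though exact samples for a fixed seed may differ. A minimal backward-sampling skeleton in the new style (scalar toy model with made-up moments, not dynamax's API):

import jax.numpy as jnp
import jax.random as jr
from jax import lax

def _step(next_state, args):
    key, mean, var = args
    # Toy backward conditional: made-up AR(1)-style dependence on x_{t+1}.
    state = mean + 0.5 * next_state + jnp.sqrt(var) * jr.normal(key)
    return state, state

num_timesteps = 4
key, subkey = jr.split(jr.PRNGKey(0))
means = jnp.zeros(num_timesteps - 1)  # stand-ins for conditional means
varis = jnp.ones(num_timesteps - 1)   # stand-ins for conditional variances

last_state = jr.normal(subkey)
_, states = lax.scan(
    _step,
    last_state,
    (jr.split(key, num_timesteps - 1), means, varis),
    reverse=True,
)
states = jnp.concatenate([states, jnp.array([last_state])])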
16 changes: 10 additions & 6 deletions dynamax/linear_gaussian_ssm/info_inference.py
@@ -271,13 +271,17 @@ def _smooth_step(carry, args):
return (smoothed_eta, smoothed_prec), (smoothed_eta, smoothed_prec)

# Run the Kalman smoother
-    init_carry = (filtered_etas[-1], filtered_precisions[-1])
-    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_etas[:-1][::-1], filtered_precisions[:-1][::-1])
-    _, (smoothed_etas, smoothed_precisions) = lax.scan(_smooth_step, init_carry, args)
+    _, (smoothed_etas, smoothed_precisions) = lax.scan(
+        _smooth_step,
+        (filtered_etas[-1], filtered_precisions[-1]),
+        (jnp.arange(num_timesteps - 1), filtered_etas[:-1], filtered_precisions[:-1]),
+        reverse=True
+    )

-    # Reverse the arrays and return
-    smoothed_etas = jnp.vstack((smoothed_etas[::-1], filtered_etas[-1][None, ...]))
-    smoothed_precisions = jnp.vstack((smoothed_precisions[::-1], filtered_precisions[-1][None, ...]))
+    # Concatenate the arrays and return
+    smoothed_etas = jnp.vstack((smoothed_etas, filtered_etas[-1][None, ...]))
+    smoothed_precisions = jnp.vstack((smoothed_precisions, filtered_precisions[-1][None, ...]))
+
return PosteriorGSSMInfoSmoothed(
marginal_loglik=ll,
filtered_etas=filtered_etas,
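Since every change in this commit is the same mechanical rewrite, one generic property check covers them all: for any step function, the old flip-scan-flip pattern and the new reverse=True pattern should agree. A sketch of such a check (hypothetical helper names, not part of the repository's test suite):

import jax.numpy as jnp
from jax import lax

def old_pattern(step, init, ts, xs):
    # Flip the inputs, scan forward, flip the outputs back.
    _, out = lax.scan(step, init, (ts[::-1], xs[::-1]))
    return out[::-1]

def new_pattern(step, init, ts, xs):
    # Scan backward directly; outputs come out in forward order.
    _, out = lax.scan(step, init, (ts, xs), reverse=True)
    return out

def step(carry, args):
    # Arbitrary recursion; any pure function works here.
    t, x = args
    carry = 0.9 * carry + x + 0.01 * t
    return carry, carry

ts = jnp.arange(4.0)
xs = jnp.linspace(0.0, 1.0, 4)
assert jnp.allclose(old_pattern(step, 0.0, ts, xs),
                    new_pattern(step, 0.0, ts, xs))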
(Diffs for the remaining 4 changed files did not load.)