Skip to content

Commit

Permalink
Fix bug in marginal log likelihood of parallel KF with inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikstrb committed Nov 6, 2024
1 parent f61c71b commit 27c643d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions dynamax/linear_gaussian_ssm/parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _first_message(params, y, u):
eta = jnp.zeros_like(b)
J = jnp.eye(len(b))

logZ = _marginal_loglik_elem(P, H, R, y - D @ u - d)
logZ = _marginal_loglik_elem(P, H, R, y - H @ m - D @ u - d)
return A, b, C, J, eta, logZ


Expand All @@ -190,14 +190,15 @@ def _generic_message(params, y, u, t):
S_inv = _emissions_scale(Q, H, R)
K = Q @ H.T @ S_inv

eta = F.T @ H.T @ S_inv @ (y - H @ b - D @ u - d)
innov = (y - H @ b - D @ u - d)
eta = F.T @ H.T @ S_inv @ innov
J = symmetrize(F.T @ H.T @ S_inv @ H @ F)

A = F - K @ H @ F
b = b + B @ u + K @ (y - H @ b - D @ u - d)
b = b + B @ u + K @ innov
C = symmetrize(Q - K @ H @ Q)

logZ = _marginal_loglik_elem(Q, H, R, y - D @ u - d)
logZ = _marginal_loglik_elem(Q, H, R, innov)
return A, b, C, J, eta, logZ

A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0])
Expand Down

0 comments on commit 27c643d

Please sign in to comment.