From 27c643d31e7966253bc6a8eacb358c1ed2e8ccb9 Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Wed, 6 Nov 2024 11:26:55 +0100 Subject: [PATCH] Fix bug in marginal log likelihood of parallel KF with inputs --- dynamax/linear_gaussian_ssm/parallel_inference.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index ac8a6e05..516387e1 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -179,7 +179,7 @@ def _first_message(params, y, u): eta = jnp.zeros_like(b) J = jnp.eye(len(b)) - logZ = _marginal_loglik_elem(P, H, R, y - D @ u - d) + logZ = _marginal_loglik_elem(P, H, R, y - H @ m - D @ u - d) return A, b, C, J, eta, logZ @@ -190,14 +190,15 @@ def _generic_message(params, y, u, t): S_inv = _emissions_scale(Q, H, R) K = Q @ H.T @ S_inv - eta = F.T @ H.T @ S_inv @ (y - H @ b - D @ u - d) + innov = (y - H @ b - D @ u - d) + eta = F.T @ H.T @ S_inv @ innov J = symmetrize(F.T @ H.T @ S_inv @ H @ F) A = F - K @ H @ F - b = b + B @ u + K @ (y - H @ b - D @ u - d) + b = b + B @ u + K @ innov C = symmetrize(Q - K @ H @ Q) - logZ = _marginal_loglik_elem(Q, H, R, y - D @ u - d) + logZ = _marginal_loglik_elem(Q, H, R, innov) return A, b, C, J, eta, logZ A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0])