Switching linear dynamical systems (SLDS) models in JAX

jonny-so/slds-jax

slds-jax

Inference routines for switching linear dynamical system (SLDS) models in JAX.
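As background for the routines below, here is a sketch of the generative process of an SLDS. A discrete Markov chain z_t selects, at every step, which linear-Gaussian dynamics drive the continuous state x_t, and the observations y_t are linear-Gaussian in x_t. It is a self-contained illustration in plain JAX, not part of this package. The function `sample_slds` and all of its argument names are hypothetical and are not claimed to match the argument order or conventions of `Slds(...)`. The shared (state-independent) emission model is also an assumption made for simplicity.

```python
import jax
import jax.numpy as jnp


def sample_slds(key, T, trans, init_probs, dyn, dyn_bias, dyn_cov,
                x0_mean, x0_cov, emit, emit_bias, obs_cov):
    """Sample discrete states z, latents x and observations y from a toy SLDS.

    Hypothetical parameterisation (not necessarily the Slds argument order):
      trans[i, j] = p(z_t = j | z_{t-1} = i)          shape (K, K)
      init_probs  = p(z_1)                            shape (K,)
      dyn, dyn_bias, dyn_cov: per-state dynamics      (K, D, D), (K, D), (K, D, D)
      x0_mean, x0_cov: initial latent distribution    (D,), (D, D)
      emit, emit_bias, obs_cov: shared emissions      (N, D), (N,), (N, N)
    """
    k_z0, k_x0, k_steps, k_y = jax.random.split(key, 4)
    z0 = jax.random.categorical(k_z0, jnp.log(init_probs))
    x0 = jax.random.multivariate_normal(k_x0, x0_mean, x0_cov)

    def step(carry, key_t):
        z_prev, x_prev = carry
        k_z, k_x = jax.random.split(key_t)
        # Discrete switch: draw z_t from the row of the transition matrix.
        z = jax.random.categorical(k_z, jnp.log(trans[z_prev]))
        # Continuous dynamics selected by the new discrete state.
        x = jax.random.multivariate_normal(
            k_x, dyn[z] @ x_prev + dyn_bias[z], dyn_cov[z])
        return (z, x), (z, x)

    _, (zs, xs) = jax.lax.scan(step, (z0, x0), jax.random.split(k_steps, T - 1))
    zs = jnp.concatenate([z0[None], zs])
    xs = jnp.concatenate([x0[None], xs])
    # Linear-Gaussian observations of the continuous latent state.
    noise = jax.random.multivariate_normal(
        k_y, jnp.zeros(obs_cov.shape[0]), obs_cov, shape=(T,))
    ys = xs @ emit.T + emit_bias + noise
    return zs, xs, ys


# Usage: two regimes (slow rotation vs. decay) in a 2-D latent space, 1-D observations.
theta = 0.1
rot = jnp.array([[jnp.cos(theta), -jnp.sin(theta)], [jnp.sin(theta), jnp.cos(theta)]])
zs, xs, ys = sample_slds(
    jax.random.PRNGKey(0), 100,
    trans=jnp.array([[0.95, 0.05], [0.05, 0.95]]),
    init_probs=jnp.array([0.5, 0.5]),
    dyn=jnp.stack([rot, 0.9 * jnp.eye(2)]),
    dyn_bias=jnp.zeros((2, 2)),
    dyn_cov=jnp.stack([0.01 * jnp.eye(2)] * 2),
    x0_mean=jnp.zeros(2), x0_cov=jnp.eye(2),
    emit=jnp.array([[1.0, 0.0]]), emit_bias=jnp.zeros(1),
    obs_cov=0.1 * jnp.eye(1),
)
```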

Example usage

from jax import vmap
# Slds and the inference/marginal helpers below are provided by this package.
model = Slds(A, a0, B, b, b0, C, c, Q, Q0, R)

# structured mean-field VI
inference_state = slds_smf_inference_init(model, observations)
for _ in range(10):
    inference_state = slds_smf_inference_update(inference_state)
posterior = slds_smf_inference_posterior(inference_state)

# get posterior marginals
discrete_marginals = dds_marginals(posterior.discrete_natparams)
continuous_marginals = lds_marginals(posterior.continuous_natparams)

# get posterior marginal summary statistics
discrete_probabilities = vmap(discrete_meanparams)(discrete_marginals[0])
continuous_mean = vmap(gaussian_mean)(continuous_marginals[0])

See this notebook for more complete examples.
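The last two lines of the example turn per-time-step marginals stored as natural parameters into summary statistics. The sketch below shows what such conversions typically look like, assuming two common conventions: a categorical's natural parameters are unnormalised log-probabilities, and a Gaussian is stored in information form p(x) ∝ exp(h^T x - 0.5 x^T J x). These conventions and the helper names `categorical_probs` and `gaussian_mean_info` are assumptions for illustration; this package's `discrete_meanparams` and `gaussian_mean` may use a different parameterisation.

```python
import jax
import jax.numpy as jnp
from jax import vmap


def categorical_probs(logits):
    # Assumed convention: categorical natural parameters are unnormalised
    # log-probabilities, so the mean parameters are their softmax.
    return jax.nn.softmax(logits)


def gaussian_mean_info(h, J):
    # Assumed information form p(x) ∝ exp(h^T x - 0.5 x^T J x),
    # whose mean is J^{-1} h (solve rather than invert J explicitly).
    return jnp.linalg.solve(J, h)


# Three time steps of marginals, mapped over the leading (time) axis.
logits = jnp.array([[0.0, 0.0], [jnp.log(3.0), 0.0], [0.0, jnp.log(4.0)]])
probs = vmap(categorical_probs)(logits)

J = jnp.stack([2.0 * jnp.eye(2)] * 3)
h = jnp.array([[2.0, 4.0], [0.0, 2.0], [-2.0, 0.0]])
means = vmap(gaussian_mean_info)(h, J)
```

Mapping with `vmap` over the leading axis mirrors the pattern in the example above, where one marginal is stored per time step.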

Results

The plot below shows the posteriors inferred by VI and EP on a toy problem, overlaid with the ground-truth states and observations. The ground-truth model has 3 discrete states governing the continuous dynamics. VI tends to be overconfident, even when its inferences are wrong, whereas EP typically produces better-calibrated uncertainties. The flip side is that EP can be very unstable: it often requires significant damping, and even then it may fail to converge. VI, on the other hand, is guaranteed to converge (to a local optimum of its objective) and typically does so in a small number of iterations.
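Damping, as mentioned above for EP, usually means replacing each proposed update of the natural parameters with a convex combination of the old and new values. The following is a minimal, generic sketch of that idea on an arbitrary pytree of parameters, not this package's EP implementation. `damp`, `toy_update` and `iterate` are hypothetical names, and `toy_update` is a deliberately oscillating fixed-point map chosen only to show why damping helps.

```python
import jax
import jax.numpy as jnp


def damp(old_natparams, new_natparams, step_size):
    """Convex combination of old and proposed natural parameters.

    step_size = 1.0 recovers the undamped update; smaller values take
    smaller steps towards the proposed parameters.
    """
    return jax.tree_util.tree_map(
        lambda old, new: (1.0 - step_size) * old + step_size * new,
        old_natparams, new_natparams)


def toy_update(eta):
    # Hypothetical fixed-point map eta -> 2 - eta: its fixed point is 1.0,
    # but the undamped iteration oscillates forever around it.
    return jax.tree_util.tree_map(lambda x: 2.0 - x, eta)


def iterate(eta, step_size, n_iters=20):
    for _ in range(n_iters):
        eta = damp(eta, toy_update(eta), step_size)
    return eta


eta0 = {"eta1": jnp.array([3.0]), "eta2": jnp.array([0.0])}
undamped = iterate(eta0, step_size=1.0)  # bounces between two values
damped = iterate(eta0, step_size=0.5)    # settles on the fixed point
```

With `step_size=1.0` the toy iteration alternates between two values indefinitely, while `step_size=0.5` reaches the fixed point. Because the natural parameter space of an exponential family is convex, a convex combination of two valid parameter sets is itself valid.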

SLDS inference example

What's here

  • Structured mean-field variational inference (VI) [1]
  • Expectation propagation (EP) [2,3]
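In structured mean-field VI, the joint posterior over discrete states z and continuous states x is approximated by a factorised q(z)q(x), where each factor keeps its own chain structure. Updating q(z) with q(x) held fixed gives a hidden-Markov-style chain whose marginals can be computed with the forward-backward algorithm. The log-space sketch below illustrates that step only. `hmm_marginals` is a hypothetical helper, not this package's API, and it takes the per-time log potentials as given rather than deriving them from q(x).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_marginals(log_init, log_trans, log_pots):
    """Marginals q(z_t = k) of a discrete chain via log-space forward-backward.

    log_init: (K,); log_trans: (K, K) with [i, j] = log p(z_t=j | z_{t-1}=i);
    log_pots: (T, K) per-time log potentials (in structured mean-field VI these
    would come from expectations under the continuous factor q(x)).
    """
    def fwd(log_alpha, log_pot):
        # alpha_t(j) = pot_t(j) * sum_i alpha_{t-1}(i) A[i, j]
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + log_pot
        return log_alpha, log_alpha

    log_alpha0 = log_init + log_pots[0]
    _, log_alphas = jax.lax.scan(fwd, log_alpha0, log_pots[1:])
    log_alphas = jnp.concatenate([log_alpha0[None], log_alphas])

    def bwd(log_beta, log_pot):
        # beta_t(i) = sum_j A[i, j] pot_{t+1}(j) beta_{t+1}(j)
        log_beta = logsumexp(log_trans + (log_pot + log_beta)[None, :], axis=1)
        return log_beta, log_beta

    log_beta_last = jnp.zeros_like(log_init)
    _, log_betas = jax.lax.scan(bwd, log_beta_last, log_pots[1:], reverse=True)
    log_betas = jnp.concatenate([log_betas, log_beta_last[None]])

    # Normalise alpha * beta at each time step.
    log_marg = log_alphas + log_betas
    return jnp.exp(log_marg - logsumexp(log_marg, axis=1, keepdims=True))


# Usage: a sticky two-state chain with per-step evidence over three time steps.
log_init = jnp.log(jnp.array([0.5, 0.5]))
log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
log_pots = jnp.log(jnp.array([[0.7, 0.3], [0.4, 0.6], [0.1, 0.9]]))
marginals = hmm_marginals(log_init, log_trans, log_pots)  # shape (3, 2), rows sum to 1
```

Working in log space with `logsumexp` avoids the underflow that plain products of probabilities suffer on long sequences.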

What's missing

  1. parameter learning
  2. double-loop EP [3]
  3. expectation correction [4]
  4. block Gibbs sampling
  5. Pólya-gamma augmented sampling [5]
  6. ...

References

[1] Zoubin Ghahramani and Geoffrey Hinton. Variational Learning for Switching State-Space Models. Neural Computation (2000)
[2] Tom Minka. A Family of Algorithms for Approximate Bayesian Inference. PhD thesis, MIT (2001)
[3] Tom Heskes and Onno Zoeter. Expectation Propagation for Approximate Inference in Dynamic Bayesian Networks. UAI (2002)
[4] David Barber. Expectation Correction for Smoothed Inference in Switching Linear Dynamical Systems. Journal of Machine Learning Research (2006)
[5] Scott Linderman et al. Bayesian Learning and Inference in Recurrent Switching Linear Dynamical Systems. AISTATS (2017)
