forked from probml/JSL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
kalman_sampler_test.py
69 lines (54 loc) · 2.12 KB
/
kalman_sampler_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from jax import random
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from kalman_filter import LDS, kalman_filter
from kalman_sampler import smooth_sampler
def tfp_filter(timesteps, A, transition_noise_scale, C, observation_noise_scale, mu0, x_hist,
               n_samples=None, prior_scale=1.0, seed=None):
    """Draw posterior samples of the latent states with TensorFlow Probability.

    Despite its name, this does not return filtered estimates: it builds a
    ``tfd.LinearGaussianStateSpaceModel`` with isotropic Gaussian noise and
    calls ``posterior_sample`` conditioned on ``x_hist``. It is used as a
    reference against which JSL's ``smooth_sampler`` is checked.

    Args:
        timesteps: number of time steps of the state-space model.
        A: transition matrix; its first dimension is the state size.
        transition_noise_scale: standard deviation of every transition-noise
            component (covariance ``transition_noise_scale**2 * I``).
        C: observation matrix; its first dimension is the observation size.
        observation_noise_scale: standard deviation of every observation-noise
            component (covariance ``observation_noise_scale**2 * I``).
        mu0: mean of the initial-state prior.
        x_hist: observations to condition on; its first axis is taken as the
            default sample count (presumably (timesteps, observation_size) —
            confirm against ``LDS.sample``).
        n_samples: number of posterior draws. Defaults to ``x_hist.shape[0]``,
            which reproduces the original behaviour.
        prior_scale: standard deviation of every component of the initial-state
            prior. Defaults to 1.0, the value previously hard-coded.
        seed: optional seed forwarded to ``posterior_sample``.

    Returns:
        The samples of the first latent component, i.e. ``samples[:, :, 0]`` of
        a ``(n_samples, timesteps, state_size)`` posterior sample.
    """
    state_size, _ = A.shape
    observation_size, _ = C.shape
    transition_noise = tfd.MultivariateNormalDiag(
        scale_diag=jnp.ones(state_size) * transition_noise_scale
    )
    obs_noise = tfd.MultivariateNormalDiag(
        scale_diag=jnp.ones(observation_size) * observation_noise_scale
    )
    prior = tfd.MultivariateNormalDiag(
        loc=mu0, scale_diag=tf.ones([state_size]) * prior_scale
    )
    lgssm = tfd.LinearGaussianStateSpaceModel(
        num_timesteps=timesteps,
        transition_matrix=A,
        transition_noise=transition_noise,
        observation_matrix=C,
        observation_noise=obs_noise,
        initial_state_prior=prior,
    )
    if n_samples is None:
        # Backward-compatible default: one posterior draw per row of x_hist.
        n_samples = x_hist.shape[0]
    smps = lgssm.posterior_sample(x_hist, sample_shape=n_samples, seed=seed)
    # Keep only the first latent component.
    return smps[:, :, 0]


def test_kalman_filter():
    """Compare JSL's smoothing sampler with TFP's LGSSM posterior sampler.

    Observations are simulated from a 2-D linear dynamical system. Both
    implementations then draw posterior samples of the latent trajectory given
    those observations, and the per-timestep Monte Carlo means of the first
    latent component must agree to within ``atol``.
    """
    key = random.PRNGKey(314)
    # Independent keys for data simulation and posterior sampling; reusing a
    # single key would correlate the two random streams.
    key_data, key_sampler = random.split(key)
    timesteps = 15
    # Posterior draws per sampler. This used to equal ``timesteps`` (15), whose
    # Monte Carlo error on the mean difference is comparable to ``atol``.
    n_samples = 500

    ### LDS Parameters ###
    state_size = 2
    observation_size = 2
    A = jnp.eye(state_size)
    # C maps latent states to observations: shape (observation_size, state_size).
    C = jnp.eye(observation_size, state_size)
    # The noise scales are standard deviations (tfp_filter passes them as
    # ``scale_diag``), so the covariances Q and R use their squares.
    transition_noise_scale = 1.0
    observation_noise_scale = 1.0
    Q = jnp.eye(state_size) * transition_noise_scale ** 2
    R = jnp.eye(observation_size) * observation_noise_scale ** 2

    ### Prior distribution params ###
    mu0 = jnp.array([8, 10]).astype(float)
    # Shared with tfp_filter so both models use the same initial-state prior.
    prior_scale = 1.0
    Sigma0 = jnp.eye(state_size) * prior_scale ** 2

    ### Sample data ###
    lds_instance = LDS(A, C, Q, R, mu0, Sigma0)
    _, x_hist = lds_instance.sample(key_data, timesteps)

    ### Posterior samples from both implementations ###
    jsl_z_filt, jsl_sigma_filt, _, _ = kalman_filter(lds_instance, x_hist)
    s_jax = smooth_sampler(
        lds_instance, key_sampler, jsl_z_filt, jsl_sigma_filt, n_samples=n_samples
    )[:, :, 0]
    s_tfp = tfp_filter(
        timesteps, A, transition_noise_scale, C, observation_noise_scale, mu0, x_hist,
        n_samples=n_samples, prior_scale=prior_scale, seed=314,
    )

    mean_jax = jnp.mean(s_jax, axis=0)
    mean_tfp = jnp.mean(jnp.array(s_tfp), axis=0)
    np.testing.assert_allclose(np.asarray(mean_jax), np.asarray(mean_tfp), atol=2e-1)