Skip to content

Commit

Permalink
Merge pull request #5 from Joshuaalbert/develop
Browse files Browse the repository at this point in the history
1.0.1
  • Loading branch information
Joshuaalbert authored Aug 14, 2024
2 parents b8c9bb7 + 27e4f61 commit 585c403
Show file tree
Hide file tree
Showing 10 changed files with 1,054 additions and 191 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ from essm_jax.essm import ExtendedStateSpaceModel
tfpd = tfp.distributions


def transition_fn(z, t):
def transition_fn(z, t, t_next):
mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
cov = 0.1 * jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
Expand Down Expand Up @@ -67,8 +67,7 @@ print(smooth_result)
forward_samples = essm.forward_simulate(
key=jax.random.PRNGKey(0),
num_time=25,
observations=samples.observation,
mask=mask
filter_result=filter_result
)

import pylab as plt
Expand All @@ -86,9 +85,14 @@ plt.legend()
plt.show()
```

## Online Filtering

Take a look at [examples](./docs/examples) to learn how to do online filtering, for interactive application.

# Change Log

13 August 2024: Initial release 1.0.0.
14 August 2024: 1.0.1 released. Added sparse util. Add incremental API for online filtering. Arbitrary dt.

## Star History

Expand Down
317 changes: 317 additions & 0 deletions docs/examples/excitable_damped_harmonic_oscillator.ipynb

Large diffs are not rendered by default.

196 changes: 196 additions & 0 deletions docs/examples/online_filtering.ipynb

Large diffs are not rendered by default.

426 changes: 258 additions & 168 deletions essm_jax/essm.py

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions essm_jax/pytee_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Tuple, Callable, TypeVar

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

PT = TypeVar('PT')


def pytree_unravel(example_tree: PT) -> Tuple[Callable[[PT], jax.Array], Callable[[jax.Array], PT]]:
"""
Returns functions to ravel and unravel a pytree.
Args:
example_tree: a pytree to be unravelled
Returns:
ravel_fun: a function to ravel the pytree
unravel_fun: a function to unravel
"""
leaf_list, tree_def = jax.tree.flatten(example_tree)

sizes = [np.size(leaf) for leaf in leaf_list]
shapes = [np.shape(leaf) for leaf in leaf_list]
dtypes = [leaf.dtype for leaf in leaf_list]

def ravel_fun(pytree: PT) -> jax.Array:
leaf_list, tree_def = jax.tree.flatten(pytree)
# promote types to common one
common_dtype = jnp.result_type(*dtypes)
leaf_list = [leaf.astype(common_dtype) for leaf in leaf_list]
return jnp.concatenate([lax.reshape(leaf, (size,)) for leaf, size in zip(leaf_list, sizes)])

def unravel_fun(flat_array: jax.Array) -> PT:
leaf_list = []
start = 0
for size, shape, dtype in zip(sizes, shapes, dtypes):
leaf_list.append(lax.reshape(flat_array[start:start + size], shape).astype(dtype))
start += size
return jax.tree.unflatten(tree_def, leaf_list)

return ravel_fun, unravel_fun
68 changes: 68 additions & 0 deletions essm_jax/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Tuple, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class SparseRepresentation(NamedTuple):
shape: Tuple[int, ...]
rows: jax.Array
cols: jax.Array
vals: jax.Array


def create_sparse_rep(m: np.ndarray) -> SparseRepresentation:
"""
Creates a sparse rep from matrix m. Use in linear models with materialise_jacobian=False for 2x speed up.
Args:
m: [N,M] matrix
Returns:
sparse rep
"""
rows, cols = np.where(m)
sort_indices = np.lexsort((cols, rows))
rows = rows[sort_indices]
cols = cols[sort_indices]
return SparseRepresentation(
shape=np.shape(m),
rows=jnp.asarray(rows),
cols=jnp.asarray(cols),
vals=jnp.asarray(m[rows, cols])
)


def to_dense(m: SparseRepresentation, out: jax.Array | None = None) -> jax.Array:
"""
Form dense matrix.
Args:
m: sparse rep
out: output buffer
Returns:
out + M
"""
if out is None:
out = jnp.zeros(m.shape, m.vals.dtype)

return out.at[m.rows, m.cols].add(m.vals, unique_indices=True, indices_are_sorted=True)


def matvec_sparse(m: SparseRepresentation, v: jax.Array, out: jax.Array | None = None) -> jax.Array:
"""
Compute matmul for sparse rep. Speeds up large sparse linear models by about 2x.
Args:
m: sparse rep
v: vec
out: output buffer to add to.
Returns:
out + M @ v
"""
if out is None:
out = jnp.zeros(m.shape[0])
return out.at[m.rows].add(m.vals * v[m.cols], unique_indices=True, indices_are_sorted=True)
143 changes: 127 additions & 16 deletions essm_jax/tests/test_essm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import time

import jax
import pytest

jax.config.update('jax_enable_x64', True)
from essm_jax.sparse import create_sparse_rep, matvec_sparse

import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp
Expand All @@ -13,7 +18,7 @@
def test_extended_state_space_model():
num_time = 10

def transition_fn(z, t):
def transition_fn(z, t, t_next):
mean = 2 * z
cov = jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
Expand Down Expand Up @@ -132,7 +137,7 @@ def _compare(key):


def test_jvp_essm():
def transition_fn(z, t):
def transition_fn(z, t, t_next):
mean = jnp.sin(2 * z)
cov = jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
Expand Down Expand Up @@ -188,7 +193,7 @@ def observation_fn(z, t):


def test_speed_test_jvp_essm():
def transition_fn(z, t):
def transition_fn(z, t, t_next):
mean = jnp.sin(2 * z + t)
cov = jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
Expand Down Expand Up @@ -217,19 +222,21 @@ def observation_fn(z, t):
)

sample = essm.sample(jax.random.PRNGKey(0), 1000)
filter_fn = jax.jit(lambda: essm.forward_filter(sample.observation)).lower().compile()
filter_jvp_fn = jax.jit(lambda: essm_jvp.forward_filter(sample.observation)).lower().compile()
filter_fn = jax.jit(
lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()
filter_jvp_fn = jax.jit(
lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()

t0 = time.time()
filter_results = filter_fn()
filter_results.t.block_until_ready()
filter_results.block_until_ready()
t1 = time.time()
dt1 = t1 - t0
print(f"Time for essm: {t1 - t0}")

t0 = time.time()
filter_results_jvp = filter_jvp_fn()
filter_results_jvp.t.block_until_ready()
filter_results_jvp.block_until_ready()
t1 = time.time()
dt2 = t1 - t0
print(f"Time for essm_jvp: {t1 - t0}")
Expand All @@ -238,14 +245,14 @@ def observation_fn(z, t):


def test_essm_forward_simulation():
def transition_fn(z, t):
def transition_fn(z, t, t_next):
mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
cov = 0.1 * jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

def observation_fn(z, t):
mean = z
cov = t * 0.01 * jnp.eye(np.size(z))
cov = 0.01 * jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

n = 1
Expand All @@ -266,24 +273,25 @@ def observation_fn(z, t):
# Suppose we only observe every 3rd observation
mask = jnp.arange(T) % 3 != 0

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)
assert np.all(np.isfinite(filter_result.log_cumulative_marginal_likelihood))
assert np.all(np.isfinite(filter_result.filtered_mean))

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
assert log_prob == filter_result.log_cumulative_marginal_likelihood[-1]

# Smoothed latent distribution, p(z[t] | x[:]), i.e. past latents given all future observations
# Including new estimate for prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)
assert np.all(np.isfinite(smooth_result.smoothed_mean))

# Forward simulate the model
forward_samples = essm.forward_simulate(
key=jax.random.PRNGKey(0),
num_time=25,
observations=samples.observation,
mask=mask
filter_result=filter_result
)

try:
Expand Down Expand Up @@ -313,3 +321,106 @@ def test__efficienct_add_scalar_diag():
A = jnp.eye(100)
c = 1.
assert jnp.all(_efficient_add_scalar_diag(A, c) == A + c * jnp.eye(100))


def test_incremental_filtering():
def transition_fn(z, t, t_next):
mean = z + z * jnp.sin(2 * jnp.pi * t / 10)
cov = 0.1 * jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

def observation_fn(z, t):
mean = z
cov = t * 0.01 * jnp.eye(np.size(z))
return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

n = 1

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
transition_fn=transition_fn,
observation_fn=observation_fn,
initial_state_prior=initial_state_prior,
materialise_jacobians=False, # Fast
more_data_than_params=False # if observation is bigger than latent we can speed it up.
)
samples = essm.sample(jax.random.PRNGKey(0), 100)

filter_result = essm.forward_filter(samples.observation)

filter_state = essm.create_initial_filter_state()

for i in range(100):
filter_state = essm.incremental_predict(filter_state)
filter_state, _ = essm.incremental_update(filter_state, samples.observation[i])
assert filter_state.t == filter_result.t[i]
np.testing.assert_allclose(filter_state.log_cumulative_marginal_likelihood,
filter_result.log_cumulative_marginal_likelihood[i], atol=1e-5)
np.testing.assert_allclose(filter_state.filtered_mean, filter_result.filtered_mean[i], atol=1e-5)
np.testing.assert_allclose(filter_state.filtered_cov, filter_result.filtered_cov[i], atol=1e-5)


@pytest.mark.parametrize('use_sparse', [False, True])
def test_performance_sparse(use_sparse: bool):
# Show that using sparse rep speeds up linear system
n = 128
k = 10
m = np.zeros((n, n))
rows = np.random.randint(n, size=k)
cols = np.random.randint(n, size=k)
m[rows, cols] += 1.

if use_sparse:
m = create_sparse_rep(m)
else:
m = jnp.asarray(m)

def transition_fn(z, t, t_next):
if use_sparse:
mean = matvec_sparse(m, z)
else:
mean = m @ z
scale = jnp.ones(np.size(z))
return tfpd.MultivariateNormalDiag(mean, scale)

def observation_fn(z, t):
mean = z
scale = jnp.ones(np.size(z))
return tfpd.MultivariateNormalDiag(mean, scale)

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
transition_fn=transition_fn,
observation_fn=observation_fn,
initial_state_prior=initial_state_prior,
materialise_jacobians=True
)

essm_jvp = ExtendedStateSpaceModel(
transition_fn=transition_fn,
observation_fn=observation_fn,
initial_state_prior=initial_state_prior,
materialise_jacobians=False
)

sample = essm.sample(jax.random.PRNGKey(0), 512)
filter_fn = jax.jit(
lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()
filter_jvp_fn = jax.jit(
lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()

t0 = time.time()
filter_results = filter_fn()
filter_results.block_until_ready()
t1 = time.time()
dt1 = t1 - t0
print(f"Time for essm(use_sparse={use_sparse}): {t1 - t0}")

t0 = time.time()
filter_results_jvp = filter_jvp_fn()
filter_results_jvp.block_until_ready()
t1 = time.time()
dt2 = t1 - t0
print(f"Time for essm_jvp(use_sparse={use_sparse}): {t1 - t0}")
9 changes: 6 additions & 3 deletions essm_jax/tests/test_jvp_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)
import numpy as np
import pytest

Expand All @@ -13,7 +16,7 @@ def test_jvp_linear_op():
def fn(x):
return jnp.asarray([jnp.sum(jnp.sin(x) ** i) for i in range(m)])

x = jnp.arange(n).astype(jnp.float32)
x = jnp.arange(n).astype(float)

jvp_op = JVPLinearOp(fn)
jvp_op = jvp_op(x)
Expand Down Expand Up @@ -70,8 +73,8 @@ def test_multiple_primals(init_primals: bool):
def fn(x, y):
return jnp.stack([x * y, y, -y], axis=-1) # [n, 3]

x = jnp.arange(n).astype(jnp.float32)
y = jnp.arange(n).astype(jnp.float32)
x = jnp.arange(n).astype(float)
y = jnp.arange(n).astype(float)
if init_primals:
jvp_op = JVPLinearOp(fn, primals=(x, y))
else:
Expand Down
Loading

0 comments on commit 585c403

Please sign in to comment.