Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel hmm posterior sample #342

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dynamax/hidden_markov_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
from dynamax.hidden_markov_model.inference import compute_transition_probs

from dynamax.hidden_markov_model.parallel_inference import hmm_filter as parallel_hmm_filter
from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample as parallel_hmm_posterior_sample
30 changes: 30 additions & 0 deletions dynamax/hidden_markov_model/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools as it
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
import dynamax.hidden_markov_model.inference as core
import dynamax.hidden_markov_model.parallel_inference as parallel

Expand Down Expand Up @@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3):
posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1)


def test_parallel_posterior_sample(
Copy link
Collaborator

@slinderman slinderman Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we commented this test out for a few reasons:

  1. it's expensive to draw so many samples
  2. it's a randomized test, so it can fail with some probability
  3. we could do a little math to figure out the failure probability as a function of eps and num_samples, but we never took the time to do so.

I think this is exactly the right test to run to check for correctness though, and I'd love to have it in our rotation as long as it reliably passes!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a practical level, I assume it would be reliable because the seed is fixed?

key=0, num_timesteps=5, num_states=2, eps=1e-3,
num_samples=1000000, num_iterations=5
):
if isinstance(key, int):
key = jr.PRNGKey(key)

max_unique_size = 1 << num_timesteps

def iterate_test(key_iter):
keys_iter = jr.split(key_iter, num_samples)
args = random_hmm_args(key_iter, num_timesteps, num_states)

# Sample sequences from posterior
state_seqs = vmap(parallel.hmm_posterior_sample, (0, None, None, None), (0, 0))(keys_iter, *args)[1]
unique_seqs, counts = jnp.unique(state_seqs, axis=0, size=max_unique_size, return_counts=True)
blj_sample = counts / counts.sum()

# Compute joint probabilities
blj = jnp.exp(big_log_joint(*args))
blj = jnp.ravel(blj / blj.sum())

# Compare the joint distributions
return jnp.allclose(blj_sample, blj, rtol=0, atol=eps)

keys = jr.split(key, num_iterations)
assert iterate_test(keys[0])
92 changes: 87 additions & 5 deletions dynamax/hidden_markov_model/parallel_inference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import jax.numpy as jnp
import jax.random as jr
from jax import lax, vmap, value_and_grad
from jaxtyping import Array, Float
from jaxtyping import Array, Float, Int
from typing import NamedTuple, Union
from functools import partial

from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered

class Message(NamedTuple):
#---------------------------------------------------------------------------#
# Filtering #
#---------------------------------------------------------------------------#

class FilterMessage(NamedTuple):
"""Filtering associative scan elements.

Attributes:
A: $p(z_j \mid z_i)$
log_b: $\log P(y_{i+1}, ..., y_j \mid z_i)$
"""
A: Float[Array, "num_timesteps num_states num_states"]
log_b: Float[Array, "num_timesteps num_states"]

Expand Down Expand Up @@ -43,15 +55,15 @@ def marginalize(m_ij, m_jk):
A_ij_cond, lognorm = _condition_on(m_ij.A, m_jk.log_b)
A_ik = A_ij_cond @ m_jk.A
log_b_ik = m_ij.log_b + lognorm
return Message(A=A_ik, log_b=log_b_ik)
return FilterMessage(A=A_ik, log_b=log_b_ik)


# Initialize the messages
A0, log_b0 = _condition_on(initial_probs, log_likelihoods[0])
A0 *= jnp.ones((K, K))
log_b0 *= jnp.ones(K)
A1T, log_b1T = vmap(_condition_on, in_axes=(None, 0))(transition_matrix, log_likelihoods[1:])
initial_messages = Message(
initial_messages = FilterMessage(
A=jnp.concatenate([A0[None, :, :], A1T]),
log_b=jnp.vstack([log_b0, log_b1T])
)
Expand All @@ -72,6 +84,11 @@ def marginalize(m_ij, m_jk):
predicted_probs=predicted_probs)


#---------------------------------------------------------------------------#
# Smoothing #
#---------------------------------------------------------------------------#


def hmm_smoother(initial_probs: Float[Array, "num_states"],
transition_matrix: Float[Array, "num_states num_states"],
log_likelihoods: Float[Array, "num_timesteps num_states"]
Expand Down Expand Up @@ -109,4 +126,69 @@ def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods):
initial_probs=smoothed_probs[0],
smoothed_probs=smoothed_probs,
trans_probs=trans_probs
)
)


#---------------------------------------------------------------------------#
# Sampling #
#---------------------------------------------------------------------------#
"""Associative scan elements $E_ij$ are vectors specifying a sample::

$z_j ~ p(z_j \mid z_i)$

for each possible value of $z_i$.
"""

def _initialize_sampling_messages(rng, transition_matrix, filtered_probs):
"""Preprocess filtering output to construct input for sampling assocative scan."""

T, K = filtered_probs.shape
rngs = jr.split(rng, T)

def _last_message(rng, probs):
state = jr.choice(rng, K, p=probs)
return jnp.repeat(state, K)

@vmap
def _generic_message(rng, probs):
smoothed_probs = probs * transition_matrix.T
smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K,1)
return vmap(lambda p: jr.choice(rng, K, p=p))(smoothed_probs)

En = _last_message(rngs[-1], filtered_probs[-1])
Et = _generic_message(rngs[:-1], filtered_probs[:-1])
return jnp.concatenate([Et, En[None]])


def hmm_posterior_sample(rng: jr.PRNGKey,
initial_distribution: Float[Array, "num_states"],
transition_matrix: Float[Array, "num_states num_states"],
log_likelihoods: Float[Array, "num_timesteps num_states"]
) -> Int[Array, "num_timesteps"]:
r"""Sample a sequence of hidden states from the posterior.

Args:
rng: random number generator
initial_distribution: $p(z_1 \mid u_1, \theta)$
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.

Returns:
log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$
states: sequence of hidden states $z_{1:T}$
"""
T, K = log_likelihoods.shape

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods)
log_normalizer = post.marginal_loglik
filtered_probs = post.filtered_probs

@vmap
def _operator(E_jk, E_ij):
return jnp.take(E_ij, E_jk)

initial_messages = _initialize_sampling_messages(rng, transition_matrix, filtered_probs)
final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
states = final_messages[:,0]
return log_normalizer, states