Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel hmm posterior sample #342

Merged
merged 2 commits into from
Jun 18, 2024

Conversation

calebweinreb
Copy link
Contributor

@calebweinreb calebweinreb commented Sep 5, 2023

This PR implements a parallel version of HMM posterior sampling using associative scan (see #341). The scan elements $E_{ij}$ are vectors specifying a sample

z_j ~ p(z_j \mid z_i)

for each possible value of $z_i$. They can be thought of as functions $E : [1,...,n] \to [1,...,n]$ where the associative operator is function composition. This implementation passes the test written for serial sampling (which is commented out for some reason). It starts performing better than serial sampling when the sequence length exceeds a few thousand (I'm a little mystified as to why it takes so long for the crossover to happen).

from dynamax.hidden_markov_model.inference_test import random_hmm_args
from dynamax.hidden_markov_model import hmm_posterior_sample, parallel_hmm_posterior_sample
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import time

num_states = 2
num_iters = 5
timesteps = np.logspace(0,6,10).astype(int)
serial_times, parallel_times = [], []
for num_timesteps in timesteps:
    print(num_timesteps)
    serial_time, parallel_time = 0, 0
    for itr in range(num_iters+1):
        args = random_hmm_args(jr.PRNGKey(itr), num_timesteps, 5)
        
        t = time.time()
        hmm_posterior_sample(jr.PRNGKey(itr), *args)
        print('s', time.time()-t)
        if itr > 0: serial_time += time.time()-t
            
        t = time.time()
        parallel_hmm_posterior_sample(jr.PRNGKey(itr), *args)
        print('p', time.time()-t)
        if itr > 0: parallel_time += time.time()-t
            
    serial_times.append(serial_time/num_iters)
    parallel_times.append(parallel_time/num_iters)
    
plt.plot(timesteps, serial_times, label='serial')
plt.plot(timesteps, parallel_times, label='parallel')
plt.legend(loc='upper left')
plt.xscale('log')
plt.yscale('log')
plt.ylabel('Runtime (s)')
plt.xlabel('Sequence length')
plt.gcf().set_size_inches((3,2))
Screenshot 2023-09-05 at 11 22 51 AM

@slinderman
Copy link
Collaborator

Thanks @calebweinreb! To clarify, I would say that the associative operator takes in two sets of samples,
$$z_s \sim p(z_s \mid x_{1:s}, z_{s+1}) $$
and
$$z_t \sim p(z_t \mid x_{1:t}, z_{t+1})$$
for all values of $z_{s+1} \in [K]$ and $z_{t+1} \in [K]$.

Then, assuming $t > s$, the associative operator returns a sample
$$z_s \sim p(z_s \mid x_{1:t}, z_{t+1})$$
for all $z_{t+1} \in [K]$.

The final message is a sample $z_T \sim p(z_T \mid x_{1:T})$, replicated $K$ times so that it is the same shape as the preceding messages.

The output of associative scan thus yields samples of $z_{1:T} \sim p(z_{1:T} \mid x_{1:T})$. The output shape is (T,K), but all columns are identical since they all started with the same final state. Thus, it suffices to take the first column of the output matrix.

@@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3):
posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1)


def test_parallel_posterior_sample(
Copy link
Collaborator

@slinderman slinderman Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we commented this test out for a few reasons:

  1. it's expensive to draw so many samples
  2. it's a randomized test, so it can fail with some probability
  3. we could do a little math to figure out the failure probability as a function of eps and num_samples, but we never took the time to do so.

I think this is exactly the right test to run to check for correctness though, and I'd love to have it in our rotation as long as it reliably passes!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a practical level, I assume it would be reliable because the seed is fixed?

@gileshd
Copy link
Collaborator

gileshd commented Sep 6, 2023

This looks really neat @calebweinreb!

One question about the timing results - is that on a cpu or gpu? I remember the behaviour being a bit different for different backends in the context of lgssm inference (for instance results from Adrien).

@calebweinreb
Copy link
Contributor Author

Hi Scott, thanks for clarifying! I think we landed on a good way of articulating the algorithm over slack. I'll repost here in case others are interested:

  • Let's assume an HMM with $K$ hidden states and $T$ time-steps.
  • The initial messages $E_{t,t+1}$ are samples from $p(z_t \mid x_{1:t}, z_{t+1})$ for all possible values of $z_{t+1}$
  • The initial final message $E_{T}$ is a sample from the last filtering dist, $p(z_T \mid x_{1:T})$, repeated K times so that it's the same shape as the other messages.
  • In the first iteration, the associative operator gives you samples from $p(z_t \mid x_{1:t+1}, z_{t+2})$ for all values of $z_{t+2}$ . It does so by sampling $z_{t+1} \sim p(z_{t+1} \mid x_{1:t+1}, z_{t+2})$ then sampling $z_t$ conditioned on $z_{t+1}$.
  • This step is repeated recursively in the associative scan. At any intermediate point, the message $E_{i,j}$ stores samples $z_i \sim p(z_i \mid x_{t:j-1}, z_j)$ for each possible value of $z_j$.
  • The final output is an array of shape (T,K) where the columns (which are all the same because they share the same final state) each contain the final sampled sequence $z_{1:T}$.

@calebweinreb
Copy link
Contributor Author

This looks really neat @calebweinreb!

One question about the timing results - is that on a cpu or gpu? I remember the behaviour being a bit different for different backends in the context of lgssm inference (for instance results from Adrien).

I ran the test on a GPU. I assume on a CPU, parallel would always do worse?

@slinderman slinderman merged commit a6b85ba into probml:main Jun 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants