-
Notifications
You must be signed in to change notification settings - Fork 83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Parallel hmm posterior sample #342
Parallel hmm posterior sample #342
Conversation
Thanks @calebweinreb! To clarify, I would say that the associative operator takes in two sets of samples, Then, assuming The final message is a sample The output of associative scan thus yields samples of |
@@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3): | |||
posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods) | |||
posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods) | |||
assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1) | |||
|
|||
|
|||
def test_parallel_posterior_sample( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe we commented this test out for a few reasons:
- it's expensive to draw so many samples
- it's a randomized test, so it can fail with some probability
- we could do a little math to figure out the failure probability as a function of
eps
andnum_samples
, but we never took the time to do so.
I think this is exactly the right test to run to check for correctness though, and I'd love to have it in our rotation as long as it reliably passes!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a practical level, I assume it would be reliable because the seed is fixed?
This looks really neat @calebweinreb! One question about the timing results - is that on a cpu or gpu? I remember the behaviour being a bit different for different backends in the context of lgssm inference (for instance results from Adrien). |
Hi Scott, thanks for clarifying! I think we landed on a good way of articulating the algorithm over slack. I'll repost here in case others are interested:
|
I ran the test on a GPU. I assume on a CPU, parallel would always do worse? |
This PR implements a parallel version of HMM posterior sampling using associative scan (see #341). The scan elements$E_{ij}$ are vectors specifying a sample
for each possible value of$z_i$ . They can be thought of as functions $E : [1,...,n] \to [1,...,n]$ where the associative operator is function composition. This implementation passes the test written for serial sampling (which is commented out for some reason). It starts performing better than serial sampling when the sequence length exceeds a few thousand (I'm a little mystified as to why it takes so long for the crossover to happen).