Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189

Merged

Conversation

zaxtax
Copy link
Contributor

@zaxtax zaxtax commented Nov 16, 2021

This adds more fields to the trace object returned from sample_numpyro_nuts addressing some of the concerns in #5121

@codecov
Copy link

codecov bot commented Nov 16, 2021

Codecov Report

Merging #5189 (6e4fcab) into main (a11eaa2) will decrease coverage by 0.14%.
The diff coverage is 27.08%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5189      +/-   ##
==========================================
- Coverage   78.11%   77.97%   -0.15%     
==========================================
  Files          88       88              
  Lines       14159    14210      +51     
==========================================
+ Hits        11061    11080      +19     
- Misses       3098     3130      +32     
Impacted Files Coverage Δ
pymc/sampling_jax.py 0.00% <0.00%> (ø)
pymc/backends/arviz.py 89.55% <86.66%> (ø)
pymc/bart/pgbart.py 95.14% <0.00%> (-1.10%) ⬇️
pymc/distributions/continuous.py 96.57% <0.00%> (+0.04%) ⬆️
pymc/distributions/distribution.py 94.70% <0.00%> (+0.20%) ⬆️

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Just left a comment below to avoid recreating the same code

pymc/sampling_jax.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 added the trace-backend Traces and ArviZ stuff label Nov 16, 2021
@zaxtax zaxtax force-pushed the numpyro_sample_stats_and_observed_data branch from 3ba2643 to 8fa77c3 Compare November 16, 2021 14:48
@zaxtax zaxtax force-pushed the numpyro_sample_stats_and_observed_data branch from 7d742f0 to 661ca8c Compare November 17, 2021 11:40
@zaxtax zaxtax force-pushed the numpyro_sample_stats_and_observed_data branch from 79dd918 to f5aeaf6 Compare November 17, 2021 16:00
pymc/sampling_jax.py Show resolved Hide resolved
@zaxtax zaxtax changed the title Adding observed_data and sample_stats to numpyro sampler Adding log_likelihood, observed_data, and sample_stats to numpyro sampler Nov 18, 2021
@ricardoV94 ricardoV94 merged commit fe2d101 into pymc-devs:main Nov 18, 2021
@ricardoV94
Copy link
Member

Thanks @zaxtax. Awesome work!

logp_v = replace_shared_variables([logpt(v)])
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
jax_fn = jax_funcify(fgraph)
result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, would we expect any benefits to jit_compiling this outer vmap?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use a similar approach with Aesara directly?

Here we only loop over observed variables in order to get the pointwise log likelihood. We had some discussion about this in #4489 but ended up keeping the 3 nested loops over variables, chains and draws.

Copy link
Member

@ricardoV94 ricardoV94 Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it. Here is a Notebook that documents some things I tried: https://gist.github.com/ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2

It might be faster if using shared variables...

Copy link
Contributor Author

@zaxtax zaxtax Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea. I think the easiest thing to do is just benchmark it. I don't even call optimize_graph on either the graph in this function or the main sample routine.

When I run the model in the unit test with the change

result = jax.vmap(jax.vmap(jax_fn))(*samples)[0] to
result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]

I don't really get a speed-up until there are millions of samples.

Copy link
Member

@ricardoV94 ricardoV94 Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't even call optimize_graph on either the graph in this function or the main sample routine

We should definitely call optimize_graph, otherwise the computed logps may not correspond to the ones used during sampling. For instance we have many optimizations that improve numerically stability, so you might get underflows to -inf for some of the posterior samples (which would never have been accepted by NUTS) which could screw up things downstream.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it.

Then it's probably not worth it. I was under the impression it would be possible to vectorize/broadcast the operation from the conversations in #4489 and in slack.

Copy link
Member

@ricardoV94 ricardoV94 Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It must be possible, since the vmap above works just fine. I just have no idea how they do it xD, or how/if you could do it in Aesara. I also wonder whether the vmap works for more complicated models with multivariate distributions and the like

Copy link
Contributor Author

@zaxtax zaxtax Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. I'm going to make a separate PR for some of this other stuff.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, feel free to tag me if you want me to review, I am not watching PRs. I can already say I won't be able to help with the vectorized log_likelihood thing, I tried and I lost much more time with that than what would have been healthy. I should be able to help with coords and dims though

@OriolAbril
Copy link
Member

Thanks! We should document that while posterior, log_likelihood, sample_stats and observed_data groups will be created, all coords and dims are ignored unlike with the "regular" backend.

Is sample numpyro in the docs already? Should it be?

@ricardoV94
Copy link
Member

Thanks! We should document that while posterior, log_likelihood, sample_stats and observed_data groups will be created, all coords and dims are ignored unlike with the "regular" backend.

I guess we could also retrieve these, no?

Is sample numpyro in the docs already? Should it be?

If it's not yet, it should!

@zaxtax
Copy link
Contributor Author

zaxtax commented Nov 18, 2021

@OriolAbril I think it would take me about the same effort to document the discrepancy as just do the correct thing with coords and dims.

Out of curiosity, there used to be a jax sampler based on TFP. Did that just silently get dropped?

@OriolAbril
Copy link
Member

@OriolAbril I think it would take me about the same effort to document the discrepancy as just do the correct thing with coords and dims.

Thanks! I assume the code is already written actually, either here in the to_inference_data function or in io_numpyro in ArviZ it should be a matter of combining the pieces together.

@ricardoV94
Copy link
Member

Out of curiosity, there used to be a jax sampler based on TFP. Did that just silently get dropped?

Yes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
trace-backend Traces and ArviZ stuff
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants