-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add some models for generic testing of MCMC #2049
Conversation
Thanks for reviewing, @fritzo! I will address the remaining comments and update when this PR is ready for a second review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The design for modeling, inference, and tests LGTM! I just have a minor comment regarding handlers.seed
but I would like to defer it to a separate issue, where we can discuss if it is worth to do.
I think this is good for a second round of reviewing. |
Some changes based on prior comments:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall! This is much much cleaner than I have expected and have thought about this generic
module previously.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after a couple minor comments.
After this merges I'll add some SVI examples to pyro.generic.testing and start using that in funsor.
|
||
|
||
@pytest.mark.parametrize('model', MODELS) | ||
@pytest.mark.parametrize('backend', ['pyro']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add the minipyro
backend here, or does that not support MCMC?
EDIT I'd like to add the minipyro
backend once we have some SVI examples, but that can wait until a follow-up PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minipyro
doesn't support MCMC. Maybe we can add a vanilla HMC to minipyro that uses the integrator in pyro.ops
, but I'm not sure if it will be as concise as SVI. It could be a nice little hackathon project.
with pyro_backend(backend), handlers.seed(rng_seed=20): | ||
f = MODELS[model]() | ||
model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) | ||
nuts_kernel = infer.NUTS(model=model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well we should later factor this out into check_model()
so as to minimize the effort of adding tests to each backend, but we can do that in a follow-up PR when we decide how to add SVI tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well we should later factor this out into check_model() so as to minimize the effort of adding tests to each backend
Could you elaborate on that? There is a very basic testing.check_model
that runs the model forward and traces it. I have tried to follow a convention that args represent model inputs and kwargs are used to represent observed data (like Stan's data block). The idea was that each backend will just import MODELS
from testing and use it as they see fit. We can continue to enrich the metadata from each model, currently its just the model callable, inputs and observed data (maybe we can add ML estimate for the latent posteriors etc).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally I'd like to implement a very narrow testing interface with two methods that do not mention pytest and do not require testing code to mention pyro.generic. Their use is equivalent to the following:
# in my_pyro_implementation/test/test_generic.py
import pytest
from pyro.generic.testing import get_test_cases, check_test_case
# n.b. do not import any pyro stuff, including pyro.generic
@pytest.mark.parametrize('test_case', get_test_cases())
def test_generic(test_case):
try:
check_test_case(test_case)
except NotImplementedError as e:
pytest.xfail(f"Not implemented:\n{e}")
I don't care what these are named or whether they are lists or functions or whatever, but I do think we should separate concerns such that:
- all pyro.generic interface is hidden inside these test helpers, and
- all testing infra (pytest) is applied outside these test helpers.
I see these helpers as kind of an autoconf mechanism for pyro implementations: running them automatically shows what features are available in a given backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've moved this discussion to a new issue #2053
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thanks for taking the time to work out a good extensible API!
Some of the changes needed to enable this in NumPyro are not merged - pyro-ppl/numpyro#374.
Simple models can use generic dispatch to run MCMC. This has the following changes:
generic.py
into a separategeneric
module.generic.distributions
module which basically imports distributions from the Pyro namespace along with torch constraints and transforms. We can also create separate dispatch modules for either of these, but it seemed a little excess.handlers
in generic. Creates aseed
contextmanager in Pyro which simply callspyro.set_rng_seed
in Pyro but calls theseed
handler in NumPyro.I have not formalized the generic API. Doing so will simply involve creating abstract base classes in generic, and ensuring that the dispatched classes inherit from this generic API. This doesn't preclude such a formalization in the future.