Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some models for generic testing of MCMC #2049

Merged
merged 14 commits into from
Oct 3, 2019
Merged

Conversation

neerajprad
Copy link
Member

Some of the changes needed to enable this in NumPyro are not merged - pyro-ppl/numpyro#374.

Simple models can use generic dispatch to run MCMC. This has the following changes:

  • Moves generic.py into a separate generic module.
  • Creates generic.distributions module which basically imports distributions from the Pyro namespace along with torch constraints and transforms. We can also create separate dispatch modules for either of these, but it seemed a little excess.
  • Exposes handlers in generic. Creates a seed contextmanager in Pyro which simply calls pyro.set_rng_seed in Pyro but calls the seed handler in NumPyro.
  • A simple testing module is added, but it can be modified to add separate model, guide pairs to test SVI. This is important for funsor, and might at one point be useful to have for NumPyro.

I have not formalized the generic API. Doing so will simply involve creating abstract base classes in generic, and ensuring that the dispatched classes inherit from this generic API. This doesn't preclude such a formalization in the future.

@neerajprad neerajprad requested review from fehiepsi and fritzo October 1, 2019 00:55
pyro/generic/testing.py Outdated Show resolved Hide resolved
pyro/generic/testing.py Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
@neerajprad
Copy link
Member Author

Thanks for reviewing, @fritzo! I will address the remaining comments and update when this PR is ready for a second review.

fehiepsi
fehiepsi previously approved these changes Oct 1, 2019
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The design for modeling, inference, and tests LGTM! I just have a minor comment regarding handlers.seed but I would like to defer it to a separate issue, where we can discuss if it is worth to do.

@neerajprad
Copy link
Member Author

I think this is good for a second round of reviewing.

@neerajprad
Copy link
Member Author

Some changes based on prior comments:

  • I have exposed transforms and constraints in the generic module itself that dispatch to torch constraints/transforms. This is because we have a local pyro.distributions.transforms module, and given how important these are (for transformed distributions, etc), it makes sense to have them directly in generic.
  • Also added the generic module itself. This was needed so that we don't have a mock handlers.seed in pyro and can use the generic.seed decorator / contextmanager instead. This however will only have any effect when under the pyro_backend context manager which requires this module be available for generic dispatching. An alternative would be to build a handlers.seed mock handler in other backends.

fehiepsi
fehiepsi previously approved these changes Oct 2, 2019
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall! This is much much cleaner than I have expected and have thought about this generic module previously.

pyro/infer/__init__.py Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
tests/test_generic.py Show resolved Hide resolved
docs/source/pyro.generic.rst Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
pyro/generic/generic.py Outdated Show resolved Hide resolved
pyro/distributions/__init__.py Outdated Show resolved Hide resolved
pyro/distributions/__init__.py Outdated Show resolved Hide resolved
@neerajprad
Copy link
Member Author

Thanks a lot for a thorough review, @fritzo, @fehiepsi. I think I have addressed all existing comments, so this should be ready for another round of review!

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after a couple minor comments.

After this merges I'll add some SVI examples to pyro.generic.testing and start using that in funsor.

pyro/generic/dispatch.py Outdated Show resolved Hide resolved


@pytest.mark.parametrize('model', MODELS)
@pytest.mark.parametrize('backend', ['pyro'])
Copy link
Member

@fritzo fritzo Oct 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the minipyro backend here, or does that not support MCMC?

EDIT I'd like to add the minipyro backend once we have some SVI examples, but that can wait until a follow-up PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minipyro doesn't support MCMC. Maybe we can add a vanilla HMC to minipyro that uses the integrator in pyro.ops, but I'm not sure if it will be as concise as SVI. It could be a nice little hackathon project.

with pyro_backend(backend), handlers.seed(rng_seed=20):
f = MODELS[model]()
model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
nuts_kernel = infer.NUTS(model=model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we should later factor this out into check_model() so as to minimize the effort of adding tests to each backend, but we can do that in a follow-up PR when we decide how to add SVI tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we should later factor this out into check_model() so as to minimize the effort of adding tests to each backend

Could you elaborate on that? There is a very basic testing.check_model that runs the model forward and traces it. I have tried to follow a convention that args represent model inputs and kwargs are used to represent observed data (like Stan's data block). The idea was that each backend will just import MODELS from testing and use it as they see fit. We can continue to enrich the metadata from each model, currently its just the model callable, inputs and observed data (maybe we can add ML estimate for the latent posteriors etc).

Copy link
Member

@fritzo fritzo Oct 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I'd like to implement a very narrow testing interface with two methods that do not mention pytest and do not require testing code to mention pyro.generic. Their use is equivalent to the following:

# in my_pyro_implementation/test/test_generic.py
import pytest
from pyro.generic.testing import get_test_cases, check_test_case
# n.b. do not import any pyro stuff, including pyro.generic

@pytest.mark.parametrize('test_case', get_test_cases())
def test_generic(test_case):
    try:
        check_test_case(test_case)
    except NotImplementedError as e:
        pytest.xfail(f"Not implemented:\n{e}")

I don't care what these are named or whether they are lists or functions or whatever, but I do think we should separate concerns such that:

  1. all pyro.generic interface is hidden inside these test helpers, and
  2. all testing infra (pytest) is applied outside these test helpers.

I see these helpers as kind of an autoconf mechanism for pyro implementations: running them automatically shows what features are available in a given backend.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved this discussion to a new issue #2053

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks for taking the time to work out a good extensible API!

@fritzo fritzo merged commit 8a2b87f into pyro-ppl:dev Oct 3, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants