Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some models for generic testing of MCMC #2049

Merged
merged 14 commits into from
Oct 3, 2019
2 changes: 2 additions & 0 deletions docs/source/poutine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Handlers

.. autofunction:: pyro.poutine.scale

.. autofunction:: pyro.poutine.seed

.. autofunction:: pyro.poutine.trace

.. autofunction:: pyro.infer.enum.config_enumerate
Expand Down
2 changes: 2 additions & 0 deletions docs/source/primitives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ Primitives
.. autofunction:: pyro.enable_validation

.. autofunction:: pyro.ops.jit.trace

.. autofunction:: pyro.set_rng_seed
2 changes: 1 addition & 1 deletion docs/source/pyro.generic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Generic Interface
The ``pyro.generic`` module provides an interface to dynamically dispatch Pyro code
to custom backends.

.. automodule:: pyro.generic
.. automodule:: pyro.generic.dispatch
:members:
:undoc-members:
:show-inheritance:
Expand Down
11 changes: 11 additions & 0 deletions pyro/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pyro.generic.dispatch import distributions, handlers, infer, ops, optim, pyro, pyro_backend

__all__ = [
'distributions',
'handlers',
'infer',
'ops',
'optim',
'pyro',
'pyro_backend',
]
25 changes: 15 additions & 10 deletions pyro/generic.py → pyro/generic/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib

from contextlib import contextmanager


Expand Down Expand Up @@ -54,37 +53,43 @@ def pyro_backend(*aliases, **new_backends):

_ALIASES = {
'pyro': {
'pyro': 'pyro',
'distributions': 'pyro.distributions',
'handlers': 'pyro.poutine',
'infer': 'pyro.infer',
'optim': 'pyro.optim',
'ops': 'torch',
'optim': 'pyro.optim',
'pyro': 'pyro',
},
'minipyro': {
'pyro': 'pyro.contrib.minipyro',
'distributions': 'pyro.distributions',
'handlers': 'pyro.poutine',
'infer': 'pyro.contrib.minipyro',
'optim': 'pyro.contrib.minipyro',
'ops': 'torch',
'optim': 'pyro.contrib.minipyro',
'pyro': 'pyro.contrib.minipyro',
},
'funsor': {
'pyro': 'funsor.minipyro',
'infer': 'funsor.minipyro',
'optim': 'funsor.minipyro',
'distributions': 'funsor.distributions',
'handlers': 'pyro.poutine',
neerajprad marked this conversation as resolved.
Show resolved Hide resolved
'infer': 'funsor.minipyro',
'ops': 'funsor.ops',
'optim': 'funsor.minipyro',
'pyro': 'funsor.minipyro',
},
'numpy': {
'pyro': 'numpyro.compat.pyro',
'distributions': 'numpyro.compat.distributions',
'handlers': 'numpyro.compat.handlers',
'infer': 'numpyro.compat.infer',
'ops': 'numpyro.compat.ops',
'optim': 'numpyro.compat.optim',
'ops': 'numpy',
'pyro': 'numpyro.compat.pyro',
},
}

# These modules can be overridden.
pyro = GenericModule('pyro', 'pyro')
distributions = GenericModule('distributions', 'pyro.distributions')
handlers = GenericModule('handlers', 'pyro.poutine')
infer = GenericModule('infer', 'pyro.infer')
optim = GenericModule('optim', 'pyro.optim')
ops = GenericModule('ops', 'torch')
111 changes: 111 additions & 0 deletions pyro/generic/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Models for testing the generic interface.
For specifying the arguments to model functions, the convention followed is
that positional arguments are inputs to the model and keyword arguments denote
observed data.
"""

import argparse
from collections import OrderedDict

from pyro.generic import distributions as dist, handlers, ops, pyro, pyro_backend

MODELS = OrderedDict()


def register(rng_seed=None):
def _register_fn(fn):
MODELS[fn.__name__] = handlers.seed(fn, rng_seed)

return _register_fn


@register(rng_seed=1)
def logistic_regression():
N, dim = 3000, 3
# generic way to sample from distributions
data = pyro.sample('data', dist.Normal(0., 1.), sample_shape=(N, dim))
true_coefs = ops.arange(1., dim + 1.)
logits = ops.sum(true_coefs * data, axis=-1)
labels = pyro.sample('labels', dist.Bernoulli(logits=logits))

def model(x, y=None):
coefs = pyro.sample('coefs', dist.Normal(ops.zeros(dim), ops.ones(dim)))
intercept = pyro.sample('intercept', dist.Normal(0., 1.))
logits = ops.sum(coefs * x, axis=-1) + intercept
return pyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

return {'model': model, 'model_args': (data,), 'model_kwargs': {'y': labels}}


@register(rng_seed=1)
def neals_funnel():
def model(dim):
y = pyro.sample('y', dist.Normal(0, 3))
pyro.sample('x', dist.TransformedDistribution(
dist.Normal(ops.zeros(dim - 1), 1), dist.transforms.AffineTransform(0, ops.exp(y / 2))))

return {'model': model, 'model_args': (10,)}


@register(rng_seed=1)
def eight_schools():
J = 8
y = ops.tensor([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = ops.tensor([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def model(J, sigma, y=None):
mu = pyro.sample('mu', dist.Normal(0, 5))
tau = pyro.sample('tau', dist.HalfCauchy(5))
with pyro.plate('J', J):
theta = pyro.sample('theta', dist.Normal(mu, tau))
pyro.sample('obs', dist.Normal(theta, sigma), obs=y)

return {'model': model, 'model_args': (J, sigma), 'model_kwargs': {'y': y}}


@register(rng_seed=1)
def beta_binomial():
N, D1, D2 = 10, 2, 2
true_probs = ops.tensor([[0.7, 0.4], [0.6, 0.4]])
total_count = ops.tensor([[1000, 600], [400, 800]])

data = pyro.sample('data', dist.Binomial(total_count=total_count, probs=true_probs),
sample_shape=(N,))

def model(N, D1, D2, data=None):
with pyro.plate("plate_0", D1):
alpha = pyro.sample("alpha", dist.HalfCauchy(1.))
beta = pyro.sample("beta", dist.HalfCauchy(1.))
with pyro.plate("plate_1", D2):
probs = pyro.sample("probs", dist.Beta(alpha, beta))
with pyro.plate("data", N):
pyro.sample("binomial", dist.Binomial(probs=probs, total_count=total_count), obs=data)

return {'model': model, 'model_args': (N, D1, D2), 'model_kwargs': {'data': data}}


def check_model(backend, name):
get_model = MODELS[name]
print('Running model "{}" on backend "{}".'.format(name, args.backend))
with pyro_backend(backend), handlers.seed(rng_seed=2):
f = get_model()
model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
print('Sample from prior...')
model(*model_args)
print('Trace model...')
handlers.trace(model).get_trace(*model_args, **model_kwargs)


def main(args):
for name in MODELS:
check_model(args.backend, name)


if __name__ == '__main__':
assert pyro.__version__.startswith('0.4.1')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-b", "--backend", default="pyro")
args = parser.parse_args()
main(args)
6 changes: 6 additions & 0 deletions pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from pyro.infer.elbo import ELBO
from pyro.infer.enum import config_enumerate
from pyro.infer.importance import Importance
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.smcfilter import SMCFilter
from pyro.infer.svi import SVI
Expand All @@ -23,13 +26,16 @@
"is_validation_enabled",
"ELBO",
"EmpiricalMarginal",
"HMC",
"Importance",
"IMQSteinKernel",
"infer_discrete",
"JitTraceEnum_ELBO",
"JitTraceGraph_ELBO",
"JitTraceMeanField_ELBO",
"JitTrace_ELBO",
"MCMC",
"NUTS",
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
"RBFSteinKernel",
"RenyiELBO",
"SMCFilter",
Expand Down
3 changes: 2 additions & 1 deletion pyro/poutine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .handlers import (block, broadcast, condition, do, enum, escape, infer_config, lift, markov, mask, queue, replay,
scale, trace, uncondition)
scale, seed, trace, uncondition)
from .runtime import NonlocalExit
from .trace_struct import Trace
from .util import enable_validation, is_validation_enabled
Expand All @@ -21,6 +21,7 @@
"replay",
"queue",
"scale",
"seed",
"trace",
"Trace",
"uncondition",
Expand Down
30 changes: 29 additions & 1 deletion pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
import functools

from pyro.poutine import util

from pyro.poutine.messenger import Messenger
from pyro.util import set_rng_seed
from .block_messenger import BlockMessenger
from .broadcast_messenger import BroadcastMessenger
from .condition_messenger import ConditionMessenger
Expand All @@ -66,6 +67,7 @@
from .trace_messenger import TraceMessenger
from .uncondition_messenger import UnconditionMessenger


############################################
# Begin primitive operations
############################################
Expand Down Expand Up @@ -497,3 +499,29 @@ def markov(fn=None, history=1, keep=False):
return MarkovMessenger(history=history, keep=keep).generator(iterable=fn)
# Used as a decorator with bound args
return MarkovMessenger(history=history, keep=keep)(fn)


class _SeedMessenger(Messenger):
def __init__(self, rng_seed):
assert isinstance(rng_seed, int)
self.rng_seed = rng_seed
super(_SeedMessenger, self).__init__()

def __enter__(self):
set_rng_seed(self.rng_seed)
super(_SeedMessenger, self).__enter__()


def seed(fn=None, rng_seed=None):
"""
Handler to set the random number generator to a pre-defined state by setting its
seed. This is the same as calling :func:`pyro.set_rng_seed` before the
call to `fn`. This handler has no additional effect on primitive statements on the
standard Pyro backend, but it might intercept ``pyro.sample`` calls in other
backends. e.g. the NumPy backend.
:param fn: a stochastic function (callable containing Pyro primitive calls).
:param int rng_seed: rng seed.
"""
msngr = _SeedMessenger(rng_seed)
return msngr(fn) if fn is not None else msngr
3 changes: 2 additions & 1 deletion pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

def set_rng_seed(rng_seed):
"""
Sets seeds of torch and torch.cuda (if available).
Sets seeds of `torch` and `torch.cuda` (if available).
:param int rng_seed: The seed value.
"""
torch.manual_seed(rng_seed)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from pyro.generic import infer, pyro_backend, handlers
from pyro.generic.testing import MODELS


pytestmark = pytest.mark.stage('unit')


@pytest.mark.parametrize('model', MODELS)
@pytest.mark.parametrize('backend', ['pyro'])
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

@fritzo fritzo Oct 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the minipyro backend here, or does that not support MCMC?

EDIT I'd like to add the minipyro backend once we have some SVI examples, but that can wait until a follow-up PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minipyro doesn't support MCMC. Maybe we can add a vanilla HMC to minipyro that uses the integrator in pyro.ops, but I'm not sure if it will be as concise as SVI. It could be a nice little hackathon project.

def test_mcmc_interface(model, backend):
with pyro_backend(backend), handlers.seed(rng_seed=20):
f = MODELS[model]()
model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
nuts_kernel = infer.NUTS(model=model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we should later factor this out into check_model() so as to minimize the effort of adding tests to each backend, but we can do that in a follow-up PR when we decide how to add SVI tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we should later factor this out into check_model() so as to minimize the effort of adding tests to each backend

Could you elaborate on that? There is a very basic testing.check_model that runs the model forward and traces it. I have tried to follow a convention that args represent model inputs and kwargs are used to represent observed data (like Stan's data block). The idea was that each backend will just import MODELS from testing and use it as they see fit. We can continue to enrich the metadata from each model, currently its just the model callable, inputs and observed data (maybe we can add ML estimate for the latent posteriors etc).

Copy link
Member

@fritzo fritzo Oct 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I'd like to implement a very narrow testing interface with two methods that do not mention pytest and do not require testing code to mention pyro.generic. Their use is equivalent to the following:

# in my_pyro_implementation/test/test_generic.py
import pytest
from pyro.generic.testing import get_test_cases, check_test_case
# n.b. do not import any pyro stuff, including pyro.generic

@pytest.mark.parametrize('test_case', get_test_cases())
def test_generic(test_case):
    try:
        check_test_case(test_case)
    except NotImplementedError as e:
        pytest.xfail(f"Not implemented:\n{e}")

I don't care what these are named or whether they are lists or functions or whatever, but I do think we should separate concerns such that:

  1. all pyro.generic interface is hidden inside these test helpers, and
  2. all testing infra (pytest) is applied outside these test helpers.

I see these helpers as kind of an autoconf mechanism for pyro implementations: running them automatically shows what features are available in a given backend.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved this discussion to a new issue #2053

mcmc = infer.MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
mcmc.run(*args, **kwargs)
mcmc.summary()