From 8a2b87f41f2e3939f0ba9ca2a00b1e9ffa1f51b5 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 3 Oct 2019 10:19:23 -0700 Subject: [PATCH] Add some models for generic testing of MCMC (#2049) --- docs/source/poutine.rst | 2 + docs/source/primitives.rst | 2 + docs/source/pyro.generic.rst | 2 +- pyro/generic/__init__.py | 11 +++ pyro/{generic.py => generic/dispatch.py} | 25 +++-- pyro/generic/testing.py | 111 +++++++++++++++++++++++ pyro/infer/__init__.py | 6 ++ pyro/poutine/__init__.py | 3 +- pyro/poutine/handlers.py | 30 +++++- pyro/util.py | 3 +- tests/test_generic.py | 19 ++++ 11 files changed, 200 insertions(+), 14 deletions(-) create mode 100644 pyro/generic/__init__.py rename pyro/{generic.py => generic/dispatch.py} (90%) create mode 100644 pyro/generic/testing.py create mode 100644 tests/test_generic.py diff --git a/docs/source/poutine.rst b/docs/source/poutine.rst index e572a30c58..49725c01d5 100644 --- a/docs/source/poutine.rst +++ b/docs/source/poutine.rst @@ -36,6 +36,8 @@ Handlers .. autofunction:: pyro.poutine.scale +.. autofunction:: pyro.poutine.seed + .. autofunction:: pyro.poutine.trace .. autofunction:: pyro.infer.enum.config_enumerate diff --git a/docs/source/primitives.rst b/docs/source/primitives.rst index 2f25796d15..864870089a 100644 --- a/docs/source/primitives.rst +++ b/docs/source/primitives.rst @@ -18,3 +18,5 @@ Primitives .. autofunction:: pyro.enable_validation .. autofunction:: pyro.ops.jit.trace + +.. autofunction:: pyro.set_rng_seed diff --git a/docs/source/pyro.generic.rst b/docs/source/pyro.generic.rst index d91f427bcb..4d0f0f924a 100644 --- a/docs/source/pyro.generic.rst +++ b/docs/source/pyro.generic.rst @@ -4,7 +4,7 @@ Generic Interface The ``pyro.generic`` module provides an interface to dynamically dispatch Pyro code to custom backends. -.. automodule:: pyro.generic +.. automodule:: pyro.generic.dispatch :members: :undoc-members: :show-inheritance: diff --git a/pyro/generic/__init__.py b/pyro/generic/__init__.py new file mode 100644 index 0000000000..dec476aab2 --- /dev/null +++ b/pyro/generic/__init__.py @@ -0,0 +1,11 @@ +from pyro.generic.dispatch import distributions, handlers, infer, ops, optim, pyro, pyro_backend + +__all__ = [ + 'distributions', + 'handlers', + 'infer', + 'ops', + 'optim', + 'pyro', + 'pyro_backend', +] diff --git a/pyro/generic.py b/pyro/generic/dispatch.py similarity index 90% rename from pyro/generic.py rename to pyro/generic/dispatch.py index 76b632be3b..246afc958f 100644 --- a/pyro/generic.py +++ b/pyro/generic/dispatch.py @@ -1,5 +1,4 @@ import importlib - from contextlib import contextmanager @@ -54,37 +53,43 @@ def pyro_backend(*aliases, **new_backends): _ALIASES = { 'pyro': { - 'pyro': 'pyro', 'distributions': 'pyro.distributions', + 'handlers': 'pyro.poutine', 'infer': 'pyro.infer', - 'optim': 'pyro.optim', 'ops': 'torch', + 'optim': 'pyro.optim', + 'pyro': 'pyro', }, 'minipyro': { - 'pyro': 'pyro.contrib.minipyro', + 'distributions': 'pyro.distributions', + 'handlers': 'pyro.poutine', 'infer': 'pyro.contrib.minipyro', - 'optim': 'pyro.contrib.minipyro', 'ops': 'torch', + 'optim': 'pyro.contrib.minipyro', + 'pyro': 'pyro.contrib.minipyro', }, 'funsor': { - 'pyro': 'funsor.minipyro', - 'infer': 'funsor.minipyro', - 'optim': 'funsor.minipyro', 'distributions': 'funsor.distributions', + 'handlers': 'funsor.minipyro', + 'infer': 'funsor.minipyro', 'ops': 'funsor.ops', + 'optim': 'funsor.minipyro', + 'pyro': 'funsor.minipyro', }, 'numpy': { - 'pyro': 'numpyro.compat.pyro', 'distributions': 'numpyro.compat.distributions', + 'handlers': 'numpyro.compat.handlers', 'infer': 'numpyro.compat.infer', + 'ops': 'numpyro.compat.ops', 'optim': 'numpyro.compat.optim', - 'ops': 'numpy', + 'pyro': 'numpyro.compat.pyro', }, } # These modules can be overridden. pyro = GenericModule('pyro', 'pyro') distributions = GenericModule('distributions', 'pyro.distributions') +handlers = GenericModule('handlers', 'pyro.poutine') infer = GenericModule('infer', 'pyro.infer') optim = GenericModule('optim', 'pyro.optim') ops = GenericModule('ops', 'torch') diff --git a/pyro/generic/testing.py b/pyro/generic/testing.py new file mode 100644 index 0000000000..5c11b4d087 --- /dev/null +++ b/pyro/generic/testing.py @@ -0,0 +1,111 @@ +""" +Models for testing the generic interface. + +For specifying the arguments to model functions, the convention followed is +that positional arguments are inputs to the model and keyword arguments denote +observed data. +""" + +import argparse +from collections import OrderedDict + +from pyro.generic import distributions as dist, handlers, ops, pyro, pyro_backend + +MODELS = OrderedDict() + + +def register(rng_seed=None): + def _register_fn(fn): + MODELS[fn.__name__] = handlers.seed(fn, rng_seed) + + return _register_fn + + +@register(rng_seed=1) +def logistic_regression(): + N, dim = 3000, 3 + # generic way to sample from distributions + data = pyro.sample('data', dist.Normal(0., 1.), sample_shape=(N, dim)) + true_coefs = ops.arange(1., dim + 1.) + logits = ops.sum(true_coefs * data, axis=-1) + labels = pyro.sample('labels', dist.Bernoulli(logits=logits)) + + def model(x, y=None): + coefs = pyro.sample('coefs', dist.Normal(ops.zeros(dim), ops.ones(dim))) + intercept = pyro.sample('intercept', dist.Normal(0., 1.)) + logits = ops.sum(coefs * x, axis=-1) + intercept + return pyro.sample('obs', dist.Bernoulli(logits=logits), obs=y) + + return {'model': model, 'model_args': (data,), 'model_kwargs': {'y': labels}} + + +@register(rng_seed=1) +def neals_funnel(): + def model(dim): + y = pyro.sample('y', dist.Normal(0, 3)) + pyro.sample('x', dist.TransformedDistribution( + dist.Normal(ops.zeros(dim - 1), 1), dist.transforms.AffineTransform(0, ops.exp(y / 2)))) + + return {'model': model, 'model_args': (10,)} + + +@register(rng_seed=1) +def eight_schools(): + J = 8 + y = ops.tensor([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) + sigma = ops.tensor([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) + + def model(J, sigma, y=None): + mu = pyro.sample('mu', dist.Normal(0, 5)) + tau = pyro.sample('tau', dist.HalfCauchy(5)) + with pyro.plate('J', J): + theta = pyro.sample('theta', dist.Normal(mu, tau)) + pyro.sample('obs', dist.Normal(theta, sigma), obs=y) + + return {'model': model, 'model_args': (J, sigma), 'model_kwargs': {'y': y}} + + +@register(rng_seed=1) +def beta_binomial(): + N, D1, D2 = 10, 2, 2 + true_probs = ops.tensor([[0.7, 0.4], [0.6, 0.4]]) + total_count = ops.tensor([[1000, 600], [400, 800]]) + + data = pyro.sample('data', dist.Binomial(total_count=total_count, probs=true_probs), + sample_shape=(N,)) + + def model(N, D1, D2, data=None): + with pyro.plate("plate_0", D1): + alpha = pyro.sample("alpha", dist.HalfCauchy(1.)) + beta = pyro.sample("beta", dist.HalfCauchy(1.)) + with pyro.plate("plate_1", D2): + probs = pyro.sample("probs", dist.Beta(alpha, beta)) + with pyro.plate("data", N): + pyro.sample("binomial", dist.Binomial(probs=probs, total_count=total_count), obs=data) + + return {'model': model, 'model_args': (N, D1, D2), 'model_kwargs': {'data': data}} + + +def check_model(backend, name): + get_model = MODELS[name] + print('Running model "{}" on backend "{}".'.format(name, args.backend)) + with pyro_backend(backend), handlers.seed(rng_seed=2): + f = get_model() + model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) + print('Sample from prior...') + model(*model_args) + print('Trace model...') + handlers.trace(model).get_trace(*model_args, **model_kwargs) + + +def main(args): + for name in MODELS: + check_model(args.backend, name) + + +if __name__ == '__main__': + assert pyro.__version__.startswith('0.4.1') + parser = argparse.ArgumentParser(description="Mini Pyro demo") + parser.add_argument("-b", "--backend", default="pyro") + args = parser.parse_args() + main(args) diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index cba6e18156..4e5e6571f5 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -4,6 +4,9 @@ from pyro.infer.elbo import ELBO from pyro.infer.enum import config_enumerate from pyro.infer.importance import Importance +from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc.hmc import HMC +from pyro.infer.mcmc.nuts import NUTS from pyro.infer.renyi_elbo import RenyiELBO from pyro.infer.smcfilter import SMCFilter from pyro.infer.svi import SVI @@ -23,6 +26,7 @@ "is_validation_enabled", "ELBO", "EmpiricalMarginal", + "HMC", "Importance", "IMQSteinKernel", "infer_discrete", @@ -30,6 +34,8 @@ "JitTraceGraph_ELBO", "JitTraceMeanField_ELBO", "JitTrace_ELBO", + "MCMC", + "NUTS", "RBFSteinKernel", "RenyiELBO", "SMCFilter", diff --git a/pyro/poutine/__init__.py b/pyro/poutine/__init__.py index b7bf1074d0..cfe87bbb57 100644 --- a/pyro/poutine/__init__.py +++ b/pyro/poutine/__init__.py @@ -1,5 +1,5 @@ from .handlers import (block, broadcast, condition, do, enum, escape, infer_config, lift, markov, mask, queue, replay, - scale, trace, uncondition) + scale, seed, trace, uncondition) from .runtime import NonlocalExit from .trace_struct import Trace from .util import enable_validation, is_validation_enabled @@ -21,6 +21,7 @@ "replay", "queue", "scale", + "seed", "trace", "Trace", "uncondition", diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index c80b75dfee..84cddce16d 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -49,7 +49,8 @@ import functools from pyro.poutine import util - +from pyro.poutine.messenger import Messenger +from pyro.util import set_rng_seed from .block_messenger import BlockMessenger from .broadcast_messenger import BroadcastMessenger from .condition_messenger import ConditionMessenger @@ -66,6 +67,7 @@ from .trace_messenger import TraceMessenger from .uncondition_messenger import UnconditionMessenger + ############################################ # Begin primitive operations ############################################ @@ -497,3 +499,29 @@ def markov(fn=None, history=1, keep=False): return MarkovMessenger(history=history, keep=keep).generator(iterable=fn) # Used as a decorator with bound args return MarkovMessenger(history=history, keep=keep)(fn) + + +class _SeedMessenger(Messenger): + def __init__(self, rng_seed): + assert isinstance(rng_seed, int) + self.rng_seed = rng_seed + super(_SeedMessenger, self).__init__() + + def __enter__(self): + set_rng_seed(self.rng_seed) + super(_SeedMessenger, self).__enter__() + + +def seed(fn=None, rng_seed=None): + """ + Handler to set the random number generator to a pre-defined state by setting its + seed. This is the same as calling :func:`pyro.set_rng_seed` before the + call to `fn`. This handler has no additional effect on primitive statements on the + standard Pyro backend, but it might intercept ``pyro.sample`` calls in other + backends. e.g. the NumPy backend. + + :param fn: a stochastic function (callable containing Pyro primitive calls). + :param int rng_seed: rng seed. + """ + msngr = _SeedMessenger(rng_seed) + return msngr(fn) if fn is not None else msngr diff --git a/pyro/util.py b/pyro/util.py index 5ec05a2272..b33604a2b0 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -15,7 +15,8 @@ def set_rng_seed(rng_seed): """ - Sets seeds of torch and torch.cuda (if available). + Sets seeds of `torch` and `torch.cuda` (if available). + :param int rng_seed: The seed value. """ torch.manual_seed(rng_seed) diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 0000000000..70d08b92a1 --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,19 @@ +import pytest + +from pyro.generic import infer, pyro_backend, handlers +from pyro.generic.testing import MODELS + + +pytestmark = pytest.mark.stage('unit') + + +@pytest.mark.parametrize('model', MODELS) +@pytest.mark.parametrize('backend', ['pyro']) +def test_mcmc_interface(model, backend): + with pyro_backend(backend), handlers.seed(rng_seed=20): + f = MODELS[model]() + model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) + nuts_kernel = infer.NUTS(model=model) + mcmc = infer.MCMC(nuts_kernel, num_samples=10, warmup_steps=10) + mcmc.run(*args, **kwargs) + mcmc.summary()