Skip to content

Commit

Permalink
Add some models for generic testing of MCMC (#2049)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fritzo committed Oct 3, 2019
1 parent 716737d commit 8a2b87f
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 14 deletions.
2 changes: 2 additions & 0 deletions docs/source/poutine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Handlers

.. autofunction:: pyro.poutine.scale

.. autofunction:: pyro.poutine.seed

.. autofunction:: pyro.poutine.trace

.. autofunction:: pyro.infer.enum.config_enumerate
Expand Down
2 changes: 2 additions & 0 deletions docs/source/primitives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ Primitives
.. autofunction:: pyro.enable_validation

.. autofunction:: pyro.ops.jit.trace

.. autofunction:: pyro.set_rng_seed
2 changes: 1 addition & 1 deletion docs/source/pyro.generic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Generic Interface
The ``pyro.generic`` module provides an interface to dynamically dispatch Pyro code
to custom backends.

.. automodule:: pyro.generic
.. automodule:: pyro.generic.dispatch
:members:
:undoc-members:
:show-inheritance:
Expand Down
11 changes: 11 additions & 0 deletions pyro/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pyro.generic.dispatch import distributions, handlers, infer, ops, optim, pyro, pyro_backend

__all__ = [
'distributions',
'handlers',
'infer',
'ops',
'optim',
'pyro',
'pyro_backend',
]
25 changes: 15 additions & 10 deletions pyro/generic.py → pyro/generic/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib

from contextlib import contextmanager


Expand Down Expand Up @@ -54,37 +53,43 @@ def pyro_backend(*aliases, **new_backends):

_ALIASES = {
'pyro': {
'pyro': 'pyro',
'distributions': 'pyro.distributions',
'handlers': 'pyro.poutine',
'infer': 'pyro.infer',
'optim': 'pyro.optim',
'ops': 'torch',
'optim': 'pyro.optim',
'pyro': 'pyro',
},
'minipyro': {
'pyro': 'pyro.contrib.minipyro',
'distributions': 'pyro.distributions',
'handlers': 'pyro.poutine',
'infer': 'pyro.contrib.minipyro',
'optim': 'pyro.contrib.minipyro',
'ops': 'torch',
'optim': 'pyro.contrib.minipyro',
'pyro': 'pyro.contrib.minipyro',
},
'funsor': {
'pyro': 'funsor.minipyro',
'infer': 'funsor.minipyro',
'optim': 'funsor.minipyro',
'distributions': 'funsor.distributions',
'handlers': 'funsor.minipyro',
'infer': 'funsor.minipyro',
'ops': 'funsor.ops',
'optim': 'funsor.minipyro',
'pyro': 'funsor.minipyro',
},
'numpy': {
'pyro': 'numpyro.compat.pyro',
'distributions': 'numpyro.compat.distributions',
'handlers': 'numpyro.compat.handlers',
'infer': 'numpyro.compat.infer',
'ops': 'numpyro.compat.ops',
'optim': 'numpyro.compat.optim',
'ops': 'numpy',
'pyro': 'numpyro.compat.pyro',
},
}

# These modules can be overridden.
pyro = GenericModule('pyro', 'pyro')
distributions = GenericModule('distributions', 'pyro.distributions')
handlers = GenericModule('handlers', 'pyro.poutine')
infer = GenericModule('infer', 'pyro.infer')
optim = GenericModule('optim', 'pyro.optim')
ops = GenericModule('ops', 'torch')
111 changes: 111 additions & 0 deletions pyro/generic/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Models for testing the generic interface.
For specifying the arguments to model functions, the convention followed is
that positional arguments are inputs to the model and keyword arguments denote
observed data.
"""

import argparse
from collections import OrderedDict

from pyro.generic import distributions as dist, handlers, ops, pyro, pyro_backend

MODELS = OrderedDict()


def register(rng_seed=None):
def _register_fn(fn):
MODELS[fn.__name__] = handlers.seed(fn, rng_seed)

return _register_fn


@register(rng_seed=1)
def logistic_regression():
N, dim = 3000, 3
# generic way to sample from distributions
data = pyro.sample('data', dist.Normal(0., 1.), sample_shape=(N, dim))
true_coefs = ops.arange(1., dim + 1.)
logits = ops.sum(true_coefs * data, axis=-1)
labels = pyro.sample('labels', dist.Bernoulli(logits=logits))

def model(x, y=None):
coefs = pyro.sample('coefs', dist.Normal(ops.zeros(dim), ops.ones(dim)))
intercept = pyro.sample('intercept', dist.Normal(0., 1.))
logits = ops.sum(coefs * x, axis=-1) + intercept
return pyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

return {'model': model, 'model_args': (data,), 'model_kwargs': {'y': labels}}


@register(rng_seed=1)
def neals_funnel():
def model(dim):
y = pyro.sample('y', dist.Normal(0, 3))
pyro.sample('x', dist.TransformedDistribution(
dist.Normal(ops.zeros(dim - 1), 1), dist.transforms.AffineTransform(0, ops.exp(y / 2))))

return {'model': model, 'model_args': (10,)}


@register(rng_seed=1)
def eight_schools():
J = 8
y = ops.tensor([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = ops.tensor([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def model(J, sigma, y=None):
mu = pyro.sample('mu', dist.Normal(0, 5))
tau = pyro.sample('tau', dist.HalfCauchy(5))
with pyro.plate('J', J):
theta = pyro.sample('theta', dist.Normal(mu, tau))
pyro.sample('obs', dist.Normal(theta, sigma), obs=y)

return {'model': model, 'model_args': (J, sigma), 'model_kwargs': {'y': y}}


@register(rng_seed=1)
def beta_binomial():
N, D1, D2 = 10, 2, 2
true_probs = ops.tensor([[0.7, 0.4], [0.6, 0.4]])
total_count = ops.tensor([[1000, 600], [400, 800]])

data = pyro.sample('data', dist.Binomial(total_count=total_count, probs=true_probs),
sample_shape=(N,))

def model(N, D1, D2, data=None):
with pyro.plate("plate_0", D1):
alpha = pyro.sample("alpha", dist.HalfCauchy(1.))
beta = pyro.sample("beta", dist.HalfCauchy(1.))
with pyro.plate("plate_1", D2):
probs = pyro.sample("probs", dist.Beta(alpha, beta))
with pyro.plate("data", N):
pyro.sample("binomial", dist.Binomial(probs=probs, total_count=total_count), obs=data)

return {'model': model, 'model_args': (N, D1, D2), 'model_kwargs': {'data': data}}


def check_model(backend, name):
get_model = MODELS[name]
print('Running model "{}" on backend "{}".'.format(name, args.backend))
with pyro_backend(backend), handlers.seed(rng_seed=2):
f = get_model()
model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
print('Sample from prior...')
model(*model_args)
print('Trace model...')
handlers.trace(model).get_trace(*model_args, **model_kwargs)


def main(args):
for name in MODELS:
check_model(args.backend, name)


if __name__ == '__main__':
assert pyro.__version__.startswith('0.4.1')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-b", "--backend", default="pyro")
args = parser.parse_args()
main(args)
6 changes: 6 additions & 0 deletions pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from pyro.infer.elbo import ELBO
from pyro.infer.enum import config_enumerate
from pyro.infer.importance import Importance
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.smcfilter import SMCFilter
from pyro.infer.svi import SVI
Expand All @@ -23,13 +26,16 @@
"is_validation_enabled",
"ELBO",
"EmpiricalMarginal",
"HMC",
"Importance",
"IMQSteinKernel",
"infer_discrete",
"JitTraceEnum_ELBO",
"JitTraceGraph_ELBO",
"JitTraceMeanField_ELBO",
"JitTrace_ELBO",
"MCMC",
"NUTS",
"RBFSteinKernel",
"RenyiELBO",
"SMCFilter",
Expand Down
3 changes: 2 additions & 1 deletion pyro/poutine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .handlers import (block, broadcast, condition, do, enum, escape, infer_config, lift, markov, mask, queue, replay,
scale, trace, uncondition)
scale, seed, trace, uncondition)
from .runtime import NonlocalExit
from .trace_struct import Trace
from .util import enable_validation, is_validation_enabled
Expand All @@ -21,6 +21,7 @@
"replay",
"queue",
"scale",
"seed",
"trace",
"Trace",
"uncondition",
Expand Down
30 changes: 29 additions & 1 deletion pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
import functools

from pyro.poutine import util

from pyro.poutine.messenger import Messenger
from pyro.util import set_rng_seed
from .block_messenger import BlockMessenger
from .broadcast_messenger import BroadcastMessenger
from .condition_messenger import ConditionMessenger
Expand All @@ -66,6 +67,7 @@
from .trace_messenger import TraceMessenger
from .uncondition_messenger import UnconditionMessenger


############################################
# Begin primitive operations
############################################
Expand Down Expand Up @@ -497,3 +499,29 @@ def markov(fn=None, history=1, keep=False):
return MarkovMessenger(history=history, keep=keep).generator(iterable=fn)
# Used as a decorator with bound args
return MarkovMessenger(history=history, keep=keep)(fn)


class _SeedMessenger(Messenger):
def __init__(self, rng_seed):
assert isinstance(rng_seed, int)
self.rng_seed = rng_seed
super(_SeedMessenger, self).__init__()

def __enter__(self):
set_rng_seed(self.rng_seed)
super(_SeedMessenger, self).__enter__()


def seed(fn=None, rng_seed=None):
"""
Handler to set the random number generator to a pre-defined state by setting its
seed. This is the same as calling :func:`pyro.set_rng_seed` before the
call to `fn`. This handler has no additional effect on primitive statements on the
standard Pyro backend, but it might intercept ``pyro.sample`` calls in other
backends. e.g. the NumPy backend.
:param fn: a stochastic function (callable containing Pyro primitive calls).
:param int rng_seed: rng seed.
"""
msngr = _SeedMessenger(rng_seed)
return msngr(fn) if fn is not None else msngr
3 changes: 2 additions & 1 deletion pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

def set_rng_seed(rng_seed):
"""
Sets seeds of torch and torch.cuda (if available).
Sets seeds of `torch` and `torch.cuda` (if available).
:param int rng_seed: The seed value.
"""
torch.manual_seed(rng_seed)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from pyro.generic import infer, pyro_backend, handlers
from pyro.generic.testing import MODELS


pytestmark = pytest.mark.stage('unit')


@pytest.mark.parametrize('model', MODELS)
@pytest.mark.parametrize('backend', ['pyro'])
def test_mcmc_interface(model, backend):
with pyro_backend(backend), handlers.seed(rng_seed=20):
f = MODELS[model]()
model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
nuts_kernel = infer.NUTS(model=model)
mcmc = infer.MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
mcmc.run(*args, **kwargs)
mcmc.summary()

0 comments on commit 8a2b87f

Please sign in to comment.