Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simple svi tests to pyroapi.tests #5

Merged
merged 7 commits into from
Oct 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
python: 3.6
script:
- flake8
- pytest -vs --tb=short
- pytest -vs --tb=short test
25 changes: 0 additions & 25 deletions pyroapi/tests.py

This file was deleted.

1 change: 1 addition & 0 deletions pyroapi/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .test_svi import * # noqa F401
196 changes: 196 additions & 0 deletions pyroapi/tests/test_svi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import pytest

from pyroapi.dispatch import distributions as dist
from pyroapi.dispatch import infer, ops, optim, pyro

# This file tests a variety of model,guide pairs with valid and invalid structure.
# See https://github.com/pyro-ppl/pyro/blob/0.3.1/tests/infer/test_valid_models.py
#
# Note that the backend arg to these tests must be provided as a
# user-defined fixture that sets the pyro_backend. For demonstration,
# see test/conftest.py.


def assert_ok(model, guide, elbo, *args, **kwargs):
"""
Assert that inference works without warnings or errors.
"""
pyro.get_param_store().clear()
adam = optim.Adam({"lr": 1e-6})
inference = infer.SVI(model, guide, adam, elbo)
for i in range(2):
inference.step(*args, **kwargs)


def test_generate_data(backend):

def model(data=None):
loc = pyro.param("loc", ops.tensor(2.0))
scale = pyro.param("scale", ops.tensor(1.0))
x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
return x

data = model()
assert data.shape == ()


def test_generate_data_plate(backend):
num_points = 1000

def model(data=None):
loc = pyro.param("loc", ops.tensor(2.0))
scale = pyro.param("scale", ops.tensor(1.0))
with pyro.plate("data", 1000, dim=-1):
x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
return x

data = model()
if type(data).__module__.startswith('funsor'):
pytest.xfail(reason='plate is an input, and does not appear in .shape')
assert data.shape == (num_points,)
mean = data.sum().item() / num_points
assert 1.9 <= mean <= 2.1


@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
@pytest.mark.parametrize("optim_name", ["Adam", "ClippedAdam"])
def test_optimizer(backend, optim_name, jit):

def model(data):
p = pyro.param("p", ops.tensor(0.5))
pyro.sample("x", dist.Bernoulli(p), obs=data)

def guide(data):
pass

data = ops.tensor(0.)
pyro.get_param_store().clear()
Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
elbo = Elbo(ignore_jit_warnings=True)
optimizer = getattr(optim, optim_name)({"lr": 1e-6})
inference = infer.SVI(model, guide, optimizer, elbo)
for i in range(2):
inference.step(data)


@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
def test_nonempty_model_empty_guide_ok(backend, jit):

def model(data):
loc = pyro.param("loc", ops.tensor(0.0))
pyro.sample("x", dist.Normal(loc, 1.), obs=data)

def guide(data):
pass

data = ops.tensor(2.)
Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
elbo = Elbo(ignore_jit_warnings=True)
assert_ok(model, guide, elbo, data)


@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
def test_plate_ok(backend, jit):
data = ops.randn(10)

def model():
locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
p = ops.tensor([0.2, 0.3, 0.5])
with pyro.plate("plate", len(data), dim=-1):
x = pyro.sample("x", dist.Categorical(p))
pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

def guide():
p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
with pyro.plate("plate", len(data), dim=-1):
pyro.sample("x", dist.Categorical(p))

Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
elbo = Elbo(ignore_jit_warnings=True)
assert_ok(model, guide, elbo)


@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
def test_nested_plate_plate_ok(backend, jit):
data = ops.randn(2, 3)

def model():
loc = ops.tensor(3.0)
with pyro.plate("plate_outer", data.size(-1), dim=-1):
x = pyro.sample("x", dist.Normal(loc, 1.))
with pyro.plate("plate_inner", data.size(-2), dim=-2):
pyro.sample("y", dist.Normal(x, 1.), obs=data)

def guide():
loc = pyro.param("loc", ops.tensor(0.))
scale = pyro.param("scale", ops.tensor(1.))
with pyro.plate("plate_outer", data.size(-1), dim=-1):
pyro.sample("x", dist.Normal(loc, scale))

Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
elbo = Elbo(ignore_jit_warnings=True)
assert_ok(model, guide, elbo)


@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
def test_local_param_ok(backend, jit):
data = ops.randn(10)

def model():
locs = pyro.param("locs", ops.tensor([-1., 0., 1.]))
with pyro.plate("plate", len(data), dim=-1):
x = pyro.sample("x", dist.Categorical(ops.ones(3) / 3))
pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

def guide():
with pyro.plate("plate", len(data), dim=-1):
p = pyro.param("p", ops.ones(len(data), 3) / 3, event_dim=1)
pyro.sample("x", dist.Categorical(p))
return p

Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
elbo = Elbo(ignore_jit_warnings=True)
assert_ok(model, guide, elbo)

# Check that pyro.param() can be called without init_value.
expected = guide()
actual = pyro.param("p")
assert ops.allclose(actual, expected)


@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
def test_constraints(backend, jit):
data = ops.tensor(0.5)

def model():
locs = pyro.param("locs", ops.randn(3),
constraint=dist.constraints.real)
scales = pyro.param("scales", ops.randn(3).exp(),
constraint=dist.constraints.positive)
p = ops.tensor([0.5, 0.3, 0.2])
x = pyro.sample("x", dist.Categorical(p))
pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data)

def guide():
q = pyro.param("q", ops.randn(3).exp(),
constraint=dist.constraints.simplex)
pyro.sample("x", dist.Categorical(q))

Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
elbo = Elbo(ignore_jit_warnings=True)
assert_ok(model, guide, elbo)


def test_mean_field_ok(backend):

def model():
x = pyro.sample("x", dist.Normal(0., 1.))
pyro.sample("y", dist.Normal(x, 1.))

def guide():
loc = pyro.param("loc", ops.tensor(0.))
x = pyro.sample("x", dist.Normal(loc, 1.))
pyro.sample("y", dist.Normal(x, 1.))

elbo = infer.TraceMeanField_ELBO()
assert_ok(model, guide, elbo)
2 changes: 2 additions & 0 deletions test/test_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pyroapi import pyro_backend
from pyroapi.tests import * # noqa F401

pytestmark = pytest.mark.filterwarnings("ignore::numpyro.compat.util.UnsupportedAPIWarning")
Copy link
Member

@neerajprad neerajprad Oct 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, it will makes more sense to have this in pyro-api instead.



@pytest.fixture(params=["pyro", "minipyro", "numpy", "funsor"])
def backend(request):
Expand Down