Skip to content

Commit

Permalink
Merge pull request #3779 from pymc-devs/disable-broken-test
Browse files Browse the repository at this point in the history
Fix broken VI test (pytest issue)
  • Loading branch information
lucianopaz authored Jan 20, 2020
2 parents 636ef8a + af75d8b commit 03a64aa
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
12 changes: 6 additions & 6 deletions pymc3/tests/test_shape_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@
test_to_shapes = [None, tuple(), (10, 5, 4), (10, 1, 1, 5, 1)]


@pytest.fixture(scope="module", params=test_sizes, ids=str)
@pytest.fixture(params=test_sizes, ids=str)
def fixture_sizes(request):
return request.param


@pytest.fixture(scope="module", params=test_shapes, ids=str)
@pytest.fixture(params=test_shapes, ids=str)
def fixture_shapes(request):
return request.param


@pytest.fixture(scope="module", params=[False, True], ids=str)
@pytest.fixture(params=[False, True], ids=str)
def fixture_exception_handling(request):
return request.param


@pytest.fixture(scope="module")
@pytest.fixture()
def samples_to_broadcast(fixture_sizes, fixture_shapes):
samples = [np.empty(s) for s in fixture_shapes]
try:
Expand All @@ -68,7 +68,7 @@ def samples_to_broadcast(fixture_sizes, fixture_shapes):
return fixture_sizes, samples, broadcast_shape


@pytest.fixture(scope="module", params=test_to_shapes, ids=str)
@pytest.fixture(params=test_to_shapes, ids=str)
def samples_to_broadcast_to(request, samples_to_broadcast):
to_shape = request.param
size, samples, broadcast_shape = samples_to_broadcast
Expand All @@ -82,7 +82,7 @@ def samples_to_broadcast_to(request, samples_to_broadcast):
return to_shape, size, samples, broadcast_shape


@pytest.fixture(scope="module")
@pytest.fixture
def fixture_model():
with pm.Model() as model:
n = 5
Expand Down
5 changes: 3 additions & 2 deletions pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def three_var_approx(three_var_model, three_var_groups):
def three_var_approx_single_group_mf(three_var_model):
return MeanField(model=three_var_model)


@pytest.fixture(
params = [
('ndarray', None),
Expand Down Expand Up @@ -566,7 +567,7 @@ def use_minibatch(request):
return request.param


@pytest.fixture('module')
@pytest.fixture
def simple_model_data(use_minibatch):
n = 1000
sigma0 = 2.
Expand All @@ -590,7 +591,7 @@ def simple_model_data(use_minibatch):
)


@pytest.fixture(scope='module')
@pytest.fixture
def simple_model(simple_model_data):
with pm.Model() as model:
mu_ = pm.Normal(
Expand Down

0 comments on commit 03a64aa

Please sign in to comment.