Skip to content

Commit

Permalink
Run all ABC tests in single process
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 5, 2021
1 parent a8c041f commit 7e01fd2
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,32 +229,29 @@ def abs_diff(eps, obs_data, sim_data):

def test_one_gaussian(self):
with self.SMABC_test:
trace = pm.sample_smc(draws=1000)
trace = pm.sample_smc(draws=1000, chains=2, cores=1, return_inferencedata=False)
pr_p = pm.sample_prior_predictive(1000)
po_p = pm.sample_posterior_predictive(trace, 1000)

assert abs(self.data.mean() - trace["a"].mean()) < 0.05
assert abs(self.data.std() - trace["b"].mean()) < 0.05

def test_sim_data_ppc(self):
with self.SMABC_test:
trace = pm.sample_smc(draws=1000, chains=1)
pr_p = pm.sample_prior_predictive(1000)
po_p = pm.sample_posterior_predictive(trace, 1000)

assert pr_p["s"].shape == (1000, 1000)
assert abs(0 - pr_p["s"].mean()) < 0.05
assert abs(1.4 - pr_p["s"].std()) < 0.05
assert abs(0 - pr_p["s"].mean()) < 0.10
assert abs(1.4 - pr_p["s"].std()) < 0.10

assert po_p["s"].shape == (1000, 1000)
assert abs(0 - po_p["s"].mean()) < 0.05
assert abs(1 - po_p["s"].std()) < 0.05
assert abs(self.data.mean() - po_p["s"].mean()) < 0.10
assert abs(self.data.std() - po_p["s"].std()) < 0.10

def test_custom_dist_sum(self):
with self.SMABC_test2:
trace = pm.sample_smc(draws=100)
trace = pm.sample_smc(draws=100, chains=1)

@pytest.mark.xfail(reason="standard SMC is failing with Potentials")
@pytest.mark.xfail(reason="Potential is failing in SMC")
def test_potential(self):
with self.SMABC_potential:
trace = pm.sample_smc(draws=1000)
trace = pm.sample_smc(draws=1000, chains=1)
assert np.all(trace["a"] >= 0)

@pytest.mark.xfail(reason="KL not refactored")
Expand All @@ -281,7 +278,7 @@ def test_repr_latex(self):
def test_name_is_string_type(self):
with self.SMABC_potential:
assert not self.SMABC_potential.name
trace = pm.sample_smc(draws=10, kernel="SMC")
trace = pm.sample_smc(draws=10, cores=1)
assert isinstance(trace._straces[0].name, str)

def test_named_models_are_unsupported(self):
Expand Down Expand Up @@ -339,7 +336,7 @@ def fn(rng, a, size):
observed=data2,
)

trace = pm.sample_smc(chains=1)
trace = pm.sample_smc(chains=1, return_inferencedata=False)

assert abs(true_a - trace["a"].mean()) < 0.05
assert abs(true_b - trace["b"].mean()) < 0.05
Expand Down

0 comments on commit 7e01fd2

Please sign in to comment.