diff --git a/pymc3/tests/test_smc.py b/pymc3/tests/test_smc.py index 09688dd9959..882387b3fe3 100644 --- a/pymc3/tests/test_smc.py +++ b/pymc3/tests/test_smc.py @@ -229,32 +229,29 @@ def abs_diff(eps, obs_data, sim_data): def test_one_gaussian(self): with self.SMABC_test: - trace = pm.sample_smc(draws=1000) + trace = pm.sample_smc(draws=1000, chains=2, cores=1, return_inferencedata=False) + pr_p = pm.sample_prior_predictive(1000) + po_p = pm.sample_posterior_predictive(trace, 1000) assert abs(self.data.mean() - trace["a"].mean()) < 0.05 assert abs(self.data.std() - trace["b"].mean()) < 0.05 - def test_sim_data_ppc(self): - with self.SMABC_test: - trace = pm.sample_smc(draws=1000, chains=1) - pr_p = pm.sample_prior_predictive(1000) - po_p = pm.sample_posterior_predictive(trace, 1000) - assert pr_p["s"].shape == (1000, 1000) - assert abs(0 - pr_p["s"].mean()) < 0.05 - assert abs(1.4 - pr_p["s"].std()) < 0.05 + assert abs(0 - pr_p["s"].mean()) < 0.10 + assert abs(1.4 - pr_p["s"].std()) < 0.10 + assert po_p["s"].shape == (1000, 1000) - assert abs(0 - po_p["s"].mean()) < 0.05 - assert abs(1 - po_p["s"].std()) < 0.05 + assert abs(self.data.mean() - po_p["s"].mean()) < 0.10 + assert abs(self.data.std() - po_p["s"].std()) < 0.10 def test_custom_dist_sum(self): with self.SMABC_test2: - trace = pm.sample_smc(draws=100) + trace = pm.sample_smc(draws=100, chains=1) - @pytest.mark.xfail(reason="standard SMC is failing with Potentials") + @pytest.mark.xfail(reason="Potential is failing in SMC") def test_potential(self): with self.SMABC_potential: - trace = pm.sample_smc(draws=1000) + trace = pm.sample_smc(draws=1000, chains=1) assert np.all(trace["a"] >= 0) @pytest.mark.xfail(reason="KL not refactored") @@ -281,7 +278,7 @@ def test_repr_latex(self): def test_name_is_string_type(self): with self.SMABC_potential: assert not self.SMABC_potential.name - trace = pm.sample_smc(draws=10, kernel="SMC") + trace = pm.sample_smc(draws=10, cores=1) assert isinstance(trace._straces[0].name, str) def test_named_models_are_unsupported(self): @@ -339,7 +336,7 @@ def fn(rng, a, size): observed=data2, ) - trace = pm.sample_smc(chains=1) + trace = pm.sample_smc(chains=1, return_inferencedata=False) assert abs(true_a - trace["a"].mean()) < 0.05 assert abs(true_b - trace["b"].mean()) < 0.05