From 9739855c809f67fef9956c0a1a2764b019b515c0 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 19 May 2022 20:44:15 +0200 Subject: [PATCH] Test seeding of sample_numpyro_nuts --- pymc/tests/test_sampling_jax.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index b7e9dc0d4c7..dc3e4ad4c84 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -199,9 +199,16 @@ def test_get_batched_jittered_initial_points(): assert np.all(ips[0][0] != ips[0][1]) +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ], +) @pytest.mark.parametrize("random_seed", (None, 123)) @pytest.mark.parametrize("chains", (1, 2)) -def test_seeding(chains, random_seed): +def test_seeding(chains, random_seed, sampler): sample_kwargs = dict( tune=100, draws=5, @@ -211,8 +218,8 @@ def test_seeding(chains, random_seed): with pm.Model() as m: pm.Normal("x", mu=0, sigma=1) - result1 = sample_numpyro_nuts(**sample_kwargs) - result2 = sample_numpyro_nuts(**sample_kwargs) + result1 = sampler(**sample_kwargs) + result2 = sampler(**sample_kwargs) all_equal = np.all(result1.posterior["x"] == result2.posterior["x"]) if random_seed is None: