Skip to content

Commit

Permalink
Test seeding of sample_numpyro_nuts
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 19, 2022
1 parent 3f5b297 commit 9739855
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,16 @@ def test_get_batched_jittered_initial_points():
assert np.all(ips[0][0] != ips[0][1])


@pytest.mark.parametrize(
"sampler",
[
sample_blackjax_nuts,
sample_numpyro_nuts,
],
)
@pytest.mark.parametrize("random_seed", (None, 123))
@pytest.mark.parametrize("chains", (1, 2))
def test_seeding(chains, random_seed):
def test_seeding(chains, random_seed, sampler):
sample_kwargs = dict(
tune=100,
draws=5,
Expand All @@ -211,8 +218,8 @@ def test_seeding(chains, random_seed):

with pm.Model() as m:
pm.Normal("x", mu=0, sigma=1)
result1 = sample_numpyro_nuts(**sample_kwargs)
result2 = sample_numpyro_nuts(**sample_kwargs)
result1 = sampler(**sample_kwargs)
result2 = sampler(**sample_kwargs)

all_equal = np.all(result1.posterior["x"] == result2.posterior["x"])
if random_seed is None:
Expand Down

0 comments on commit 9739855

Please sign in to comment.