Skip to content

Commit

Permalink
Add tests for geometric JAX samples
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create authored and rlouf committed Feb 21, 2023
1 parent 02bb94f commit 311c901
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,21 @@ def test_random_bernoulli(size):
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)


@pytest.mark.parametrize(
"p, size",
[
(0.6, ()),
(0.2, (4,)),
],
)
def test_random_geometric(p, size):
rng = shared(np.random.RandomState(123))
g = at.random.geometric(p, size=(1000,) + size, rng=rng)
g_fn = function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(), 1 / p, atol=0.1)


def test_random_mvnormal():
rng = shared(np.random.RandomState(123))

Expand Down

0 comments on commit 311c901

Please sign in to comment.