Skip to content

Commit

Permalink
Add an InvGamma JAX implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pscemama-mitre authored and brandonwillard committed Mar 23, 2023
1 parent 4a687c0 commit 4e01892
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
21 changes: 21 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,24 @@ def sample_fn(rng, size, dtype, *parameters):
return (rng, samples)

return sample_fn


@jax_sample_fn.register(aer.InvGammaRV)
def jax_sample_fn_invgamma(op):
"""JAX implementation of `InvGammaRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

(
shape,
scale,
) = parameters
# InvGamma[shape, scale] <-> 1 / Gamma[shape, 1 / scale]
samples = 1 / (jax.random.gamma(sampling_key, shape, size, dtype) / scale)

rng["jax_state"] = rng_key
return (rng, samples)

return sample_fn
13 changes: 13 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,19 @@ def test_random_dirichlet(parameter, size):
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)


@pytest.mark.parametrize(
"shape, scale",
[(3, 3), (2, 1), (2, 5)],
)
def test_random_invgamma(shape, scale):
rng = shared(np.random.RandomState(123))
g = at.random.invgamma(shape, scale, size=(100000,), rng=rng)
g_fn = function([], g, mode=jax_mode)
samples = g_fn()
# mean = scale / (shape - 1) only exists for shape > 1
np.testing.assert_allclose(samples.mean(), scale / (shape - 1), rtol=1e-01)


def test_random_choice():
# Elements are picked at equal frequency
num_samples = 10000
Expand Down

0 comments on commit 4e01892

Please sign in to comment.