diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index 015d295a94..8e9cd08059 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -346,3 +346,18 @@ def sample_fn(rng, size, dtype, *parameters): return (rng, sample_exp) return sample_fn + + +@jax_sample_fn.register(aer.ChiSquareRV) +def jax_sample_fn_chisquare(op): + """JAX implementation of `ChiSquareRV`""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + (df,) = parameters + sample = jax.random.gamma(sampling_key, df / 2, size, dtype) * 2 + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index d9710adfd5..e7c09e70f7 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -95,6 +95,18 @@ def test_random_updates(rng_ctor): "cauchy", lambda *args: args, ), + ( + aer.chisquare, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ) + ], + (2,), + "chi2", + lambda *args: args, + ), ( aer.exponential, [