diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index 08552f232a..6fb9c28c84 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -389,3 +389,20 @@ def sample_fn(rng, size, dtype, *parameters): return (rng, samples) return sample_fn + + +@jax_sample_fn.register(aer.GeometricRV) +def jax_sample_fn_geometric(op): + """JAX implementation of `GeometricRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + p = parameters[0] + sample_num = jax.numpy.log(jax.random.uniform(sampling_key, size)) + sample = sample_num / jax.numpy.log1p(-p) + sample_ceil = jax.numpy.ceil(sample) + rng["jax_state"] = rng_key + return (rng, sample_ceil) + + return sample_fn