From 02bb94f40b721891c51391f29a7cc36251c0f536 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Mon, 20 Feb 2023 17:57:21 +0530 Subject: [PATCH] Implement JAX geometric sampling --- aesara/link/jax/dispatch/random.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index 08552f232a..6fb9c28c84 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -389,3 +389,20 @@ def sample_fn(rng, size, dtype, *parameters): return (rng, samples) return sample_fn + + +@jax_sample_fn.register(aer.GeometricRV) +def jax_sample_fn_geometric(op): + """JAX implementation of `GeometricRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + p = parameters[0] + sample_num = jax.numpy.log(jax.random.uniform(sampling_key, size)) + sample = sample_num / jax.numpy.log1p(-p) + sample_ceil = jax.numpy.ceil(sample) + rng["jax_state"] = rng_key + return (rng, sample_ceil) + + return sample_fn