Skip to content

Commit

Permalink
Implement JAX geometric sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create authored and rlouf committed Feb 21, 2023
1 parent 4c74650 commit 02bb94f
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,20 @@ def sample_fn(rng, size, dtype, *parameters):
return (rng, samples)

return sample_fn


@jax_sample_fn.register(aer.GeometricRV)
def jax_sample_fn_geometric(op):
"""JAX implementation of `GeometricRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
p = parameters[0]
sample_num = jax.numpy.log(jax.random.uniform(sampling_key, size))
sample = sample_num / jax.numpy.log1p(-p)
sample_ceil = jax.numpy.ceil(sample)
rng["jax_state"] = rng_key
return (rng, sample_ceil)

return sample_fn

0 comments on commit 02bb94f

Please sign in to comment.