Skip to content

Commit

Permalink
Add ChiSquareRV JAX implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored and rlouf committed Dec 15, 2022
1 parent 3418496 commit 870e1c3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
15 changes: 15 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,18 @@ def sample_fn(rng, size, dtype, *parameters):
return (rng, sample_exp)

return sample_fn


@jax_sample_fn.register(aer.ChiSquareRV)
def jax_sample_fn_chisquare(op):
"""JAX implementation of `ChiSquareRV`"""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(df,) = parameters
sample = jax.random.gamma(sampling_key, df / 2, size, dtype) * 2
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
12 changes: 12 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def test_random_updates(rng_ctor):
"cauchy",
lambda *args: args,
),
(
aer.chisquare,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
"chi2",
lambda *args: args,
),
(
aer.exponential,
[
Expand Down

0 comments on commit 870e1c3

Please sign in to comment.