Skip to content

Commit

Permalink
Add ChiSquareRV JAX implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored and brandonwillard committed Mar 9, 2023
1 parent 83ce997 commit be7e3ea
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
16 changes: 15 additions & 1 deletion aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def jax_sample_fn_wald(op):
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

mean, scale = parameters

key1, key2 = jax.random.split(sampling_key, 2)
Expand All @@ -391,6 +390,21 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(aer.ChiSquareRV)
def jax_sample_fn_chisquare(op):
"""JAX implementation of `ChiSquareRV`"""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(df,) = parameters
sample = jax.random.gamma(sampling_key, df / 2, size, dtype) * 2
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn


@jax_sample_fn.register(aer.GeometricRV)
def jax_sample_fn_geometric(op):
"""JAX implementation of `GeometricRV`."""
Expand Down
21 changes: 16 additions & 5 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ def test_random_updates(rng_ctor):
lambda *args: args,
None,
),
(
aer.chisquare,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
"chi2",
lambda *args: args,
None,
),
(
aer.exponential,
[
Expand Down Expand Up @@ -587,12 +600,10 @@ def test_random_concrete_shape_subtensor_tuple():
assert jax_fn(np.ones((2, 3))).shape == (2,)


@pytest.mark.xfail(
reason="`size_at` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
"""JAX cannot JIT-compile random variables whose `size` argument is not static."""
rng = shared(np.random.RandomState(123))
size_at = at.scalar()
out = at.random.normal(0, 1, size=size_at, rng=rng)
jax_fn = function([size_at], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)
with pytest.raises(NotImplementedError, match=r".* concrete values .*"):
function([size_at], out, mode=jax_mode)

0 comments on commit be7e3ea

Please sign in to comment.