Skip to content

Commit

Permalink
Add HalfCauchyRV JAX implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
theorashid authored and rlouf committed Dec 15, 2022
1 parent 67ea5b6 commit 3418496
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
20 changes: 20 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,26 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(aer.HalfCauchyRV)
def jax_sample_fn_halfcauchy(op):
"""JAX implementation of `HalfCauchyRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(
loc,
scale,
) = parameters
sample = loc + jax.numpy.abs(
jax.random.cauchy(sampling_key, size, dtype) * scale
)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn


@jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`."""
Expand Down
16 changes: 16 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,22 @@ def test_random_updates(rng_ctor):
"halfnorm",
lambda *args: args,
),
(
aer.halfcauchy,
[
set_test_value(
at.dvector(),
np.array([-1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1000.0, dtype=np.float64),
),
],
(2,),
"halfcauchy",
lambda *args: args,
),
],
)
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
Expand Down

0 comments on commit 3418496

Please sign in to comment.