From 341849645ebeb1b0cc67d475b44c4909680a88f2 Mon Sep 17 00:00:00 2001 From: theorashid Date: Wed, 14 Dec 2022 16:20:47 +0000 Subject: [PATCH] Add HalfCauchyRV JAX implementation --- aesara/link/jax/dispatch/random.py | 20 ++++++++++++++++++++ tests/link/jax/test_random.py | 16 ++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index 16934676c8..015d295a94 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -282,6 +282,26 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn +@jax_sample_fn.register(aer.HalfCauchyRV) +def jax_sample_fn_halfcauchy(op): + """JAX implementation of `HalfCauchyRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + ( + loc, + scale, + ) = parameters + sample = loc + jax.numpy.abs( + jax.random.cauchy(sampling_key, size, dtype) * scale + ) + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn + + @jax_sample_fn.register(aer.ChoiceRV) def jax_funcify_choice(op): """JAX implementation of `ChoiceRV`.""" diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 32f980f7dd..d9710adfd5 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -296,6 +296,22 @@ def test_random_updates(rng_ctor): "halfnorm", lambda *args: args, ), + ( + aer.halfcauchy, + [ + set_test_value( + at.dvector(), + np.array([-1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1000.0, dtype=np.float64), + ), + ], + (2,), + "halfcauchy", + lambda *args: args, + ), ], ) def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):