Skip to content

Commit

Permalink
jax lognormal
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 9, 2022
1 parent fe3e76d commit 8d3089d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
15 changes: 15 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,18 @@ def sample_fn(rng, size, dtype, *parameters):
return (rng, sample)

return sample_fn


@jax_sample_fn.register(aer.LogNormalRV)
def jax_sample_fn_lognormal(op):
"""JAX implementation of `LogNormalRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
loc, scale = parameters
sample = loc + jax.random.normal(rng_key, size, dtype) * scale
sample_exp = jax.numpy.exp(sample)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample_exp)

return sample_fn
16 changes: 16 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,22 @@ def test_random_updates(rng_ctor):
"logistic",
lambda *args: args,
),
(
aer.lognormal,
[
set_test_value(
at.lvector(),
np.array([0, 0], dtype=np.int64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"lognorm",
lambda *args: args,
),
(
aer.normal,
[
Expand Down

0 comments on commit 8d3089d

Please sign in to comment.