If you want to use random.seed, random.uniform, and so on in JAX, you have to create a key with from jax.random import PRNGKey and pass it to every sampling call, so the random module is used quite differently from NumPy's.
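A short side-by-side comparison may make the difference concrete. NumPy keeps a hidden global generator that is seeded once. JAX has no global state: each sampling function takes an explicit key, and the usual pattern is to split the key before each use so that successive draws differ. This is a sketch of the standard documented JAX usage, not of any new API.

```python
import numpy as np
import jax

# NumPy: a global, stateful generator seeded once.
np.random.seed(42)
a = np.random.normal(size=3)

# JAX: no global state; every sampling call receives an explicit key.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)  # split so each draw gets a fresh subkey
b = jax.random.normal(subkey, shape=(3,))

key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey, shape=(3,), minval=0.0, maxval=1.0)
```

Reusing the same key always yields the same numbers, which is what makes JAX randomness reproducible without a global seed.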
So I would like to implement these functions for the random module in JAX, making it simpler and letting us use it like numpy.random when registering JAX modules.
I would like to use it like this:
jax.random.seed(42)
jax.random.normal()  # produces the same results as long as a given random seed is specified.
# Currently jax.random.normal exists, but it requires a PRNG key argument.
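One possible shape for the requested behavior is a thin stateful wrapper that hides a module-level key and splits it on every call. The names seed, normal, uniform, and _next_key below are hypothetical helpers for illustration; jax.random does not provide a seed() function or any global state.

```python
import jax

# Hypothetical module-level state; JAX itself keeps no global seed.
_key = jax.random.PRNGKey(0)


def _as_shape(size):
    """Accept an int or a sequence, like NumPy's size argument."""
    return (size,) if isinstance(size, int) else tuple(size)


def seed(s: int) -> None:
    """Reset the hidden key, mimicking numpy.random.seed (hypothetical)."""
    global _key
    _key = jax.random.PRNGKey(s)


def _next_key():
    """Split the hidden key: keep one half as new state, return the other."""
    global _key
    _key, sub = jax.random.split(_key)
    return sub


def normal(size=()):
    """Standard normal samples drawn from the hidden key (hypothetical)."""
    return jax.random.normal(_next_key(), shape=_as_shape(size))


def uniform(low=0.0, high=1.0, size=()):
    """Uniform samples in [low, high) drawn from the hidden key (hypothetical)."""
    return jax.random.uniform(
        _next_key(), shape=_as_shape(size), minval=low, maxval=high
    )


seed(42)
x1 = normal(3)
x2 = normal(3)  # different from x1: the hidden key advanced
seed(42)
y1 = normal(3)  # identical to x1: same seed, same sequence
```

A caveat on this design: the key update is a Python side effect, so inside a jax.jit-compiled function it runs only once at trace time and the "random" values become fixed constants. This is likely why JAX requires explicit keys, and a wrapper like this would probably need to be limited to eager, non-jitted code.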
to24toro changed the title from "Registering a module like random module in jax as default" to "Registering a random module in jax when registering jax modules" on May 24, 2023.
to24toro changed the title from "Registering a random module in jax when registering jax modules" to "Register a random module in jax when registering jax modules" on May 24, 2023.