
PRNGKey split patterns #10030

Answered by mattjj
cgarciae asked this question in General
Mar 24, 2022 · 2 comments · 2 replies

Thanks for the question!

The first example is better, because there the caller guarantees that random values can't be accidentally reused.
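The first example isn't reproduced in this thread. As a hedged sketch, assume it follows the common "split, then consume the subkey" pattern: the loop splits `key` into a carried key and a fresh `subkey`, and only `subkey` is handed to `train_step`. The `train_step` below is a hypothetical stand-in that just draws two normal samples.

```python
import jax

def train_step(key):
    # Hypothetical stand-in for a real training step: it splits the key
    # it was given and draws one normal sample from each half.
    k1, k2 = jax.random.split(key)
    return jax.random.normal(k1, ()), jax.random.normal(k2, ())

key = jax.random.PRNGKey(0)
results = []
for i in range(3):
    # `key` is carried forward and only ever split again;
    # `subkey` is consumed by train_step and never touched by the loop.
    key, subkey = jax.random.split(key)
    results.append(train_step(subkey))
```

Because the loop only ever passes the carried `key` back into `jax.random.split`, whatever `train_step` derives from `subkey` stays on a separate branch of the key tree from the keys the loop generates later.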

In the latter example the key is used twice: once as an argument to train_step and again as an argument to split on the next loop iteration. That's dangerous because different calls to train_step might end up using the same random values. Here's an example definition of train_step where that happens:

```python
import jax
key = jax.random.PRNGKey(0)

def train_step(key):
  key, = jax.random.split(key, 1)
  print(jax.random.normal(key, ()))
  key, = jax.random.split(key, 1)
  print(jax.random.normal(key, ()))

for i in range(2):
  print(f'iter {i}')
  key, =
```
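The loop above is cut off, so the following is a self-contained reconstruction under one assumption: that the loop re-splits the same `key` it hands to `train_step`, matching the "used twice" description. This `train_step` mirrors the one above but returns its samples instead of printing them; `run_unsafe` and `run_safe` are hypothetical helper names.

```python
import jax

def train_step(key):
    # Same key handling as the train_step above, but returns the two
    # samples so they can be compared instead of printed.
    key, = jax.random.split(key, 1)
    a = jax.random.normal(key, ())
    key, = jax.random.split(key, 1)
    b = jax.random.normal(key, ())
    return a, b

def run_unsafe(n_iters):
    # Assumed "latter" pattern: the key passed to train_step is also the
    # key the loop splits on the next iteration.
    key = jax.random.PRNGKey(0)
    out = []
    for _ in range(n_iters):
        key, = jax.random.split(key, 1)
        out.append(train_step(key))
    return out

def run_safe(n_iters):
    # train_step only ever sees a subkey that the loop never reuses.
    key = jax.random.PRNGKey(0)
    out = []
    for _ in range(n_iters):
        key, subkey = jax.random.split(key)
        out.append(train_step(subkey))
    return out

unsafe = run_unsafe(2)
# Iteration 0's second sample equals iteration 1's first sample.
print(float(unsafe[0][1]) == float(unsafe[1][0]))  # prints True
```

In `run_unsafe`, both the loop and `train_step` compute `jax.random.split(k, 1)` on the same `k`, so iteration 1 starts from exactly the key that iteration 0 used for its second sample. `run_safe` avoids this by giving `train_step` a subkey the loop never splits again.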

cgarciae Mar 25, 2022
Collaborator Author

Answer selected by cgarciae
@jakevdp
4 participants