-
In a regular training loop I usually split keys using this pattern, which seems to be the most common:

```python
key = jax.random.PRNGKey(0)
for x, y in data:
    key, step_key = jax.random.split(key, 2)
    model = train_step(step_key, model, x, y)
```

However, I've recently realized you can also do something like this:

```python
key = jax.random.PRNGKey(0)
for x, y in data:
    key = jax.random.split(key, 1)[0]
    model = train_step(key, model, x, y)
```

Since …
-
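Thanks for the question! The first example is better, because in it the caller guarantees random values can't be accidentally reused. In the latter example the key is used twice: once as an argument to `train_step` and again as an argument to `split` on the next loop iteration. That's dangerous in the sense that different calls to `train_step` might end up using the same random values. Here's an example definition of `train_step` where that happens: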
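```python
import jax

key = jax.random.PRNGKey(0)

def train_step(key):
    # Derives two keys by repeatedly splitting its argument.
    key, = jax.random.split(key, 1)
    print(jax.random.normal(key, ()))
    key, = jax.random.split(key, 1)
    print(jax.random.normal(key, ()))

for i in range(2):
    print(f'iter {i}')
    # The same `key` goes to train_step here and is split again on the next
    # iteration, so iteration 1's first value repeats iteration 0's second.
    key, = jax.random.split(key, 1)
    train_step(key)
```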
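To check the collision programmatically instead of comparing printed values, here is a minimal sketch. It assumes the same `train_step` body as above, modified (hypothetically) to return its two draws instead of printing them; `run_reused` and `run_split` are illustrative helper names, not JAX APIs. `run_reused` follows the second pattern from the question and `run_split` the first.

```python
import jax


def train_step(key):
    # Same body as the example above, but returns the two draws
    # instead of printing them.
    key, = jax.random.split(key, 1)
    a = jax.random.normal(key, ())
    key, = jax.random.split(key, 1)
    b = jax.random.normal(key, ())
    return [float(a), float(b)]


def run_reused(num_steps):
    # Second pattern: the key handed to train_step is also the key
    # that gets split on the next iteration.
    key = jax.random.PRNGKey(0)
    draws = []
    for _ in range(num_steps):
        key, = jax.random.split(key, 1)
        draws += train_step(key)
    return draws


def run_split(num_steps):
    # First pattern: `key` is only ever split; train_step gets `step_key`.
    key = jax.random.PRNGKey(0)
    draws = []
    for _ in range(num_steps):
        key, step_key = jax.random.split(key, 2)
        draws += train_step(step_key)
    return draws


reused = run_reused(2)
separate = run_split(2)
print(reused[1] == reused[2])  # True: iteration 1's first draw repeats iteration 0's second
print(len(set(separate)) == len(separate))
```

Both runs start from `PRNGKey(0)`; the only difference is whether the key passed to `train_step` is also the one split on the next iteration.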
-
```python
from itertools import accumulate, chain, repeat
from jax.random import PRNGKey, split
from operator import itemgetter

seed = 42
# Endless stream of subkeys: each step splits the carried key (row 1 of the
# previous split), and map(itemgetter(0), ...) yields row 0 as the subkey.
keys = map(
    itemgetter(0),
    accumulate(
        chain((split(PRNGKey(seed)),), repeat(None)),
        lambda acc, _: split(acc[1]),
    ),
)
print(next(keys))  # [2465931498 3679230171]
print(next(keys))  # [1224796891 3487907634]
print(next(keys))  # [252818381 516428635]
print(next(keys))  # [3126834917 176854060]
print(next(keys))  # [ 207868211 2198494137]
print(next(keys))  # [358328843 191880770]
print(next(keys))  # [2134635248 4286345057]
print(next(keys))  # [2517085557 1569996244]
```
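The accumulate/chain one-liner above can also be written as a plain generator, which may be easier to read. This is a minimal sketch; `key_stream` is a hypothetical helper name. It yields row 0 of each split and carries row 1 forward, the same ordering the one-liner uses, so the two should produce identical keys, and the snippet checks that for the first five.

```python
from itertools import accumulate, chain, islice, repeat
from operator import itemgetter

import jax
import numpy as np


def key_stream(seed):
    """Yield an endless sequence of fresh subkeys derived from `seed`."""
    key = jax.random.PRNGKey(seed)
    while True:
        # Same ordering as the one-liner: yield row 0, carry row 1 forward.
        subkey, key = jax.random.split(key)
        yield subkey


# Rebuild the one-liner so the two streams can be compared side by side.
one_liner = map(
    itemgetter(0),
    accumulate(
        chain((jax.random.split(jax.random.PRNGKey(42)),), repeat(None)),
        lambda acc, _: jax.random.split(acc[1]),
    ),
)
matches = [
    bool(np.array_equal(a, b))
    for a, b in zip(islice(one_liner, 5), islice(key_stream(42), 5))
]
print(all(matches))  # True
```

Swapping the unpacking to `key, subkey = ...` (the convention used in the question) is equally valid; it just yields a different, equally fresh, sequence of keys.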