jax.random.fold_in function is using static_argnum=(1,) #1033

Closed
pzhokhov opened this issue Jul 19, 2019 · 4 comments
Labels
bug Something isn't working

Comments

@pzhokhov
Contributor

pzhokhov commented Jul 19, 2019

Currently, the implementation of jax.random.fold_in treats its second argument, data, as static:
https://github.com/google/jax/blob/4c34541c00c02fa750b63a6ea9149909e6c4078f/jax/random.py#L193
Is that the intended behavior?
If I understand correctly, this means, for instance, that innocent-looking loops of the sort:

import jax

base_key = jax.random.PRNGKey(0)
for i in range(N):
    do_something_random(jax.random.fold_in(base_key, i))

cause fold_in to be recompiled on every iteration of the loop, which takes a lot of time.
On the other hand, if we jit-compile the fold_in function ourselves, things seem to work correctly and fast, as in this gist (thanks @christopherhesse):
https://gist.github.com/christopherhesse/f493e516b7786533d76c3ef689cb6a45
Should that be the default behavior?
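A minimal sketch of the difference described above, assuming a recent JAX release. The wrappers `fold_in_static` and `fold_in_dynamic` and the `traces` counter are illustrative names, not part of JAX. Because a jitted function's Python body only runs while JAX traces it, counting executions of the body counts (re)compilations.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Python-side counters: the body of a jitted function only runs while JAX
# traces it, so each increment corresponds to one trace + compilation.
traces = {"static": 0, "dynamic": 0}


@partial(jax.jit, static_argnums=(1,))
def fold_in_static(key, data):
    traces["static"] += 1
    return jax.random.fold_in(key, data)


@jax.jit
def fold_in_dynamic(key, data):
    traces["dynamic"] += 1
    return jax.random.fold_in(key, data)


base_key = jax.random.PRNGKey(0)
for i in range(5):
    k_static = fold_in_static(base_key, i)    # new static value -> retrace
    k_dynamic = fold_in_dynamic(base_key, i)  # same abstract type -> cache hit

print(traces)  # {'static': 5, 'dynamic': 1}
```

With `data` static, every new value of `i` triggers a fresh trace and compile; with `data` traced, all Python ints share one abstract type (a scalar int32 under the default config), so a single compilation is reused and both versions produce the same keys.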

@mattjj
Collaborator

mattjj commented Jul 19, 2019

Thanks for spotting this; I think it's almost just a bug, but it looks like we only handle 64-bit Python integer data correctly by using static_argnums=1. Yet there's no test for that, and moreover we could just define fold_in to handle only 32-bit integers (updating the docstring).

Want to make a PR with your fix?
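One way to read the suggestion of defining fold_in only for 32-bit integers is sketched below, assuming a recent JAX release. `fold_in_uint32` is a hypothetical wrapper, not necessarily what the eventual fix did: it rejects out-of-range Python ints up front and passes `data` as a traced uint32, so one compilation covers every value.

```python
import jax
import jax.numpy as jnp

# Jit with `data` as a regular (traced) argument: one compilation serves
# every 32-bit value, unlike static_argnums=1.
_fold_in_traced = jax.jit(jax.random.fold_in)


def fold_in_uint32(key, data):
    """Hypothetical wrapper that only accepts data representable as uint32."""
    if isinstance(data, int) and not 0 <= data < 2**32:
        raise ValueError(f"data={data} does not fit in 32 bits")
    return _fold_in_traced(key, jnp.uint32(data))


key = jax.random.PRNGKey(0)
k = fold_in_uint32(key, 5)
```

The explicit range check replaces the silent downcast a traced Python int would otherwise get, at the cost of refusing values above 2**32 - 1.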

@pzhokhov
Contributor Author

Thanks for the quick response, @mattjj! I'm not sure I understand the 64-bit vs 32-bit integer part... Do you mean a 64-bit integer cannot be passed as an argument to a jit-compiled function, hence it was made static? I tried running the code from the gist above, but passing np.int64(i) as the data argument, and both versions (with and without the static argument) seem to work correctly.
As for making a PR - sure thing, will do!

@mattjj
Collaborator

mattjj commented Jul 23, 2019

The 64-bit issue is just that in the jax.random.PRNGKey function we handle 64-bit values with an isinstance check, and that kind of code won't work when data is traced as an argument to a jitted function, because when JAX_ENABLE_X64=0 (the default) it gets cast down to a 32-bit int. (Using static_argnums avoids that because it delays the point at which data is effectively staged into JAX until after that isinstance check.)

I'm not worried about that behavior, so it's better to merge #1039. Thanks!
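The mechanism described above can be illustrated with a small function, `describe`, which is hypothetical and not JAX code: it records what a Python-level `isinstance` check sees under each jit mode. This is a sketch assuming a recent JAX release and the default `JAX_ENABLE_X64=0`.

```python
import jax

seen = []


def describe(data):
    # Record what a Python-level type check sees, mimicking the kind of
    # isinstance check mentioned for jax.random.PRNGKey.
    seen.append(isinstance(data, int))
    return data


out = jax.jit(describe)(7)                  # data arrives as a tracer
jax.jit(describe, static_argnums=(0,))(7)   # data stays a Python int

print(seen)       # [False, True]
print(out.dtype)  # int32 under the default JAX_ENABLE_X64=0
```

When `data` is traced, the check only sees an abstract int32 tracer, so any Python-level special-casing of large integers is skipped; marking it static keeps it a plain Python int until after the check runs.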

@hawkinsp added the bug Something isn't working label Aug 4, 2019
@mattjj
Collaborator

mattjj commented Dec 17, 2019

I verified that this issue is solved. Thanks, @joaogui1!

@mattjj closed this as completed Dec 17, 2019

3 participants