jax.random.fold_in function is using static_argnum=(1,) #1033
Thanks for spotting this! I think it's mostly just a bug, but it looks like we only handle 64-bit Python integers … Want to make a PR with your fix?
Thanks for the quick response, @mattjj! I'm not sure I understand the 64-bit vs. 32-bit integer part. Do you mean that a 64-bit integer cannot be passed as an argument to a jit-compiled function, and that's why it was made static? I tried running the code as in the gist above, but passing …
The 64-bit issue is just that in the … I'm not worried about that behavior, so it's better to merge #1039. Thanks!
I verified that this issue is solved. Thanks, @joaogui1!
Currently, the implementation of the `jax.random.fold_in` function treats its second argument, `data`, as static: https://github.com/google/jax/blob/4c34541c00c02fa750b63a6ea9149909e6c4078f/jax/random.py#L193
Is that the intended behavior?
If I understand correctly, this means that an innocent-looking loop that calls `fold_in` with a different integer on each iteration causes `fold_in` to recompile at every iteration of the loop, which takes a lot of time.
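A minimal sketch of the effect described above, assuming a recent JAX install. `fold_in_static` is a hypothetical wrapper (not part of JAX's API) that marks `data` as static with `static_argnums=(1,)`, mirroring the pattern the issue points to, and `trace_count` is an illustrative counter bumped by a Python side effect that only runs while JAX traces the function. Each new value of `i` misses the compilation cache, so the wrapper is traced and compiled once per iteration.

```python
from functools import partial

import jax

# Hypothetical counter: incremented only while JAX traces the function,
# so it counts how many times a new compilation was triggered.
trace_count = 0

# Hypothetical wrapper mirroring the static pattern: argument 1 (`data`) is static,
# so every distinct Python value of `data` yields a separate trace and compilation.
@partial(jax.jit, static_argnums=(1,))
def fold_in_static(key, data):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jax.random.fold_in(key, data)

key = jax.random.PRNGKey(0)
for i in range(5):
    # A new static value on every iteration -> cache miss -> retrace and recompile.
    key = fold_in_static(key, i)

print(trace_count)  # 5: one trace per distinct value of i
```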
On the other hand, if we jit-compile the `fold_in` function, things seem to work correctly and fast, as in this gist (thanks @christopherhesse): https://gist.github.com/christopherhesse/f493e516b7786533d76c3ef689cb6a45
Should that be the default behavior?
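For comparison, a sketch of the jit-compiled alternative. It is an illustration in the spirit of the linked gist, not its contents; `fold_in_traced`, `fold_in_jit`, and `trace_count` are hypothetical names. Without `static_argnums`, `data` is traced as an abstract integer scalar, so a single compiled version serves every iteration of the loop.

```python
import jax

# Hypothetical counter: incremented only while JAX traces the function.
trace_count = 0

def fold_in_traced(key, data):
    global trace_count
    trace_count += 1  # runs at trace time only
    return jax.random.fold_in(key, data)

# No static_argnums: `data` becomes a traced int32 scalar, so every Python int
# maps to the same abstract signature and the compiled function is reused.
fold_in_jit = jax.jit(fold_in_traced)

key = jax.random.PRNGKey(0)
for i in range(5):
    key = fold_in_jit(key, i)

print(trace_count)  # 1: traced and compiled once, reused for every i
```

The trade-off: a static argument lets Python code inside the function branch on its concrete value, but every new value costs a recompilation; a traced argument compiles once, at the price of the value being known only abstractly during tracing.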