Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

define a loop-free untrue batching rule for rng_bit_generator #20094

Merged
merged 1 commit into from
Mar 8, 2024

Conversation

froystig
Copy link
Member

@froystig froystig commented Mar 6, 2024

fixes #19085, #16792

See #19085 for details.

Note that this is a random-bits-altering change, though only for vmapped random generation.

@froystig froystig requested a review from mattjj March 6, 2024 04:16
@froystig froystig self-assigned this Mar 6, 2024
@froystig froystig linked an issue Mar 6, 2024 that may be closed by this pull request
@froystig froystig added the pull ready Ready for copybara import and testing label Mar 6, 2024
@froystig froystig force-pushed the vmap-rbg branch 4 times, most recently from f89b528 to 2cf3d97 Compare March 6, 2024 22:43
@copybara-service copybara-service bot merged commit c4cf265 into jax-ml:main Mar 8, 2024
13 checks passed
@froystig froystig deleted the vmap-rbg branch March 8, 2024 21:59
@vikmary
Copy link

vikmary commented Mar 17, 2024

This solves my issue with per-example independent sampling (before it was 100X slower compared to dependent sampling), thanks for the fix!

@froystig froystig mentioned this pull request Apr 4, 2024
ruomingp added a commit to ruomingp/axlearn that referenced this pull request May 10, 2024
jax-ml/jax#20094 changes the behavior of RNG in vmap, so we can no longer rely on identical layer param initialization when using vmap vs. not. This affects RepeatedTransformerLayer and fused QKV layers.

The fix is to convert layer params from the reference layer to the test layer instead of relying on identical initialization.
github-merge-queue bot pushed a commit to apple/axlearn that referenced this pull request May 11, 2024
* Upgrades jax from 0.4.25 to 0.4.27.

* Fixes attention_test.

jax-ml/jax#20094 changes the behavior of RNG in vmap, so we can no longer rely on identical layer param initialization when using vmap vs. not. This affects RepeatedTransformerLayer and fused QKV layers.

The fix is to convert layer params from the reference layer to the test layer instead of relying on identical initialization.

* Fixes rnn_test.

* Fixes test_split_prng_key.

* Fixes test_split_prng_key.

* Fixes test_parent_children.

* Upgrades to jax 0.4.28.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

efficient untrue batching of random_bit_generator unsafe_rbg + vmap --> 10x slow down
4 participants