[pallas] Fix the handling of captured consts #22550
Merged
gnecula force-pushed the pallas_consts branch 10 times, most recently from bccf6d1 to 478704d (July 21, 2024 16:05)
superbobry approved these changes on Jul 21, 2024
Thanks, George!
Maybe reference #21557?
Done, I think that this fixes that bug.
gnecula force-pushed the pallas_consts branch 2 times, most recently from 1a47c62 to 9d9b69f (July 22, 2024 06:21)
superbobry approved these changes on Jul 22, 2024
A previous attempt to handle consts captured by the kernel was incomplete and had errors: the calling convention was wrong, and the support for handling consts along with scalar prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests in tpu_pallas_test.py (to handle scalar prefetch, with and without scratch inputs).

The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`. This differs from the previous order (`*consts, *scalar_refs, *ins, ...`): the new order keeps the block arguments (consts, ins, outs) together, which makes the lowering easier to write.

I will follow up with a cleanup PR for the handling of grid_mapping. Here I attempted to minimize the changes.
gnecula added a commit to gnecula/jax that referenced this pull request on Jul 30, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was not working correctly in many cases. Now we disallow const capturing because it can lead to surprises. Instead, the kernel function must receive all the arrays it needs as explicit inputs, with proper block specs.
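As a rough illustration of the explicit-input style this commit message asks for, here is a minimal sketch. It assumes a recent JAX with `jax.experimental.pallas` available and runs in interpret mode so that no accelerator is needed. The kernel name `add_bias_kernel` and the arrays are hypothetical. No `grid` or block specs are given, so each ref covers its whole array; a real kernel would normally supply them.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_bias_kernel(x_ref, bias_ref, o_ref):
    # `bias_ref` is an explicit kernel input, not an array captured
    # from the enclosing Python scope.
    o_ref[...] = x_ref[...] + bias_ref[...]


x = jnp.arange(8, dtype=jnp.float32)
bias = jnp.full((8,), 2.0, dtype=jnp.float32)  # hypothetical data

out = pl.pallas_call(
    add_bias_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # assumption: run the interpreter path, e.g. on CPU
)(x, bias)
```

The point of the sketch is only the shape of the kernel signature: every array the kernel reads appears as a ref argument and as an operand of the `pallas_call` result.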
gnecula added a commit to gnecula/jax that referenced this pull request on Jul 31, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was not working correctly in many cases. Now we disallow const capturing because it can lead to surprises. Instead, the kernel function must receive all the arrays it needs as explicit inputs, with proper block specs.
nitins17 pushed a commit to google-ml-infra/jax-fork that referenced this pull request on Aug 27, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was not working correctly in many cases. Now we disallow const capturing because it can lead to surprises. Instead, the kernel function must receive all the arrays it needs as explicit inputs, with proper block specs.
A previous attempt to handle consts captured by the kernel was incomplete and had errors: the calling convention was wrong, and the support for handling consts along with scalar prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py, one in pallas_vmap_test.py, and four tests in tpu_pallas_test.py (to handle scalar prefetch, with and without scratch inputs, with and without vmap).

The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`. This differs from the previous order (`*consts, *scalar_refs, *ins, ...`): the new order keeps the block arguments (consts, ins, outs) together, which makes the lowering easier to write.

Fixes: #21557.

I will follow up with a cleanup PR for the handling of grid_mapping. Here I attempted to minimize the changes.
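The argument order described above can be sketched in plain Python. This is an illustration only, not the code in this PR: `KernelArgs`, `split_kernel_args`, `block_args`, and the count parameters are hypothetical names, and the strings stand in for refs.

```python
from typing import NamedTuple, Sequence


class KernelArgs(NamedTuple):
    """Hypothetical container for the five argument groups."""
    scalar_refs: list
    consts: list
    ins: list
    outs: list
    scratch: list


def split_kernel_args(args: Sequence, *, num_scalar_refs: int, num_consts: int,
                      num_ins: int, num_outs: int) -> KernelArgs:
    """Split a flat list ordered as *scalar_refs, *consts, *ins, *outs, *scratch."""
    args = list(args)
    sizes = [num_scalar_refs, num_consts, num_ins, num_outs]
    if sum(sizes) > len(args):
        raise ValueError("fewer arguments than the declared group sizes")
    groups, start = [], 0
    for size in sizes:
        groups.append(args[start:start + size])
        start += size
    # Whatever remains after the outputs is treated as scratch.
    return KernelArgs(*groups, scratch=args[start:])


def block_args(args: Sequence, *, num_scalar_refs: int, num_consts: int,
               num_ins: int, num_outs: int) -> list:
    """Return consts, ins, and outs as one contiguous slice of the flat list."""
    stop = num_scalar_refs + num_consts + num_ins + num_outs
    return list(args)[num_scalar_refs:stop]


flat = ["s0", "c0", "c1", "x", "y", "o", "tmp"]
counts = dict(num_scalar_refs=1, num_consts=2, num_ins=2, num_outs=1)
parts = split_kernel_args(flat, **counts)
blocks = block_args(flat, **counts)
```

Because consts, ins, and outs sit next to each other, `block_args` needs a single slice. With the earlier `*consts, *scalar_refs, *ins, ...` order, the same group would have to be assembled from two separate slices.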