Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas] Fix the handling of captured consts #22550

Merged
merged 1 commit into from
Jul 22, 2024

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jul 21, 2024

There was an attempt to handle consts captured by the kernel, but it was incomplete and with errors: the calling convention was wrong, and the support for handling consts along with scalar prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py, one in pallas_vmap_test.py, and four tests in tpu_pallas_test.py (to handle scalar prefetch, with and without scratch inputs, with and without vmap).

The calling convention now: *scalar_refs, *consts, *ins, *outs, *scratch. This is different from before (*consts, *scalar_refs, *ins, ...) so that it keeps the block arguments (consts, ins, outs) together and makes it easier to write the lowering.

Fixes: #21557.

I will follow up with a cleanup PR for the handling of grid_mapping. Here I attempted to minimize the changes.

@gnecula gnecula self-assigned this Jul 21, 2024
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jul 21, 2024
@gnecula gnecula force-pushed the pallas_consts branch 10 times, most recently from bccf6d1 to 478704d Compare July 21, 2024 16:05
@gnecula gnecula requested a review from superbobry July 21, 2024 16:06
Copy link
Collaborator

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, George!

Maybe reference #21557?

tests/pallas/pallas_test.py Show resolved Hide resolved
@gnecula
Copy link
Collaborator Author

gnecula commented Jul 22, 2024

Thanks, George!

Maybe reference #21557?

Done, I think that this fixes that bug.

@gnecula gnecula force-pushed the pallas_consts branch 2 times, most recently from 1a47c62 to 9d9b69f Compare July 22, 2024 06:21
jax/_src/pallas/mosaic/lowering.py Outdated Show resolved Hide resolved
There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
@copybara-service copybara-service bot merged commit 433f66a into jax-ml:main Jul 22, 2024
11 of 14 checks passed
@gnecula gnecula deleted the pallas_consts branch July 22, 2024 12:12
gnecula added a commit to gnecula/jax that referenced this pull request Jul 30, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
gnecula added a commit to gnecula/jax that referenced this pull request Jul 30, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
gnecula added a commit to gnecula/jax that referenced this pull request Jul 30, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
gnecula added a commit to gnecula/jax that referenced this pull request Jul 30, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
gnecula added a commit to gnecula/jax that referenced this pull request Jul 31, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
nitins17 pushed a commit to google-ml-infra/jax-fork that referenced this pull request Aug 27, 2024
Previously this was allowed, but until recently (jax-ml#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

constants in Pallas kernel causes ValueError: safe_zip() argument 2 is shorter than argument 1
3 participants