[pallas] More simplification of grid mapping and calling convention #22593
Sorry for the delay. It's a big change, so I will do another pass later today.
class GridSpec:
-  """Encodes the parameters of the grid, as given through the API.
+  """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.
Random thought: should we rename `pallas_call` to `call`, since it will always be used with `pl.`?
I think it is a good idea
@@ -276,7 +276,7 @@ def block_shape(self):

  @property
  def compute_index(self):
Maybe make this a method?
def compute_index(self, *args):
...
I have not yet gotten to the cleanup in the `mosaic` directory; I made only the required changes there. I'll leave this for later.
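To illustrate the property-versus-method suggestion above, here is a minimal sketch. The classes `WithProperty` and `WithMethod` and the identity index mapping are hypothetical stand-ins, not the actual Mosaic code; the point is only that both forms are called the same way, while the method form makes the arguments visible in the signature.

```python
class WithProperty:
    """Hypothetical: `compute_index` is a property that returns a callable."""

    @property
    def compute_index(self):
        # The caller receives a function and then has to call it.
        return lambda *args: tuple(args)


class WithMethod:
    """Hypothetical: `compute_index` is a plain method, as suggested."""

    def compute_index(self, *args):
        # The arguments are part of the method signature itself.
        return tuple(args)


# Both spellings look identical at the call site.
from_property = WithProperty().compute_index(1, 2)
from_method = WithMethod().compute_index(1, 2)
```

The method form avoids an extra level of indirection and lets tools show the accepted arguments directly.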
jax/_src/pallas/core.py
Outdated
dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(grid_spec)
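The code comment above can be reproduced with a small sketch. `ToySpec` is a hypothetical stand-in for a dataclass with a hand-written `__init__`; it is not the real `GridSpec`. The string `"n"` as a dynamic bound and `None` as its static placeholder are assumptions made for illustration. `dataclasses.replace` rebuilds the object by passing every field to `__init__`, so it fails as soon as a field is not an `__init__` parameter, whereas `copy.copy` followed by attribute assignment still works.

```python
import copy
import dataclasses


@dataclasses.dataclass(init=False)
class ToySpec:
    """Hypothetical dataclass whose fields do not match __init__'s signature."""

    grid: tuple
    num_dims: int  # derived field, not accepted by __init__

    def __init__(self, grid):
        self.grid = tuple(grid)
        self.num_dims = len(self.grid)


# "n" stands in for a dynamic (non-int) grid bound; this is an assumption.
spec = ToySpec(grid=(4, "n"))

try:
    # replace() calls ToySpec(grid=..., num_dims=...), which __init__ rejects.
    dataclasses.replace(spec, grid=(4, 8))
    replace_failed = False
except TypeError:
    replace_failed = True

# Same filtering as in the snippet above: keep only the non-int bounds.
dynamic_bounds = tuple(d for d in spec.grid if not isinstance(d, int))

# Shallow copy, then overwrite the field on the copy only.
static_spec = copy.copy(spec)
static_spec.grid = tuple(d if isinstance(d, int) else None for d in spec.grid)
```

The original object keeps its grid; only the shallow copy is modified.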
Ideally, we should find a way to make `GridSpec` frozen, so that it is easier to reason about. This is, of course, for a different PR.
We could, but we don't need to. `GridSpec` is now used very little internally, and it is turned very early into `GridMapping`, which is frozen. `GridSpec` is part of the API, and making it frozen today would break, e.g., aqt2, because their code to create a `GridSpec` relies on refining it. Of course, they could use `replace` ...
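The trade-off discussed here can be sketched with a frozen dataclass. `FrozenSpec` and its fields are hypothetical and do not reflect the real `GridSpec` definition. Code that refines a spec by assigning attributes breaks once the class is frozen, while `dataclasses.replace` returns a refined copy instead.

```python
import dataclasses
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class FrozenSpec:
    """Hypothetical frozen spec; not the actual Pallas GridSpec."""

    grid: Tuple[int, ...] = ()
    block_shape: Optional[Tuple[int, ...]] = None


spec = FrozenSpec(grid=(4, 4))

try:
    # In-place refinement, which a frozen dataclass rejects.
    spec.block_shape = (8, 128)
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# Refinement by copy: the original is untouched, the copy carries the change.
refined = dataclasses.replace(spec, block_shape=(8, 128))
```

This works here because every field of `FrozenSpec` is also an `__init__` parameter, which is the precondition for `dataclasses.replace`.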
(index_map_grid_aval,) * len(grid_spec.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))

num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
It's a bit unsatisfying that the Pallas core has to know about `num_scalar_prefetch` and `scratch_shapes`, since both are TPU-specific. I can't think of a better way, but I'm flagging it in case you have ideas.
I am actually thinking that we can merge the functionality of `PrefetchScalarGridSpec` from the TPU backend into the core.
But even if we don't do that, I think it is worth having a single calling convention, with optional parts that may be applicable only to some platforms.
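The idea of a single calling convention with optional, platform-specific parts can be sketched as follows. `BaseSpec`, `TpuLikeSpec`, and `count_scalar_prefetch` are hypothetical names for illustration; only the `getattr(..., "num_scalar_prefetch", 0)` pattern comes from the diff above.

```python
import dataclasses


@dataclasses.dataclass
class BaseSpec:
    """Hypothetical platform-independent spec."""

    grid: tuple = ()


@dataclasses.dataclass
class TpuLikeSpec(BaseSpec):
    """Hypothetical platform-specific spec with an extra optional part."""

    num_scalar_prefetch: int = 0


def count_scalar_prefetch(spec) -> int:
    # Shared code reads the optional part with a default, so specs that do
    # not define it are treated as having zero scalar prefetch operands.
    return getattr(spec, "num_scalar_prefetch", 0)


generic_count = count_scalar_prefetch(BaseSpec(grid=(2, 2)))
tpu_count = count_scalar_prefetch(TpuLikeSpec(grid=(2, 2), num_scalar_prefetch=3))
```

With this pattern the shared code path needs no platform check; the cost is that the core still has to know the names of the optional attributes, which is the concern raised in the comment above.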
In previous PR jax-ml#22552 I have expanded `GridMapping` to encode more parts of the calling convention. Here we use that new functionality and clean up some code. I have removed the internal methods from `BlockSpec` and `GridSpec` because these classes are part of the API. I added entries to pallas/CHANGELOG.
Uses the helper functions for the calling convention from jax-ml#22552 and jax-ml#22593. PiperOrigin-RevId: 658692284
In previous PR #22552 I have expanded `GridMapping` to encode more parts of the calling convention. Here we use that new functionality and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because these classes are part of the API. I also removed the dangerous `unsafe_hash` tags on `BlockSpec` and `GridSpec`.

I added entries to pallas/CHANGELOG.