[pallas] More simplification of grid mapping and calling convention #22593

Merged (1 commit), Jul 29, 2024

gnecula
Collaborator

@gnecula gnecula commented Jul 23, 2024

In the previous PR, #22552, I expanded `GridMapping` to encode more
parts of the calling convention. This PR uses that new functionality
and cleans up some code.

I have removed the internal methods from BlockSpec and GridSpec because
these classes are part of the API. I also removed dangerous unsafe_hash tags
on BlockSpec and GridSpec.
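
To illustrate why `unsafe_hash` on a mutable dataclass is risky, here is a minimal sketch. `MutableSpec` is a hypothetical stand-in, not the real `BlockSpec` or `GridSpec`. Once a field is changed after the object has been used as a dict key, neither the old nor the new value finds the entry again:

```python
import dataclasses

# Hypothetical stand-in for a mutable spec class; not the real Pallas API.
@dataclasses.dataclass(unsafe_hash=True)
class MutableSpec:
    block_shape: tuple

spec = MutableSpec(block_shape=(8, 128))
cache = {spec: "compiled kernel"}

# Mutate the key after insertion. The dict still files the entry under
# the hash computed from the old field value.
spec.block_shape = (16, 128)

# A key equal to the original value hashes to the old slot, but no longer
# compares equal to the (mutated) stored key.
found_old = MutableSpec((8, 128)) in cache
# A key equal to the new value compares equal, but has a different hash.
found_new = MutableSpec((16, 128)) in cache
```

The entry is still in `cache`, but it has become unreachable through normal lookups, which is the kind of silent failure that makes these hashes dangerous for cache keys.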

I added entries to pallas/CHANGELOG.

@gnecula gnecula self-assigned this Jul 23, 2024
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jul 23, 2024
@gnecula gnecula changed the title Pallas more simplification [pallas] More simplification of grid mapping and calling convention Jul 23, 2024
@gnecula gnecula force-pushed the pallas_more_simplification branch 7 times, most recently from 8b6d346 to abbe898 Compare July 24, 2024 12:28
@gnecula gnecula requested a review from superbobry July 24, 2024 12:28
@gnecula gnecula force-pushed the pallas_more_simplification branch from abbe898 to dd177d7 Compare July 24, 2024 12:33
@gnecula gnecula force-pushed the pallas_more_simplification branch 8 times, most recently from 48467be to 40154f9 Compare July 26, 2024 09:19
Collaborator

@superbobry superbobry left a comment

Sorry for the delay. It's a big change, so I will do another pass later today.

 class GridSpec:
-  """Encodes the parameters of the grid, as given through the API.
+  """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.
Collaborator

Random thought: should we rename pallas_call to call since it will always be used with pl.?

Collaborator Author

I think it is a good idea

tests/pallas/pallas_test.py (outdated, resolved)
@@ -276,7 +276,7 @@ def block_shape(self):

@property
def compute_index(self):
Collaborator

Maybe make this a method?

def compute_index(self, *args):
  ...

Collaborator Author

I have not yet gotten to the cleanup in the mosaic directory; I made only the required changes there. I'll leave this for later.

dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(grid_spec)
Collaborator

We should ideally find a way to make GridSpec frozen to make it easier to reason about.

This is of course for a different PR.

Collaborator Author

We could, but we don't need to. It is now used very little internally, and it is converted very early into GridMapping, which is frozen. GridSpec is part of the API, and making it frozen today would break, e.g., aqt2, because their code that creates a GridSpec relies on refining it. Of course, they could use replace ...
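
For reference, a minimal sketch of the trade-off discussed here, using a hypothetical `FrozenSpec` rather than the real `GridSpec`: a frozen dataclass rejects in-place refinement, while `dataclasses.replace` produces an updated copy. Note that `replace` works by calling `__init__` again, which is consistent with the code comment in the diff above about fields that are incompatible with `__init__`'s signature.

```python
import dataclasses

# Hypothetical frozen spec for illustration; the real GridSpec is not frozen.
@dataclasses.dataclass(frozen=True)
class FrozenSpec:
    grid: tuple
    in_specs: tuple = ()

spec = FrozenSpec(grid=(4, 4))

# In-place refinement fails on a frozen dataclass.
try:
    spec.in_specs = ("x_spec",)
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# dataclasses.replace builds a new instance by calling __init__ again,
# leaving the original untouched.
refined = dataclasses.replace(spec, in_specs=("x_spec",))
```

Callers that currently refine a spec after construction would need to switch to the `replace` form, which is the API break mentioned above.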

@gnecula gnecula force-pushed the pallas_more_simplification branch 4 times, most recently from 9c48945 to 7c83c51 Compare July 29, 2024 05:23
(index_map_grid_aval,) * len(grid_spec.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))

num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
Collaborator

It's a bit unsatisfying that Pallas core has to know about num_scalar_prefetch and scratch_shapes, since both are TPU-specific.

I can't think of a better way, but just flagging in case you have ideas.

Collaborator Author

I am actually thinking that we can merge the functionality of PrefetchScalarGridSpec from the TPU into the core.

But even if we don't do that, I think it is worth having a single calling convention, with optional parts that may be applicable only to some platform.
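
One way to picture a single calling convention with optional, platform-specific parts is the `getattr` pattern from the diff above. The classes and the helper below are hypothetical stand-ins, not the real `GridSpec` or `PrefetchScalarGridSpec`:

```python
# Hypothetical stand-ins for illustration only; not the real Pallas classes.
class CoreSpec:
    def __init__(self, grid):
        self.grid = grid


class TpuSpec(CoreSpec):
    def __init__(self, grid, num_scalar_prefetch):
        super().__init__(grid)
        # Platform-specific field that the core spec does not define.
        self.num_scalar_prefetch = num_scalar_prefetch


def count_scalar_prefetch(grid_spec) -> int:
    # Core code reads the optional attribute and defaults to 0 when absent.
    return getattr(grid_spec, "num_scalar_prefetch", 0)


core_count = count_scalar_prefetch(CoreSpec(grid=(2,)))
tpu_count = count_scalar_prefetch(TpuSpec(grid=(2,), num_scalar_prefetch=3))
```

With this shape, the core still has to know the attribute name, which is the coupling flagged in the review comment; merging the TPU spec into the core would replace the `getattr` with an ordinary field.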

In previous PR jax-ml#22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
@gnecula gnecula force-pushed the pallas_more_simplification branch from 7c83c51 to 70a11ac Compare July 29, 2024 13:54
@copybara-service copybara-service bot merged commit e78e643 into jax-ml:main Jul 29, 2024
13 of 14 checks passed
@gnecula gnecula deleted the pallas_more_simplification branch July 29, 2024 17:28
copybara-service bot pushed a commit that referenced this pull request Jul 30, 2024
Uses the helper functions for the calling convention from #22552 and #22593.

PiperOrigin-RevId: 657524854
copybara-service bot pushed a commit that referenced this pull request Aug 2, 2024
Uses the helper functions for the calling convention from #22552 and #22593.

PiperOrigin-RevId: 657524854
copybara-service bot pushed a commit that referenced this pull request Aug 2, 2024
Uses the helper functions for the calling convention from #22552 and #22593.

PiperOrigin-RevId: 658692284
nitins17 pushed a commit to google-ml-infra/jax-fork that referenced this pull request Aug 27, 2024
Uses the helper functions for the calling convention from jax-ml#22552 and jax-ml#22593.

PiperOrigin-RevId: 658692284
3 participants