Skip to content

Commit

Permalink
[pallas] Small cleanup in the Mosaic lowering
Browse files Browse the repository at this point in the history
Uses the helper functions for the calling convention from #22552 and #22593.

PiperOrigin-RevId: 657524854
  • Loading branch information
gnecula authored and jax authors committed Jul 30, 2024
1 parent cc21245 commit 57ba7bc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
8 changes: 6 additions & 2 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,14 @@ def slice_index_ops(self):

@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel.
"""Returns a slice to select the block operands to a kernel.
The block operands are: *consts, *ins, *outs, the same for which we
have `self.block_mappings`.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
return slice(self.num_index_operands, None)
return slice(self.num_index_operands,
self.num_index_operands + len(self.block_mappings))

@property
def slice_scratch_ops(self):
Expand Down
18 changes: 4 additions & 14 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,6 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.vmapped_dims
# TODO(necula): clean this using new grid_mapping helpers
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
- num_scratch
)
user_grid = tuple(
g for i, g in enumerate(self.grid) if i not in self.mapped_dims
)
Expand All @@ -315,8 +306,6 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
raise ValueError(
"Must have dimension semantics for each dimension of the grid."
)
if num_operands != len(self.block_mappings):
raise ValueError("Must have block mappings for each operand.")
assert len(self.mapped_dims) + len(dimension_semantics) == len(
self.grid
), (
Expand All @@ -332,9 +321,10 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
)

in_avals = [invar.aval for invar in self.jaxpr.invars]
scalar_prefetch_avals, operand_avals, scratch_avals = split_list(
in_avals, [num_scalar_prefetch, num_operands]
)
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
scalar_prefetch_avals = in_avals[grid_mapping.slice_index_ops]
operand_avals = in_avals[grid_mapping.slice_block_ops]
scratch_avals = in_avals[grid_mapping.slice_scratch_ops]
self.scalar_prefetch_types, _ = unzip2([
_get_arg_type(aval, None)
for aval in scalar_prefetch_avals])
Expand Down

0 comments on commit 57ba7bc

Please sign in to comment.