[pallas] More simplification of grid mapping and calling convention
In the previous PR #22552 I expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.
gnecula committed Jul 24, 2024
1 parent 35e7360 commit 8b6d346
Showing 7 changed files with 86 additions and 118 deletions.
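
For context, here is a minimal sketch of the flat operand ordering that the calling convention encodes — *index, *consts, *inputs, *outputs, *scratch — and the kind of slice helpers `GridMapping` exposes. The counts and names are made up for illustration; this is not the actual class.

    # Hypothetical operand counts, for illustration only.
    num_index_operands = 1
    num_constant_operands = 1
    num_inputs = 2
    num_outputs = 1
    num_scratch_operands = 1

    # Flat kernel argument order: *index, *consts, *inputs, *outputs, *scratch.
    args = ["idx0", "const0", "in0", "in1", "out0", "scratch0"]

    # Mirrors the GridMapping.slice_index_ops property seen in the diff below.
    slice_index_ops = slice(0, num_index_operands)
    assert args[slice_index_ops] == ["idx0"]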
39 changes: 18 additions & 21 deletions jax/_src/pallas/core.py

@@ -15,12 +15,13 @@
 """Module for pallas-core functionality."""
 from __future__ import annotations

-from collections.abc import Callable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 import contextlib
 import copy
 import dataclasses
 import enum
 import functools
+import itertools
 import threading
 from typing import Any, Hashable, Union
 import warnings
@@ -425,11 +426,6 @@ def slice_index_ops(self):
     """Returns a slice object to select the index operands to a kernel."""
     return slice(0, self.num_index_operands)

-  @property
-  def slice_block_ops(self):
-    """Returns a slice to select all but the index operands to a kernel."""
-    return slice(self.num_index_operands, None)
-
   @property
   def slice_scratch_ops(self):
     """Returns a slice object to select the scratch operands to a kernel."""
@@ -438,30 +434,31 @@ def slice_scratch_ops(self):
     else:
       return slice(0, 0)

   # TODO(necula): this is used to recover the old `in_shapes`, but it probably
   # is not needed anymore, with some cleanup.
   @property
-  def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
+  def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
     """The shapes of *index, *consts, *inputs."""
-    index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape,
+    index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
                                          ia.inner_aval.dtype)
-                    for ia in self.index_map_avals[len(self.grid):]]
-    consts_inputs_shapes = [
+                    for ia in self.index_map_avals[len(self.grid):])
+    consts_inputs_shapes = (
         bm.array_shape_dtype
         for bm in self.block_mappings[
            0:
-           self.num_constant_operands + self.num_inputs]]
-    return tuple(index_shapes + consts_inputs_shapes)
+           self.num_constant_operands + self.num_inputs])
+    return itertools.chain(index_shapes, consts_inputs_shapes)

+  @property
+  def block_mappings_output(self) -> Iterable[BlockMapping]:
+    return itertools.islice(
+        self.block_mappings,
+        self.num_constant_operands + self.num_inputs,
+        self.num_constant_operands + self.num_inputs + self.num_outputs)
+
   # TODO(necula): this is used to recover the old `out_shapes`, but it probably
   # is not needed anymore, with some cleanup.
   @property
-  def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
+  def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
     return tuple(
-        bm.array_shape_dtype
-        for bm in self.block_mappings[
-            self.num_constant_operands + self.num_inputs:
-            self.num_constant_operands + self.num_inputs + self.num_outputs])
+        bm.array_shape_dtype for bm in self.block_mappings_output)


 def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
   if isinstance(dim, jax.Array):
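
A standalone sketch of what the new `block_mappings_output` property computes, with strings standing in for `BlockMapping` objects and made-up operand counts:

    import itertools

    # Stand-ins for BlockMapping objects, in calling-convention order:
    # *consts, *inputs, *outputs.
    block_mappings = ["const0", "in0", "in1", "out0", "out1"]
    num_constant_operands, num_inputs, num_outputs = 1, 2, 2

    # islice picks out exactly the output block mappings.
    block_mappings_output = itertools.islice(
        block_mappings,
        num_constant_operands + num_inputs,
        num_constant_operands + num_inputs + num_outputs)
    print(list(block_mappings_output))  # ['out0', 'out1']

Note that `in_shapes` now returns an `itertools.chain` iterator rather than a tuple, so a caller that needs to traverse it more than once must materialize it first (see the `tuple(...)` fix in the Mosaic GPU lowering below).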
30 changes: 5 additions & 25 deletions jax/_src/pallas/mosaic/lowering.py

@@ -430,10 +430,6 @@ def lower_jaxpr_to_module(
     mesh: mesh_lib.Mesh | None = None,
     for_verification: bool = False,
 ) -> tuple[Module, tuple[Any, ...]]:
-  # TODO(necula): cleanup
-  in_shapes = grid_mapping.in_shapes
-  out_shapes = grid_mapping.out_shapes
-
   mosaic_grid_mapping = MosaicGridMapping(
       jaxpr, grid_mapping, dimension_semantics, mesh)
   mosaic_grid_mapping.maybe_compress_grid()
@@ -448,31 +444,15 @@
   window_params = []
   grid = mosaic_grid_mapping.grid
   if grid:
-    invars = jaxpr.invars
-    if grid_mapping.num_scratch_operands > 0:
-      invars = invars[
-          grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
-    else:
-      invars = invars[grid_mapping.num_index_operands:]
-    # invars now = *consts, *ins, *outs
-    avals = tuple(v.aval for v in invars)
-    # TODO(necula): we should not need block_operand_shapes anymore
-    block_operand_shapes = (
-        *in_shapes[grid_mapping.num_index_operands:],
-        *out_shapes,
-    )
-    assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
-    for i, (full_ty, bm, aval) in enumerate(
-        zip(block_operand_shapes, grid_mapping.block_mappings, avals)
-    ):
+    for i, bm in enumerate(grid_mapping.block_mappings):
       func_name = f"transform_{i}"
       # ANY operands don't support windowing and require empty window_params.
-      if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
+      if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY:
         # We may not require windowing if our block_shape matches the original
         # shape or the dimensions are mapped.
         requires_windowing = any(
            b != s
-           for b, s in zip(bm.block_shape, full_ty.shape)
+           for b, s in zip(bm.block_shape, bm.array_shape_dtype.shape)
            if not (b is pl_core.mapped and s == 1)
        )
        if np.prod(grid) != 1:
@@ -492,7 +472,7 @@
       mlir_func = lower_jaxpr_to_transform_func(
           ctx,
           bm.index_map_jaxpr.jaxpr,
-          aval,
+          bm.block_aval,
           name=func_name,
           mosaic_grid_mapping=mosaic_grid_mapping,
           for_verification=for_verification,
@@ -503,7 +483,7 @@
       ]
       # If we have an extended dtype, we need to add the block shape for the
       # remaining physical dtype.
-      block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval))
+      block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval))
       window_shape = ir.DenseI64ArrayAttr.get(block_shape)
       block_params = dict(
           window_bounds=window_shape,
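
The windowing test above now reads both shapes off the `BlockMapping` itself, so the externally zipped shape lists can go away. A self-contained sketch of that predicate, with a stand-in for the `pl_core.mapped` sentinel and made-up shapes:

    mapped = object()  # stand-in for the pl_core.mapped sentinel

    def requires_windowing(block_shape, array_shape) -> bool:
      # A dimension needs windowing when the block extent differs from the
      # full array extent, except where the dimension is mapped away and
      # the corresponding array extent is 1.
      return any(
          b != s
          for b, s in zip(block_shape, array_shape)
          if not (b is mapped and s == 1))

    print(requires_windowing((mapped, 128), (1, 512)))  # True: 128 != 512
    print(requires_windowing((mapped, 512), (1, 512)))  # False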
6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic/pallas_call_registration.py

@@ -73,8 +73,6 @@ def pallas_call_tpu_lowering_rule(
     compiler_params: dict[str, Any]):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   del interpret
-  # TODO(necula): cleanup
-  out_shapes = grid_mapping.out_shapes
   if debug:
     print(jaxpr)
   if "mosaic_params" in compiler_params:
@@ -117,7 +115,9 @@ def lower_module(for_verification: bool):
       (a[0] + num_dyn_bounds + num_extra_args, a[1])
       for a in input_output_aliases
   )
-  out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
+  out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape,
+                                    bm.array_shape_dtype.dtype)
+               for bm in grid_mapping.block_mappings_output]

   if promela_dump_path := _DUMP_PROMELA_TO.value:
     num_devices = 1 if mesh is None else mesh.devices.size
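
A hedged sketch of the new `out_avals` construction, using a `jax.ShapeDtypeStruct` as a stand-in for the `array_shape_dtype` carried by each output `BlockMapping` (the internal import mirrors what this file aliases as `jax_core`):

    import jax
    import jax.numpy as jnp
    from jax._src import core as jax_core  # internal alias used by this file

    # Stand-in for bm.array_shape_dtype on a single output block mapping.
    output_shape_dtypes = [jax.ShapeDtypeStruct((8, 128), jnp.float32)]

    out_avals = [jax_core.ShapedArray(s.shape, s.dtype)
                 for s in output_shape_dtypes]
    print(out_avals)  # [ShapedArray(float32[8,128])]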
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic_gpu/lowering.py

@@ -160,7 +160,7 @@ def lower_jaxpr_to_module(
     name: str,
     compiler_params: dict[str, Any],
 ) -> LoweringResult:
-  in_structs = grid_mapping.in_shapes
+  in_structs = tuple(grid_mapping.in_shapes)
   out_structs = grid_mapping.out_shapes
   assert len(jaxpr.outvars) == 0
   assert not grid_mapping.vmapped_dims
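
The added `tuple(...)` matters because `in_shapes` now returns an `itertools.chain` iterator (see core.py above), which can only be traversed once. A tiny illustration of the hazard:

    import itertools

    shapes = itertools.chain(("idx",), ("in0", "in1"))
    print(list(shapes))  # ['idx', 'in0', 'in1']
    print(list(shapes))  # [] -- the iterator is exhausted; hence tuple(...)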
[Diffs for the remaining three changed files were not loaded in this view.]