diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst index d63e748cdce9..d250d682d32a 100644 --- a/docs/jax.experimental.pallas.rst +++ b/docs/jax.experimental.pallas.rst @@ -10,6 +10,7 @@ Classes :toctree: _autosummary BlockSpec + GridSpec Slice Functions @@ -34,4 +35,4 @@ Functions atomic_or atomic_xchg - debug_print \ No newline at end of file + debug_print diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 30106a5a7e43..9d075e585cd0 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -17,6 +17,13 @@ Remember to align the itemized text with the first line of an item within a list * {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to be passed *before* `index_map`. The old argument order is deprecated and will be removed in a future release. + * {class}`jax.experimental.pallas.GridSpec` no longer has the `in_specs_tree` + and `out_specs_tree` fields; the `in_specs` and `out_specs` fields now + store the values as pytrees of BlockSpec. Previously, `in_specs` and + `out_specs` were flattened ({jax-issue}`#22552`). + * The method `compute_index` of {class}`jax.experimental.pallas.BlockSpec` has + been removed because it is private. Similarly, the `get_grid_mapping` and + `unzip_dynamic_grid_bounds` methods have been removed from `GridSpec` ({jax-issue}`#22593`). * Fixed the interpreter mode to work with BlockSpec that involve padding ({jax-issue}`#22275`). Padding in interpreter mode will be with NaN, to help debug out-of-bounds diff --git a/jax/_src/core.py b/jax/_src/core.py index 87ede1e54862..b09f51c88cc4 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1705,7 +1705,7 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize: """Canonicalizes and checks for errors in a user-provided shape dimension value. Args: - f: a Python value that represents a dimension. + d: a Python value that represents a dimension. Returns: A canonical dimension value. diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index a44052d720bc..4f0b07fb6ffb 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,12 +15,13 @@ """Module for pallas-core functionality.""" from __future__ import annotations -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence import contextlib import copy import dataclasses import enum import functools +import itertools import threading from typing import Any, Hashable, Union import warnings @@ -55,6 +56,10 @@ def __repr__(self): Grid = Union[NamedGrid, TupleGrid] StaticGrid = tuple[int, ...] GridMappingGrid = tuple[int | DynamicGridDim, ...] + +# Pytrees of jax.ShapeDtypeStruct +ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...] + split_list = util.split_list map, unsafe_map = util.safe_map, map @@ -202,14 +207,13 @@ def __repr__(self): IndexingMode = Union[Blocked, Unblocked] -@dataclasses.dataclass(unsafe_hash=True) +@dataclasses.dataclass class BlockSpec: - """Specifies how an array should be sliced for each iteration of a kernel. + """Specifies how an array should be sliced for each invocation of a kernel. See :ref:`pallas_blockspec` for more details. - This object contains the parameters passed through the API. - An internal canonicalized version is in BlockMapping. """ + # An internal canonicalized version is in BlockMapping. block_shape: tuple[int | None, ...] 
| None = None index_map: Callable[..., Any] | None = None memory_space: Any | None = dataclasses.field(kw_only=True, default=None) @@ -242,22 +246,25 @@ def __init__( self.memory_space = memory_space self.indexing_mode = indexing_mode - def compute_index(self, *args): - assert self.index_map is not None - out = self.index_map(*args) - if not isinstance(out, tuple): - out = (out,) - return out + +def compute_index(bs: BlockSpec, *args): + assert bs.index_map is not None + out = bs.index_map(*args) + if not isinstance(out, tuple): + out = (out,) + return out + class NoBlockSpec: - pass + def __repr__(self): + return "NoBlockSpec" no_block_spec = NoBlockSpec() # A PyTree of BlockSpec | NoBlockSpec. +# BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec BlockSpecTree = Any - @dataclasses.dataclass(frozen=True) class BlockMapping: """An internal canonicalized version of BlockSpec. @@ -399,6 +406,7 @@ def check_invariants(self) -> None: if self.grid_names is not None: assert len(self.grid) == len(self.grid_names), (self.grid, self.grid_names) + for bm in self.block_mappings: bm.check_invariants() assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( @@ -439,45 +447,52 @@ def trace_env(self): @property def slice_index_ops(self): - """Returns a slice object to select the index operands to a kernel.""" + """Returns a slice object to select the index operands to a kernel. + This works on a sequence that contains *index, *consts, *ins, *outs, *scratch. + """ return slice(0, self.num_index_operands) @property def slice_block_ops(self): - """Returns a slice to select all but the index operands to a kernel.""" + """Returns a slice to select all but the index operands to a kernel. + This works on a sequence that contains *index, *consts, *ins, *outs, *scratch. + """ return slice(self.num_index_operands, None) @property def slice_scratch_ops(self): - """Returns a slice object to select the scratch operands to a kernel.""" + """Returns a slice object to select the scratch operands to a kernel. + This works on a sequence that contains *index, *consts, *ins, *outs, *scratch. + """ if self.num_scratch_operands: return slice(-self.num_scratch_operands, None) else: return slice(0, 0) - # TODO(necula): this is used to recover the old `in_shapes`, but it probably - # is not needed anymore, with some cleanup. @property - def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]: + def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: """The shapes of *index, *consts, *inputs.""" - index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape, + index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape, ia.inner_aval.dtype) - for ia in self.index_map_avals[len(self.grid):]] - consts_inputs_shapes = [ + for ia in self.index_map_avals[len(self.grid):]) + consts_inputs_shapes = ( bm.array_shape_dtype for bm in self.block_mappings[ - :self.num_constant_operands + self.num_inputs]] - return tuple(index_shapes + consts_inputs_shapes) + :self.num_constant_operands + self.num_inputs]) + return itertools.chain(index_shapes, consts_inputs_shapes) - # TODO(necula): this is used to recover the old `out_shapes`, but it probably - # is not needed anymore, with some cleanup. 
@property - def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]: + def block_mappings_output(self) -> Iterable[BlockMapping]: + return itertools.islice( + self.block_mappings, + self.num_constant_operands + self.num_inputs, + self.num_constant_operands + self.num_inputs + self.num_outputs) + + @property + def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: return tuple( - bm.array_shape_dtype - for bm in self.block_mappings[ - self.num_constant_operands + self.num_inputs: - self.num_constant_operands + self.num_inputs + self.num_outputs]) + bm.array_shape_dtype for bm in self.block_mappings_output) + def _is_valid_grid_dim(dim: int | jax.Array) -> bool: if isinstance(dim, jax.Array): @@ -500,9 +515,9 @@ def _convert_block_spec_to_block_mapping( if block_spec is no_block_spec: block_spec = BlockSpec(None, None) if block_spec.index_map is None: - compute_index = lambda *args: (0,) * len(array_aval.shape) + index_map_func = lambda *args: (0,) * len(array_aval.shape) else: - compute_index = block_spec.compute_index + index_map_func = functools.partial(compute_index, block_spec) if block_spec.block_shape is None: block_shape = array_aval.shape else: @@ -522,7 +537,7 @@ def _convert_block_spec_to_block_mapping( "dynamically-shaped blocks. " f"{origin} has block_shape: {block_aval.shape}") - flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), + flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func), index_map_tree) with tracing_grid_env(grid, mapped_dims): jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun, @@ -559,16 +574,21 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None shape = tuple(s for s in block_shape if s is not None) return ref.update(inner_aval=ref.inner_aval.update(shape=shape)) -@dataclasses.dataclass(init=False, unsafe_hash=True) +index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) + +@dataclasses.dataclass(init=False) class GridSpec: - """Encodes the parameters of the grid, as given through the API. + """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`. - An internal sanitized version is in GridMapping. + See the documentation for :func:`jax.experimental.pallas.pallas_call`, + and also :ref:`pallas_grids_and_blockspecs` for a more detailed + description of the parameters. """ + # A canonicalized internal version is in GridMapping. grid: TupleGrid grid_names: tuple[Hashable, ...] | None - in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec - out_specs: tuple[BlockSpec | NoBlockSpec, ...] 
| NoBlockSpec + in_specs: BlockSpecTree + out_specs: BlockSpecTree def __init__( self, @@ -601,149 +621,151 @@ def __init__( self.grid = grid # type: ignore self.grid_names = grid_names - def get_grid_mapping( - self, - in_avals: Sequence[jax_core.AbstractValue], - in_tree: tree_util.PyTreeDef, - in_paths: Sequence[tree_util.KeyPath], - out_avals: Sequence[jax_core.AbstractValue], - out_tree: tree_util.PyTreeDef, - out_paths: Sequence[tree_util.KeyPath], - num_scalar_prefetch: int = 0, - scratch_shapes: Sequence[Any] = (), - ) -> tuple[tuple[AbstractMemoryRef, ...], - GridMapping]: - assert all(i is None or isinstance(i, int) for i in self.grid) - grid_mapping_grid = tuple( - dynamic_grid_dim if d is None else d for d in self.grid - ) - # The inputs for the index maps - index_map_avals = ( - (jax_core.ShapedArray((), jnp.dtype("int32")),) * len(self.grid)) - index_map_tree = tree_util.tree_structure((index_map_avals, {})) - - if num_scalar_prefetch: - all_avals = tree_util.tree_unflatten(in_tree, in_avals) - scalar_avals, unflat_in_avals = split_list( - all_avals, [num_scalar_prefetch]) - flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals) - num_flat_scalar_prefetch = len(flat_scalar_avals) - scalar_ref_avals = [ - self._make_scalar_ref_aval(aval) - for aval in flat_scalar_avals] - jaxpr_scalar_ref_avals = tree_util.tree_unflatten( - scalar_tree, scalar_ref_avals) - in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals)) - index_map_tree = tree_util.tree_structure(((*index_map_avals, - *scalar_avals), {})) - index_map_avals = (*index_map_avals, *scalar_ref_avals) - del scalar_ref_avals, flat_scalar_avals, scalar_tree - del scalar_avals, unflat_in_avals, all_avals - else: - num_flat_scalar_prefetch = 0 - jaxpr_scalar_ref_avals = () - - if scratch_shapes: - flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( - scratch_shapes) - flat_scratch_avals = map(self._make_scratch_aval, flat_scratch_shapes) - num_flat_scratch_operands = len(flat_scratch_avals) - jaxpr_scratch_avals = tree_util.tree_unflatten( - scratch_tree, flat_scratch_avals) - if not isinstance(jaxpr_scratch_avals, (tuple, list)): - jaxpr_scratch_avals = (jaxpr_scratch_avals,) - del flat_scratch_avals, flat_scratch_shapes, scratch_tree - else: - num_flat_scratch_operands = 0 - jaxpr_scratch_avals = () - - if self.in_specs is not no_block_spec: - flat_in_specs, in_specs_tree = tree_util.tree_flatten(self.in_specs) - if in_specs_tree != in_tree: - raise ValueError( - pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree, - "inputs", in_tree)) - else: - flat_in_specs = [no_block_spec] * len(in_avals) - - in_block_mappings = map( - partial( - _convert_block_spec_to_block_mapping, - index_map_avals=index_map_avals, - index_map_tree=index_map_tree, - grid=grid_mapping_grid, - mapped_dims=(), - what="inputs", - ), - flat_in_specs, - in_paths[num_flat_scalar_prefetch:], - in_avals, - ) - - if self.out_specs is not no_block_spec: - flat_out_specs, out_specs_tree = tree_util.tree_flatten(self.out_specs) - if out_specs_tree != out_tree: - raise ValueError( - pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree, - "`out_shape`", out_tree)) - else: - flat_out_specs = [no_block_spec] * len(out_avals) - - out_block_mappings = map( - partial( - _convert_block_spec_to_block_mapping, - index_map_avals=index_map_avals, - index_map_tree=index_map_tree, - grid=grid_mapping_grid, - mapped_dims=(), - what="outputs", - ), - flat_out_specs, - out_paths, - out_avals, - ) - grid_mapping = GridMapping( - 
grid=grid_mapping_grid, # type: ignore[arg-type] - grid_names=self.grid_names, - block_mappings=(*in_block_mappings, *out_block_mappings), - index_map_avals=index_map_avals, # type: ignore[arg-type] - index_map_tree=index_map_tree, - vmapped_dims=(), - num_index_operands=num_flat_scalar_prefetch, - num_constant_operands=0, # Fixed up later - num_inputs=len(flat_in_specs), - num_outputs=len(flat_out_specs), - num_scratch_operands=num_flat_scratch_operands, - ) - grid_mapping.check_invariants() - in_ref_avals = [bm.block_aval for bm in in_block_mappings] - jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals) - jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals) - out_ref_avals = [bm.block_aval for bm in out_block_mappings] - jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals) - if not isinstance(jaxpr_out_avals, (tuple, list)): - jaxpr_out_avals = (jaxpr_out_avals,) - return (*jaxpr_in_avals, *jaxpr_out_avals, - *jaxpr_scratch_avals), grid_mapping - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: assert False # Not needed in GridSpec def _make_scalar_ref_aval(self, aval): assert False # Not needed in GridSpec - def unzip_dynamic_grid_bounds( - self, - ) -> tuple[GridSpec, tuple[Any, ...]]: - static_grid = tuple( - d if isinstance(d, int) else None for d in self.grid - ) - dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int)) - # We can't use dataclasses.replace, because our fields are incompatible - # with __init__'s signature. - static_self = copy.copy(self) - static_self.grid = static_grid # type: ignore - return static_self, dynamic_bounds + +def get_grid_mapping( + grid_spec: GridSpec, + in_avals: Sequence[jax_core.AbstractValue], + in_tree: tree_util.PyTreeDef, + in_paths: Sequence[tree_util.KeyPath], + out_avals: Sequence[jax_core.AbstractValue], + out_tree: tree_util.PyTreeDef, + out_paths: Sequence[tree_util.KeyPath], +) -> tuple[tuple[jax_core.AbstractValue, ...], + GridMapping]: + assert all(i is None or isinstance(i, int) for i in grid_spec.grid) + grid_mapping_grid = tuple( + dynamic_grid_dim if d is None else d for d in grid_spec.grid + ) + # The inputs for the index maps + index_map_avals = ( + (index_map_grid_aval,) * len(grid_spec.grid)) + index_map_tree = tree_util.tree_structure((index_map_avals, {})) + + num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0) + if num_scalar_prefetch: + all_avals = tree_util.tree_unflatten(in_tree, in_avals) + scalar_avals, unflat_in_avals = split_list( + all_avals, [num_scalar_prefetch]) + flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals) + num_flat_scalar_prefetch = len(flat_scalar_avals) + scalar_ref_avals = [ + grid_spec._make_scalar_ref_aval(aval) + for aval in flat_scalar_avals] + jaxpr_scalar_ref_avals = tree_util.tree_unflatten( + scalar_tree, scalar_ref_avals) + in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals)) + index_map_tree = tree_util.tree_structure(((*index_map_avals, + *scalar_avals), {})) + index_map_avals = (*index_map_avals, *scalar_ref_avals) + del scalar_ref_avals, flat_scalar_avals, scalar_tree + del scalar_avals, unflat_in_avals, all_avals + else: + num_flat_scalar_prefetch = 0 + jaxpr_scalar_ref_avals = () + + scratch_shapes: tuple[Any, ...] 
= getattr(grid_spec, "scratch_shapes", ()) + if scratch_shapes: + flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( + scratch_shapes) + flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes) + num_flat_scratch_operands = len(flat_scratch_avals) + jaxpr_scratch_avals = tree_util.tree_unflatten( + scratch_tree, flat_scratch_avals) + if not isinstance(jaxpr_scratch_avals, (tuple, list)): + jaxpr_scratch_avals = (jaxpr_scratch_avals,) + del flat_scratch_avals, flat_scratch_shapes, scratch_tree + else: + num_flat_scratch_operands = 0 + jaxpr_scratch_avals = () + + if grid_spec.in_specs is not no_block_spec: + flat_in_specs, in_specs_tree = tree_util.tree_flatten(grid_spec.in_specs) + if in_specs_tree != in_tree: + raise ValueError( + pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree, + "inputs", in_tree)) + else: + flat_in_specs = [no_block_spec] * len(in_avals) + + in_block_mappings = map( + partial( + _convert_block_spec_to_block_mapping, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid_mapping_grid, + mapped_dims=(), + what="inputs", + ), + flat_in_specs, + in_paths[num_flat_scalar_prefetch:], + in_avals, + ) + + if grid_spec.out_specs is not no_block_spec: + flat_out_specs, out_specs_tree = tree_util.tree_flatten(grid_spec.out_specs) + if out_specs_tree != out_tree: + raise ValueError( + pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree, + "`out_shape`", out_tree)) + else: + flat_out_specs = [no_block_spec] * len(out_avals) + + out_block_mappings = map( + partial( + _convert_block_spec_to_block_mapping, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid_mapping_grid, + mapped_dims=(), + what="outputs", + ), + flat_out_specs, + out_paths, + out_avals, + ) + grid_mapping = GridMapping( + grid=grid_mapping_grid, # type: ignore[arg-type] + grid_names=grid_spec.grid_names, + block_mappings=(*in_block_mappings, *out_block_mappings), + index_map_avals=index_map_avals, # type: ignore[arg-type] + index_map_tree=index_map_tree, + vmapped_dims=(), + num_index_operands=num_flat_scalar_prefetch, + num_constant_operands=0, # Fixed up later + num_inputs=len(flat_in_specs), + num_outputs=len(flat_out_specs), + num_scratch_operands=num_flat_scratch_operands, + ) + grid_mapping.check_invariants() + in_ref_avals = [bm.block_aval for bm in in_block_mappings] + jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals) + jaxpr_in_avals = (*jaxpr_scalar_ref_avals, + *jaxpr_in_ref_avals) + out_ref_avals = [bm.block_aval for bm in out_block_mappings] + jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals) + if not isinstance(jaxpr_out_avals, (tuple, list)): + jaxpr_out_avals = (jaxpr_out_avals,) + return (*jaxpr_in_avals, *jaxpr_out_avals, + *jaxpr_scratch_avals), grid_mapping + + +def unzip_dynamic_grid_bounds( + grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]: + static_grid = tuple( + d if isinstance(d, int) else None for d in grid_spec.grid + ) + dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int)) + # We can't use dataclasses.replace, because our fields are incompatible + # with __init__'s signature. 
+ static_self = copy.copy(grid_spec) + static_self.grid = static_grid # type: ignore + return static_self, dynamic_bounds def pytreedef_mismatch_err_msg( diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 2f02eba5e5be..75e5101de142 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -67,7 +67,7 @@ class barrier_semaphore(semaphore_dtype): pass class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), jnp.dtype('int32')) + return pallas_core.index_map_grid_aval class AbstractSemaphoreTy(dtypes.ExtendedDType): name: str @@ -145,8 +145,8 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): grid: TupleGrid grid_names: tuple[Hashable, ...] | None num_scalar_prefetch: int - in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec - out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec + in_specs: pallas_core.BlockSpecTree + out_specs: pallas_core.BlockSpecTree scratch_shapes: tuple[Any, ...] def __init__( @@ -173,14 +173,6 @@ def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: raise ValueError(f"No registered conversion for {type(obj)}. " "Only VMEM and SemaphoreType are supported.") - def get_grid_mapping( # type: ignore[override] - self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths - ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: - return super().get_grid_mapping(in_avals, in_tree, in_paths, - out_avals, out_tree, out_paths, - num_scalar_prefetch=self.num_scalar_prefetch, - scratch_shapes=self.scratch_shapes) - @dataclasses.dataclass(frozen=True) class TensorCore: diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b55afa4b8eab..14ea00e624c4 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -262,6 +262,7 @@ def _get_arg_type( memory_space = TPUMemorySpace.VMEM if isinstance(aval, tpu_core.AbstractSemaphore): return aval_to_ir_type(aval), None + # TODO(necula): clean this None block_mapping if block_mapping is None: return aval_to_ir_type(aval, memory_space=memory_space), aval.shape shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape) @@ -296,6 +297,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping, self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims + # TODO(necula): clean this using new grid_mapping helpers num_scalar_prefetch = grid_mapping.num_index_operands num_scratch = grid_mapping.num_scratch_operands # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch] @@ -348,7 +350,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping, for aval in scratch_avals ) self.grid_types, _ = unzip2([ - _get_arg_type(jax_core.ShapedArray((), jnp.int32), None) + _get_arg_type(pl_core.index_map_grid_aval, None) for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) @@ -432,9 +434,6 @@ def lower_jaxpr_to_module( mesh: mesh_lib.Mesh | None = None, for_verification: bool = False, ) -> tuple[Module, tuple[Any, ...]]: - # TODO(necula): cleanup - in_shapes = grid_mapping.in_shapes - out_shapes = grid_mapping.out_shapes for bm in grid_mapping.block_mappings: def err_details(): return (f"Block spec for {bm.origin} has block shape " @@ -510,33 +509,17 @@ def err_details(): window_params = [] grid = mosaic_grid_mapping.grid if grid: - invars = jaxpr.invars - 
if grid_mapping.num_scratch_operands > 0: - invars = invars[ - grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands] - else: - invars = invars[grid_mapping.num_index_operands:] - # invars now = *consts, *ins, *outs - avals = tuple(v.aval for v in invars) - # TODO(necula): we should not need block_operand_shapes anymore - block_operand_shapes = ( - *in_shapes[grid_mapping.num_index_operands:], - *out_shapes, - ) - assert len(block_operand_shapes) == len(grid_mapping.block_mappings) - for i, (full_ty, bm, aval) in enumerate( - zip(block_operand_shapes, grid_mapping.block_mappings, avals) - ): + for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" # ANY operands don't support windowing and require empty window_params. - if aval.memory_space == tpu_core.TPUMemorySpace.ANY: + if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY: # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue mlir_func = lower_jaxpr_to_transform_func( ctx, bm.index_map_jaxpr.jaxpr, - aval, + bm.block_aval, name=func_name, mosaic_grid_mapping=mosaic_grid_mapping, for_verification=for_verification, @@ -547,7 +530,7 @@ def err_details(): ] # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. - block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval)) + block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) window_shape = ir.DenseI64ArrayAttr.get(block_shape) block_params = dict( window_bounds=window_shape, @@ -941,7 +924,7 @@ def _make_index(s): def _maybe_cast_to_index(cast_to_index, x): if cast_to_index: return _make_index(x) - return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32)) + return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval) def _index_to_start_size_stride( @@ -2156,9 +2139,8 @@ def _run_body(i, args): if unroll != 1: raise NotImplementedError( f"Only unroll={num_steps=} and unroll=1 supported. 
Got {unroll=}.") - i32 = jax_core.ShapedArray((), jnp.int32) - lbd = _ensure_mlir_value(start, i32) - ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, i32)) + lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval) + ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval)) step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))) for_op = scf.ForOp(lbd, ubd, step, args) with ir.InsertionPoint(for_op.body): @@ -2626,8 +2608,8 @@ def _linearize_mesh_indices(*indices): return sum(a * b for a, b in zip(indices, mesh_strides)) lower_ctx = LoweringRuleContext( lowering_context=ctx.lowering_context, - avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids), - avals_out=[jax_core.ShapedArray((), jnp.int32)], + avals_in=[pl_core.index_map_grid_aval] * len(device_ids), + avals_out=[pl_core.index_map_grid_aval], block_shapes=(None,) * len(device_ids), ) return lower_fun(_linearize_mesh_indices, multiple_results=False)( diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index b4c52c49026d..862ab5ff6463 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -74,8 +74,6 @@ def pallas_call_tpu_lowering_rule( compiler_params: dict[str, Any]): """Lowers a pallas_call to a Mosaic TPU custom call.""" del interpret - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes if debug: print(jaxpr) if "mosaic_params" in compiler_params: @@ -118,7 +116,9 @@ def lower_module(for_verification: bool): (a[0] + num_dyn_bounds + num_extra_args, a[1]) for a in input_output_aliases ) - out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] + out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype) + for bm in grid_mapping.block_mappings_output] if promela_dump_path := _DUMP_PROMELA_TO.value: num_devices = 1 if mesh is None else mesh.devices.size diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 1935f89f1699..5cee5d2779d7 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -276,7 +276,7 @@ def block_shape(self): @property def compute_index(self): - return self.spec.compute_index + return lambda *args: pallas_core.compute_index(self.spec, *args) @property def memory_space(self): diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 4a349c8f71fa..860dc2d6951b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -147,7 +147,7 @@ def lower_jaxpr_to_module( name: str, compiler_params: dict[str, Any], ) -> LoweringResult: - in_structs = grid_mapping.in_shapes + in_structs = tuple(grid_mapping.in_shapes) out_structs = grid_mapping.out_shapes assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index a52c8ec6a862..38523e7c66cf 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,7 +15,7 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial, reduce import itertools from typing import Any @@ -53,6 +53,7 @@ zip, unsafe_zip = safe_zip, zip Grid = pallas_core.Grid +TupleGrid = pallas_core.TupleGrid GridSpec = 
pallas_core.GridSpec BlockMapping = pallas_core.BlockMapping GridMapping = pallas_core.GridMapping @@ -118,14 +119,16 @@ def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) def _initialize_output_vals( - out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]: + block_mappings_output: Iterable[BlockMapping], + input_args, input_output_aliases) -> Sequence[jax.Array]: oi_map = {v: k for k, v in input_output_aliases} output_vals = [] - for i, out_shape in enumerate(out_shapes): + for i, bm in enumerate(block_mappings_output): if i in oi_map: output_vals.append(input_args[oi_map[i]]) else: - output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype)) + output_vals.append(uninitialized_value(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype)) return output_vals def _logical_to_interpret_mode_dtype(dtype): @@ -171,8 +174,6 @@ def _pallas_call_impl_interpret( grid_mapping: GridMapping, compiler_params: Any): del compiler_params, name - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes # If we're in interpreter mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore @@ -189,16 +190,18 @@ def _pallas_call_impl_interpret( discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: print(discharged_jaxpr) - out = _initialize_output_vals(out_shapes, args, input_output_aliases) - scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore + out = _initialize_output_vals(grid_mapping.block_mappings_output, + args, input_output_aliases) + scalars = args[grid_mapping.slice_index_ops] + block_args = args[len(scalars):] # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch] - # args now contains: *consts, *inputs, *outputs + # block_args now contains: *consts, *inputs, *outputs scratch_invars = jaxpr.invars[grid_mapping.slice_scratch_ops] scratch_avals = [v.aval for v in scratch_invars] scratch_values = _initialize_scratch_vals(scratch_avals) carry = [] - for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings): + for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings): if isinstance(bm.indexing_mode, pallas_core.Unblocked): padding = bm.indexing_mode.padding if padding is not None and any(p != (0, 0) for p in padding): @@ -224,7 +227,7 @@ def _pallas_call_impl_interpret( carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) - num_inout = len(args) + len(out) + num_inout_blocks = len(block_args) + len(out) grid_start_indices = (jnp.int32(0),) * len(grid) if grid: num_iterations = reduce(jnp.multiply, grid) @@ -239,19 +242,19 @@ def cond(carry): i, *_ = carry return i < num_iterations def body(carry): - i, loop_idx, *carry = carry + i, loop_idx, *carry_blocks = carry local_grid_env = tuple( pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.vmapped_dims ) - carry, scratch = split_list(carry, [num_inout]) + carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) for bm in grid_mapping.block_mappings] - blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry, - is_indexing_dim) + blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, + 
carry_consts_ins, is_indexing_dim) with pallas_core.grid_env(local_grid_env): assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len( scratch_values @@ -263,20 +266,21 @@ def body(carry): ) blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch) - blocks = blocks[grid_mapping.num_index_operands:] - blocks, out_scratch = split_list(blocks, [num_inout]) - carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes, - carry, blocks, is_indexing_dim) - return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch) + + _, out_inout, out_scratch = split_list( + blocks, [grid_mapping.num_index_operands, num_inout_blocks]) + out_carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes, + carry_consts_ins, out_inout, is_indexing_dim) + return (i + 1, _get_next_indices(grid, loop_idx), + *out_carry, *out_scratch) (_, _, *carry) = lax.while_loop( cond, body, (jnp.int32(0), grid_start_indices, *carry) ) - _, out, _ = split_list(carry, [len(args), len(out)]) - assert len(grid_mapping.block_mappings) == len(args) + len(out) - out_block_mappings = grid_mapping.block_mappings[len(args):] + + out_out = carry[len(block_args):len(block_args) + len(out)] out_nopad = [] - for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings): + for o, bm in zip(out_out, grid_mapping.block_mappings_output): if isinstance(bm.indexing_mode, pallas_core.Unblocked): padding = bm.indexing_mode.padding if padding is not None and any(p != (0, 0) for p in padding): @@ -285,23 +289,22 @@ def body(carry): pad_low, pad_high = zip(*padding) limit_indices = [s - p for s, p in zip(o.shape, pad_high)] o = lax.slice(o, pad_low, limit_indices) - if o.shape != expected_o_shape.shape: - o = lax.slice(o, (0,) * o.ndim, expected_o_shape.shape) + if o.shape != bm.array_shape_dtype.shape: + o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape) out_nopad.append(o) return out_nopad pallas_call_p.def_impl(_pallas_call_impl) -def _pallas_call_abstract_eval(*avals, grid_mapping, **_): - out_shapes = grid_mapping.out_shapes - return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes) +def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): + return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype) + for bm in grid_mapping.block_mappings_output) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping, debug, interpret, compiler_params: Any): - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") if grid_mapping.num_index_operands: @@ -310,7 +313,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, raise NotImplementedError("JVP with aliasing not supported.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] tangents = [t for t in tangents if type(t) is not ad_util.Zero] - nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes) + nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts @@ -322,7 +325,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, # 
compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around # the jaxpr's invars. primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list( - jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)] + jvp_jaxpr.invars, [len(primals), grid_mapping.num_outputs, len(tangents)] ) invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs) effs = [] @@ -335,6 +338,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs) if debug: print(jvp_jaxpr) + # TODO(necula): does this work with consts? in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)]) jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms) jvp_grid_mapping = grid_mapping.replace( @@ -369,8 +373,7 @@ def _block_map_function(new_idx, *args): if dim is not batching.not_mapped: indices.insert(dim, new_idx) return tuple(indices) - i32_aval = jax_core.ShapedArray((), jnp.int32) - idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals] + idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -444,8 +447,6 @@ def _batch_with_explicit_loop( to the current iteration index and dynamic_updates an (initially empty) output allocation. """ - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes if not dims: raise NotImplementedError("vmapping pallas_call with no arguments.") @@ -465,10 +466,9 @@ def _batch_with_explicit_loop( # The output arrays are completely overwritten, so we can just initialize # empty arrays. initial_state = [ - jnp.empty( - tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype - ) - for out_shape in out_shapes + jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size), + dtype=bm.array_shape_dtype.dtype) + for bm in grid_mapping.block_mappings_output ] def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: @@ -528,8 +528,6 @@ def _pallas_call_batching_rule( interpret: bool, compiler_params: Any, ): - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped ) -> jax.Array: @@ -631,7 +629,7 @@ def _maybe_squeeze_out_bdim( args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size ) - all_dims = list(dims) + [0] * len(out_shapes) + all_dims = list(dims) + [0] * grid_mapping.num_outputs num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands @@ -647,10 +645,12 @@ def _maybe_squeeze_out_bdim( block_mappings, ) - index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(grid_mapping.index_map_avals) + index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten( + grid_mapping.index_map_avals) assert not index_map_tree_kwargs - batched_index_map_args = (jax_core.ShapedArray((), jnp.int32),) + index_map_tree_args - batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten((batched_index_map_args, {})) + batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args + batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( + (batched_index_map_args, {})) batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), @@ -706,18 +706,12 @@ def 
pallas_call_checkify_rule(error: checkify.Error, # 4) Create block specs for the error state and call pallas_call with # the new kernel. dynamic_grid_bounds, scalars, args = split_list( # type: ignore - args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands] + args, [grid_mapping.num_dynamic_grid_bounds, + grid_mapping.num_index_operands] ) num_scalars = len(scalars) - num_invars = len(jaxpr.invars) - num_inputs_outputs = ( - num_invars - - grid_mapping.num_index_operands - - grid_mapping.num_scratch_operands - ) num_kernel_inputs = len(args) - num_scratch = num_invars - num_inputs_outputs - num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs + num_kernel_outputs = grid_mapping.num_outputs # Trace the jaxpr to get an initial error value so the kernel jaxpr has all of # the required inputs. @@ -989,11 +983,11 @@ def pallas_call( out_shape: Any, *, grid_spec: GridSpec | None = None, - debug: bool = False, - grid: Grid = (), + grid: TupleGrid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, input_output_aliases: dict[int, int] = {}, + debug: bool = False, interpret: bool = False, name: str | None = None, compiler_params: dict[str, Any] | None = None, @@ -1008,9 +1002,8 @@ def pallas_call( corresponding ``in_specs`` and ``out_specs``. out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape and dtypes of the outputs. - grid_spec: TO BE DOCUMENTED. - debug: if True, Pallas prints various intermediate forms of the kernel - as it is being processed. + grid_spec: An alternative way to specify ``grid``, ``in_specs``, and + ``out_specs``. If given, the other parameters must not also be given. grid: the iteration space, as a tuple of integers. The kernel is executed as many times as ``prod(grid)``. See details at :ref:`pallas_grid`. @@ -1027,6 +1020,8 @@ input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. + debug: if True, Pallas prints various intermediate forms of the kernel + as it is being processed. interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the grid whose body is the kernel lowered as a JAX function. This does not require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. @@ -1059,7 +1054,7 @@ pallas_call( "If `grid_spec` is specified, then `out_specs` must " f"be `no_block_spec`. It is {out_specs}") del grid, in_specs, out_specs - grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds() + grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage # but it is lossy, because it prevents expressing functions that return # lists. @@ -1078,7 +1073,8 @@ def wrapped(*args): flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) for v in flat_out_shapes) # TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc. 
- kernel_avals, grid_mapping = grid_spec.get_grid_mapping( + kernel_avals, grid_mapping = pallas_core.get_grid_mapping( + grid_spec, flat_in_avals, in_tree, in_paths, flat_out_avals, out_tree, out_paths) flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 173bffdff2e4..440e46ac3fee 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -250,7 +250,6 @@ def _new_ir_context() -> ir.Context: def lower_jaxpr_to_triton_module( jaxpr: jax_core.Jaxpr, - in_out_shapes, grid_mapping: GridMapping, name: str, platform: str @@ -313,23 +312,22 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." ) - for bm in grid_mapping.block_mappings: - if not isinstance(bm.indexing_mode, Blocked): - raise NotImplementedError( - "Only Blocked indexing mode is supported in Triton lowering." - ) + if not all(isinstance(bm.indexing_mode, Blocked) + for bm in grid_mapping.block_mappings): + raise NotImplementedError( + "Only Blocked indexing mode is supported in Triton lowering." + ) start_indices = map( functools.partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings, ) block_infos = [ BlockInfo( - jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype), + block_mapping.array_shape_dtype, start_idx, block_mapping.block_shape, ) - for shape_dtype, block_mapping, start_idx in zip( - in_out_shapes, + for block_mapping, start_idx in zip( grid_mapping.block_mappings, start_indices, ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index c0b12b549b28..4bc71a0441ae 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -50,9 +50,14 @@ def pallas_call_lowering( compiler_params: dict[str, Any], ): del interpret - # TODO(necula): cleanup - in_shapes = grid_mapping.in_shapes - out_shapes = grid_mapping.out_shapes + if grid_mapping.num_dynamic_grid_bounds: + raise NotImplementedError( + "dynamic grid bounds not supported in the Triton backend" + ) + if grid_mapping.num_index_operands: + raise NotImplementedError( + "scalar prefetch not implemented in the Triton backend" + ) triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.pop("num_warps", 4) [lowering_platform] = ctx.platforms or ctx.module_context.platforms @@ -66,7 +71,7 @@ def pallas_call_lowering( print(grid_mapping) lowering_result = lowering.lower_jaxpr_to_triton_module( - jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform + jaxpr, grid_mapping, name, lowering_platform ) module_op = lowering_result.module.operation if debug: @@ -74,8 +79,9 @@ def pallas_call_lowering( grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) out_types = [ - ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) - for shape in out_shapes + ir.RankedTensorType.get(bm.array_shape_dtype.shape, + mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype)) + for bm in grid_mapping.block_mappings_output ] buf = io.BytesIO() module_op.write_bytecode(buf) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 4b25476362fa..e6bde1924e61 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -25,6 +25,7 @@ from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import 
Unblocked from jax._src.pallas.core import unblocked +from jax._src.pallas.core import GridSpec from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add
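
Reviewer note: a minimal sketch of the user-facing API after this change, with `GridSpec` now exported from `jax.experimental.pallas` and `BlockSpec` taking `block_shape` before `index_map`. The kernel, shapes, and grid below are illustrative, not taken from this PR.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Each grid step sees one (128, 128) block of the input and the output.
  o_ref[...] = x_ref[...] + 1.0

# GridSpec stores in_specs/out_specs as pytrees of BlockSpec instead of
# the old flattened tuples plus *_tree fields.
grid_spec = pl.GridSpec(
    grid=(2, 2),
    in_specs=[pl.BlockSpec((128, 128), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)),
)

x = jnp.arange(256 * 256, dtype=jnp.float32).reshape(256, 256)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid_spec=grid_spec,  # grid/in_specs/out_specs must then be left unset
    interpret=True,       # runs as a jitted scan over the grid, CPU-friendly
)(x)
assert jnp.allclose(y, x + 1.0)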
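
For downstream code that reached into the private module, a migration sketch of the method-to-function moves (jax._src.pallas.core is internal and unstable; the names are taken from this diff, not a documented API):

from jax._src.pallas import core as pallas_core

bs = pallas_core.BlockSpec((8, 128), lambda i, j: (i, j))
# Before: bs.compute_index(0, 1). The module-level helper also wraps a
# scalar index_map result into a 1-tuple.
assert pallas_core.compute_index(bs, 0, 1) == (0, 1)

gs = pallas_core.GridSpec(grid=(4,))
# Before: gs.unzip_dynamic_grid_bounds(). Now a free function taking the
# GridSpec as its first argument.
static_gs, dynamic_bounds = pallas_core.unzip_dynamic_grid_bounds(gs)
assert dynamic_bounds == ()  # every grid dimension here is a static int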
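
Finally, a toy model of the operand layout documented in the new slice_*_ops docstrings: plain Python slices over a flat [*index, *consts, *ins, *outs, *scratch] sequence, with the operand counts invented for the example.

num_index_operands, num_scratch_operands = 2, 1
ops = ["idx0", "idx1", "const0", "in0", "out0", "scratch0"]

slice_index_ops = slice(0, num_index_operands)
slice_block_ops = slice(num_index_operands, None)  # all but the index operands
slice_scratch_ops = (slice(-num_scratch_operands, None)
                     if num_scratch_operands else slice(0, 0))

assert ops[slice_index_ops] == ["idx0", "idx1"]
assert ops[slice_block_ops] == ["const0", "in0", "out0", "scratch0"]
assert ops[slice_scratch_ops] == ["scratch0"]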