From 8b6d346b65ab5812d0916ef23305d3ac4b26f3e7 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 23 Jul 2024 15:25:14 +0300 Subject: [PATCH] [pallas] More simplification of grid mapping and calling convention In previous PR #22552 I have expanded `GridMapping` to encode more parts of the calling convention. Here we use that new functionality and clean up some code. --- jax/_src/pallas/core.py | 39 ++++--- jax/_src/pallas/mosaic/lowering.py | 30 +----- .../pallas/mosaic/pallas_call_registration.py | 6 +- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/pallas_call.py | 101 +++++++++--------- jax/_src/pallas/triton/lowering.py | 16 ++- .../pallas/triton/pallas_call_registration.py | 10 +- 7 files changed, 86 insertions(+), 118 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index d5133f4a47c7..010713f2f07c 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,12 +15,13 @@ """Module for pallas-core functionality.""" from __future__ import annotations -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence import contextlib import copy import dataclasses import enum import functools +import itertools import threading from typing import Any, Hashable, Union import warnings @@ -425,11 +426,6 @@ def slice_index_ops(self): """Returns a slice object to select the index operands to a kernel.""" return slice(0, self.num_index_operands) - @property - def slice_block_ops(self): - """Returns a slice to select all but the index operands to a kernel.""" - return slice(self.num_index_operands, None) - @property def slice_scratch_ops(self): """Returns a slice object to select the scratch operands to a kernel.""" @@ -438,30 +434,31 @@ def slice_scratch_ops(self): else: return slice(0, 0) - # TODO(necula): this is used to recover the old `in_shapes`, but it probably - # is not needed anymore, with some cleanup. @property - def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]: + def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: """The shapes of *index, *consts, *inputs.""" - index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape, + index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape, ia.inner_aval.dtype) - for ia in self.index_map_avals[len(self.grid):]] - consts_inputs_shapes = [ + for ia in self.index_map_avals[len(self.grid):]) + consts_inputs_shapes = ( bm.array_shape_dtype for bm in self.block_mappings[ 0: - self.num_constant_operands + self.num_inputs]] - return tuple(index_shapes + consts_inputs_shapes) + self.num_constant_operands + self.num_inputs]) + return itertools.chain(index_shapes, consts_inputs_shapes) + + @property + def block_mappings_output(self) -> Iterable[BlockMapping]: + return itertools.islice( + self.block_mappings, + self.num_constant_operands + self.num_inputs, + self.num_constant_operands + self.num_inputs + self.num_outputs) - # TODO(necula): this is used to recover the old `out_shapes`, but it probably - # is not needed anymore, with some cleanup. @property - def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]: + def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: return tuple( - bm.array_shape_dtype - for bm in self.block_mappings[ - self.num_constant_operands + self.num_inputs: - self.num_constant_operands + self.num_inputs + self.num_outputs]) + bm.array_shape_dtype for bm in self.block_mappings_output) + def _is_valid_grid_dim(dim: int | jax.Array) -> bool: if isinstance(dim, jax.Array): diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index d67ad6838bc1..391fde228981 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -430,10 +430,6 @@ def lower_jaxpr_to_module( mesh: mesh_lib.Mesh | None = None, for_verification: bool = False, ) -> tuple[Module, tuple[Any, ...]]: - # TODO(necula): cleanup - in_shapes = grid_mapping.in_shapes - out_shapes = grid_mapping.out_shapes - mosaic_grid_mapping = MosaicGridMapping( jaxpr, grid_mapping, dimension_semantics, mesh) mosaic_grid_mapping.maybe_compress_grid() @@ -448,31 +444,15 @@ def lower_jaxpr_to_module( window_params = [] grid = mosaic_grid_mapping.grid if grid: - invars = jaxpr.invars - if grid_mapping.num_scratch_operands > 0: - invars = invars[ - grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands] - else: - invars = invars[grid_mapping.num_index_operands:] - # invars now = *consts, *ins, *outs - avals = tuple(v.aval for v in invars) - # TODO(necula): we should not need block_operand_shapes anymore - block_operand_shapes = ( - *in_shapes[grid_mapping.num_index_operands:], - *out_shapes, - ) - assert len(block_operand_shapes) == len(grid_mapping.block_mappings) - for i, (full_ty, bm, aval) in enumerate( - zip(block_operand_shapes, grid_mapping.block_mappings, avals) - ): + for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" # ANY operands don't support windowing and require empty window_params. - if aval.memory_space == tpu_core.TPUMemorySpace.ANY: + if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY: # We may not require windowing if our block_shape matches the original # shape or the dimensions are mapped. requires_windowing = any( b != s - for b, s in zip(bm.block_shape, full_ty.shape) + for b, s in zip(bm.block_shape, bm.array_shape_dtype.shape) if not (b is pl_core.mapped and s == 1) ) if np.prod(grid) != 1: @@ -492,7 +472,7 @@ def lower_jaxpr_to_module( mlir_func = lower_jaxpr_to_transform_func( ctx, bm.index_map_jaxpr.jaxpr, - aval, + bm.block_aval, name=func_name, mosaic_grid_mapping=mosaic_grid_mapping, for_verification=for_verification, @@ -503,7 +483,7 @@ def lower_jaxpr_to_module( ] # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. - block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval)) + block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) window_shape = ir.DenseI64ArrayAttr.get(block_shape) block_params = dict( window_bounds=window_shape, diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 951a3f158c3f..4945d50bd0a3 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -73,8 +73,6 @@ def pallas_call_tpu_lowering_rule( compiler_params: dict[str, Any]): """Lowers a pallas_call to a Mosaic TPU custom call.""" del interpret - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes if debug: print(jaxpr) if "mosaic_params" in compiler_params: @@ -117,7 +115,9 @@ def lower_module(for_verification: bool): (a[0] + num_dyn_bounds + num_extra_args, a[1]) for a in input_output_aliases ) - out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] + out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype) + for bm in grid_mapping.block_mappings_output] if promela_dump_path := _DUMP_PROMELA_TO.value: num_devices = 1 if mesh is None else mesh.devices.size diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d90a92ab055d..02a52851af0d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -160,7 +160,7 @@ def lower_jaxpr_to_module( name: str, compiler_params: dict[str, Any], ) -> LoweringResult: - in_structs = grid_mapping.in_shapes + in_structs = tuple(grid_mapping.in_shapes) out_structs = grid_mapping.out_shapes assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index e9f897f46737..53c1b3a51c9e 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,7 +15,7 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial, reduce import itertools from typing import Any @@ -118,14 +118,16 @@ def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) def _initialize_output_vals( - out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]: + block_mappings_output: Iterable[BlockMapping], + input_args, input_output_aliases) -> Sequence[jax.Array]: oi_map = {v: k for k, v in input_output_aliases} output_vals = [] - for i, out_shape in enumerate(out_shapes): + for i, bm in enumerate(block_mappings_output): if i in oi_map: output_vals.append(input_args[oi_map[i]]) else: - output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype)) + output_vals.append(uninitialized_value(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype)) return output_vals def _logical_to_interpret_mode_dtype(dtype): @@ -173,8 +175,6 @@ def _pallas_call_impl_interpret( grid_mapping: GridMapping, compiler_params: Any): del compiler_params, name - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes # If we're in interpreter mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore @@ -191,16 +191,18 @@ def _pallas_call_impl_interpret( discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: print(discharged_jaxpr) - out = _initialize_output_vals(out_shapes, args, input_output_aliases) - scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore + out = _initialize_output_vals(grid_mapping.block_mappings_output, + args, input_output_aliases) + scalars = args[grid_mapping.slice_index_ops] + block_args = args[len(scalars):] # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch] - # args now contains: *consts, *inputs, *outputs + # block_args now contains: *consts, *inputs, *outputs scratch_invars = jaxpr.invars[grid_mapping.slice_scratch_ops] scratch_avals = [v.aval for v in scratch_invars] scratch_values = _initialize_scratch_vals(scratch_avals) carry = [] - for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings): + for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings): if isinstance(bm.indexing_mode, pallas_core.Unblocked): padding = bm.indexing_mode.padding if padding is not None and any(p != (0, 0) for p in padding): @@ -226,7 +228,7 @@ def _pallas_call_impl_interpret( carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) - num_inout = len(args) + len(out) + num_inout_blocks = len(block_args) + len(out) grid_start_indices = (jnp.int32(0),) * len(grid) if grid: num_iterations = reduce(jnp.multiply, grid) @@ -241,19 +243,19 @@ def cond(carry): i, *_ = carry return i < num_iterations def body(carry): - i, loop_idx, *carry = carry + i, loop_idx, *carry_blocks = carry local_grid_env = tuple( pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.vmapped_dims ) - carry, scratch = split_list(carry, [num_inout]) + carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) for bm in grid_mapping.block_mappings] - blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry, - is_indexing_dim) + blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, + carry_consts_ins, is_indexing_dim) with pallas_core.grid_env(local_grid_env): assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len( scratch_values @@ -265,20 +267,22 @@ def body(carry): ) blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch) - blocks = blocks[grid_mapping.num_index_operands:] - blocks, out_scratch = split_list(blocks, [num_inout]) - carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes, - carry, blocks, is_indexing_dim) - return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch) + + _, out_inout, out_scratch = split_list( + blocks, [grid_mapping.num_index_operands, num_inout_blocks]) + out_carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes, + carry_consts_ins, out_inout, is_indexing_dim) + return (i + 1, _get_next_indices(grid, loop_idx), + *out_carry, *out_scratch) (_, _, *carry) = lax.while_loop( cond, body, (jnp.int32(0), grid_start_indices, *carry) ) - _, out, _ = split_list(carry, [len(args), len(out)]) - assert len(grid_mapping.block_mappings) == len(args) + len(out) - out_block_mappings = grid_mapping.block_mappings[len(args):] + + out_out = carry[len(block_args):len(block_args) + len(out)] + out_block_mappings = grid_mapping.block_mappings_output out_nopad = [] - for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings): + for o, bm in zip(out_out, out_block_mappings): if isinstance(bm.indexing_mode, pallas_core.Unblocked): padding = bm.indexing_mode.padding if padding is not None and any(p != (0, 0) for p in padding): @@ -287,23 +291,22 @@ def body(carry): pad_low, pad_high = zip(*padding) limit_indices = [s - p for s, p in zip(o.shape, pad_high)] o = lax.slice(o, pad_low, limit_indices) - if o.shape != expected_o_shape.shape: - o = lax.slice(o, (0,) * o.ndim, expected_o_shape.shape) + if o.shape != bm.array_shape_dtype.shape: + o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape) out_nopad.append(o) return out_nopad pallas_call_p.def_impl(_pallas_call_impl) -def _pallas_call_abstract_eval(*avals, grid_mapping, **_): - out_shapes = grid_mapping.out_shapes - return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes) +def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): + return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype) + for bm in grid_mapping.block_mappings_output) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping, debug, interpret, compiler_params: Any): - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") if grid_mapping.num_index_operands: @@ -312,7 +315,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, raise NotImplementedError("JVP with aliasing not supported.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] tangents = [t for t in tangents if type(t) is not ad_util.Zero] - nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes) + nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts @@ -324,7 +327,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around # the jaxpr's invars. primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list( - jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)] + jvp_jaxpr.invars, [len(primals), grid_mapping.num_outputs, len(tangents)] ) invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs) effs = [] @@ -337,6 +340,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs) if debug: print(jvp_jaxpr) + # TODO(necula): does this work with consts? in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)]) jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms) jvp_grid_mapping = grid_mapping.replace( @@ -447,8 +451,6 @@ def _batch_with_explicit_loop( to the current iteration index and dynamic_updates an (initially empty) output allocation. """ - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes if not dims: raise NotImplementedError("vmapping pallas_call with no arguments.") @@ -468,10 +470,9 @@ def _batch_with_explicit_loop( # The output arrays are completelly overwritten, so we can just initialize # empty arrays. initial_state = [ - jnp.empty( - tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype - ) - for out_shape in out_shapes + jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size), + dtype=bm.array_shape_dtype.dtype) + for bm in grid_mapping.block_mappings_output ] def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: @@ -531,8 +532,6 @@ def _pallas_call_batching_rule( interpret: bool, compiler_params: Any, ): - # TODO(necula): cleanup - out_shapes = grid_mapping.out_shapes def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped ) -> jax.Array: @@ -634,7 +633,7 @@ def _maybe_squeeze_out_bdim( args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size ) - all_dims = list(dims) + [0] * len(out_shapes) + all_dims = list(dims) + [0] * grid_mapping.num_outputs num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands @@ -650,10 +649,12 @@ def _maybe_squeeze_out_bdim( block_mappings, ) - index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(grid_mapping.index_map_avals) + index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten( + grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (jax_core.ShapedArray((), jnp.int32),) + index_map_tree_args - batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten((batched_index_map_args, {})) + batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( + (batched_index_map_args, {})) batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), @@ -709,18 +710,12 @@ def pallas_call_checkify_rule(error: checkify.Error, # 4) Create block specs for the error state and call pallas_call with # the new kernel. dynamic_grid_bounds, scalars, args = split_list( # type: ignore - args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands] + args, [grid_mapping.num_dynamic_grid_bounds, + grid_mapping.num_index_operands] ) num_scalars = len(scalars) - num_invars = len(jaxpr.invars) - num_inputs_outputs = ( - num_invars - - grid_mapping.num_index_operands - - grid_mapping.num_scratch_operands - ) num_kernel_inputs = len(args) - num_scratch = num_invars - num_inputs_outputs - num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs + num_kernel_outputs = grid_mapping.num_outputs # Trace the jaxpr to get an initial error value so the kernel jaxpr has all of # the required inputs. diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 65300ec5f9d2..9760a8c82f0c 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -250,7 +250,6 @@ def _new_ir_context() -> ir.Context: def lower_jaxpr_to_triton_module( jaxpr: jax_core.Jaxpr, - in_out_shapes, grid_mapping: GridMapping, name: str, platform: str @@ -294,23 +293,22 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." ) - for bm in grid_mapping.block_mappings: - if not isinstance(bm.indexing_mode, Blocked): - raise NotImplementedError( - "Only Blocked indexing mode is supported in Triton lowering." - ) + if not all(isinstance(bm.indexing_mode, Blocked) + for bm in grid_mapping.block_mappings): + raise NotImplementedError( + "Only Blocked indexing mode is supported in Triton lowering." + ) start_indices = map( functools.partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings, ) block_infos = [ BlockInfo( - jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype), + block_mapping.array_shape_dtype, start_idx, block_mapping.block_shape, ) - for shape_dtype, block_mapping, start_idx in zip( - in_out_shapes, + for block_mapping, start_idx in zip( grid_mapping.block_mappings, start_indices, ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 052e27aae72d..6959a7936060 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -50,9 +50,6 @@ def pallas_call_lowering( compiler_params: dict[str, Any], ): del interpret - # TODO(necula): cleanup - in_shapes = grid_mapping.in_shapes - out_shapes = grid_mapping.out_shapes if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" @@ -74,7 +71,7 @@ def pallas_call_lowering( print(grid_mapping) lowering_result = lowering.lower_jaxpr_to_triton_module( - jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform + jaxpr, grid_mapping, name, lowering_platform ) module_op = lowering_result.module.operation if debug: @@ -82,8 +79,9 @@ def pallas_call_lowering( grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) out_types = [ - ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) - for shape in out_shapes + ir.RankedTensorType.get(bm.array_shape_dtype.shape, + mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype)) + for bm in grid_mapping.block_mappings_output ] buf = io.BytesIO() module_op.write_bytecode(buf)