diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index c088fbcfe1d3..ae5f267d2fbb 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -245,7 +245,7 @@ class BlockMapping: index_map_jaxpr: jax_core.ClosedJaxpr indexing_mode: IndexingMode - def compute_start_indices(self, loop_idx, *args): + def compute_start_indices_interpret(self, loop_idx, *args): discharged_jaxpr, discharged_consts = state_discharge.discharge_state( self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts ) @@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping( f"{len(aval.shape)} values to match {block_shape=}. " f"Currently returning {len(out_avals)} values." ) + if consts: + raise NotImplementedError( + f"Index map for {what}{tree_util.keystr(path)} captures constants: " + f"{consts}") return BlockMapping( block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode ) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 8177c7f81a39..5bbeff2f1302 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -274,7 +274,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping, self.mapped_dims = grid_mapping.mapped_dims num_scalar_prefetch = grid_mapping.num_index_operands num_scratch = grid_mapping.num_scratch_operands - # jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch] + # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch] num_operands = ( len(self.jaxpr.invars) - num_scalar_prefetch @@ -411,9 +411,13 @@ def lower_jaxpr_to_module( grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands] else: invars = invars[grid_mapping.num_index_operands:] + # invars now = *consts, *ins, *outs avals = tuple(v.aval for v in invars) block_operand_shapes = ( - *in_shapes[grid_mapping.num_index_operands :], + *[jax.ShapeDtypeStruct(v.aval.shape, + v.aval.dtype) + for v in invars[:grid_mapping.num_constant_operands]], + *in_shapes[grid_mapping.num_index_operands:], *out_shapes, ) assert len(block_operand_shapes) == len(grid_mapping.block_mappings) @@ -425,10 +429,9 @@ def lower_jaxpr_to_module( raise NotImplementedError( "BlockSpecs are required on TPU when grid is specified" ) - if bm.index_map_jaxpr.consts: - raise NotImplementedError("Index map jaxpr with consts not supported.") # ANY operands don't support windowing and require empty window_params. - if aval.memory_space == tpu_core.TPUMemorySpace.ANY: + if (hasattr(aval, "memory_space") and # happens when AbstractRef are created for consts + aval.memory_space == tpu_core.TPUMemorySpace.ANY): # We may not require windowing if our block_shape matches the original # shape or the dimensions are mapped. requires_windowing = any( diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 70d9cc469008..9e2486eb1581 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule( compiler_params: dict[str, Any]): """Lowers a pallas_call to a Mosaic TPU custom call.""" if interpret: + # TODO(necula): is this branch still needed? return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, in_shapes=in_shapes, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 740f0c31ebb7..6aca2b74ef04 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -42,6 +42,7 @@ def pallas_call_lowering( compiler_params: dict[str, Any], ): if interpret: + # TODO(necula): is this still needed? return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *args, @@ -68,7 +69,10 @@ def pallas_call_lowering( if debug: print(jaxpr) print(grid_mapping) - + if grid_mapping.num_constant_operands: + raise NotImplementedError( + "captured consts not supported in the Mosaic GPU backend" + ) lowering_result = lowering.lower_jaxpr_to_module( grid_mapping, in_shapes, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 359a42e8647b..d28e857525fc 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -187,12 +187,13 @@ def _pallas_call_impl_interpret( ) assert next(dynamic_grid_args_iter, None) is None with grid_mapping.trace_env(): - discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ()) + discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: print(discharged_jaxpr) out = _initialize_output_vals(out_shapes, args, input_output_aliases) scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore - # invars: [*scalar_prefetch, *inputs, *outputs, *scratch] + # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch] + # args now contains: *consts, *inputs, *outputs num_invars = len(jaxpr.invars) num_inputs_outputs = ( num_invars @@ -243,6 +244,10 @@ def _pallas_call_impl_interpret( else: # Base case is always one iteration when grid is () num_iterations = 1 + + # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) + # i:int32 is the interation index + # loop_idx: tuple[int32] are the program ids for each grid axis def cond(carry): i, *_ = carry return i < num_iterations @@ -256,7 +261,7 @@ def body(carry): carry, scratch = split_list(carry, [num_inout]) with pallas_core.grid_env(local_grid_env): start_indices = [ - None if bm is None else bm.compute_start_indices(loop_idx, *scalars) + None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) for bm in grid_mapping.block_mappings] blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry, is_indexing_dim) @@ -269,13 +274,14 @@ def body(carry): len(blocks), len(scratch_values), ) - blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars, + blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch) blocks = blocks[grid_mapping.num_index_operands:] blocks, out_scratch = split_list(blocks, [num_inout]) carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes, carry, blocks, is_indexing_dim) return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch) + (_, _, *carry) = lax.while_loop( cond, body, (jnp.int32(0), grid_start_indices, *carry) ) @@ -900,11 +906,34 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug) if consts: - jaxpr = state_utils.hoist_consts_to_refs(jaxpr) # Pad ``block_mappings`` to account for the hoisted constants. + # The constants will be right after the index operands and just before + # the real inputs and outputs. + jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=grid_mapping.num_index_operands) + num_constant_operands = len(consts) + # TODO(necula): refactor grid_mapping to remove this code duplication + grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid) + if grid_mapping.num_index_operands: + grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore + # Create args, kwargs pytree def + grid_tree = tree_util.tree_structure((tuple(grid_avals), {})) + const_block_mappings = [] + for c_idx, c in enumerate(consts): + const_block_mapping = pallas_core._convert_block_spec_to_block_mapping( + grid_avals, + pallas_core.BlockSpec(None, None), + path=(tree_util.SequenceKey(c_idx),), + aval=jax_core.ShapedArray(c.shape, c.dtype), + in_tree=grid_tree, + grid=grid_mapping.grid, + mapped_dims=(), + what="consts", + ) + const_block_mappings.append(const_block_mapping) + grid_mapping = grid_mapping.replace( - block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)), - num_constant_operands=len(consts), + block_mappings=(*const_block_mappings, *grid_mapping.block_mappings), + num_constant_operands=num_constant_operands, ) return grid_mapping, jaxpr, consts, out_tree_thunk() @@ -1105,8 +1134,9 @@ def wrapped(*args): f"and to output{tree_util.keystr(out_paths[o_idx])} with " f"a different abstract value {out_aval}.") + index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) out_flat = pallas_call_p.bind( - *dynamic_grid_bounds, *consts, *flat_args, + *dynamic_grid_bounds, *index_args, *consts, *rest_args, jaxpr=jaxpr, name=name, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in flat_args), diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index d8eaa8d28348..5653dcf6a6b7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context: def lower_jaxpr_to_triton_module( jaxpr: jax_core.Jaxpr, - in_shapes, + in_out_shapes, grid_mapping: GridMapping, name: str, platform: str @@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module( functools.partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings, ) + consts_shapes = [ + jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype) + for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands] + ] block_infos = [ BlockInfo( jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype), @@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module( if block_mapping is not None else None for shape_dtype, block_mapping, start_idx in zip( - (*in_shapes, *[()] * grid_mapping.num_constant_operands), + (*consts_shapes, *in_out_shapes), grid_mapping.block_mappings, start_indices, ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index e6d521692ec2..800b328f28e9 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -54,6 +54,7 @@ def pallas_call_lowering( compiler_params: dict[str, Any], ): if interpret: + # TODO(necula): is this branch still needed? return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *in_nodes, @@ -72,6 +73,10 @@ def pallas_call_lowering( raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" ) + if grid_mapping.num_index_operands: + raise NotImplementedError( + "scalar prefetch not implemented in the Triton backend" + ) triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.pop("num_warps", 4) [lowering_platform] = ctx.platforms or ctx.module_context.platforms diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index ac40d1358f85..10f3c9ee7b00 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -150,7 +150,6 @@ def pallas_call(self, *args, **kwargs): class PallasCallTest(PallasBaseTest): - def test_add_one(self): if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: # TODO: assertion failures on CPU in 64-bit mode @@ -468,19 +467,22 @@ def kernel(o_ref): def test_hoisted_consts(self): # See https://github.com/google/jax/issues/21557. - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - self.skipTest("On TPU the test works only in interpret mode") - x = jnp.zeros(32) - indices = jnp.arange(4).reshape((2, 2)) + # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. + to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) + x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype), + grid=(2,), + in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))], + out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)), ) def kernel(src, dst): - dst[indices] = src[indices] + dst[0:1] = to_store - jax.block_until_ready(kernel(x)) + res = kernel(x) + self.assertAllClose(res[0:1], to_store) def test_vector_slicing(self): if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: @@ -744,6 +746,17 @@ def test_pallas_call_index_map_wrong_number_of_results(self): "Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."): f(a) + def test_pallas_call_index_map_captures_consts(self): + a = np.arange(256, dtype=np.int32) + index_map_result = np.array([0], dtype=np.int32) + f = self.pallas_call(lambda x_ref, o1_ref: None, + out_shape=a, + in_specs=[pl.BlockSpec((4,), lambda: index_map_result)]) + with self.assertRaisesRegex( + NotImplementedError, + "Index map for input\\[0\\] captures constants"): + f(a) + def test_pallas_call_out_specs_mismatch_shape(self): a = np.arange(256, dtype=np.int32) f = self.pallas_call(lambda x_ref, o1_ref: None, diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index b428705427e8..1241b8d6cdd9 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -115,6 +115,40 @@ def body(_, x_ref, o_ref): )(s, x) np.testing.assert_array_equal(out, x) + @jtu.parameterized_filterable( + kwargs=[ + dict(with_scratch=with_scratch) + for with_scratch in [True, False] + ] + ) + def test_scalar_prefetch_with_scratch_and_hoisted_const(self, *, with_scratch): + # to_store will be hoisted as constants. Choose distinct shapes from in/outs. + to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) + x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) + + s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0 + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(2,), + in_specs=[pl.BlockSpec((8, 128), + lambda i, s_ref: (pl.load(s_ref, (i,)), 0))], + out_specs=pl.BlockSpec((32, 128), + lambda i, s_ref: (pl.load(s_ref, i), 0)), + scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if with_scratch + else []), + ), + ) + def kernel(s_ref, src, dst, *scratch_refs): + store_idx = s_ref[pl.program_id(0)] + pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store) + + res = kernel(s, x) + self.assertAllClose(res[0:1], to_store) + self.assertAllClose(res[33:34], to_store) + def test_block_spec_with_wrong_block_shape_errors(self): def body(x_ref, o_ref): o_ref[...] = x_ref[...]