diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md
index 9c6bfd2164cf..0672521bd95f 100644
--- a/docs/pallas/CHANGELOG.md
+++ b/docs/pallas/CHANGELOG.md
@@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list

 ## Released with jax 0.4.32

 * Changes
+  * The kernel function is no longer allowed to close over constants. Instead, all the arrays
+    it needs must be passed as inputs, with the corresponding block specs ({jax-issue}`#22746`).

 * Deprecations
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 7d8881635b5c..bb18dd353569 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -342,12 +342,11 @@ class GridMapping:
   Encodes the calling conventions of the pallas_call primitive, the kernel,
   and the index maps.

-  The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *consts, *inputs``.
+  The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *inputs``.
   The ``index`` operands are for the scalar prefetch.
-  The ``consts`` are constants captured by the kernel function.

   The kernel function is invoked with:
-  ``*index, *consts, *inputs, *scratch``.
+  ``*index, *inputs, *scratch``.

   The index map functions are invoked with:
   ``*program_ids, *index``.
@@ -357,7 +356,7 @@ class GridMapping:
   grid: GridMappingGrid
   grid_names: tuple[Hashable, ...] | None

-  # Block mappings for: *consts, *inputs, *outputs
+  # Block mappings for: *inputs, *outputs
   block_mappings: tuple[BlockMapping, ...]
   # The inputs for tracing the index map: the tree and the flat avals
   index_map_tree: tree_util.PyTreeDef
@@ -366,17 +365,14 @@ class GridMapping:
   vmapped_dims: tuple[int, ...]

   num_index_operands: int
-  # Number of captured constants hoisted to operands.
-  num_constant_operands: int
   num_inputs: int
   num_outputs: int
   num_scratch_operands: int

   def check_invariants(self) -> None:
     if not config.enable_checks.value: return
-    assert (len(self.block_mappings) ==
-            self.num_constant_operands + self.num_inputs + self.num_outputs), (
-        self.num_constant_operands, self.num_inputs, self.num_outputs,
+    assert (len(self.block_mappings) == self.num_inputs + self.num_outputs), (
+        self.num_inputs, self.num_outputs,
         self.block_mappings
     )
     # index_map_avals = int32[] * len(self.grid) + index_operands
@@ -443,21 +439,21 @@ def trace_env(self):
   @property
   def slice_index_ops(self):
     """Returns a slice object to select the index operands to a kernel.
-    This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
+    This works on a sequence that contains *index, *ins, *outs, *scratch.
     """
     return slice(0, self.num_index_operands)

   @property
   def slice_block_ops(self):
     """Returns a slice to select all but the index operands to a kernel.
-    This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
+    This works on a sequence that contains *index, *ins, *outs, *scratch.
     """
     return slice(self.num_index_operands, None)

   @property
   def slice_scratch_ops(self):
     """Returns a slice object to select the scratch operands to a kernel.
-    This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
+    This works on a sequence that contains *index, *ins, *outs, *scratch.
     """
     if self.num_scratch_operands:
       return slice(-self.num_scratch_operands, None)
@@ -466,22 +462,21 @@ def slice_scratch_ops(self):

   @property
   def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
-    """The shapes of *index, *consts, *inputs."""
+    """The shapes of *index, *inputs."""
     index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
                                          ia.inner_aval.dtype)
                     for ia in self.index_map_avals[len(self.grid):])
-    consts_inputs_shapes = (
+    inputs_shapes = (
         bm.array_shape_dtype
-        for bm in self.block_mappings[
-            :self.num_constant_operands + self.num_inputs])
-    return itertools.chain(index_shapes, consts_inputs_shapes)
+        for bm in self.block_mappings[:self.num_inputs])
+    return itertools.chain(index_shapes, inputs_shapes)

   @property
   def block_mappings_output(self) -> Iterable[BlockMapping]:
     return itertools.islice(
         self.block_mappings,
-        self.num_constant_operands + self.num_inputs,
-        self.num_constant_operands + self.num_inputs + self.num_outputs)
+        self.num_inputs,
+        self.num_inputs + self.num_outputs)

   @property
   def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
@@ -742,7 +737,6 @@ def get_grid_mapping(
       index_map_tree=index_map_tree,
       vmapped_dims=(),
       num_index_operands=num_flat_scalar_prefetch,
-      num_constant_operands=0,  # Fixed up later
       num_inputs=len(flat_in_specs),
       num_outputs=len(flat_out_specs),
       num_scratch_operands=num_flat_scratch_operands,
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 14ea00e624c4..421dab06be6e 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -300,7 +300,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
     # TODO(necula): clean this using new grid_mapping helpers
     num_scalar_prefetch = grid_mapping.num_index_operands
     num_scratch = grid_mapping.num_scratch_operands
-    # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
+    # jaxpr has signature [*scalar_prefetch, *in_ops, *out_ops, *scratch]
     num_operands = (
         len(self.jaxpr.invars)
         - num_scalar_prefetch
diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
index 2f117625beb1..24f26f53f213 100644
--- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -50,10 +50,7 @@ def pallas_call_lowering(
   if debug:
     print(jaxpr)
     print(grid_mapping)
-  if grid_mapping.num_constant_operands:
-    raise NotImplementedError(
-        "captured consts not supported in the Mosaic GPU backend"
-    )
+
   lowering_result = lowering.lower_jaxpr_to_module(
       grid_mapping,
       jaxpr,
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index c2d93fd77f23..d1b69f424e6c 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -37,7 +37,6 @@
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.primitives import uninitialized_value
 from jax._src.state import discharge as state_discharge
-from jax._src.state import utils as state_utils
 from jax._src.util import (
     safe_map,
     safe_zip,
@@ -841,7 +840,8 @@ def _trace_kernel_to_jaxpr(fun: Callable,
                            grid_mapping: GridMapping,
                            kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
                            kernel_in_tree: tree_util.PyTreeDef,
-                           interpret: bool):
+                           interpret: bool
+                           ) -> jax_core.ClosedJaxpr:
   if interpret:
     kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
                              kernel_avals))
@@ -852,38 +852,20 @@ def _trace_kernel_to_jaxpr(fun: Callable,
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                      kernel_avals, debug)
   if consts:
-    # Pad ``block_mappings`` to account for the hoisted constants.
-    # The constants will be right after the index operands and just before
-    # the real inputs and outputs.
-    jaxpr = state_utils.hoist_consts_to_refs(
-        jaxpr,
-        index=grid_mapping.num_index_operands,
-        make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
-    num_constant_operands = len(consts)
-    const_block_mappings = []
-    for c_idx, c in enumerate(consts):
-      const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
-          pallas_core.BlockSpec(None, None),
-          origin=f"consts[{c_idx}]",
-          array_aval=jax_core.ShapedArray(c.shape, c.dtype),
-          index_map_avals=grid_mapping.index_map_avals,
-          index_map_tree=grid_mapping.index_map_tree,
-          grid=grid_mapping.grid,
-          mapped_dims=(),
-      )
-      const_block_mappings.append(const_block_mapping)
+    consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
+                    for c in consts]
+    raise ValueError(
+        f"The kernel function {fun_src_info} in a "
+        "pallas_call should not capture constants. You should pass them "
+        f"as inputs. It captures constants of shapes: {consts_avals}")
-    grid_mapping = grid_mapping.replace(
-        block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
-        num_constant_operands=num_constant_operands,
-    )
   kernel_out_tree = out_tree_thunk()
   if kernel_out_tree != tree_util.tree_structure(None):
     raise ValueError(
         f"The kernel function {fun_src_info} in a "
         f"pallas_call should return None. "
         f"It returns a PyTree: {kernel_out_tree}")
-  return grid_mapping, jaxpr, consts
+  return jaxpr


 def _extract_function_name(f: Callable, name: str | None) -> str:
   if name is None:
@@ -1095,7 +1077,7 @@ def wrapped(*args):
         flat_in_avals, in_tree, in_origins,
         flat_out_avals, out_tree, out_origins)
     flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
-    grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
+    jaxpr = _trace_kernel_to_jaxpr(
         kernel, kernel_src_info,
         grid_mapping, tuple(flat_kernel_avals),
         kernel_in_tree, interpret=interpret)
@@ -1122,7 +1104,7 @@ def wrapped(*args):
     index_args, rest_args = split_list(
         flat_args, [grid_mapping.num_index_operands])
     out_flat = pallas_call_p.bind(
-        *dynamic_grid_bounds, *index_args, *consts, *rest_args,
+        *dynamic_grid_bounds, *index_args, *rest_args,
         jaxpr=jaxpr, name=name,
         debug=debug,
         interpret=interpret,
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index eb7915d12380..464bd451d70e 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -96,13 +96,13 @@ def test_weak_dtype(self, fn, dtype):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
     )
-    def kernel(x_ref, o_ref):
-      o_ref[...] = fn(x_ref[...], y)
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = fn(x_ref[...], y_ref[...])

     x = jnp.full((8, 128), 4, dtype=dtype)
     y = jnp.full((8, 128),
                  2 if jnp.issubdtype(dtype, jnp.integer) else 2.0, dtype=dtype)
-    np.testing.assert_allclose(kernel(x), fn(x, y))
+    np.testing.assert_allclose(kernel(x, y), fn(x, y))

   @parameterized.named_parameters(
       ('integer_1_1', (1, 1)),
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 0601165f12e7..0ce703808abb 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -488,8 +488,10 @@ def test_hoisted_consts(self):
     def kernel(src, dst):
       dst[0:1] = to_store

-    res = kernel(x)
-    self.assertAllClose(res[0:1], to_store)
+    with self.assertRaisesRegex(
+        ValueError,
+        "The kernel function .* should not capture constants"):
+      kernel(x)

   def test_vector_slicing(self):
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py
index 8006a78890b4..724285abbbca 100644
--- a/tests/pallas/pallas_vmap_test.py
+++ b/tests/pallas/pallas_vmap_test.py
@@ -131,7 +131,6 @@ def add_one(x_ref, o_ref):
     np.testing.assert_allclose(out, out_ref)

   def test_vmap_with_hoisted_consts(self):
-    # to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
     to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
     x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128))

@@ -146,9 +145,10 @@ def test_vmap_with_hoisted_consts(self):
     def kernel(src, dst):
       dst[0:1] = to_store

-    res = kernel(x)
-    for i in range(x.shape[0]):
-      self.assertAllClose(res[i, 0:1], to_store)
+    with self.assertRaisesRegex(
+        ValueError,
+        "The kernel function .* should not capture constants"):
+      kernel(x)

   def test_vmap_of_kernel_with_input_output_aliases(self):
     @functools.partial(
diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py
index c5356abf110c..40aa30c3af49 100644
--- a/tests/pallas/tpu_pallas_test.py
+++ b/tests/pallas/tpu_pallas_test.py
@@ -130,7 +130,6 @@ def test_scalar_prefetch_calling_convention(
     # dynamic_grid_dims, index, inputs, outputs, scratch.
    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
      self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
-    # to_store will be hoisted as constants. Choose distinct shapes from in/outs.
     to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
     if vmap:
       x_shape = (4, 16, 128)
@@ -138,7 +137,7 @@ def test_scalar_prefetch_calling_convention(
       x_shape = (16, 128)
     x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)

-    def f(x, grid_size):
+    def f(x, grid_size, to_store):
       s = jnp.array([1, 0], jnp.int32)  # iteration 0 -> 1, iteration 1 -> 0
       @functools.partial(
           self.pallas_call,
@@ -147,29 +146,30 @@ def f(x, grid_size):
           num_scalar_prefetch=1,  # 1 pytree
           grid=(grid_size,),
           in_specs=[pl.BlockSpec((8, 128),
-                                 lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0))],
+                                 lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0)),
+                    pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0))],
           out_specs=pl.BlockSpec((32, 128),
                                  lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
           scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
                           else []),
           ),
       )
-      def kernel(s_refs, src, dst, *scratch_refs):
+      def kernel(s_refs, src, to_store, dst, *scratch_refs):
        s_ref, s2, s3 = s_refs
        assert s_ref.shape == (2,)
        assert s2.shape == (3,)
        assert s3 is None
        store_idx = s_ref[pl.program_id(0)]
-        pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
+        pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...])

      # Pass a pytree of scalar
-      return kernel((s, np.arange(3, dtype=np.int32), None), x)
+      return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store)

    if dyn_grid:
      f = jax.jit(f)
    if vmap:
-      res = jax.vmap(lambda x: f(x, 2))(x)
+      res = jax.vmap(lambda x: f(x, 2, to_store))(x)
    else:
-      res = f(x, 2)
+      res = f(x, 2, to_store)

    if vmap:
      for i in range(x.shape[0]):
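
For reference, a minimal sketch (not part of the patch above) of the migration this change requires: an array that a kernel used to close over is passed as an extra input with its own BlockSpec, here one that selects the same block at every grid step. The names `add_bias_kernel` / `add_bias` and the shapes are illustrative assumptions, not taken from this patch; only the `pl.pallas_call` / `pl.BlockSpec` usage mirrors the updated tests.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_bias_kernel(x_ref, bias_ref, o_ref):
  # `bias_ref` used to be a closed-over constant; it is now an ordinary input ref.
  o_ref[...] = x_ref[...] + bias_ref[...]


def add_bias(x, bias):
  # x: (16, 128), bias: (8, 128). The grid walks the two row blocks of x, while
  # the BlockSpec for `bias` returns block (0, 0) at every step, i.e. the whole array.
  return pl.pallas_call(
      add_bias_kernel,
      grid=(2,),
      in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0)),
                pl.BlockSpec((8, 128), lambda i: (0, 0))],
      out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,  # interpreter mode so the sketch also runs without a TPU/GPU
  )(x, bias)


x = jnp.arange(16 * 128, dtype=jnp.float32).reshape(16, 128)
bias = jnp.full((8, 128), 2.0, dtype=jnp.float32)
np.testing.assert_allclose(add_bias(x, bias), x + jnp.tile(bias, (2, 1)))

A kernel that still reads `bias` from the enclosing scope now fails at trace time with "The kernel function ... should not capture constants", which is exactly what the updated pallas_test.py and pallas_vmap_test.py cases assert.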