[pallas] Disallow capturing of consts by kernel functions.
Previously this was allowed, but until recently (jax-ml#22550) it did not
work correctly in many cases. We now disallow capturing constants
because it can lead to surprises. Instead, the kernel function must
receive all the arrays it needs as explicit inputs, with proper block
specs.
gnecula committed Jul 31, 2024
1 parent 35ba6f7 commit 987bf33
Showing 9 changed files with 48 additions and 71 deletions.
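
As a quick illustration of the new requirement, here is a minimal sketch of the migration (the shapes, dtype, and names are hypothetical; the pattern mirrors the ops_test.py change below):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

y = jnp.full((8, 128), 2.0, dtype=jnp.float32)

# Before (now rejected): the kernel closed over `y`.
#   def kernel(x_ref, o_ref):
#     o_ref[...] = x_ref[...] + y

# After: `y` is an explicit input, so the kernel receives a ref for it.
def kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.full((8, 128), 4.0, dtype=jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x, y)
```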
2 changes: 2 additions & 0 deletions docs/pallas/CHANGELOG.md
@@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list
## Released with jax 0.4.32

* Changes
* The kernel function is not allowed to close over constants. Instead, all the needed arrays
must be passed as inputs, with proper block specs ({jax-issue}`#22746`).

* Deprecations

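In practice, the "proper block specs" part of the changelog entry above means giving the former constant its own entry in `in_specs`. A hedged sketch (the grid, shapes, and names are illustrative, not taken from this commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, table_ref, o_ref):
  o_ref[...] = x_ref[...] * table_ref[...]

x = jnp.arange(16 * 128, dtype=jnp.float32).reshape((16, 128))
table = jnp.ones((8, 128), jnp.float32)  # previously captured, now an input

out = pl.pallas_call(
    kernel,
    grid=(2,),
    in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0)),   # x: row block i
              pl.BlockSpec((8, 128), lambda i: (0, 0))],  # table: same block each step
    out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
)(x, table)
```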
34 changes: 14 additions & 20 deletions jax/_src/pallas/core.py
@@ -342,12 +342,11 @@ class GridMapping:
Encodes the calling conventions of the pallas_call primitive, the kernel,
and the index maps.
The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *consts, *inputs``.
The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *inputs``.
The ``index`` operands are for the scalar prefetch.
The ``consts`` are constants captured by the kernel function.
The kernel function is invoked with:
``*index, *consts, *inputs, *scratch``.
``*index, *inputs, *scratch``.
The index map functions are invoked with:
``*program_ids, *index``.
@@ -357,7 +356,7 @@ class GridMapping:
grid: GridMappingGrid
grid_names: tuple[Hashable, ...] | None

# Block mappings for: *consts, *inputs, *outputs
# Block mappings for: *inputs, *outputs
block_mappings: tuple[BlockMapping, ...]
# The inputs for tracing the index map: the tree and the flat avals
index_map_tree: tree_util.PyTreeDef
@@ -366,17 +365,14 @@ class GridMapping:
vmapped_dims: tuple[int, ...]

num_index_operands: int
# Number of captured constants hoisted to operands.
num_constant_operands: int
num_inputs: int
num_outputs: int
num_scratch_operands: int

def check_invariants(self) -> None:
if not config.enable_checks.value: return
assert (len(self.block_mappings) ==
self.num_constant_operands + self.num_inputs + self.num_outputs), (
self.num_constant_operands, self.num_inputs, self.num_outputs,
assert (len(self.block_mappings) == self.num_inputs + self.num_outputs), (
self.num_inputs, self.num_outputs,
self.block_mappings
)
# index_map_avals = int32[] * len(self.grid) + index_operands
@@ -443,21 +439,21 @@ def trace_env(self):
@property
def slice_index_ops(self):
"""Returns a slice object to select the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
This works on a sequence that contains *index, *ins, *outs, *scratch.
"""
return slice(0, self.num_index_operands)

@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
This works on a sequence that contains *index, *ins, *outs, *scratch.
"""
return slice(self.num_index_operands, None)

@property
def slice_scratch_ops(self):
"""Returns a slice object to select the scratch operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
This works on a sequence that contains *index, *ins, *outs, *scratch.
"""
if self.num_scratch_operands:
return slice(-self.num_scratch_operands, None)
@@ -466,22 +462,21 @@ def slice_scratch_ops(self):

@property
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *consts, *inputs."""
"""The shapes of *index, *inputs."""
index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):])
consts_inputs_shapes = (
inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[
:self.num_constant_operands + self.num_inputs])
return itertools.chain(index_shapes, consts_inputs_shapes)
for bm in self.block_mappings[:self.num_inputs])
return itertools.chain(index_shapes, inputs_shapes)

@property
def block_mappings_output(self) -> Iterable[BlockMapping]:
return itertools.islice(
self.block_mappings,
self.num_constant_operands + self.num_inputs,
self.num_constant_operands + self.num_inputs + self.num_outputs)
self.num_inputs,
self.num_inputs + self.num_outputs)

@property
def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
@@ -742,7 +737,6 @@ def get_grid_mapping(
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
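To summarize the calling convention GridMapping now encodes, here is a tiny self-contained sketch (plain Python, not part of core.py; the operand counts are made up, and `slice(0, 0)` stands in for the else branch elided in the hunk above) of how the slice_* properties select operands from *index, *ins, *outs, *scratch:

```python
# Operand layout after this change: [*index, *ins, *outs, *scratch] -- no *consts.
num_index_operands, num_scratch_operands = 1, 2
ops = ["idx0", "in0", "in1", "out0", "scratch0", "scratch1"]

slice_index_ops = slice(0, num_index_operands)
slice_block_ops = slice(num_index_operands, None)
slice_scratch_ops = (slice(-num_scratch_operands, None)
                     if num_scratch_operands else slice(0, 0))

assert ops[slice_index_ops] == ["idx0"]
assert ops[slice_block_ops] == ["in0", "in1", "out0", "scratch0", "scratch1"]
assert ops[slice_scratch_ops] == ["scratch0", "scratch1"]
```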
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/lowering.py
@@ -300,7 +300,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
# TODO(necula): clean this using new grid_mapping helpers
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
5 changes: 1 addition & 4 deletions jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -50,10 +50,7 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)
if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)

lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
jaxpr,
40 changes: 11 additions & 29 deletions jax/_src/pallas/pallas_call.py
@@ -37,7 +37,6 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas.primitives import uninitialized_value
from jax._src.state import discharge as state_discharge
from jax._src.state import utils as state_utils
from jax._src.util import (
safe_map,
safe_zip,
@@ -841,7 +840,8 @@ def _trace_kernel_to_jaxpr(fun: Callable,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
interpret: bool):
interpret: bool
) -> jax_core.ClosedJaxpr:
if interpret:
kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
kernel_avals))
@@ -852,38 +852,20 @@ def _trace_kernel_to_jaxpr(fun: Callable,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
if consts:
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(
jaxpr,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
pallas_core.BlockSpec(None, None),
origin=f"consts[{c_idx}]",
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=(),
)
const_block_mappings.append(const_block_mapping)
consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
for c in consts]
raise ValueError(
f"The kernel function {fun_src_info} in a "
"pallas_call should not capture constants. You should pass them "
f"as inputs. It captures constants of shapes: {consts_avals}")

grid_mapping = grid_mapping.replace(
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
f"The kernel function {fun_src_info} in a "
f"pallas_call should return None. "
f"It returns a PyTree: {kernel_out_tree}")
return grid_mapping, jaxpr, consts
return jaxpr

def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
@@ -1095,7 +1077,7 @@ def wrapped(*args):
flat_in_avals, in_tree, in_origins,
flat_out_avals, out_tree, out_origins)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
jaxpr = _trace_kernel_to_jaxpr(
kernel, kernel_src_info,
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
@@ -1122,7 +1104,7 @@ def wrapped(*args):

index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
*dynamic_grid_bounds, *index_args, *rest_args,
jaxpr=jaxpr, name=name,
debug=debug,
interpret=interpret,
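For completeness, a sketch of what a kernel that still closes over a constant now produces (hypothetical shapes and names; the error text is the one added in _trace_kernel_to_jaxpr above):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

c = jnp.ones((8, 128), jnp.float32)

def kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + c  # closes over `c` -- now rejected at trace time

try:
  pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(jnp.zeros((8, 128), jnp.float32))
except ValueError as e:
  assert "should not capture constants" in str(e)
```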
6 changes: 3 additions & 3 deletions tests/pallas/ops_test.py
@@ -96,13 +96,13 @@ def test_weak_dtype(self, fn, dtype):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = fn(x_ref[...], y)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = fn(x_ref[...], y_ref[...])

x = jnp.full((8, 128), 4, dtype=dtype)
y = jnp.full((8, 128), 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0,
dtype=dtype)
np.testing.assert_allclose(kernel(x), fn(x, y))
np.testing.assert_allclose(kernel(x, y), fn(x, y))

@parameterized.named_parameters(
('integer_1_1', (1, 1)),
6 changes: 4 additions & 2 deletions tests/pallas/pallas_test.py
@@ -488,8 +488,10 @@ def test_hoisted_consts(self):
def kernel(src, dst):
dst[0:1] = to_store

res = kernel(x)
self.assertAllClose(res[0:1], to_store)
with self.assertRaisesRegex(
ValueError,
"The kernel function .* should not capture constants"):
kernel(x)

def test_vector_slicing(self):
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
8 changes: 4 additions & 4 deletions tests/pallas/pallas_vmap_test.py
@@ -131,7 +131,6 @@ def add_one(x_ref, o_ref):
np.testing.assert_allclose(out, out_ref)

def test_vmap_with_hoisted_consts(self):
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128))

@@ -146,9 +145,10 @@ def test_vmap_with_hoisted_consts(self):
def kernel(src, dst):
dst[0:1] = to_store

res = kernel(x)
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)
with self.assertRaisesRegex(
ValueError,
"The kernel function .* should not capture constants"):
kernel(x)

def test_vmap_of_kernel_with_input_output_aliases(self):
@functools.partial(
16 changes: 8 additions & 8 deletions tests/pallas/tpu_pallas_test.py
@@ -130,15 +130,14 @@ def test_scalar_prefetch_calling_convention(
# dynamic_grid_dims, index, inputs, outputs, scratch.
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
# to_store will be hoisted as constants. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
if vmap:
x_shape = (4, 16, 128)
else:
x_shape = (16, 128)
x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)

def f(x, grid_size):
def f(x, grid_size, to_store):
s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0
@functools.partial(
self.pallas_call,
@@ -147,29 +146,30 @@ def f(x, grid_size):
num_scalar_prefetch=1, # 1 pytree
grid=(grid_size,),
in_specs=[pl.BlockSpec((8, 128),
lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0))],
lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0)),
pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0))],
out_specs=pl.BlockSpec((32, 128),
lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
else []),
),
)
def kernel(s_refs, src, dst, *scratch_refs):
def kernel(s_refs, src, to_store, dst, *scratch_refs):
s_ref, s2, s3 = s_refs
assert s_ref.shape == (2,)
assert s2.shape == (3,)
assert s3 is None
store_idx = s_ref[pl.program_id(0)]
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...])
# Pass a pytree of scalar
return kernel((s, np.arange(3, dtype=np.int32), None), x)
return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store)

if dyn_grid:
f = jax.jit(f)
if vmap:
res = jax.vmap(lambda x: f(x, 2))(x)
res = jax.vmap(lambda x: f(x, 2, to_store))(x)
else:
res = f(x, 2)
res = f(x, 2, to_store)

if vmap:
for i in range(x.shape[0]):
