Skip to content

Commit

Permalink
Merge pull request #22746 from gnecula:pallas_consts
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658050734
  • Loading branch information
jax authors committed Jul 31, 2024
2 parents 858dc54 + 987bf33 commit d696813
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 71 deletions.
2 changes: 2 additions & 0 deletions docs/pallas/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list
## Released with jax 0.4.32

* Changes
* The kernel function is not allowed to close over constants. Instead, all the needed arrays
must be passed as inputs, with proper block specs ({jax-issue}`#22746`).

* Deprecations

Expand Down
34 changes: 14 additions & 20 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,11 @@ class GridMapping:
Encodes the calling conventions of the pallas_call primitive, the kernel,
and the index maps.
The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *consts, *inputs``.
The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *inputs``.
The ``index`` operands are for the scalar prefetch.
The ``consts`` are constants captured by the kernel function.
The kernel function is invoked with:
``*index, *consts, *inputs, *scratch``.
``*index, *inputs, *scratch``.
The index map functions are invoked with:
``*program_ids, *index``.
Expand All @@ -357,7 +356,7 @@ class GridMapping:
grid: GridMappingGrid
grid_names: tuple[Hashable, ...] | None

# Block mappings for: *consts, *inputs, *outputs
# Block mappings for: *inputs, *outputs
block_mappings: tuple[BlockMapping, ...]
# The inputs for tracing the index map: the tree and the flat avals
index_map_tree: tree_util.PyTreeDef
Expand All @@ -366,17 +365,14 @@ class GridMapping:
vmapped_dims: tuple[int, ...]

num_index_operands: int
# Number of captured constants hoisted to operands.
num_constant_operands: int
num_inputs: int
num_outputs: int
num_scratch_operands: int

def check_invariants(self) -> None:
if not config.enable_checks.value: return
assert (len(self.block_mappings) ==
self.num_constant_operands + self.num_inputs + self.num_outputs), (
self.num_constant_operands, self.num_inputs, self.num_outputs,
assert (len(self.block_mappings) == self.num_inputs + self.num_outputs), (
self.num_inputs, self.num_outputs,
self.block_mappings
)
# index_map_avals = int32[] * len(self.grid) + index_operands
Expand Down Expand Up @@ -443,21 +439,21 @@ def trace_env(self):
@property
def slice_index_ops(self):
"""Returns a slice object to select the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
This works on a sequence that contains *index, *ins, *outs, *scratch.
"""
return slice(0, self.num_index_operands)

@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
This works on a sequence that contains *index, *ins, *outs, *scratch.
"""
return slice(self.num_index_operands, None)

@property
def slice_scratch_ops(self):
"""Returns a slice object to select the scratch operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
This works on a sequence that contains *index, *ins, *outs, *scratch.
"""
if self.num_scratch_operands:
return slice(-self.num_scratch_operands, None)
Expand All @@ -466,22 +462,21 @@ def slice_scratch_ops(self):

@property
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *consts, *inputs."""
"""The shapes of *index, *inputs."""
index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):])
consts_inputs_shapes = (
inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[
:self.num_constant_operands + self.num_inputs])
return itertools.chain(index_shapes, consts_inputs_shapes)
for bm in self.block_mappings[:self.num_inputs])
return itertools.chain(index_shapes, inputs_shapes)

@property
def block_mappings_output(self) -> Iterable[BlockMapping]:
return itertools.islice(
self.block_mappings,
self.num_constant_operands + self.num_inputs,
self.num_constant_operands + self.num_inputs + self.num_outputs)
self.num_inputs,
self.num_inputs + self.num_outputs)

@property
def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
Expand Down Expand Up @@ -742,7 +737,6 @@ def get_grid_mapping(
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
# TODO(necula): clean this using new grid_mapping helpers
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
Expand Down
5 changes: 1 addition & 4 deletions jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)
if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)

lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
jaxpr,
Expand Down
40 changes: 11 additions & 29 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas.primitives import uninitialized_value
from jax._src.state import discharge as state_discharge
from jax._src.state import utils as state_utils
from jax._src.util import (
safe_map,
safe_zip,
Expand Down Expand Up @@ -841,7 +840,8 @@ def _trace_kernel_to_jaxpr(fun: Callable,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
interpret: bool):
interpret: bool
) -> jax_core.ClosedJaxpr:
if interpret:
kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
kernel_avals))
Expand All @@ -852,38 +852,20 @@ def _trace_kernel_to_jaxpr(fun: Callable,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
if consts:
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(
jaxpr,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
pallas_core.BlockSpec(None, None),
origin=f"consts[{c_idx}]",
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=(),
)
const_block_mappings.append(const_block_mapping)
consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
for c in consts]
raise ValueError(
f"The kernel function {fun_src_info} in a "
"pallas_call should not capture constants. You should pass them "
f"as inputs. It captures constants of shapes: {consts_avals}")

grid_mapping = grid_mapping.replace(
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
f"The kernel function {fun_src_info} in a "
f"pallas_call should return None. "
f"It returns a PyTree: {kernel_out_tree}")
return grid_mapping, jaxpr, consts
return jaxpr

def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
Expand Down Expand Up @@ -1095,7 +1077,7 @@ def wrapped(*args):
flat_in_avals, in_tree, in_origins,
flat_out_avals, out_tree, out_origins)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
jaxpr = _trace_kernel_to_jaxpr(
kernel, kernel_src_info,
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
Expand All @@ -1122,7 +1104,7 @@ def wrapped(*args):

index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
*dynamic_grid_bounds, *index_args, *rest_args,
jaxpr=jaxpr, name=name,
debug=debug,
interpret=interpret,
Expand Down
6 changes: 3 additions & 3 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def test_weak_dtype(self, fn, dtype):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = fn(x_ref[...], y)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = fn(x_ref[...], y_ref[...])

x = jnp.full((8, 128), 4, dtype=dtype)
y = jnp.full((8, 128), 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0,
dtype=dtype)
np.testing.assert_allclose(kernel(x), fn(x, y))
np.testing.assert_allclose(kernel(x, y), fn(x, y))

@parameterized.named_parameters(
('integer_1_1', (1, 1)),
Expand Down
6 changes: 4 additions & 2 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,10 @@ def test_hoisted_consts(self):
def kernel(src, dst):
dst[0:1] = to_store

res = kernel(x)
self.assertAllClose(res[0:1], to_store)
with self.assertRaisesRegex(
ValueError,
"The kernel function .* should not capture constants"):
kernel(x)

def test_vector_slicing(self):
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
Expand Down
8 changes: 4 additions & 4 deletions tests/pallas/pallas_vmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def add_one(x_ref, o_ref):
np.testing.assert_allclose(out, out_ref)

def test_vmap_with_hoisted_consts(self):
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128))

Expand All @@ -146,9 +145,10 @@ def test_vmap_with_hoisted_consts(self):
def kernel(src, dst):
dst[0:1] = to_store

res = kernel(x)
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)
with self.assertRaisesRegex(
ValueError,
"The kernel function .* should not capture constants"):
kernel(x)

def test_vmap_of_kernel_with_input_output_aliases(self):
@functools.partial(
Expand Down
16 changes: 8 additions & 8 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,14 @@ def test_scalar_prefetch_calling_convention(
# dynamic_grid_dims, index, inputs, outputs, scratch.
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
# to_store will be hoisted as constants. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
if vmap:
x_shape = (4, 16, 128)
else:
x_shape = (16, 128)
x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)

def f(x, grid_size):
def f(x, grid_size, to_store):
s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0
@functools.partial(
self.pallas_call,
Expand All @@ -147,29 +146,30 @@ def f(x, grid_size):
num_scalar_prefetch=1, # 1 pytree
grid=(grid_size,),
in_specs=[pl.BlockSpec((8, 128),
lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0))],
lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0)),
pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0))],
out_specs=pl.BlockSpec((32, 128),
lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
else []),
),
)
def kernel(s_refs, src, dst, *scratch_refs):
def kernel(s_refs, src, to_store, dst, *scratch_refs):
s_ref, s2, s3 = s_refs
assert s_ref.shape == (2,)
assert s2.shape == (3,)
assert s3 is None
store_idx = s_ref[pl.program_id(0)]
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...])
# Pass a pytree of scalar
return kernel((s, np.arange(3, dtype=np.int32), None), x)
return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store)

if dyn_grid:
f = jax.jit(f)
if vmap:
res = jax.vmap(lambda x: f(x, 2))(x)
res = jax.vmap(lambda x: f(x, 2, to_store))(x)
else:
res = f(x, 2)
res = f(x, 2, to_store)

if vmap:
for i in range(x.shape[0]):
Expand Down

0 comments on commit d696813

Please sign in to comment.