Skip to content

Commit

Permalink
[pallas] Fix the handling of captured consts
Browse files Browse the repository at this point in the history
There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
  • Loading branch information
gnecula committed Jul 21, 2024
1 parent 9632a2d commit bccf6d1
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 25 deletions.
6 changes: 5 additions & 1 deletion jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class BlockMapping:
index_map_jaxpr: jax_core.ClosedJaxpr
indexing_mode: IndexingMode

def compute_start_indices(self, loop_idx, *args):
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
)
Expand Down Expand Up @@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping(
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
if consts:
raise NotImplementedError(
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
f"{consts}")
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
Expand Down
13 changes: 8 additions & 5 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
self.mapped_dims = grid_mapping.mapped_dims
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
Expand Down Expand Up @@ -411,9 +411,13 @@ def lower_jaxpr_to_module(
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
block_operand_shapes = (
*in_shapes[grid_mapping.num_index_operands :],
*[jax.ShapeDtypeStruct(v.aval.shape,
v.aval.dtype)
for v in invars[:grid_mapping.num_constant_operands]],
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
Expand All @@ -425,10 +429,9 @@ def lower_jaxpr_to_module(
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
if bm.index_map_jaxpr.consts:
raise NotImplementedError("Index map jaxpr with consts not supported.")
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
if (hasattr(aval, "memory_space") and # happens when AbstractRef are created for consts
aval.memory_space == tpu_core.TPUMemorySpace.ANY):
# We may not require windowing if our block_shape matches the original
# shape or the dimensions are mapped.
requires_windowing = any(
Expand Down
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*args,
Expand All @@ -68,7 +69,10 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)

if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
in_shapes,
Expand Down
46 changes: 38 additions & 8 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,13 @@ def _pallas_call_impl_interpret(
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
Expand Down Expand Up @@ -243,6 +244,10 @@ def _pallas_call_impl_interpret(
else:
# Base case is always one iteration when grid is ()
num_iterations = 1

# The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
# i:int32 is the interation index
# loop_idx: tuple[int32] are the program ids for each grid axis
def cond(carry):
i, *_ = carry
return i < num_iterations
Expand All @@ -256,7 +261,7 @@ def body(carry):
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
Expand All @@ -269,13 +274,14 @@ def body(carry):
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)

(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
Expand Down Expand Up @@ -900,11 +906,34 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=grid_mapping.num_index_operands)
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)

grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()

Expand Down Expand Up @@ -1105,8 +1134,9 @@ def wrapped(*args):
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")

index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context:

def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_shapes,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
Expand Down Expand Up @@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
Expand All @@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module(
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*in_shapes, *[()] * grid_mapping.num_constant_operands),
(*consts_shapes, *in_out_shapes),
grid_mapping.block_mappings,
start_indices,
)
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
Expand All @@ -72,6 +73,10 @@ def pallas_call_lowering(
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
Expand Down
29 changes: 21 additions & 8 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def pallas_call(self, *args, **kwargs):

class PallasCallTest(PallasBaseTest):


def test_add_one(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
Expand Down Expand Up @@ -468,19 +467,22 @@ def kernel(o_ref):

def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
x = jnp.zeros(32)
indices = jnp.arange(4).reshape((2, 2))
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[indices] = src[indices]
dst[0:1] = to_store

jax.block_until_ready(kernel(x))
res = kernel(x)
self.assertAllClose(res[0:1], to_store)

def test_vector_slicing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
Expand Down Expand Up @@ -744,6 +746,17 @@ def test_pallas_call_index_map_wrong_number_of_results(self):
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)

def test_pallas_call_index_map_captures_consts(self):
a = np.arange(256, dtype=np.int32)
index_map_result = np.array([0], dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
f(a)

def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
Expand Down
34 changes: 34 additions & 0 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,40 @@ def body(_, x_ref, o_ref):
)(s, x)
np.testing.assert_array_equal(out, x)

@jtu.parameterized_filterable(
kwargs=[
dict(with_scratch=with_scratch)
for with_scratch in [True, False]
]
)
def test_scalar_prefetch_with_scratch_and_hoisted_const(self, *, with_scratch):
# to_store will be hoisted as constants. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(2,),
in_specs=[pl.BlockSpec((8, 128),
lambda i, s_ref: (pl.load(s_ref, (i,)), 0))],
out_specs=pl.BlockSpec((32, 128),
lambda i, s_ref: (pl.load(s_ref, i), 0)),
scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if with_scratch
else []),
),
)
def kernel(s_ref, src, dst, *scratch_refs):
store_idx = s_ref[pl.program_id(0)]
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)

res = kernel(s, x)
self.assertAllClose(res[0:1], to_store)
self.assertAllClose(res[33:34], to_store)

def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
Expand Down

0 comments on commit bccf6d1

Please sign in to comment.