[pallas] Fix the handling of captured consts
There was a previous attempt to handle consts captured by the kernel,
but it was incomplete and contained errors: the calling convention was
wrong, and the handling of consts alongside scalar prefetch and scratch
values was only partially implemented.

I expanded the tests: one in pallas_test.py and two in
tpu_pallas_test.py (covering scalar prefetch, with and
without scratch inputs).
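
For illustration, here is a minimal standalone version of the scenario being
fixed, modeled on the updated test_hoisted_consts below. This is a sketch only:
it assumes a recent jax.experimental.pallas with the
pl.BlockSpec(block_shape, index_map) form used in these tests, and a backend
(or interpret=True) that can run the kernel.

    import functools

    import jax
    import numpy as np
    from jax.experimental import pallas as pl

    # to_store is not a kernel argument: the kernel closes over it, so it is
    # hoisted into the kernel jaxpr as a captured const (an extra Ref operand).
    to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
    x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
        grid=(2,),
        in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
        out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
    )
    def kernel(src, dst):
      dst[0:1] = to_store  # uses the captured const, not an explicit input

    res = kernel(x)
    np.testing.assert_allclose(res[0:1], to_store)

Because to_store is closed over rather than passed in, it appears in the kernel
jaxpr as a hoisted const Ref, which is exactly the operand class whose calling
convention this change fixes.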

The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`.
This differs from the previous order (`*consts, *scalar_refs, *ins, ...`):
keeping the block arguments (consts, ins, outs) together makes the
lowering easier to write.
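
As a sketch of what this ordering means when splitting the kernel jaxpr's
invars (an illustrative helper, not actual JAX internals; the counts mirror
grid_mapping.num_index_operands, num_constant_operands, and
num_scratch_operands):

    def split_invars(invars, num_index_operands, num_constant_operands,
                     num_outputs, num_scratch_operands):
      """Splits [*scalar_refs, *consts, *ins, *outs, *scratch] into groups."""
      n = len(invars)
      scalar_refs = invars[:num_index_operands]
      consts = invars[num_index_operands:
                      num_index_operands + num_constant_operands]
      ins = invars[num_index_operands + num_constant_operands:
                   n - num_outputs - num_scratch_operands]
      outs = invars[n - num_outputs - num_scratch_operands:
                    n - num_scratch_operands]
      # Empty slice when there are no scratch operands.
      scratch = invars[n - num_scratch_operands:]
      return scalar_refs, consts, ins, outs, scratch

    # 1 scalar-prefetch ref, 2 consts, 2 inputs, 1 output, 1 scratch buffer:
    assert split_invars(["s0", "c0", "c1", "x", "y", "o", "acc"],
                        num_index_operands=1, num_constant_operands=2,
                        num_outputs=1, num_scratch_operands=1) == (
        ["s0"], ["c0", "c1"], ["x", "y"], ["o"], ["acc"])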

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
gnecula committed Jul 22, 2024
1 parent b3469a6 commit 9d9b69f
Showing 10 changed files with 171 additions and 31 deletions.
6 changes: 5 additions & 1 deletion jax/_src/pallas/core.py
@@ -245,7 +245,7 @@ class BlockMapping:
index_map_jaxpr: jax_core.ClosedJaxpr
indexing_mode: IndexingMode

def compute_start_indices(self, loop_idx, *args):
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
)
@@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping(
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
if consts:
raise NotImplementedError(
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
f"{consts}")
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
13 changes: 8 additions & 5 deletions jax/_src/pallas/mosaic/lowering.py
@@ -274,7 +274,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
self.mapped_dims = grid_mapping.mapped_dims
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
@@ -411,9 +411,13 @@ def lower_jaxpr_to_module(
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
block_operand_shapes = (
*in_shapes[grid_mapping.num_index_operands :],
*[jax.ShapeDtypeStruct(v.aval.shape,
v.aval.dtype)
for v in invars[:grid_mapping.num_constant_operands]],
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
@@ -425,10 +429,9 @@
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
if bm.index_map_jaxpr.consts:
raise NotImplementedError("Index map jaxpr with consts not supported.")
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
if (hasattr(aval, "memory_space") and  # happens when AbstractRefs are created for consts
aval.memory_space == tpu_core.TPUMemorySpace.ANY):
# We may not require windowing if our block_shape matches the original
# shape or the dimensions are mapped.
requires_windowing = any(
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
6 changes: 5 additions & 1 deletion jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -42,6 +42,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*args,
@@ -68,7 +69,10 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)

if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
in_shapes,
58 changes: 47 additions & 11 deletions jax/_src/pallas/pallas_call.py
@@ -187,12 +187,13 @@ def _pallas_call_impl_interpret(
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
@@ -243,6 +244,10 @@ def _pallas_call_impl_interpret(
else:
# Base case is always one iteration when grid is ()
num_iterations = 1

# The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
# i: int32 is the iteration index
# loop_idx: tuple[int32] are the program ids for each grid axis
def cond(carry):
i, *_ = carry
return i < num_iterations
@@ -256,7 +261,7 @@ def body(carry):
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
@@ -269,13 +274,14 @@ def body(carry):
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)

(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
@@ -604,13 +610,13 @@ def _maybe_squeeze_out_bdim(
# Ordinarily, adding support for scalar prefetch in vmap would involve
# modifying the block specs in a nontrivial way. However, if we are only
# vmapping over 1-sized dimensions, we can just get rid of the dimensions
# and pretend we were never vmapping over them at all.
# and pretend we were never vmapped over them at all.
if all(
bdim is batching.not_mapped or arg.shape[bdim] == 1
for arg, bdim in zip(scalar_args, scalar_bdims)
):
scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
scalar_bdims = [None] * len(scalar_args)
scalar_bdims = [batching.not_mapped] * len(scalar_args)
args = (*scalar_args, *args)
dims = (*scalar_bdims, *bdims)
else:
@@ -648,6 +654,7 @@ def _maybe_squeeze_out_bdim(
all_dims = list(dims) + [0] * len(out_shapes)

num_index_operands = grid_mapping.num_index_operands
num_constant_operands = grid_mapping.num_constant_operands
num_scratch_operands = grid_mapping.num_scratch_operands

# Only add a batch dimension for the avals that actually have a grid mapping.
@@ -661,11 +668,16 @@ def _maybe_squeeze_out_bdim(
block_mappings,
)

# TODO(necula): should fix in_shapes to include the consts
dims_no_consts = (
dims[:num_index_operands] +
dims[num_index_operands + num_constant_operands:]
)
batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
x.dtype)
for x, dim in zip(in_shapes, dims))
for x, dim in zip(in_shapes, dims_no_consts))
batched_out_shapes = tuple(
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
for x in out_shapes)
@@ -900,11 +912,34 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=grid_mapping.num_index_operands)
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)

grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()

@@ -1105,8 +1140,9 @@ def wrapped(*args):
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")

index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
8 changes: 6 additions & 2 deletions jax/_src/pallas/triton/lowering.py
@@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context:

def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_shapes,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
@@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
@@ -310,7 +314,7 @@
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*in_shapes, *[()] * grid_mapping.num_constant_operands),
(*consts_shapes, *in_out_shapes),
grid_mapping.block_mappings,
start_indices,
)
5 changes: 5 additions & 0 deletions jax/_src/pallas/triton/pallas_call_registration.py
@@ -54,6 +54,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
@@ -72,6 +73,10 @@
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
29 changes: 21 additions & 8 deletions tests/pallas/pallas_test.py
@@ -150,7 +150,6 @@ def pallas_call(self, *args, **kwargs):

class PallasCallTest(PallasBaseTest):


def test_add_one(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
@@ -468,19 +467,22 @@ def kernel(o_ref):

def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
x = jnp.zeros(32)
indices = jnp.arange(4).reshape((2, 2))
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[indices] = src[indices]
dst[0:1] = to_store

jax.block_until_ready(kernel(x))
res = kernel(x)
self.assertAllClose(res[0:1], to_store)

def test_vector_slicing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
@@ -744,6 +746,17 @@ def test_pallas_call_index_map_wrong_number_of_results(self):
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)

def test_pallas_call_index_map_captures_consts(self):
a = np.arange(256, dtype=np.int32)
index_map_result = np.array([0], dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
f(a)

def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
24 changes: 22 additions & 2 deletions tests/pallas/pallas_vmap_test.py
@@ -37,7 +37,7 @@


@jtu.with_config(jax_traceback_filtering="off")
class PallasTest(jtu.JaxTestCase):
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False

def setUp(self):
@@ -58,7 +58,7 @@ def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


class PallasCallVmapTest(PallasTest):
class PallasCallVmapTest(PallasBaseTest):

def setUp(self):
super().setUp()
@@ -130,6 +130,26 @@ def add_one(x_ref, o_ref):
out_ref = jnp.arange(1, 9).reshape((4, 2))
np.testing.assert_allclose(out, out_ref)

def test_vmap_with_hoisted_consts(self):
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128))

@jax.vmap
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[0:1] = to_store

res = kernel(x)
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)

def test_vmap_of_kernel_with_input_output_aliases(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),