Skip to content

Commit

Permalink
[pallas] Fix the handling of captured consts
Browse files Browse the repository at this point in the history
There was an attempt to handle consts captured by the kernel,
but it was incomplete and had errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_test.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
  • Loading branch information
gnecula committed Jul 22, 2024
1 parent b3469a6 commit b7105cc
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 32 deletions.
6 changes: 5 additions & 1 deletion jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class BlockMapping:
index_map_jaxpr: jax_core.ClosedJaxpr
indexing_mode: IndexingMode

def compute_start_indices(self, loop_idx, *args):
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
)
Expand Down Expand Up @@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping(
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
if consts:
raise NotImplementedError(
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
f"{consts}")
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
Expand Down
10 changes: 6 additions & 4 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
self.mapped_dims = grid_mapping.mapped_dims
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
Expand Down Expand Up @@ -411,9 +411,13 @@ def lower_jaxpr_to_module(
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
block_operand_shapes = (
*in_shapes[grid_mapping.num_index_operands :],
*[jax.ShapeDtypeStruct(v.aval.shape,
v.aval.dtype)
for v in invars[:grid_mapping.num_constant_operands]],
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
Expand All @@ -425,8 +429,6 @@ def lower_jaxpr_to_module(
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
if bm.index_map_jaxpr.consts:
raise NotImplementedError("Index map jaxpr with consts not supported.")
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
# We may not require windowing if our block_shape matches the original
Expand Down
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*args,
Expand All @@ -68,7 +69,10 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)

if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
in_shapes,
Expand Down
61 changes: 50 additions & 11 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,13 @@ def _pallas_call_impl_interpret(
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
Expand Down Expand Up @@ -243,6 +244,10 @@ def _pallas_call_impl_interpret(
else:
# Base case is always one iteration when grid is ()
num_iterations = 1

# The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
# i:int32 is the iteration index
# loop_idx: tuple[int32] are the program ids for each grid axis
def cond(carry):
i, *_ = carry
return i < num_iterations
Expand All @@ -256,7 +261,7 @@ def body(carry):
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
Expand All @@ -269,13 +274,14 @@ def body(carry):
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)

(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
Expand Down Expand Up @@ -604,13 +610,13 @@ def _maybe_squeeze_out_bdim(
# Ordinarily, adding support for scalar prefetch in vmap would involve
# modifying the block specs in a nontrivial way. However, if we are only
# vmapping over 1-sized dimensions, we can just get rid of the dimensions
# and pretend we were never vmapping over them at all.
# and pretend we were never vmapped over them at all.
if all(
bdim is batching.not_mapped or arg.shape[bdim] == 1
for arg, bdim in zip(scalar_args, scalar_bdims)
):
scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
scalar_bdims = [None] * len(scalar_args)
scalar_bdims = [batching.not_mapped] * len(scalar_args)
args = (*scalar_args, *args)
dims = (*scalar_bdims, *bdims)
else:
Expand Down Expand Up @@ -648,6 +654,7 @@ def _maybe_squeeze_out_bdim(
all_dims = list(dims) + [0] * len(out_shapes)

num_index_operands = grid_mapping.num_index_operands
num_constant_operands = grid_mapping.num_constant_operands
num_scratch_operands = grid_mapping.num_scratch_operands

# Only add a batch dimension for the avals that actually have a grid mapping.
Expand All @@ -661,11 +668,16 @@ def _maybe_squeeze_out_bdim(
block_mappings,
)

# TODO(necula): should fix in_shapes to include the consts
dims_no_consts = (
dims[:num_index_operands] +
dims[num_index_operands + num_constant_operands:]
)
batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
x.dtype)
for x, dim in zip(in_shapes, dims))
for x, dim in zip(in_shapes, dims_no_consts))
batched_out_shapes = tuple(
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
for x in out_shapes)
Expand Down Expand Up @@ -900,11 +912,37 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(
jaxpr,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)

grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()

Expand Down Expand Up @@ -1105,8 +1143,9 @@ def wrapped(*args):
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")

index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context:

def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_shapes,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
Expand Down Expand Up @@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
Expand All @@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module(
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*in_shapes, *[()] * grid_mapping.num_constant_operands),
(*consts_shapes, *in_out_shapes),
grid_mapping.block_mappings,
start_indices,
)
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
Expand All @@ -72,6 +73,10 @@ def pallas_call_lowering(
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
Expand Down
13 changes: 11 additions & 2 deletions jax/_src/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Utilities for tracing stateful functions."""

from typing import Callable

from jax._src.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
Expand All @@ -24,13 +26,20 @@
zip, unsafe_zip = safe_zip, zip


def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
def hoist_consts_to_refs(
jaxpr: core.Jaxpr,
*,
index: int = 0,
make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval)
) -> core.Jaxpr:
"""Hoists the constants in the given jaxpr into invars.
Args:
jaxpr: The jaxpr.
index: The index where the invars for the constants should be inserted.
By default, the new invars are inserted *before* any existing invars.
make_abstract_ref: a callable to construct an AbstractRef, or subtype
thereof, from a constant AbstractValue.
Returns:
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
Expand All @@ -42,7 +51,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
isinstance(var.aval, AbstractRef) for var in jaxpr.constvars
]
const_avals = [
var.aval if is_ref else AbstractRef(var.aval)
var.aval if is_ref else make_abstract_ref(var.aval)
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
]
in_avals = [var.aval for var in jaxpr.invars]
Expand Down
29 changes: 21 additions & 8 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def pallas_call(self, *args, **kwargs):

class PallasCallTest(PallasBaseTest):


def test_add_one(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
Expand Down Expand Up @@ -468,19 +467,22 @@ def kernel(o_ref):

def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
x = jnp.zeros(32)
indices = jnp.arange(4).reshape((2, 2))
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[indices] = src[indices]
dst[0:1] = to_store

jax.block_until_ready(kernel(x))
res = kernel(x)
self.assertAllClose(res[0:1], to_store)

def test_vector_slicing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
Expand Down Expand Up @@ -744,6 +746,17 @@ def test_pallas_call_index_map_wrong_number_of_results(self):
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)

def test_pallas_call_index_map_captures_consts(self):
# Checks that an index map which closes over an array value is rejected
# with NotImplementedError rather than being silently accepted.
a = np.arange(256, dtype=np.int32)
# This array is referenced from inside the index-map lambda below, so it is
# captured by the index map instead of being passed in as an argument.
index_map_result = np.array([0], dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
# The message must identify the offending operand (input[0]).
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
f(a)

def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
Expand Down
Loading

0 comments on commit b7105cc

Please sign in to comment.