Skip to content

Commit

Permalink
Merge pull request jax-ml#22550 from gnecula:pallas_consts
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654686876
  • Loading branch information
jax authors committed Jul 22, 2024
2 parents b3469a6 + b7105cc commit 433f66a
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 32 deletions.
6 changes: 5 additions & 1 deletion jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class BlockMapping:
index_map_jaxpr: jax_core.ClosedJaxpr
indexing_mode: IndexingMode

def compute_start_indices(self, loop_idx, *args):
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
)
Expand Down Expand Up @@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping(
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
if consts:
raise NotImplementedError(
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
f"{consts}")
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
Expand Down
10 changes: 6 additions & 4 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
self.mapped_dims = grid_mapping.mapped_dims
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
Expand Down Expand Up @@ -411,9 +411,13 @@ def lower_jaxpr_to_module(
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
block_operand_shapes = (
*in_shapes[grid_mapping.num_index_operands :],
*[jax.ShapeDtypeStruct(v.aval.shape,
v.aval.dtype)
for v in invars[:grid_mapping.num_constant_operands]],
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
Expand All @@ -425,8 +429,6 @@ def lower_jaxpr_to_module(
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
if bm.index_map_jaxpr.consts:
raise NotImplementedError("Index map jaxpr with consts not supported.")
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
# We may not require windowing if our block_shape matches the original
Expand Down
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*args,
Expand All @@ -68,7 +69,10 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)

if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
in_shapes,
Expand Down
61 changes: 50 additions & 11 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,13 @@ def _pallas_call_impl_interpret(
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
Expand Down Expand Up @@ -243,6 +244,10 @@ def _pallas_call_impl_interpret(
else:
# Base case is always one iteration when grid is ()
num_iterations = 1

# The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
# i:int32 is the interation index
# loop_idx: tuple[int32] are the program ids for each grid axis
def cond(carry):
i, *_ = carry
return i < num_iterations
Expand All @@ -256,7 +261,7 @@ def body(carry):
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
Expand All @@ -269,13 +274,14 @@ def body(carry):
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)

(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
Expand Down Expand Up @@ -604,13 +610,13 @@ def _maybe_squeeze_out_bdim(
# Ordinarily, adding support for scalar prefetch in vmap would involve
# modifying the block specs in a nontrivial way. However, if we are only
# vmapping over 1-sized dimensions, we can just get rid of the dimensions
# and pretend we were never vmapping over them at all.
# and pretend we were never vmapped over them at all.
if all(
bdim is batching.not_mapped or arg.shape[bdim] == 1
for arg, bdim in zip(scalar_args, scalar_bdims)
):
scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
scalar_bdims = [None] * len(scalar_args)
scalar_bdims = [batching.not_mapped] * len(scalar_args)
args = (*scalar_args, *args)
dims = (*scalar_bdims, *bdims)
else:
Expand Down Expand Up @@ -648,6 +654,7 @@ def _maybe_squeeze_out_bdim(
all_dims = list(dims) + [0] * len(out_shapes)

num_index_operands = grid_mapping.num_index_operands
num_constant_operands = grid_mapping.num_constant_operands
num_scratch_operands = grid_mapping.num_scratch_operands

# Only add a batch dimension for the avals that actually have a grid mapping.
Expand All @@ -661,11 +668,16 @@ def _maybe_squeeze_out_bdim(
block_mappings,
)

# TODO(necula): should fix in_shapes to include the consts
dims_no_consts = (
dims[:num_index_operands] +
dims[num_index_operands + num_constant_operands:]
)
batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
x.dtype)
for x, dim in zip(in_shapes, dims))
for x, dim in zip(in_shapes, dims_no_consts))
batched_out_shapes = tuple(
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
for x in out_shapes)
Expand Down Expand Up @@ -900,11 +912,37 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(
jaxpr,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)

grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()

Expand Down Expand Up @@ -1105,8 +1143,9 @@ def wrapped(*args):
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")

index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context:

def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_shapes,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
Expand Down Expand Up @@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
Expand All @@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module(
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*in_shapes, *[()] * grid_mapping.num_constant_operands),
(*consts_shapes, *in_out_shapes),
grid_mapping.block_mappings,
start_indices,
)
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
Expand All @@ -72,6 +73,10 @@ def pallas_call_lowering(
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
Expand Down
13 changes: 11 additions & 2 deletions jax/_src/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Utilities for tracing stateful functions."""

from typing import Callable

from jax._src.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
Expand All @@ -24,13 +26,20 @@
zip, unsafe_zip = safe_zip, zip


def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
def hoist_consts_to_refs(
jaxpr: core.Jaxpr,
*,
index: int = 0,
make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval)
) -> core.Jaxpr:
"""Hoists the constants in the given jaxpr into invars.
Args:
jaxpr: The jaxpr.
index: The index where the invars for the constants should be inserted.
By default, the new invars are inserted *before* any existing invars.
make_abstract_ref: a callable to construct an AbstractRef, or subtype
thereof, from a constant AbstractValue.
Returns:
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
Expand All @@ -42,7 +51,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
isinstance(var.aval, AbstractRef) for var in jaxpr.constvars
]
const_avals = [
var.aval if is_ref else AbstractRef(var.aval)
var.aval if is_ref else make_abstract_ref(var.aval)
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
]
in_avals = [var.aval for var in jaxpr.invars]
Expand Down
29 changes: 21 additions & 8 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def pallas_call(self, *args, **kwargs):

class PallasCallTest(PallasBaseTest):


def test_add_one(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
Expand Down Expand Up @@ -468,19 +467,22 @@ def kernel(o_ref):

def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
x = jnp.zeros(32)
indices = jnp.arange(4).reshape((2, 2))
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[indices] = src[indices]
dst[0:1] = to_store

jax.block_until_ready(kernel(x))
res = kernel(x)
self.assertAllClose(res[0:1], to_store)

def test_vector_slicing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
Expand Down Expand Up @@ -744,6 +746,17 @@ def test_pallas_call_index_map_wrong_number_of_results(self):
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)

def test_pallas_call_index_map_captures_consts(self):
a = np.arange(256, dtype=np.int32)
index_map_result = np.array([0], dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
f(a)

def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
Expand Down
Loading

0 comments on commit 433f66a

Please sign in to comment.