[pallas] Improve error localization
  * Added the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Constructed BlockSpec origins for arguments using JAX helper functions
    to get argument names.
  * Removed redundant API error tests from tpu_pallas_test.py
gnecula committed Jul 30, 2024
1 parent cc21245 commit 20379e8
Showing 9 changed files with 157 additions and 122 deletions.
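
As context for the summary above, a minimal sketch of the class of mistake this change localizes; the sketch is not part of the commit, and the kernel, shapes, and exact error wording are illustrative assumptions:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.float32)
try:
  pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(1,),
      # Bug: the index map returns one value for a 2-D block shape.
      in_specs=[pl.BlockSpec((8, 128), lambda i: i)],
      out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
  )(x)
except ValueError as e:
  # After this commit the message names the index map and its source
  # location, e.g.: "Index map function <lambda> at .../example.py:14
  # for x_ref must return 2 values to match block_shape=(8, 128).
  # Currently returning 1 values."
  print(e)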
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@ see {ref}`pallas-changelog`.

<!--
Remember to align the itemized text with the first line of an item within a list.
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.32
12 changes: 11 additions & 1 deletion docs/pallas/CHANGELOG.md
@@ -11,7 +11,17 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->

## Released with JAX 0.4.31
## Released with jax 0.4.32

* Changes

* Deprecations

* New functionality:
  * Improved error messages for mistakes in the signatures of index map functions
    to include the name and source location of the index map.

## Released with jax 0.4.31 (July 29, 2024)

* Changes
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
3 changes: 1 addition & 2 deletions jax/_src/interpreters/partial_eval.py
@@ -2301,8 +2301,7 @@ class DebugInfo(NamedTuple):
def debug_info(fn: Callable, in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool, traced_for: str) -> DebugInfo:
try: sig = inspect.signature(fn)
except (ValueError, TypeError): sig = None
sig = api_util.fun_signature(fn)
src_info = fun_sourceinfo(fn)
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for)
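
For reference, `api_util.fun_signature` consolidates the inline try/except removed in this hunk; a sketch of the equivalent behavior, assuming the helper mirrors the removed code:

import inspect
from typing import Any, Callable

def fun_signature(fn: Callable[..., Any]) -> inspect.Signature | None:
  # Return None for callables (e.g. some builtins) whose signature
  # cannot be introspected, as the removed try/except did.
  try:
    return inspect.signature(fn)
  except (ValueError, TypeError):
    return None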
73 changes: 38 additions & 35 deletions jax/_src/pallas/core.py
@@ -247,14 +247,6 @@ def __init__(
self.indexing_mode = indexing_mode


def compute_index(bs: BlockSpec, *args):
assert bs.index_map is not None
out = bs.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out


class NoBlockSpec:
def __repr__(self):
return "NoBlockSpec"
@@ -274,6 +266,7 @@ class BlockMapping:
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: str # function_name at filename:linenumber
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: str # The origin, e.g. input[2]["field"]
@@ -501,32 +494,31 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool:

def _convert_block_spec_to_block_mapping(
block_spec: BlockSpec,
path: tree_util.KeyPath,
origin: str, # for error messages, e.g., x["field"]
array_aval: jax_core.ShapedArray,
*,
# Inputs for the index_map
index_map_avals: Sequence[jax_core.AbstractValue],
index_map_tree: tree_util.PyTreeDef,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
what: str, # Used to localize error messages, e.g., {what}{path}
) -> BlockMapping:
origin = f"{what}{tree_util.keystr(path)}"
if block_spec is no_block_spec:
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
index_map_func = functools.partial(compute_index, block_spec)
index_map_func = block_spec.index_map
if block_spec.block_shape is None:
block_shape = array_aval.shape
else:
block_shape = block_spec.block_shape
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {origin} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {array_aval.shape}"
)
"must have the same number of dimensions as the "
f"array shape {array_aval.shape}.")

unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
block_spec.memory_space)
@@ -535,31 +527,44 @@ def _convert_block_spec_to_block_mapping(
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"{origin} has block_shape: {block_aval.shape}")
f"Block spec for {origin} has block_shape: {block_aval.shape}")

flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func),
index_map_tree)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func), index_map_tree)
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
False, "pallas_call index_map")
index_map_src_info = debug.func_src_info or "<unknown>"
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
index_map_avals)
index_map_avals,
debug_info=debug)
mapped_block_shape = tuple(
mapped if s is None else s for s in block_shape)
if len(out_avals) != len(mapped_block_shape):
mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):
raise ValueError(
# TODO(necula): show the name and location of the index map function
f"Index map for {origin} must return "
f"{len(block_aval.shape)} values to match block shape {mapped_block_shape}. "
f"Currently returning {len(out_avals)} values."
)
f"Index map function {index_map_src_info} for "
f"{origin} must return "
f"{len(block_shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values.")
for i, ov in enumerate(out_avals):
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}.")


if consts:
raise NotImplementedError(
# TODO(necula): show the name and location of the index map function
f"Index map for {origin} captures constants: "
f"{consts}")
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must not capture constants: {consts}")


mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=block_spec.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
origin=origin,
@@ -632,10 +637,10 @@ def get_grid_mapping(
grid_spec: GridSpec,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
in_origins: Sequence[str],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
out_origins: Sequence[str],
) -> tuple[tuple[jax_core.AbstractValue, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in grid_spec.grid)
@@ -700,10 +705,9 @@ def get_grid_mapping(
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="inputs",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_origins[num_flat_scalar_prefetch:],
in_avals,
)

@@ -723,10 +727,9 @@ def get_grid_mapping(
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="outputs",
),
flat_out_specs,
out_paths,
out_origins,
out_avals,
)
grid_mapping = GridMapping(
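
The new per-output check above also rejects non-integer index map results; continuing the sketch near the top of the page (the error wording is paraphrased from the diff, not an exact quote):

# Assumption: `pl` and the kernel are as in the earlier sketch. An index
# map returning a float now fails at trace time, pointing at its source.
bad_specs = [pl.BlockSpec((8, 128), lambda i: (i * 0.5, 0))]
# ValueError: Index map function <lambda> at .../example.py:NN for x_ref
# must return integer scalars. Output[0] has type float32[].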
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/pipeline.py
@@ -276,7 +276,7 @@ def block_shape(self):

@property
def compute_index(self):
return lambda *args: pallas_core.compute_index(self.spec, *args)
return self.spec.index_map

@property
def memory_space(self):
60 changes: 46 additions & 14 deletions jax/_src/pallas/pallas_call.py
@@ -21,9 +21,9 @@
from typing import Any

import jax
from jax import api_util
from jax import lax
from jax._src import ad_util
from jax._src import api_util
from jax._src import checkify
from jax._src import config
from jax._src import core as jax_core
@@ -797,15 +797,15 @@ def _ensure_2d_error_shape(arg):
# for the new error inputs and outputs.
error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
error_origins = tuple(f"errors{tree_util.keystr(p)}" for p in error_paths)
error_block_mappings = map(
partial(
pallas_core._convert_block_spec_to_block_mapping,
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=grid_mapping.vmapped_dims,
what="error"),
error_block_specs, error_paths, shaped_err_avals)
mapped_dims=grid_mapping.vmapped_dims),
error_block_specs, error_origins, shaped_err_avals)
input_block_mappings, output_block_mappings = split_list(
grid_mapping.block_mappings, [num_kernel_inputs,])
grid_mapping_with_error = grid_mapping.replace(
@@ -837,6 +837,7 @@ def _ensure_2d_error_shape(arg):

@weakref_lru_cache
def _trace_kernel_to_jaxpr(fun: Callable,
fun_src_info: str, # <func> at <file>:<line>
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
@@ -863,13 +864,12 @@ def _trace_kernel_to_jaxpr(fun: Callable,
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
origin=f"consts[{c_idx}]",
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)

@@ -880,8 +880,9 @@ def _trace_kernel_to_jaxpr(fun: Callable,
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {kernel_out_tree}")
f"The kernel function {fun_src_info} in a "
f"pallas_call should return None. "
f"It returns a PyTree: {kernel_out_tree}")
return grid_mapping, jaxpr, consts

def _extract_function_name(f: Callable, name: str | None) -> str:
Expand Down Expand Up @@ -979,7 +980,7 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params):


def pallas_call(
f: Callable[..., None],
kernel: Callable[..., None],
out_shape: Any,
*,
grid_spec: GridSpec | None = None,
@@ -997,7 +998,7 @@
See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.
Args:
f: the kernel function, that receives a Ref for each input and output.
kernel: the kernel function, that receives a Ref for each input and output.
The shape of the Refs are given by the ``block_shape`` in the
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
@@ -1034,7 +1035,7 @@
invoke the Pallas kernel.
"""
name = _extract_function_name(f, name)
name = _extract_function_name(kernel, name)
if compiler_params is None:
compiler_params = {}

@@ -1072,14 +1073,31 @@ def wrapped(*args):
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)

kernel_fun_sig = api_util.fun_signature(kernel)
arg_names = None
kernel_src_info = "<unknown>"
if kernel_fun_sig:
kernel_debug_info = api_util.debug_info(
"pallas_call kernel",
api_util.fun_sourceinfo(kernel),
kernel_fun_sig,
[1] * len(kernel_fun_sig.parameters), {}, (), ())
if kernel_debug_info:
arg_names = kernel_debug_info.arg_names
kernel_src_info = kernel_debug_info.func_src_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
grid_spec,
flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
flat_in_avals, in_tree, in_origins,
flat_out_avals, out_tree, out_origins)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
f, grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
kernel, kernel_src_info,
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
@@ -1116,6 +1134,20 @@ def wrapped(*args):
return wrapped


def in_path_to_input_origin(in_path: tree_util.KeyPath,
arg_names: tuple[str, ...] | None) -> str:
"""Converts `args[k]<rest>` into `arg_k_name<rest>`."""
if arg_names is None:
return f"args{tree_util.keystr(in_path)}"
if len(in_path) == 0:
return "args"
arg_idx, *rest_path = in_path
if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names):
return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
else:
return f"args{tree_util.keystr(tuple(in_path))}"


# We import the TPU backend at the top level because it defines flags. Note that
# we can only do that at the bottom of this file, because it also depends on
# this module already being initialized.
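
A hypothetical usage sketch of the new `in_path_to_input_origin` helper; the argument names and tree path are invented for illustration:

from jax import tree_util

# With kernel argument names available, a path into the inputs tree is
# rendered with the matching name; otherwise it falls back to `args[...]`.
path = (tree_util.SequenceKey(1), tree_util.DictKey("field"))
in_path_to_input_origin(path, ("x", "y"))  # -> "y['field']"
in_path_to_input_origin(path, None)        # -> "args[1]['field']"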
2 changes: 1 addition & 1 deletion jax/_src/pallas/primitives.py
@@ -810,7 +810,7 @@ def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
"""Call the function with allocated references.
Args:
f: The function that generatest the jaxpr.
f: The function that generates the jaxpr.
*types: The types of the function's positional arguments.
**kw_types: The types of the function's keyword arguments.
"""