Skip to content

Commit

Permalink
[pallas] More simplification of grid mapping and calling convention
Browse files Browse the repository at this point in the history
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
  • Loading branch information
gnecula committed Jul 26, 2024
1 parent 8ed94bc commit 40154f9
Show file tree
Hide file tree
Showing 14 changed files with 284 additions and 186 deletions.
3 changes: 2 additions & 1 deletion docs/jax.experimental.pallas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Classes
:toctree: _autosummary

BlockSpec
GridSpec
Slice

Functions
Expand All @@ -34,4 +35,4 @@ Functions
atomic_or
atomic_xchg

debug_print
debug_print
7 changes: 7 additions & 0 deletions docs/pallas/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ Remember to align the itemized text with the first line of an item within a list
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* {class}`jax.experimental.pallas.GridSpec` no longer has the `in_specs_tree`
  and `out_specs_tree` fields; the `in_specs` and `out_specs` fields now
  store the values as pytrees of `BlockSpec`. Previously, `in_specs` and
  `out_specs` were flattened ({jax-issue}`#22552`).
* The method `compute_index` of {class}`jax.experimental.pallas.BlockSpec` has
  been removed because it is private. Similarly, the `get_grid_mapping` and
  `unzip_dynamic_grid_bounds` methods have been removed from `GridSpec`
  ({jax-issue}`#22593`).
* Fixed the interpreter mode to work with `BlockSpec`s that involve padding
({jax-issue}`#22275`).
Padding in interpreter mode will be with NaN, to help debug out-of-bounds
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,7 +1704,7 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
f: a Python value that represents a dimension.
d: a Python value that represents a dimension.
Returns:
A canonical dimension value.
Expand Down
244 changes: 192 additions & 52 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
"""Module for pallas-core functionality."""
from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
import contextlib
import copy
import dataclasses
import enum
import functools
import itertools
import threading
from typing import Any, Hashable, Union
import warnings
Expand Down Expand Up @@ -55,6 +56,10 @@ def __repr__(self):
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]

# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]

split_list = util.split_list

map, unsafe_map = util.safe_map, map
Expand Down Expand Up @@ -202,14 +207,13 @@ def __repr__(self):
IndexingMode = Union[Blocked, Unblocked]


@dataclasses.dataclass(unsafe_hash=True)
@dataclasses.dataclass
class BlockSpec:
"""Specifies how an array should be sliced for each iteration of a kernel.
See :ref:`pallas_blockspec` for more details.
This object contains the parameters passed through the API.
An internal canonicalized version is in BlockMapping.
"""
# An internal canonicalized version is in BlockMapping.
block_shape: tuple[int | None, ...] | None = None
index_map: Callable[..., Any] | None = None
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
Expand Down Expand Up @@ -242,22 +246,25 @@ def __init__(
self.memory_space = memory_space
self.indexing_mode = indexing_mode

def compute_index(self, *args):
assert self.index_map is not None
out = self.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out

def compute_index(bs: BlockSpec, *args):
  """Evaluates `bs.index_map` on `args`, normalizing the result to a tuple.

  An index map may return a single (non-tuple) value; callers always expect
  a tuple of indices, so a bare value is wrapped into a 1-tuple.
  """
  index_map = bs.index_map
  assert index_map is not None
  result = index_map(*args)
  return result if isinstance(result, tuple) else (result,)


class NoBlockSpec:
pass
def __repr__(self):
return "NoBlockSpec"
no_block_spec = NoBlockSpec()


# A PyTree of BlockSpec | NoBlockSpec.
# BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
BlockSpecTree = Any


@dataclasses.dataclass(frozen=True)
class BlockMapping:
"""An internal canonicalized version of BlockSpec.
Expand Down Expand Up @@ -311,6 +318,14 @@ def compute_start_indices_interpret(self, loop_idx, *args):
else:
raise RuntimeError(f"Unknown indexing mode: {self.indexing_mode}")

  def has_non_trivial_window(self) -> bool:
    """Returns True if this mapping selects a proper sub-block of the array.

    The window is trivial only when (a) each block dimension equals the
    corresponding array dimension (a `mapped` dimension over an array
    dimension of size 1 also counts as covering it), and (b) the index map
    returns the literal 0 for every output, i.e. the block never moves.
    """
    for b, s in zip(self.block_shape, self.array_shape_dtype.shape):
      # A mismatched dimension means the block tiles the array.
      if b != s and not (b is mapped and s == 1):
        return True
    for atom in self.index_map_jaxpr.jaxpr.outvars:
      # Any index-map output that is not the constant 0 can offset the block.
      if not (isinstance(atom, jax_core.Literal) and atom.val == 0):
        return True
    return False

@contextlib.contextmanager
def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
Expand Down Expand Up @@ -432,11 +447,6 @@ def slice_index_ops(self):
"""Returns a slice object to select the index operands to a kernel."""
return slice(0, self.num_index_operands)

@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel."""
return slice(self.num_index_operands, None)

@property
def slice_scratch_ops(self):
"""Returns a slice object to select the scratch operands to a kernel."""
Expand All @@ -445,29 +455,30 @@ def slice_scratch_ops(self):
else:
return slice(0, 0)

# TODO(necula): this is used to recover the old `in_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *consts, *inputs."""
index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape,
index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):]]
consts_inputs_shapes = [
for ia in self.index_map_avals[len(self.grid):])
consts_inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[
:self.num_constant_operands + self.num_inputs]]
return tuple(index_shapes + consts_inputs_shapes)
:self.num_constant_operands + self.num_inputs])
return itertools.chain(index_shapes, consts_inputs_shapes)

  @property
  def block_mappings_output(self) -> Iterable[BlockMapping]:
    """The block mappings for the kernel outputs.

    `block_mappings` is ordered as *consts, *inputs, *outputs; this slices
    out the output window without materializing a new tuple.
    """
    return itertools.islice(
        self.block_mappings,
        self.num_constant_operands + self.num_inputs,
        self.num_constant_operands + self.num_inputs + self.num_outputs)

# TODO(necula): this is used to recover the old `out_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
return tuple(
bm.array_shape_dtype
for bm in self.block_mappings[
self.num_constant_operands + self.num_inputs:
self.num_constant_operands + self.num_inputs + self.num_outputs])
bm.array_shape_dtype for bm in self.block_mappings_output)


def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
if isinstance(dim, jax.Array):
Expand All @@ -490,9 +501,9 @@ def _convert_block_spec_to_block_mapping(
if block_spec is no_block_spec:
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
compute_index = lambda *args: (0,) * len(array_aval.shape)
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
compute_index = block_spec.compute_index
index_map_func = functools.partial(compute_index, block_spec)
if block_spec.block_shape is None:
block_shape = array_aval.shape
else:
Expand All @@ -512,7 +523,7 @@ def _convert_block_spec_to_block_mapping(
"dynamically-shaped blocks. "
f"{origin} has block_shape: {block_aval.shape}")

flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index),
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func),
index_map_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
Expand Down Expand Up @@ -549,16 +560,19 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
shape = tuple(s for s in block_shape if s is not None)
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))

@dataclasses.dataclass(init=False, unsafe_hash=True)
@dataclasses.dataclass(init=False)
class GridSpec:
"""Encodes the parameters of the grid, as given through the API.
"""Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.
An internal sanitized version is in GridMapping.
See the documentation for :func:`jax.experimental.pallas.pallas_call`,
and also :ref:`pallas_grids_and_blockspecs` for a more detailed
description of the parameters.
"""
# A canonicalized internal version is in GridMapping.
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
in_specs: BlockSpecTree
out_specs: BlockSpecTree

def __init__(
self,
Expand Down Expand Up @@ -722,18 +736,144 @@ def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
def _make_scalar_ref_aval(self, aval):
assert False # Not needed in GridSpec

def unzip_dynamic_grid_bounds(
self,
) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(
d if isinstance(d, int) else None for d in self.grid
)
dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(self)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds
def unzip_dynamic_grid_bounds(
    grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]:
  """Separates the static grid from the dynamic grid bounds.

  Returns a copy of `grid_spec` whose grid has every non-`int` (dynamic)
  dimension replaced by None, paired with the tuple of dynamic bound values
  that were taken out, in grid order.
  """
  static_dims = []
  dynamic_bounds = []
  for dim in grid_spec.grid:
    if isinstance(dim, int):
      static_dims.append(dim)
    else:
      static_dims.append(None)
      dynamic_bounds.append(dim)
  # We can't use dataclasses.replace, because our fields are incompatible
  # with __init__'s signature.
  static_spec = copy.copy(grid_spec)
  static_spec.grid = tuple(static_dims)  # type: ignore
  return static_spec, tuple(dynamic_bounds)


def get_grid_mapping(
    grid_spec: GridSpec,
    in_avals: Sequence[jax_core.AbstractValue],
    in_tree: tree_util.PyTreeDef,
    in_paths: Sequence[tree_util.KeyPath],
    out_avals: Sequence[jax_core.AbstractValue],
    out_tree: tree_util.PyTreeDef,
    out_paths: Sequence[tree_util.KeyPath],
) -> tuple[tuple[jax_core.AbstractValue, ...],
           GridMapping]:
  """Canonicalizes a user-facing GridSpec into an internal GridMapping.

  Args:
    grid_spec: the grid spec as given through the API; its grid must already
      be static (ints) with None standing for dynamic dimensions.
    in_avals: flattened abstract values of the kernel inputs.
    in_tree: pytree structure of the kernel inputs.
    in_paths: key paths of the flattened inputs (parallel to `in_avals`).
    out_avals: flattened abstract values of the kernel outputs.
    out_tree: pytree structure of the kernel outputs.
    out_paths: key paths of the flattened outputs (parallel to `out_avals`).

  Returns:
    A pair of (a) the abstract values seen by the kernel function, in the
    order *scalar_refs, *input_refs, *output_refs, *scratch, and (b) the
    canonicalized GridMapping.
  """
  assert all(i is None or isinstance(i, int) for i in grid_spec.grid)
  # None grid entries are dynamic bounds (see unzip_dynamic_grid_bounds);
  # encode them with the dynamic_grid_dim marker.
  grid_mapping_grid = tuple(
      dynamic_grid_dim if d is None else d for d in grid_spec.grid
  )
  # The inputs for the index maps: one i32 scalar per grid dimension.
  index_map_avals = (
      (jax_core.ShapedArray((), jnp.dtype("int32")),) * len(grid_spec.grid))
  index_map_tree = tree_util.tree_structure((index_map_avals, {}))

  # Scalar prefetch is only present on some GridSpec subclasses, hence the
  # getattr with a 0 default for plain GridSpec.
  num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
  if num_scalar_prefetch:
    # The leading num_scalar_prefetch inputs are scalars: turn them into
    # refs, drop them from the regular inputs, and append them to the
    # index-map arguments.
    all_avals = tree_util.tree_unflatten(in_tree, in_avals)
    scalar_avals, unflat_in_avals = split_list(
        all_avals, [num_scalar_prefetch])
    flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
    num_flat_scalar_prefetch = len(flat_scalar_avals)
    scalar_ref_avals = [
        grid_spec._make_scalar_ref_aval(aval)
        for aval in flat_scalar_avals]
    jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
        scalar_tree, scalar_ref_avals)
    # Rebind in_avals/in_tree to cover only the non-scalar inputs.
    in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
    index_map_tree = tree_util.tree_structure(((*index_map_avals,
                                                *scalar_avals), {}))
    index_map_avals = (*index_map_avals, *scalar_ref_avals)
    # Delete intermediates so they cannot be used accidentally below.
    del scalar_ref_avals, flat_scalar_avals, scalar_tree
    del scalar_avals, unflat_in_avals, all_avals
  else:
    num_flat_scalar_prefetch = 0
    jaxpr_scalar_ref_avals = ()

  # Scratch shapes are likewise subclass-specific; empty for plain GridSpec.
  scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ())
  if scratch_shapes:
    flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
        scratch_shapes)
    # NOTE: `map` is util.safe_map, which returns a list (len() below is ok).
    flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes)
    num_flat_scratch_operands = len(flat_scratch_avals)
    jaxpr_scratch_avals = tree_util.tree_unflatten(
        scratch_tree, flat_scratch_avals)
    if not isinstance(jaxpr_scratch_avals, (tuple, list)):
      jaxpr_scratch_avals = (jaxpr_scratch_avals,)
    del flat_scratch_avals, flat_scratch_shapes, scratch_tree
  else:
    num_flat_scratch_operands = 0
    jaxpr_scratch_avals = ()

  if grid_spec.in_specs is not no_block_spec:
    flat_in_specs, in_specs_tree = tree_util.tree_flatten(grid_spec.in_specs)
    # The in_specs pytree must mirror the (non-scalar) inputs pytree.
    if in_specs_tree != in_tree:
      raise ValueError(
          pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
                                     "inputs", in_tree))
  else:
    flat_in_specs = [no_block_spec] * len(in_avals)

  # in_paths still includes the scalar-prefetch entries; skip past them so
  # paths line up with flat_in_specs and in_avals.
  in_block_mappings = map(
      partial(
          _convert_block_spec_to_block_mapping,
          index_map_avals=index_map_avals,
          index_map_tree=index_map_tree,
          grid=grid_mapping_grid,
          mapped_dims=(),
          what="inputs",
      ),
      flat_in_specs,
      in_paths[num_flat_scalar_prefetch:],
      in_avals,
  )

  if grid_spec.out_specs is not no_block_spec:
    flat_out_specs, out_specs_tree = tree_util.tree_flatten(grid_spec.out_specs)
    # The out_specs pytree must mirror the outputs (out_shape) pytree.
    if out_specs_tree != out_tree:
      raise ValueError(
          pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
                                     "`out_shape`", out_tree))
  else:
    flat_out_specs = [no_block_spec] * len(out_avals)

  out_block_mappings = map(
      partial(
          _convert_block_spec_to_block_mapping,
          index_map_avals=index_map_avals,
          index_map_tree=index_map_tree,
          grid=grid_mapping_grid,
          mapped_dims=(),
          what="outputs",
      ),
      flat_out_specs,
      out_paths,
      out_avals,
  )
  grid_mapping = GridMapping(
      grid=grid_mapping_grid,  # type: ignore[arg-type]
      grid_names=grid_spec.grid_names,
      block_mappings=(*in_block_mappings, *out_block_mappings),
      index_map_avals=index_map_avals,  # type: ignore[arg-type]
      index_map_tree=index_map_tree,
      vmapped_dims=(),
      num_index_operands=num_flat_scalar_prefetch,
      num_constant_operands=0,  # Fixed up later
      num_inputs=len(flat_in_specs),
      num_outputs=len(flat_out_specs),
      num_scratch_operands=num_flat_scratch_operands,
  )
  grid_mapping.check_invariants()
  # Kernel-visible avals are the block refs, re-nested to the user's pytrees.
  in_ref_avals = [bm.block_aval for bm in in_block_mappings]
  jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
  jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
                    *jaxpr_in_ref_avals)
  out_ref_avals = [bm.block_aval for bm in out_block_mappings]
  jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
  # A single non-container output aval is normalized to a 1-tuple so the
  # returned sequence can always be unpacked uniformly.
  if not isinstance(jaxpr_out_avals, (tuple, list)):
    jaxpr_out_avals = (jaxpr_out_avals,)
  return (*jaxpr_in_avals, *jaxpr_out_avals,
          *jaxpr_scratch_avals), grid_mapping


def pytreedef_mismatch_err_msg(
Expand Down
Loading

0 comments on commit 40154f9

Please sign in to comment.