Put abstract_mesh on every eqn so that we can preserve it during `eval_jaxpr` and `check_jaxpr` roundtrip.

Also allow users to enter `Auto`/`User` mode inside `jit` along all or some mesh axes.
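
As a rough illustration only (not documented public API in this commit), the new `with_axis_types` and `sharding_cast` helpers added in the diff below could be combined like this; the helper name `to_auto_along_all_axes`, the `{AxisTypes.Auto: ...}` dict format, and the import paths are assumptions inferred from the diff:

    # Hypothetical sketch based on helpers added in this commit; import paths,
    # the axis_types dict format, and the AxisTypes.Auto spelling are
    # assumptions, not committed public API.
    from jax._src import mesh as mesh_lib
    from jax._src.pjit import sharding_cast
    from jax.sharding import NamedSharding, PartitionSpec as P

    def to_auto_along_all_axes(x):
      # Rebuild x's (abstract) mesh with every axis in Auto mode, then cast x
      # onto it so the compiler may choose downstream shardings freely.
      cur_mesh = x.sharding.mesh
      auto_mesh = cur_mesh.with_axis_types(
          {mesh_lib.AxisTypes.Auto: cur_mesh.axis_names})
      return sharding_cast(
          x, NamedSharding(auto_mesh, P(*[None] * x.ndim)))

Per the diff, passing a single `NamedSharding` to `sharding_cast` takes the `tree_map` fast path, while a pytree of shardings is matched to the inputs via `flatten_axes`.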

Add checks to make sure that the mesh of avals inside a context matches the surrounding abstract mesh context. For now this check happens inside `abstract_eval` rules, but we may need a more central place for it, which we can create later on.
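
For context, the check is exposed as `core.check_avals_context_mesh` (added below). A primitive's abstract_eval rule can call it as in this placeholder sketch, mirroring the call sites this commit adds to `lax/utils.py` and `lax/parallel.py`; `my_prim` and the rule body are hypothetical, not code from this commit:

    # Placeholder sketch; `my_prim` is hypothetical. The check is a no-op
    # unless config.sharding_in_types is enabled.
    from jax._src import core

    def _my_prim_abstract_eval(*avals, **params):
      # Raises ValueError when an input aval's sharding mesh does not match
      # the surrounding abstract mesh context.
      core.check_avals_context_mesh(avals, 'my_prim')
      return avals[0]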

PiperOrigin-RevId: 707128096
yashk2810 authored and Google-ML-Automation committed Dec 17, 2024
1 parent e1f037b commit 473e2bf
Showing 11 changed files with 262 additions and 43 deletions.
2 changes: 0 additions & 2 deletions jax/_src/api.py
@@ -2187,8 +2187,6 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return src # pytype: disable=bad-return-type
if isinstance(x, array.ArrayImpl):
return x.sharding
if config.sharding_in_types.value and hasattr(x, 'sharding'):
return x.sharding
if isinstance(x, core.Tracer):
val = x.to_concrete_value()
if val is not None and isinstance(val, array.ArrayImpl):
40 changes: 29 additions & 11 deletions jax/_src/core.py
@@ -273,10 +273,12 @@ def __init__(self, compute_type: str | None, threefry_partitionable: bool,
xla_metadata=None):
self.compute_type = compute_type
self.threefry_partitionable = threefry_partitionable
self.cur_abstract_mesh = mesh_lib.get_abstract_mesh()
self.xla_metadata = xla_metadata
self._managers = [
(compute_on.extend_compute_type, self.compute_type),
(config.threefry_partitionable.__call__, self.threefry_partitionable),
(mesh_lib.set_abstract_mesh, self.cur_abstract_mesh),
(xla_metadata_lib.set_xla_metadata, self.xla_metadata),
]

@@ -292,6 +294,7 @@ def __repr__(self):
return (
f"JaxprEqnContext(compute_type={self.compute_type}, "
f"threefry_partitionable={self.threefry_partitionable}, "
f"cur_abstract_mesh={self.cur_abstract_mesh}, "
f"xla_metadata={self.xla_metadata})"
)

@@ -535,6 +538,17 @@ def write(v: Var, val: Any) -> None:
clean_up_dead_vars(eqn, env, lu)
return map(read, jaxpr.outvars)

def check_avals_context_mesh(avals, prim_name):
if config.sharding_in_types.value:
for a in avals:
cur_mesh = mesh_lib.get_abstract_mesh()
if a.sharding.mesh != cur_mesh:
raise ValueError(
f"For primitive {prim_name}, context mesh {cur_mesh} should match"
f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. This"
" error occurs at source: "
f" {source_info_util.summarize(source_info_util.current())}")


# -------------------- tracing --------------------

@@ -1622,7 +1636,10 @@ def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding # type: ignore

if sharding is not None:
assert len(sharding.spec) == ndim
if len(sharding.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
return _maybe_modify_sharding(sharding)

context_mesh = mesh_lib.get_abstract_mesh()
@@ -2518,17 +2535,18 @@ def write(v: Var, a: AbstractValue) -> None:
in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes

# Compute the type of the primitive application.
if prim in custom_typechecks:
  out_type, eqn_effects = custom_typechecks[prim](
      ctx_factory, *in_atoms, **eqn.params)
elif prim.call_primitive:
  out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
                                      eqn.params)
elif prim.map_primitive:
  out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
                                     eqn.params)
else:
  out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)
with eqn.ctx.manager:
  if prim in custom_typechecks:
    out_type, eqn_effects = custom_typechecks[prim](
        ctx_factory, *in_atoms, **eqn.params)
  elif prim.call_primitive:
    out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
                                        eqn.params)
  elif prim.map_primitive:
    out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
                                       eqn.params)
  else:
    out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)

# Check the computed effect type matches the eqn's annotation, and is
# included in the jaxpr's annotation.
14 changes: 0 additions & 14 deletions jax/_src/dispatch.py
@@ -522,8 +522,6 @@ def _batched_device_put_impl(
device_put_p.def_impl(_batched_device_put_impl)

def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics):
if config.sharding_in_types.value:
return [x.update(sharding=s) for x, s in zip(xs, devices)]
return xs
device_put_p.def_abstract_eval(_device_put_abstract_eval)

@@ -566,12 +564,6 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
# TODO(yashkatariya): Maybe we should add the custom calls anyways if it's
# being used inside jit? Atleast for now, this preserves the old behavior.
if ctx.module_context.all_default_mem_kind:
if config.sharding_in_types.value:
return [
mlir.wrap_with_sharding_op(
ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto())
for x, a in zip(xs, ctx.avals_out)
]
return xs
def lower(x, device, aval, out_aval):
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
@@ -597,12 +589,6 @@ def lower(x, device, aval, out_aval):


def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
if config.sharding_in_types.value:
return [
mlir.wrap_with_sharding_op(
ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto())
for x, a in zip(xs, ctx.avals_out)
]
return xs
mlir.register_lowering(device_put_p, _common_device_put_lowering)

2 changes: 2 additions & 0 deletions jax/_src/interpreters/pxla.py
@@ -2162,6 +2162,8 @@ def _abstract_to_concrete_mesh(abstract_mesh):

out = []
for s, a in zip(shardings, avals):
# Remove the `UnconstrainedSingleton` logic after UNCONSTRAINED is supported
# in out_shardings at top level jit.
if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
1 change: 1 addition & 0 deletions jax/_src/lax/parallel.py
@@ -687,6 +687,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
if config.sharding_in_types.value:
core.check_avals_context_mesh(args, 'all_reduce')
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))
2 changes: 2 additions & 0 deletions jax/_src/lax/utils.py
@@ -53,6 +53,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
weak_type = weak_type_rule(*avals, **kwargs)
least_specialized = type(max(avals, key=_get_array_abstraction_level))
if least_specialized is core.ShapedArray:
core.check_avals_context_mesh(avals, prim.name)
return core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type,
@@ -78,6 +79,7 @@ def standard_multi_result_abstract_eval(
if least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
core.check_avals_context_mesh(avals, prim.name)
out_shardings = (sharding_rule(*avals, **kwargs)
if config.sharding_in_types.value else
[None] * len(out_shapes))
9 changes: 8 additions & 1 deletion jax/_src/mesh.py
@@ -350,6 +350,9 @@ def local_devices(self):
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)

def with_axis_types(self, new_axis_types) -> Mesh:
return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)


EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))

@@ -396,8 +399,9 @@ def __eq__(self, other):
self._axis_types_tuple == other._axis_types_tuple)

def __repr__(self):
mesh_repr = ", ".join(f"'{n}': {v}" for n, v in self.shape_tuple)
atr = f", axis_types={self.axis_types}"
return f"AbstractMesh({self.shape_tuple}{atr})"
return f"AbstractMesh({mesh_repr}{atr})"

@property
def axis_names(self):
@@ -427,6 +431,9 @@ def _internal_device_list(self):
def empty(self):
return self.size == 0

def with_axis_types(self, new_axis_types) -> AbstractMesh:
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)

@functools.cached_property
def _are_all_axes_collective(self) -> bool:
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
82 changes: 75 additions & 7 deletions jax/_src/pjit.py
@@ -61,8 +61,8 @@
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
@@ -73,7 +73,7 @@
from jax._src.tree_util import (
tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves,
treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
PyTreeDef, none_leaf_registry as none_lr)
PyTreeDef, none_leaf_registry as none_lr, tree_map)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
@@ -1027,7 +1027,7 @@ def hashable_pytree(pytree):
def _create_sharding_for_array(mesh, x, name, api_name):
if x is None and (mesh is None or mesh.empty):
return UNSPECIFIED
if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)):
if isinstance(x, (AUTO, UnspecifiedValue, Sharding)):
return x
if mesh is None:
msg = ('jax.jit only supports `Sharding`s being passed to'
@@ -1339,7 +1339,7 @@ def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)):
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
out_shardings_flat = (orig_out_shardings,) * len(out_avals)
else:
out_shardings_flat = flatten_axis_resources(
@@ -1571,7 +1571,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
else:
resolved_in_shardings.append(arg_s)
else:
assert isinstance(arg_s, sharding.Sharding)
assert isinstance(arg_s, Sharding)
if dispatch.is_single_device_sharding(arg_s):
resolved_in_shardings.append(UNSPECIFIED)
else:
@@ -1903,7 +1903,7 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
core.custom_typechecks[pjit_p] = _pjit_typecheck


def _pjit_abstract_eval(*args, jaxpr, **_):
def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
return jaxpr.out_avals, jaxpr.effects
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)

@@ -2016,7 +2016,7 @@ def _pjit_batcher(axis_data, vals_in, dims_in,
batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule

def _pjit_batcher_for_sharding(
s: sharding.Sharding | UnspecifiedValue,
s: Sharding | UnspecifiedValue,
dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
if isinstance(s, UnspecifiedValue):
return s
@@ -2673,6 +2673,74 @@ def _sharding_constraint_batcher(
batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher
batching.skippable_batchers[sharding_constraint_p] = lambda _: ()

# -------------------- sharding_cast ---------------------------

def _check_mesh_shape_same(src_sharding, dst_sharding, aval):
if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
raise ValueError(
f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
' match the mesh shape of the target sharding'
f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')

def sharding_cast(xs, shardings):
if isinstance(shardings, NamedSharding):
return tree_map(lambda x: sharding_cast_p.bind(
x, src_sharding=x.sharding, dst_sharding=shardings), xs)

x_flat, treedef = tree_flatten(xs)
shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings)
out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s)
for x, s in safe_zip(x_flat, shardings_flat)]
return tree_unflatten(treedef, out_flat)

sharding_cast_p = core.Primitive('sharding_cast')
def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding):
_check_mesh_shape_same(src_sharding, dst_sharding, aval)
return aval.update(sharding=dst_sharding)
sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval)

def _sharding_cast_impl(x, src_sharding, dst_sharding):
aval = shaped_abstractify(x)
_check_mesh_shape_same(x.sharding, dst_sharding, aval)
new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types)
concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec)
# TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)`
return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x)
sharding_cast_p.def_impl(_sharding_cast_impl)

def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
return [sharding_cast_p.bind(ct, src_sharding=dst_sharding,
dst_sharding=src_sharding)]
ad.deflinear2(sharding_cast_p, _sharding_cast_transpose_rule)

def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
aval, = ctx.avals_in
aval_out, = ctx.avals_out
proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering)

# TODO(yashkatariya): Comment this in after vmap ShiT tests are added.
# def _sharding_cast_batcher(axis_data, vals_in, dims_in, src_sharding,
# dst_sharding):
# if axis_data.spmd_name is not None:
# used = {n for ns in dst_sharding.spec
# for n in (ns if isinstance(ns, tuple) else (ns,))}
# if set(axis_data.spmd_name) & used:
# raise ValueError(
# f'vmap spmd_axis_name {axis_data.spmd_name} cannot '
# f'appear in sharding_cast spec, but got spec {dst_sharding.spec}')
# x, = vals_in
# d, = dims_in

# val = None if axis_data.spmd_name is None else axis_data.spmd_name
# new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val))
# vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec)
# y = sharding_cast_p.bind(x, src_sharding=src_sharding,
# dst_sharding=vmapped_dst_sharding)
# return y, d
# batching.fancy_primitive_batchers[sharding_cast_p] = _sharding_cast_batcher
# batching.skippable_batchers[sharding_cast_p] = lambda _: ()

# -------------------- helpers --------------------

18 changes: 14 additions & 4 deletions jax/_src/sharding_impls.py
@@ -69,15 +69,18 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):

@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
if mesh.axis_types is None:
return
for p in parsed_pspec:
if p is not None:
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
if mesh_lib.AxisTypes.Auto not in mesh.axis_types and None in parsed_pspec:
raise ValueError(
f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain'
' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
f' axis_types: {mesh.axis_types}')


def hashed_index(x) -> int:
@@ -271,11 +274,15 @@ def __init__(
self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec)

def __repr__(self):
mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items())
mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
ldi = ('' if self._logical_device_ids is None else
f', logical_device_ids={self._logical_device_ids}')
return f'NamedSharding(mesh=Mesh({mesh_repr}), spec={self.spec}{mem}{ldi})'
if isinstance(self.mesh, mesh_lib.AbstractMesh):
mesh_repr = f"{self.mesh}"
else:
nv_str = ", ".join(f"'{n}': {v}" for n, v in self.mesh.shape.items())
mesh_repr = f"Mesh({nv_str})"
return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'

def __reduce__(self):
return (type(self), (self.mesh, self.spec),
@@ -381,6 +388,9 @@ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
spec = PartitionSpec(*spec)
return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)

def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -172,6 +172,8 @@ def test_primitive_coverage(self):
continue
if p.name == "sharding_constraint":
continue
if p.name == "sharding_cast":
continue
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if p.name == "optimization_barrier":
continue
