From 6ea0f631d4493f823dd16c276495884f8da0c64e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 8 Jul 2024 03:11:46 -0700 Subject: [PATCH] Moved the implementation of ``custom_partitioning`` into jax/_src This is necessary to avoid a circular dependency jax -> fused_attention_stablehlo -> experimental -> jax in #21371. PiperOrigin-RevId: 650183480 --- jax/BUILD | 1 + jax/_src/custom_partitioning.py | 548 ++++++++++++++++++++++++ jax/experimental/custom_partitioning.py | 543 +---------------------- 3 files changed, 555 insertions(+), 537 deletions(-) create mode 100644 jax/_src/custom_partitioning.py diff --git a/jax/BUILD b/jax/BUILD index 7fcf5cfa9884..1e8e19dd5162 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -209,6 +209,7 @@ py_library_providing_imports_info( "_src/checkify.py", "_src/custom_batching.py", "_src/custom_derivatives.py", + "_src/custom_partitioning.py", "_src/custom_transpose.py", "_src/debugging.py", "_src/dispatch.py", diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py new file mode 100644 index 000000000000..c5dfd130fe76 --- /dev/null +++ b/jax/_src/custom_partitioning.py @@ -0,0 +1,548 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import partial +import inspect +import weakref + +import jax +from jax import tree_util +from jax._src import api_util +from jax._src import core +from jax._src import custom_api_util +from jax._src import dispatch +from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib +from jax._src import sharding_impls +from jax._src import xla_bridge as xb +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.lib import xla_client as xc +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +from jax.errors import UnexpectedTracerError + + +def _resolve_kwargs(fun, args, kwargs): + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + raise TypeError("keyword arguments could not be resolved to positions") + else: + return ba.args + + +class _ShardingCallbackInfo: + + def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding, + in_tree, out_tree, infer_sharding_from_operands, module_context, mesh, + static_args): + self.propagate_user_sharding = propagate_user_sharding + self.partition = partition + self.to_mesh_pspec_sharding = to_mesh_pspec_sharding + self.in_tree = in_tree + self.out_tree = out_tree + self.infer_sharding_from_operands = infer_sharding_from_operands + self.module_context = module_context + self.mesh = mesh + self.static_args = static_args + + def unflatten_arg_shape(self, s, sharding): + return _to_jax_sharded_shape( + s, self.to_mesh_pspec_sharding(sharding, len(s.dimensions())) + ) + + def unflatten_arg_shapes(self, arg_shapes, arg_shardings): + return self.in_tree.unflatten( + [ + self.unflatten_arg_shape(s, sharding) + for s, sharding in zip(arg_shapes, 
arg_shardings) + ] + ) + + +_sharding_callbacks = weakref.WeakValueDictionary() # type: ignore + +_CUSTOM_PARTITIONING_CALL_NAME = "CustomSPMDPartitioning" + + +def _to_jax_shape(s): + return core.ShapedArray(s.dimensions(), s.numpy_dtype()) + + +def _to_jax_sharded_shape(s, sharding): + return jax.ShapeDtypeStruct( + s.dimensions(), s.numpy_dtype(), sharding=sharding + ) + + +def _pack_result_sharding(shape, result_shardings): + if shape.is_tuple(): + return xc.HloSharding.tuple_sharding(shape, result_shardings) + else: + return result_shardings[0] + + +def _flatten_sharding(tree, shardings, shapes): + return [ + _to_hlo_sharding(sharding, len(shape.dimensions())) + for sharding, shape in zip( + tree.flatten_up_to(shardings), shapes + ) + ] + + +def _custom_partitioning_propagate_user_sharding(user_sharding, shape, + backend_string): + info = _sharding_callbacks[backend_string] + if info.propagate_user_sharding is None: + return user_sharding + if shape.is_tuple(): + user_shapes = shape.tuple_shapes() + user_shardings = user_sharding.tuple_elements() + else: + user_shapes = (shape,) + user_shardings = (user_sharding,) + user_shape = info.out_tree.unflatten( + [ + info.unflatten_arg_shape(s, sharding) + for s, sharding in zip(user_shapes, user_shardings) + ] + ) + result_sharding = info.propagate_user_sharding( + *info.static_args, info.mesh, user_shape + ) + result_shardings = _flatten_sharding( + info.out_tree, result_sharding, user_shapes) + return _pack_result_sharding(shape, result_shardings) + + +def _to_hlo_sharding(sharding, num_dimensions): + if not isinstance(sharding, jax.sharding.Sharding): + raise ValueError("Custom Partitioning rules must return Sharding.") + return sharding._to_xla_hlo_sharding(num_dimensions) + + +def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, + result_sharding, backend_string): + info = _sharding_callbacks[backend_string] + if result_shape.is_tuple(): + result_shapes = result_shape.tuple_shapes() + result_shardings = result_sharding.tuple_elements() + else: + result_shapes = (result_shape,) + result_shardings = (result_sharding,) + mesh, lower_fn, result_sharding, arg_shardings = info.partition( + *info.static_args, + info.mesh, + info.unflatten_arg_shapes(arg_shapes, arg_shardings), + info.out_tree.unflatten( + [ + info.unflatten_arg_shape(s, sharding) + for s, sharding in zip(result_shapes, result_shardings) + ] + ), + ) + module_context = info.module_context + + result_shardings = _flatten_sharding( + info.out_tree, result_sharding, result_shapes) + arg_shardings = _flatten_sharding(info.in_tree, arg_shardings, arg_shapes) + tiled_args = [ + _to_jax_shape(sharding.tile(s)) + for sharding, s in zip(arg_shardings, arg_shapes) + ] + tiled_results = [ + _to_jax_shape(sharding.tile(s)) + for sharding, s in zip(result_shardings, result_shapes) + ] + closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( + *tiled_args + ) + if closed_jaxpr.out_avals != tiled_results: + raise ValueError( + "Mismatch in result shapes. 
%s vs %s" + % (repr(closed_jaxpr.out_avals), repr(tiled_results)) + ) + axis_context = sharding_impls.SPMDAxisContext(mesh) + with core.extend_axis_env_nd(mesh.shape.items()): + module = mlir.build_mlir_module_helper( + closed_jaxpr, + name="tmp_xla_computation", + platforms=module_context.platforms, + backend_or_name=module_context.backend_or_name, + axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), + ) + result_sharding = _pack_result_sharding(result_shape, result_shardings) + return mlir.module_to_bytecode(module), arg_shardings, result_sharding + + +def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings, + result_shape, + backend_string): + info = _sharding_callbacks[backend_string] + if result_shape.is_tuple(): + result_shapes = result_shape.tuple_shapes() + else: + result_shapes = (result_shape,) + result_sharding = info.infer_sharding_from_operands( + *info.static_args, + info.mesh, + info.unflatten_arg_shapes(arg_shapes, arg_shardings), + info.out_tree.unflatten([_to_jax_shape(s) for s in result_shapes]), + ) + result_shardings = _flatten_sharding( + info.out_tree, result_sharding, result_shapes) + return _pack_result_sharding(result_shape, result_shardings) + + +custom_partitioning_p = core.Primitive("custom_partitioning") +custom_partitioning_p.multiple_results = True +dispatch.prim_requires_devices_during_lowering.add(custom_partitioning_p) + + +def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree, + propagate_user_sharding, partition, + infer_sharding_from_operands, + decode_shardings, + static_args): + del in_tree, out_tree, propagate_user_sharding, partition + del infer_sharding_from_operands, decode_shardings, static_args + return call.out_avals + + +def _custom_partitioning_impl(*args, call, in_tree, out_tree, + propagate_user_sharding, + partition, infer_sharding_from_operands, + decode_shardings, static_args): + del in_tree, out_tree, propagate_user_sharding, partition + del infer_sharding_from_operands, decode_shardings, static_args + return core.jaxpr_as_fun(call)(*args) + + +custom_partitioning_p.def_abstract_eval(_custom_partitioning_abstract_eval) +custom_partitioning_p.def_impl(_custom_partitioning_impl) + + +def _check_for_tracers(x): + if any(isinstance(leaf, core.Tracer) for leaf in tree_util.tree_leaves(x)): + raise UnexpectedTracerError( + "Found a JAX Tracer object passed as an argument to a" + "custom_partitioning function in a position indicated as static by" + "static_argnums. " + ) + + +@custom_api_util.register_custom_decorator_type +class custom_partitioning: + """Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules. + + .. code-block:: python + + @custom_partitioning + def f(*args): + return ... + + def propagate_user_sharding(mesh, user_shape): + '''Update the sharding of the op from a user's shape.sharding.''' + user_sharding = jax.tree.map(lambda x: x.sharding, user_shape) + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(*args): + ... builds computation on per-device shapes ... + result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + # result_sharding and arg_shardings may optionally be modified and the + # partitioner will insert collectives to reshape. 
+ return mesh, lower_fn, result_sharding, arg_shardings + + def infer_sharding_from_operands(mesh, arg_shapes, shape): + '''Compute the result sharding from the sharding of the operands.''' + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + + + f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands) + + The args to ``def_partition`` are as follows: + + * ``propagate_user_sharding``: Callable which takes the sharding of a user (in the dag) + and returns a suggestion for a new `NamedSharding`. The default + implementation is just to return the suggested sharding. + * ``partition``: Callable which takes the SPMD suggested partition shapes and + partition specs and returns the mesh, a per-shard lowering function, and the final + input and output sharding specs (the SPMD partitioner will repartition the + inputs to match). The mesh is returned to allow configuring axis_names for + collectives when no mesh is provided. + * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding`` + from the ``NamedSharding`` chosen for each argument. + * ``decode_shardings``: When set to True, convert input ``GSPMDSharding``s to + ``NamedSharding`` if possible. This may not be possible if the user does not + provide a contextual mesh. + + Positional arguments can be specified as static using static_argnums. JAX uses + :code:`inspect.signature(fun)` to resolve these positional arguments. + + Examples: + + As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes + the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched + along the first N-1 dimensions. + By default, however, it will ignore the sharding of the input and gather the input on all devices. + However, since ``jax.numpy.fft.fft`` is batched along the first N-1 dimensions, + this is unnecessary. We will create a new ``my_fft`` op that, instead, does not alter the sharding + along the first `N-1` dimensions, and only gathers the input along the last dimension if needed. + + .. 
code-block:: python + + import jax + from jax.sharding import NamedSharding + from jax.experimental.custom_partitioning import custom_partitioning + from jax.experimental.pjit import pjit + from jax.sharding import PartitionSpec as P + from jax.sharding import Mesh + from jax.numpy.fft import fft + import regex as re + import numpy as np + + # Pattern to detect all-gather or dynamic-slice in the generated HLO + _PATTERN = '(dynamic-slice|all-gather)' + + # For an N-D input, keeps sharding along the first N-1 dimensions + # but replicate along the last dimension + def supported_sharding(sharding, shape): + rank = len(shape.shape) + max_shared_dims = min(len(sharding.spec), rank-1) + names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims)) + return NamedSharding(sharding.mesh, P(*names)) + + def partition(mesh, arg_shapes, result_shape): + result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + return mesh, fft, \ + supported_sharding(arg_shardings[0], arg_shapes[0]), \ + (supported_sharding(arg_shardings[0], arg_shapes[0]),) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + return supported_sharding(arg_shardings[0], arg_shapes[0]) + + @custom_partitioning + def my_fft(x): + return fft(x) + + my_fft.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition) + + Now create a 2D array sharded along the first axis, pass it through ``my_fft`` + and notice how it is still sharded as expected, and identical to the output + of ``fft``. However, inspecting the HLO + (using ``lower(x).compile().runtime_executable().hlo_modules()``) reveals that + ``my_fft`` does not create any all-gather or dynamic-slice, while ``fft`` does. + + .. code-block:: + + with Mesh(np.array(jax.devices()), ('x',)): + x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64) + y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) + pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) + pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) + print(pjit_my_fft(y)) + print(pjit_fft(y)) + # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array + assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) + # dynamic-slice or all-gather are present in the HLO for fft + assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None) + + .. code-block:: + + # my_fft + [[-38.840824 +0.j -40.649452 +11.845365j + ... + -1.6937828 +0.8402481j 15.999859 -4.0156755j]] + + # jax.numpy.fft.fft + [[-38.840824 +0.j -40.649452 +11.845365j + ... + -1.6937828 +0.8402481j 15.999859 -4.0156755j]] + + Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays. + However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension + is the dimension along which FFTs are calculated and needs to be replicated on all devices before + the computation can be done. + + .. 
code-block:: + + with Mesh(np.array(jax.devices()), ('x',)): + x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64) + y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) + pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) + pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) + print(pjit_my_fft(y)) + print(pjit_fft(y)) + # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array + assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) + # dynamic-slice or all-gather are present in the HLO for fft + assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None) + + .. code-block:: + + # my_fft + [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j + ... 1422.4502 +7271.4297j -405.84033 -3042.983j + -3012.4963 -4287.6343j] + + # jax.numpy.fft.fft + [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j + ... 1422.4502 +7271.4297j -405.84033 -3042.983j + -3012.4963 -4287.6343j] + + """ + + def __init__(self, fun, static_argnums=()): + self.fun = fun + self.partition = None + self.static_argnums = static_argnums + self.propagate_user_sharding = None + self.infer_sharding_from_operands = None + + __getattr__ = custom_api_util.forward_attr + + def def_partition(self, partition, infer_sharding_from_operands, + propagate_user_sharding=None, decode_shardings=True): + self.partition = partition + self.propagate_user_sharding = propagate_user_sharding + self.infer_sharding_from_operands = infer_sharding_from_operands + self.decode_shardings = decode_shardings + return partition + + def __call__(self, *args, **kwargs): + args = _resolve_kwargs(self.fun, args, kwargs) + if self.static_argnums: + static_argnums = set(self.static_argnums) + args = tuple(x if i in static_argnums else x for i, x in enumerate(args)) + dyn_argnums = [i for i in range(len(args)) if i not in static_argnums] + f_, dyn_args = api_util.argnums_partial( + lu.wrap_init(self.fun), + dyn_argnums, + args, + require_static_args_hashable=False, + ) + static_args = [args[i] for i in self.static_argnums] + _check_for_tracers(static_args) + else: + static_args = [] + f_, dyn_args = lu.wrap_init(self.fun), args + args_flat, in_tree = tree_util.tree_flatten(dyn_args) + flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree) + in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + debug = pe.debug_info(self.fun, in_tree, out_tree, False, + "custom_partitioning") + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + assert not len(consts) + closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) + out_flat = custom_partitioning_p.bind( + *consts, + *args_flat, + call=closed_call, + partition=self.partition, + propagate_user_sharding=self.propagate_user_sharding, + infer_sharding_from_operands=self.infer_sharding_from_operands, + decode_shardings=self.decode_shardings, + in_tree=in_tree, + out_tree=out_tree(), + static_args=static_args + ) + return tree_util.tree_unflatten(out_tree(), out_flat) + + +def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, + call, in_tree, out_tree, + propagate_user_sharding, partition, + infer_sharding_from_operands, + decode_shardings, + static_args): + mesh = mesh_lib.thread_resources.env.physical_mesh + axis_context = ctx.module_context.axis_context + if (isinstance(axis_context, 
sharding_impls.SPMDAxisContext) and + set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)): + return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) + + if isinstance(axis_context, sharding_impls.ShardingContext): + devices = axis_context.device_assignment + if devices is None: + raise AssertionError( + 'Please file a bug at https://github.com/google/jax/issues') + elif isinstance(axis_context, sharding_impls.SPMDAxisContext): + devices = axis_context.mesh._flat_devices_tuple + else: + devices = None + + if not devices or len(devices) == 1: + return mlir.lower_fun( + core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) + + def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): + if hlo_sharding is None: + return hlo_sharding + if mesh.empty or not decode_shardings: + assert devices is not None + return sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices) + pspec = sharding_impls.parse_flatten_op_sharding( + hlo_sharding, mesh)[0].get_partition_spec() + pspec = jax.sharding.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) + return jax.sharding.NamedSharding(mesh, pspec) + + sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, + partition, to_mesh_pspec_sharding, in_tree, out_tree, + infer_sharding_from_operands, ctx.module_context, mesh, static_args) + key = str(id(sharding_callback_info)) + _sharding_callbacks[bytes(key, 'utf8')] = sharding_callback_info + # We need to make sure `sharding_callback_info` is still alive when the SPMD + # partitioner runs so we keep it alive by attaching it to the executable. + ctx.module_context.add_keepalive(sharding_callback_info) + + result_types = [mlir.aval_to_ir_type(s) for s in call.out_avals] + out = hlo.CustomCallOp( + result_types, + list(values), + call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME), + has_side_effect=ir.BoolAttr.get(False), + api_version=mlir.i32_attr(2), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(key), + operand_layouts=None, + result_layouts=None) + return out.results + +mlir.register_lowering(custom_partitioning_p, + _custom_partitioning_lowering_rule) + +xc.register_custom_call_partitioner( + _CUSTOM_PARTITIONING_CALL_NAME, + _custom_partitioning_propagate_user_sharding, + _custom_partitioning_partition, + _custom_partitioning_infer_sharding_from_operands, True) # type: ignore +xb.register_plugin_callbacks( + partial( + xc.register_custom_call_partitioner, + name=_CUSTOM_PARTITIONING_CALL_NAME, + prop_user_sharding=_custom_partitioning_propagate_user_sharding, + partition=_custom_partitioning_partition, + infer_sharding_from_operands=_custom_partitioning_infer_sharding_from_operands, + can_side_effecting_have_replicated_sharding=True, + ) +) diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index aa138fe88993..3c7bfac40061 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -1,4 +1,4 @@ -# Copyright 2018 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,541 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +# Note: import as is required for names to be exported. 
+# See PEP 484 & https://github.com/google/jax/issues/7570 -from functools import partial -import inspect -from typing import Optional -import weakref - -import jax -from jax._src import core -from jax import tree_util -from jax._src import linear_util as lu -from jax._src import sharding_impls -from jax.errors import UnexpectedTracerError -from jax._src import mesh as mesh_lib -from jax._src import dispatch -from jax._src.lib.mlir.dialects import hlo -from jax._src.lib.mlir import ir -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.sharding_impls import _op_sharding_to_pos_sharding -from jax._src import custom_api_util -from jax._src import xla_bridge as xb -from jax._src.lib import xla_client as xc -from jax._src.api_util import flatten_fun_nokwargs, argnums_partial - - -def _resolve_kwargs(fun, args, kwargs): - ba = inspect.signature(fun).bind(*args, **kwargs) - ba.apply_defaults() - if ba.kwargs: - raise TypeError("keyword arguments could not be resolved to positions") - else: - return ba.args - - -class _ShardingCallbackInfo: - - def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding, - in_tree, out_tree, infer_sharding_from_operands, module_context, mesh, - static_args): - self.propagate_user_sharding = propagate_user_sharding - self.partition = partition - self.to_mesh_pspec_sharding = to_mesh_pspec_sharding - self.in_tree = in_tree - self.out_tree = out_tree - self.infer_sharding_from_operands = infer_sharding_from_operands - self.module_context = module_context - self.mesh = mesh - self.static_args = static_args - - def unflatten_arg_shape(self, s, sharding): - return _to_jax_sharded_shape( - s, self.to_mesh_pspec_sharding(sharding, len(s.dimensions())) - ) - - def unflatten_arg_shapes(self, arg_shapes, arg_shardings): - return self.in_tree.unflatten( - [ - self.unflatten_arg_shape(s, sharding) - for s, sharding in zip(arg_shapes, arg_shardings) - ] - ) - - -_sharding_callbacks = weakref.WeakValueDictionary() # type: ignore - -_CUSTOM_PARTITIONING_CALL_NAME = "CustomSPMDPartitioning" - - -def _to_jax_shape(s): - return core.ShapedArray(s.dimensions(), s.numpy_dtype()) - - -def _to_jax_sharded_shape(s, sharding): - return jax.ShapeDtypeStruct( - s.dimensions(), s.numpy_dtype(), sharding=sharding - ) - - -def _pack_result_sharding(shape, result_shardings): - if shape.is_tuple(): - return xc.HloSharding.tuple_sharding(shape, result_shardings) - else: - return result_shardings[0] - - -def _flatten_sharding(tree, shardings, shapes): - return [ - _to_hlo_sharding(sharding, len(shape.dimensions())) - for sharding, shape in zip( - tree.flatten_up_to(shardings), shapes - ) - ] - - -def _custom_partitioning_propagate_user_sharding(user_sharding, shape, - backend_string): - info = _sharding_callbacks[backend_string] - if info.propagate_user_sharding is None: - return user_sharding - if shape.is_tuple(): - user_shapes = shape.tuple_shapes() - user_shardings = user_sharding.tuple_elements() - else: - user_shapes = (shape,) - user_shardings = (user_sharding,) - user_shape = info.out_tree.unflatten( - [ - info.unflatten_arg_shape(s, sharding) - for s, sharding in zip(user_shapes, user_shardings) - ] - ) - result_sharding = info.propagate_user_sharding( - *info.static_args, info.mesh, user_shape - ) - result_shardings = _flatten_sharding( - info.out_tree, result_sharding, user_shapes) - return _pack_result_sharding(shape, result_shardings) - - -def _to_hlo_sharding(sharding, num_dimensions): - if not 
isinstance(sharding, jax.sharding.Sharding): - raise ValueError("Custom Partitioning rules must return Sharding.") - return sharding._to_xla_hlo_sharding(num_dimensions) - - -def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, - result_sharding, backend_string): - info = _sharding_callbacks[backend_string] - if result_shape.is_tuple(): - result_shapes = result_shape.tuple_shapes() - result_shardings = result_sharding.tuple_elements() - else: - result_shapes = (result_shape,) - result_shardings = (result_sharding,) - mesh, lower_fn, result_sharding, arg_shardings = info.partition( - *info.static_args, - info.mesh, - info.unflatten_arg_shapes(arg_shapes, arg_shardings), - info.out_tree.unflatten( - [ - info.unflatten_arg_shape(s, sharding) - for s, sharding in zip(result_shapes, result_shardings) - ] - ), - ) - module_context = info.module_context - - result_shardings = _flatten_sharding( - info.out_tree, result_sharding, result_shapes) - arg_shardings = _flatten_sharding(info.in_tree, arg_shardings, arg_shapes) - tiled_args = [ - _to_jax_shape(sharding.tile(s)) - for sharding, s in zip(arg_shardings, arg_shapes) - ] - tiled_results = [ - _to_jax_shape(sharding.tile(s)) - for sharding, s in zip(result_shardings, result_shapes) - ] - closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( - *tiled_args - ) - if closed_jaxpr.out_avals != tiled_results: - raise ValueError( - "Mismatch in result shapes. %s vs %s" - % (repr(closed_jaxpr.out_avals), repr(tiled_results)) - ) - axis_context = sharding_impls.SPMDAxisContext(mesh) - with core.extend_axis_env_nd(mesh.shape.items()): - module = mlir.build_mlir_module_helper( - closed_jaxpr, - name="tmp_xla_computation", - platforms=module_context.platforms, - backend_or_name=module_context.backend_or_name, - axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), - ) - result_sharding = _pack_result_sharding(result_shape, result_shardings) - return mlir.module_to_bytecode(module), arg_shardings, result_sharding - - -def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings, - result_shape, - backend_string): - info = _sharding_callbacks[backend_string] - if result_shape.is_tuple(): - result_shapes = result_shape.tuple_shapes() - else: - result_shapes = (result_shape,) - result_sharding = info.infer_sharding_from_operands( - *info.static_args, - info.mesh, - info.unflatten_arg_shapes(arg_shapes, arg_shardings), - info.out_tree.unflatten([_to_jax_shape(s) for s in result_shapes]), - ) - result_shardings = _flatten_sharding( - info.out_tree, result_sharding, result_shapes) - return _pack_result_sharding(result_shape, result_shardings) - - -custom_partitioning_p = core.Primitive("custom_partitioning") -custom_partitioning_p.multiple_results = True -dispatch.prim_requires_devices_during_lowering.add(custom_partitioning_p) - - -def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree, - propagate_user_sharding, partition, - infer_sharding_from_operands, - decode_shardings, - static_args): - del in_tree, out_tree, propagate_user_sharding, partition - del infer_sharding_from_operands, decode_shardings, static_args - return call.out_avals - - -def _custom_partitioning_impl(*args, call, in_tree, out_tree, - propagate_user_sharding, - partition, infer_sharding_from_operands, - decode_shardings, static_args): - del in_tree, out_tree, propagate_user_sharding, partition - del infer_sharding_from_operands, decode_shardings, static_args - return core.jaxpr_as_fun(call)(*args) - 
- -custom_partitioning_p.def_abstract_eval(_custom_partitioning_abstract_eval) -custom_partitioning_p.def_impl(_custom_partitioning_impl) - - -def _check_for_tracers(x): - for leaf in tree_util.tree_leaves(x): - if isinstance(x, core.Tracer): - msg = ( - "Found a JAX Tracer object passed as an argument to a" - "custom_partitioning function in a position indicated as static by" - "static_argnums. " - ) - raise UnexpectedTracerError(msg) - - -@custom_api_util.register_custom_decorator_type -class custom_partitioning: - """Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules. - - .. code-block:: python - - @custom_partitioning - def f(*args): - return ... - - def propagate_user_sharding(mesh, user_shape): - '''Update the sharding of the op from a user's shape.sharding.''' - user_sharding = jax.tree.map(lambda x: x.sharding, user_shape) - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(*args): - ... builds computation on per-device shapes ... - result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) - arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) - # result_sharding and arg_shardings may optionally be modified and the - # partitioner will insert collectives to reshape. - return mesh, lower_fn, result_sharding, arg_shardings - - def infer_sharding_from_operands(mesh, arg_shapes, shape): - '''Compute the result sharding from the sharding of the operands.''' - arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) - - - f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands) - - The args to ``def_partition`` are as follows: - - * ``propagate_user_sharding``: Callable which takes the sharding of a user (in the dag) - and returns a suggestion for a new `NamedSharding`. The default - implementation is just to return the suggested sharding. - * ``partition``: Callable which takes the SPMD suggested partition shapes and - partition specs and returns the mesh, a per-shard lowering function, and the final - input and output sharding specs (the SPMD partitioner will repartition the - inputs to match). The mesh is returned to allow configuring axis_names for - collectives when no mesh is provided. - * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding`` - from the ``NamedSharding`` chosen for each argument. - * ``decode_shardings``: When set to True, convert input ``GSPMDSharding``s to - ``NamedSharding`` if possible. This may not be possible if the user does not - provide a contextual mesh. - - Positional arguments can be specified as static using static_argnums. JAX uses - :code:`inspect.signature(fun)` to resolve these positional arguments. - - Examples: - - As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes - the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched - along the first N-1 dimensions. - By default, however, it will ignore the sharding of the input and gather the input on all devices. - However, since ``jax.numpy.fft.fft`` is batched along the first N-1 dimensions, - this is unnecessary. We will create a new ``my_fft`` op that, instead, does not alter the sharding - along the first `N-1` dimensions, and only gathers the input along the last dimension if needed. - - .. 
code-block:: python - - import jax - from jax.sharding import NamedSharding - from jax.experimental.custom_partitioning import custom_partitioning - from jax.experimental.pjit import pjit - from jax.sharding import PartitionSpec as P - from jax.sharding import Mesh - from jax.numpy.fft import fft - import regex as re - import numpy as np - - # Pattern to detect all-gather or dynamic-slice in the generated HLO - _PATTERN = '(dynamic-slice|all-gather)' - - # For an N-D input, keeps sharding along the first N-1 dimensions - # but replicate along the last dimension - def supported_sharding(sharding, shape): - rank = len(shape.shape) - max_shared_dims = min(len(sharding.spec), rank-1) - names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims)) - return NamedSharding(sharding.mesh, P(*names)) - - def partition(mesh, arg_shapes, result_shape): - result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) - arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) - return mesh, fft, \ - supported_sharding(arg_shardings[0], arg_shapes[0]), \ - (supported_sharding(arg_shardings[0], arg_shapes[0]),) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) - return supported_sharding(arg_shardings[0], arg_shapes[0]) - - @custom_partitioning - def my_fft(x): - return fft(x) - - my_fft.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition) - - Now create a 2D array sharded along the first axis, pass it through ``my_fft`` - and notice how it is still sharded as expected, and identical to the output - of ``fft``. However, inspecting the HLO - (using ``lower(x).compile().runtime_executable().hlo_modules()``) reveals that - ``my_fft`` does not create any all-gather or dynamic-slice, while ``fft`` does. - - .. code-block:: - - with Mesh(np.array(jax.devices()), ('x',)): - x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64) - y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) - pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) - pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) - print(pjit_my_fft(y)) - print(pjit_fft(y)) - # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array - assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) - # dynamic-slice or all-gather are present in the HLO for fft - assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None) - - .. code-block:: - - # my_fft - [[-38.840824 +0.j -40.649452 +11.845365j - ... - -1.6937828 +0.8402481j 15.999859 -4.0156755j]] - - # jax.numpy.fft.fft - [[-38.840824 +0.j -40.649452 +11.845365j - ... - -1.6937828 +0.8402481j 15.999859 -4.0156755j]] - - Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays. - However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension - is the dimension along which FFTs are calculated and needs to be replicated on all devices before - the computation can be done. - - .. 
code-block:: - - with Mesh(np.array(jax.devices()), ('x',)): - x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64) - y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) - pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) - pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) - print(pjit_my_fft(y)) - print(pjit_fft(y)) - # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array - assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) - # dynamic-slice or all-gather are present in the HLO for fft - assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None) - - .. code-block:: - - # my_fft - [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j - ... 1422.4502 +7271.4297j -405.84033 -3042.983j - -3012.4963 -4287.6343j] - - # jax.numpy.fft.fft - [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j - ... 1422.4502 +7271.4297j -405.84033 -3042.983j - -3012.4963 -4287.6343j] - - """ - - def __init__(self, fun, static_argnums=()): - self.fun = fun - self.partition = None - self.static_argnums = static_argnums - self.propagate_user_sharding = None - self.infer_sharding_from_operands = None - - __getattr__ = custom_api_util.forward_attr - - def def_partition(self, partition, infer_sharding_from_operands, - propagate_user_sharding=None, decode_shardings=True): - self.partition = partition - self.propagate_user_sharding = propagate_user_sharding - self.infer_sharding_from_operands = infer_sharding_from_operands - self.decode_shardings = decode_shardings - return partition - - def __call__(self, *args, **kwargs): - args = _resolve_kwargs(self.fun, args, kwargs) - if self.static_argnums: - static_argnums = set(self.static_argnums) - args = tuple(x if i in static_argnums else x for i, x in enumerate(args)) - dyn_argnums = [i for i in range(len(args)) if i not in static_argnums] - f_, dyn_args = argnums_partial( - lu.wrap_init(self.fun), - dyn_argnums, - args, - require_static_args_hashable=False, - ) - static_args = [args[i] for i in self.static_argnums] - _check_for_tracers(static_args) - else: - static_args = [] - f_, dyn_args = lu.wrap_init(self.fun), args - args_flat, in_tree = tree_util.tree_flatten(dyn_args) - flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] - debug = pe.debug_info(self.fun, in_tree, out_tree, False, - "custom_partitioning") - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) - assert not len(consts) - closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) - out_flat = custom_partitioning_p.bind( - *consts, - *args_flat, - call=closed_call, - partition=self.partition, - propagate_user_sharding=self.propagate_user_sharding, - infer_sharding_from_operands=self.infer_sharding_from_operands, - decode_shardings=self.decode_shardings, - in_tree=in_tree, - out_tree=out_tree(), - static_args=static_args - ) - return tree_util.tree_unflatten(out_tree(), out_flat) - - -def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, - call, in_tree, out_tree, - propagate_user_sharding, partition, - infer_sharding_from_operands, - decode_shardings, - static_args): - mesh = mesh_lib.thread_resources.env.physical_mesh - axis_context = ctx.module_context.axis_context - if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and - 
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)): - return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) - - if isinstance(axis_context, sharding_impls.ShardingContext): - devices = axis_context.device_assignment - if devices is None: - raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') - elif isinstance(axis_context, sharding_impls.SPMDAxisContext): - devices = axis_context.mesh._flat_devices_tuple - else: - devices = None - - if not devices or len(devices) == 1: - return mlir.lower_fun( - core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) - - def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): - if hlo_sharding is None: - return hlo_sharding - if mesh.empty or not decode_shardings: - assert devices is not None - return _op_sharding_to_pos_sharding(hlo_sharding, devices) - pspec = sharding_impls.parse_flatten_op_sharding( - hlo_sharding, mesh)[0].get_partition_spec() - pspec = jax.sharding.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) - return jax.sharding.NamedSharding(mesh, pspec) - - sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, - partition, to_mesh_pspec_sharding, in_tree, out_tree, - infer_sharding_from_operands, ctx.module_context, mesh, static_args) - key = str(id(sharding_callback_info)) - _sharding_callbacks[bytes(key, 'utf8')] = sharding_callback_info - # We need to make sure `sharding_callback_info` is still alive when the SPMD - # partitioner runs so we keep it alive by attaching it to the executable. - ctx.module_context.add_keepalive(sharding_callback_info) - - result_types = [mlir.aval_to_ir_type(s) for s in call.out_avals] - out = hlo.CustomCallOp( - result_types, - list(values), - call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME), - has_side_effect=ir.BoolAttr.get(False), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(key), - operand_layouts=None, - result_layouts=None) - return out.results - -mlir.register_lowering(custom_partitioning_p, - _custom_partitioning_lowering_rule) - -xc.register_custom_call_partitioner( - _CUSTOM_PARTITIONING_CALL_NAME, - _custom_partitioning_propagate_user_sharding, - _custom_partitioning_partition, - _custom_partitioning_infer_sharding_from_operands, True) # type: ignore -xb.register_plugin_callbacks( - partial( - xc.register_custom_call_partitioner, - name=_CUSTOM_PARTITIONING_CALL_NAME, - prop_user_sharding=_custom_partitioning_propagate_user_sharding, - partition=_custom_partitioning_partition, - infer_sharding_from_operands=_custom_partitioning_infer_sharding_from_operands, - can_side_effecting_have_replicated_sharding=True, - ) +from jax._src.custom_partitioning import ( + custom_partitioning as custom_partitioning, + custom_partitioning_p as custom_partitioning_p, )
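
For callers, the public entry point is unchanged: ``jax.experimental.custom_partitioning`` now simply re-exports the implementation from ``jax._src.custom_partitioning``, which lets internal modules such as ``fused_attention_stablehlo`` depend on ``jax._src`` directly instead of importing through ``jax.experimental`` and creating the ``jax -> experimental -> jax`` cycle. The sketch below shows how a caller might define and use a partitioned op after this change; it follows the ``my_fft`` pattern from the docstring above, but ``my_sum`` and its partitioning rules are illustrative assumptions, not part of this patch.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import numpy as np
    # The public import path is untouched by this patch; it re-exports
    # jax._src.custom_partitioning.
    from jax.experimental.custom_partitioning import custom_partitioning
    from jax.experimental.pjit import pjit
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    @custom_partitioning
    def my_sum(x):  # hypothetical op, analogous to my_fft above
      # Reduce over the last axis; the leading (batch) axis may stay sharded.
      return jnp.sum(x, axis=-1)

    def batch_sharding(sharding):
      # Keep the sharding of the leading axis; replicate the reduced axis.
      return NamedSharding(sharding.mesh, P(*tuple(sharding.spec[:1])))

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
      return batch_sharding(arg_shapes[0].sharding)

    def partition(mesh, arg_shapes, result_shape):
      arg_sharding = batch_sharding(arg_shapes[0].sharding)

      def lower_fn(x):
        # Runs on per-device shards: each device reduces its own rows
        # locally, so the partitioner inserts no collectives.
        return jnp.sum(x, axis=-1)

      return mesh, lower_fn, batch_sharding(arg_shapes[0].sharding), (arg_sharding,)

    my_sum.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition)

    with Mesh(np.array(jax.devices()), ('x',)):
      x = jnp.ones((8 * jax.device_count(), 16), dtype=jnp.float32)
      y = pjit(my_sum, in_shardings=P('x'), out_shardings=P('x'))(x)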