From 99d8d62498d051824590e8db8e09cec2957afac9 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 7 Jun 2024 11:47:04 -0700 Subject: [PATCH] Add `ffi_call` function with a similar signature to `pure_callback`. This could be useful for supporting the most common use cases for FFI custom calls. It has several benefits over using the `Primitive` based approach, but the biggest one (in my opinion) is that it doesn't require interacting with `mlir` at all. It does have the limitation that transforms would need to be registered using interfaces like `custom_vjp`, but many users of custom calls already do that. ~~The easiest to-do item (I think) is to implement batching using a `vectorized` parameter like `pure_callback`, but we could also think about more sophisticated vmapping interfaces in the future.~~ Done. The more difficult to-do is to think about how to support sharding, and we might actually want to expose an interface similar to the one from `custom_partitioning`. I have less experience with this part so I'll have to think some more about it, and feedback would be appreciated! 
PiperOrigin-RevId: 641313727 --- jax/_src/callback.py | 20 +++---- jax/_src/extend/ffi.py | 125 +++++++++++++++++++++++++++++++++++++++-- jax/extend/ffi.py | 1 + tests/extend_test.py | 62 +++++++++++++++++++- 4 files changed, 191 insertions(+), 17 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 8e9ab0b62593..8ebd9767466d 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -124,14 +124,14 @@ def pure_callback_transpose_rule(*args, **kwargs): ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule -def pure_callback_batching_rule( +def callback_batching_rule( + prim, args, dims, *, - callback: _FlatCallback, - sharding: SingleDeviceSharding | None, vectorized: bool, result_avals: Sequence[core.ShapedArray], + **kwargs: Any, ): axis_size = next(a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped) @@ -141,30 +141,30 @@ def pure_callback_batching_rule( result_avals = tuple( core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore for aval in result_avals) - outvals = pure_callback_p.bind( + outvals = prim.bind( *new_args, - callback=callback, - sharding=sharding, vectorized=vectorized, result_avals=result_avals, + **kwargs, ) else: is_batched = [d is not batching.not_mapped for d in dims] unbatched_args, batched_args = util.partition_list(is_batched, new_args) def _batch_fun(batched_args): merged_args = util.merge_lists(is_batched, unbatched_args, batched_args) - return pure_callback_p.bind( + return prim.bind( *merged_args, - callback=callback, - sharding=sharding, result_avals=result_avals, vectorized=vectorized, + **kwargs, ) outvals = lax_map(_batch_fun, batched_args) return tuple(outvals), (0,) * len(outvals) -batching.primitive_batchers[pure_callback_p] = pure_callback_batching_rule +batching.primitive_batchers[pure_callback_p] = functools.partial( + callback_batching_rule, pure_callback_p +) def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): 
diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index aec124549e1e..7c007c141c70 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -14,18 +14,26 @@ from __future__ import annotations -import os -import ctypes from collections.abc import Iterable, Mapping, Sequence +import ctypes +import functools +import os from typing import Any -import numpy as np - +from jax._src import core +from jax._src import dispatch from jax._src import dtypes +from jax._src import effects +from jax._src import util +from jax._src.callback import _check_shape_dtype, callback_batching_rule +from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib import jaxlib from jax._src.lib.mlir import ir -from jax._src.typing import DimSize +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray +import numpy as np + +map, unsafe_map = util.safe_map, map def pycapsule(funcptr): @@ -140,3 +148,110 @@ def _ir_attribute(obj: Any) -> ir.Attribute: elif isinstance(mlir_type, ir.FloatType): return ir.FloatAttr.get(mlir_type, obj) raise TypeError(f"Unsupported attribute type: {type(obj)}") + + +ffi_call_p = core.Primitive("ffi_call") +ffi_call_p.multiple_results = True +ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p)) + + +@ffi_call_p.def_abstract_eval +def ffi_call_abstract_eval( + *avals_in, + result_avals: tuple[core.ShapedArray, ...], + platforms: tuple[str, ...], + target_names: tuple[str, ...], + vectorized: bool, + **kwargs: Any, +): + del avals_in, platforms, target_names, vectorized, kwargs + return result_avals + + +batching.primitive_batchers[ffi_call_p] = functools.partial( + callback_batching_rule, ffi_call_p +) + + +def ffi_call_lowering( + ctx: mlir.LoweringRuleContext, + *operands: ir.Value, + result_avals: tuple[core.ShapedArray, ...], + platforms: tuple[str, ...], + target_names: tuple[str, ...], + vectorized: bool, + **kwargs: Any, +) -> Sequence[ir.Value]: + del 
result_avals, vectorized + return mlir.lower_per_platform( + ctx, + "ffi_call", + { + platform: ffi_lowering(target_name) + for platform, target_name in util.safe_zip(platforms, target_names) + }, + None, + effects.no_effects, + *operands, + **kwargs, + ) + + +mlir.register_lowering(ffi_call_p, ffi_call_lowering) + + +def ffi_call( + platform_target_names: Mapping[str, str], + result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray], + *args: ArrayLike, + vectorized: bool = False, + **kwargs: Any, +) -> Array | Sequence[Array]: + """Calls a foreign function interface (FFI) target. + + TODO(dfm): Explain what vectorized does. + + Args: + platform_target_names: a dictionary where the key is the platform name and + the value is the name of the XLA FFI custom call target that was + previously registered for that platform using + :func:`xla_client.register_custom_call_target`. + result_shape_dtypes: an object, or sequence of objects, with ``shape`` and + ``dtype`` attributes which are expected to match the shape and dtype of + the custom call output or outputs. :class:`jax.ShapeDtypeStruct` is often + used to define the elements of ``result_shape_dtypes``. + *args: the arguments passed to the custom call. + vectorized: boolean specifying whether the FFI target can operate in + a vectorized manner when batched, e.g. under :func:`jax.vmap`. + **kwargs: keyword arguments that are passed as named attributes to the + custom call using XLA's FFI interface. + + Returns: + One or more :class:`jax.Array` objects whose shapes and dtypes match + ``result_shape_dtypes``. 
+ """ + if "gpu" in platform_target_names: + raise ValueError("Use 'cuda' or 'rocm' instead of 'gpu' for ffi_call") + if isinstance(result_shape_dtypes, Sequence): + multiple_results = True + result_shape_dtypes_ = result_shape_dtypes + else: + multiple_results = False + result_shape_dtypes_ = (result_shape_dtypes,) + map(_check_shape_dtype, result_shape_dtypes_) + result_avals = [ + core.ShapedArray(x.shape, x.dtype) for x in result_shape_dtypes_ + ] + platforms, target_names = util.unzip2(platform_target_names.items()) + results = ffi_call_p.bind( + *args, + result_avals=tuple(result_avals), + vectorized=vectorized, + platforms=platforms, + target_names=target_names, + **kwargs, + ) + if multiple_results: + return results + else: + return results[0] diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py index 565b37cbb542..5862d0ecebde 100644 --- a/jax/extend/ffi.py +++ b/jax/extend/ffi.py @@ -16,6 +16,7 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.extend.ffi import ( + ffi_call as ffi_call, ffi_lowering as ffi_lowering, include_dir as include_dir, pycapsule as pycapsule, diff --git a/tests/extend_test.py b/tests/extend_test.py index 8aee242242cf..8ffc5b4bee2a 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os import unittest @@ -19,6 +20,7 @@ from absl.testing import absltest, parameterized import jax +from jax import lax import jax.extend as jex import jax.numpy as jnp @@ -31,7 +33,7 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.extend import ffi -from jax._src.lib import xla_extension_version +from jax._src.lib import xla_extension_version, xla_client jax.config.parse_flags_with_absl() @@ -102,7 +104,7 @@ def testHeadersExist(self): @parameterized.parameters( [True, int(1), float(5.0), np.int32(-5), np.float32(0.5)]) - def testIrAttribute(sel, value): + def testIrAttribute(self, value): with mlir.make_ir_context(), ir.Location.unknown(): const = mlir.ir_constant(value) attr = ffi._ir_attribute(value) @@ -119,6 +121,62 @@ def testParams(self, param): func = jax.jit(lambda *args: prim.bind(*args, param=param)) func.lower(jnp.linspace(0, 5, 10)) + @jtu.sample_product( + shape=[(1,), (4,), (5,)], + dtype=(np.int32,), + ) + @jtu.run_on_devices("gpu") + def testFfiCall(self, shape, dtype): + pivots_size = shape[-1] + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) + pivots = jnp.broadcast_to(pivots, shape) + expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) + actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size) + self.assertArraysEqual(actual, expected) + + @jtu.sample_product( + shape=[(1,), (4,), (5,)], + dtype=(np.int32,), + ) + @jtu.run_on_devices("gpu") + def testFfiCallBatching(self, shape, dtype): + # TODO(dfm): Add a test for `vectorized = True`. 
+ shape = (10,) + shape + pivots_size = shape[-1] + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) + pivots = jnp.broadcast_to(pivots, shape) + expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) + actual = jax.vmap( + lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size) + )(pivots) + self.assertArraysEqual(actual, expected) + + +# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` +# custom call target because that's the only one in jaxlib that uses the +# new FFI interface. Once more are available, consider using something that +# can be run on multiple platforms. +def ffi_call_lu_pivots_to_permutation(pivots, permutation_size): + dims = pivots.shape + batch_size = math.prod(dims[:-1]) + pivot_size = dims[-1] + return jex.ffi.ffi_call( + { + "cuda": "cu_lu_pivots_to_permutation", + "rocm": "hip_lu_pivots_to_permutation", + }, + jax.ShapeDtypeStruct( + shape=dims[:-1] + (permutation_size,), + dtype=pivots.dtype, + ), + pivots, + batch_size=np.int64(batch_size), + pivot_size=np.int32(pivot_size), + permutation_size=np.int32(permutation_size), + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())