diff --git a/docs/jax.extend.ffi.rst b/docs/jax.extend.ffi.rst index 070778b8f065..5928189eb647 100644 --- a/docs/jax.extend.ffi.rst +++ b/docs/jax.extend.ffi.rst @@ -6,5 +6,6 @@ .. autosummary:: :toctree: _autosummary + ffi_call ffi_lowering pycapsule diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 8e9ab0b62593..8ebd9767466d 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -124,14 +124,14 @@ def pure_callback_transpose_rule(*args, **kwargs): ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule -def pure_callback_batching_rule( +def callback_batching_rule( + prim, args, dims, *, - callback: _FlatCallback, - sharding: SingleDeviceSharding | None, vectorized: bool, result_avals: Sequence[core.ShapedArray], + **kwargs: Any, ): axis_size = next(a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped) @@ -141,30 +141,30 @@ def pure_callback_batching_rule( result_avals = tuple( core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore for aval in result_avals) - outvals = pure_callback_p.bind( + outvals = prim.bind( *new_args, - callback=callback, - sharding=sharding, vectorized=vectorized, result_avals=result_avals, + **kwargs, ) else: is_batched = [d is not batching.not_mapped for d in dims] unbatched_args, batched_args = util.partition_list(is_batched, new_args) def _batch_fun(batched_args): merged_args = util.merge_lists(is_batched, unbatched_args, batched_args) - return pure_callback_p.bind( + return prim.bind( *merged_args, - callback=callback, - sharding=sharding, result_avals=result_avals, vectorized=vectorized, + **kwargs, ) outvals = lax_map(_batch_fun, batched_args) return tuple(outvals), (0,) * len(outvals) -batching.primitive_batchers[pure_callback_p] = pure_callback_batching_rule +batching.primitive_batchers[pure_callback_p] = functools.partial( + callback_batching_rule, pure_callback_p +) def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index aec124549e1e..9a18e66a6dd7 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -14,18 +14,25 @@ from __future__ import annotations -import os -import ctypes from collections.abc import Iterable, Mapping, Sequence +import ctypes +import functools +import os from typing import Any -import numpy as np - +from jax._src import core +from jax._src import dispatch from jax._src import dtypes +from jax._src import util +from jax._src.callback import _check_shape_dtype, callback_batching_rule +from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib import jaxlib from jax._src.lib.mlir import ir -from jax._src.typing import DimSize +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray +import numpy as np + +map, unsafe_map = util.safe_map, map def pycapsule(funcptr): @@ -140,3 +147,102 @@ def _ir_attribute(obj: Any) -> ir.Attribute: elif isinstance(mlir_type, ir.FloatType): return ir.FloatAttr.get(mlir_type, obj) raise TypeError(f"Unsupported attribute type: {type(obj)}") + + +ffi_call_p = core.Primitive("ffi_call") +ffi_call_p.multiple_results = True +ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p)) + + +@ffi_call_p.def_abstract_eval +def ffi_call_abstract_eval( + *avals_in, + result_avals: tuple[core.ShapedArray, ...], + ffi_target_name: str, + vectorized: bool, + **kwargs: Any, +): + del avals_in, ffi_target_name, vectorized, kwargs + return result_avals + + +batching.primitive_batchers[ffi_call_p] = functools.partial( + callback_batching_rule, ffi_call_p +) + + +def ffi_call_lowering( + ctx: mlir.LoweringRuleContext, + *operands: ir.Value, + result_avals: tuple[core.ShapedArray, ...], + ffi_target_name: str, + vectorized: bool, + **kwargs: Any, +) -> Sequence[ir.Value]: + del result_avals, vectorized + return ffi_lowering(ffi_target_name)(ctx, *operands, **kwargs) + + +mlir.register_lowering(ffi_call_p, ffi_call_lowering) + + +def ffi_call( + ffi_target_name: str, + result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray], + *args: ArrayLike, + vectorized: bool = False, + **kwargs: Any, +) -> Array | Sequence[Array]: + """Call a foreign function interface (FFI) target. + + Like :func:`jax.pure_callback`, the behavior of ``ffi_call`` under ``vmap`` + depends on the value of ``vectorized``. When ``vectorized`` is ``True``, the + FFI target is assumed to satisfy: + ``ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])``. In other words, + calling the FFI target with an extra leading dimension should return the same + result as calling it within a loop and stacking along the zeroth axis. + Therefore, the FFI target will be called directly on batched inputs (where the + batch axes are the leading dimensions). Additionally, the callbacks should + return outputs that have corresponding leading batch axes. If ``vectorized`` + is ``False`` (the default behavior), transforming this ``ffi_call`` under + ``vmap`` will result in a :func:`jax.lax.scan` with the ``ffi_call`` in the + body. + + Args: + ffi_target_name: the name of the XLA FFI custom call target that was + registered using :func:`xla_client.register_custom_call_target`. + result_shape_dtypes: an object, or sequence of objects, with ``shape`` and + ``dtype`` attributes which are expected to match the shape and dtype of + the custom call output or outputs. :class:`jax.ShapeDtypeStruct` is often + used to define the elements of ``result_shape_dtypes``. + *args: the arguments passed to the custom call. + vectorized: boolean specifying whether the callback function can operate in + a vectorized manner, as described above. + **kwargs: keyword arguments that are passed as named attributes to the + custom call using XLA's FFI interface. + + Returns: + One or more :class:`jax.Array` objects whose shapes and dtypes match + ``result_shape_dtypes``. + """ + if isinstance(result_shape_dtypes, Sequence): + multiple_results = True + result_shape_dtypes_ = result_shape_dtypes + else: + multiple_results = False + result_shape_dtypes_ = (result_shape_dtypes,) + map(_check_shape_dtype, result_shape_dtypes_) + result_avals = [ + core.ShapedArray(x.shape, x.dtype) for x in result_shape_dtypes_ + ] + results = ffi_call_p.bind( + *args, + result_avals=tuple(result_avals), + vectorized=vectorized, + ffi_target_name=ffi_target_name, + **kwargs, + ) + if multiple_results: + return results + else: + return results[0] diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py index 565b37cbb542..5862d0ecebde 100644 --- a/jax/extend/ffi.py +++ b/jax/extend/ffi.py @@ -16,6 +16,7 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.extend.ffi import ( + ffi_call as ffi_call, ffi_lowering as ffi_lowering, include_dir as include_dir, pycapsule as pycapsule, diff --git a/tests/extend_test.py b/tests/extend_test.py index 790b4fa2d774..2d68a1fe7e33 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import os import numpy as np from absl.testing import absltest, parameterized import jax +from jax import lax import jax.extend as jex import jax.numpy as jnp @@ -99,7 +101,7 @@ def testHeadersExist(self): @parameterized.parameters( [True, int(1), float(5.0), np.int32(-5), np.float32(0.5)]) - def testIrAttribute(sel, value): + def testIrAttribute(self, value): with mlir.make_ir_context(), ir.Location.unknown(): const = mlir.ir_constant(value) attr = ffi._ir_attribute(value) @@ -116,6 +118,59 @@ def testParams(self, param): func = jax.jit(lambda *args: prim.bind(*args, param=param)) func.lower(jnp.linspace(0, 5, 10)) + @jtu.sample_product( + shape=[(1,), (4,), (5,)], + dtype=(np.int32,), + ) + @jtu.run_on_devices("gpu") + def testFfiCall(self, shape, dtype): + pivots_size = shape[-1] + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) + pivots = jnp.broadcast_to(pivots, shape) + expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) + actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size) + self.assertArraysEqual(actual, expected) + + @jtu.sample_product( + shape=[(1,), (4,), (5,)], + dtype=(np.int32,), + ) + @jtu.run_on_devices("gpu") + def testFfiCallBatching(self, shape, dtype): + # TODO(dfm): Add a test for `vectorized = True`. + shape = (10,) + shape + pivots_size = shape[-1] + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) + pivots = jnp.broadcast_to(pivots, shape) + expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) + actual = jax.vmap( + lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size) + )(pivots) + self.assertArraysEqual(actual, expected) + + +# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` +# custom call target because that's the only one in jaxlib that uses the +# new FFI interface. Once more are available, consider using something that +# can be run on multiple platforms. +def ffi_call_lu_pivots_to_permutation(pivots, permutation_size): + dims = pivots.shape + batch_size = math.prod(dims[:-1]) + pivot_size = dims[-1] + return jex.ffi.ffi_call( + "cu_lu_pivots_to_permutation", + jax.ShapeDtypeStruct( + shape=dims[:-1] + (permutation_size,), + dtype=pivots.dtype, + ), + pivots, + batch_size=np.int64(batch_size), + pivot_size=np.int32(pivot_size), + permutation_size=np.int32(permutation_size), + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())