Merge pull request #21925 from dfm:ffi-call
PiperOrigin-RevId: 648532673
jax authors committed Jul 2, 2024
2 parents 1eaaa10 + e9b087d commit b669ab7
Showing 5 changed files with 186 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/jax.extend.ffi.rst
@@ -6,5 +6,6 @@
.. autosummary::
:toctree: _autosummary

ffi_call
ffi_lowering
pycapsule
20 changes: 10 additions & 10 deletions jax/_src/callback.py
@@ -124,14 +124,14 @@ def pure_callback_transpose_rule(*args, **kwargs):
ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule


def pure_callback_batching_rule(
def callback_batching_rule(
prim,
args,
dims,
*,
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
axis_size = next(a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped)
@@ -141,30 +141,30 @@ def pure_callback_batching_rule
result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
for aval in result_avals)
outvals = pure_callback_p.bind(
outvals = prim.bind(
*new_args,
callback=callback,
sharding=sharding,
vectorized=vectorized,
result_avals=result_avals,
**kwargs,
)
else:
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return pure_callback_p.bind(
return prim.bind(
*merged_args,
callback=callback,
sharding=sharding,
result_avals=result_avals,
vectorized=vectorized,
**kwargs,
)
outvals = lax_map(_batch_fun, batched_args)
return tuple(outvals), (0,) * len(outvals)


batching.primitive_batchers[pure_callback_p] = pure_callback_batching_rule
batching.primitive_batchers[pure_callback_p] = functools.partial(
callback_batching_rule, pure_callback_p
)


def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
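
The change to callback.py above generalizes the pure_callback-specific batching rule into ``callback_batching_rule``, which takes the primitive as its first argument so the same rule can be registered for ``pure_callback_p`` here and for ``ffi_call_p`` later in this commit. A minimal, self-contained sketch of that ``functools.partial`` registration pattern (toy names, not JAX internals):

import functools

def generic_rule(prim_name, args, **params):
  # The shared body only needs to know which primitive it is serving.
  return f"{prim_name}: {len(args)} args, params={params}"

# One implementation, bound per primitive at registration time.
primitive_batchers = {
    "pure_callback": functools.partial(generic_rule, "pure_callback"),
    "ffi_call": functools.partial(generic_rule, "ffi_call"),
}

print(primitive_batchers["ffi_call"](("x", "y"), vectorized=True))
# -> ffi_call: 2 args, params={'vectorized': True}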
125 changes: 120 additions & 5 deletions jax/_src/extend/ffi.py
@@ -14,18 +14,26 @@

from __future__ import annotations

import os
import ctypes
from collections.abc import Iterable, Mapping, Sequence
import ctypes
import functools
import os
from typing import Any

import numpy as np

from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import jaxlib
from jax._src.lib.mlir import ir
from jax._src.typing import DimSize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray
import numpy as np

map, unsafe_map = util.safe_map, map


def pycapsule(funcptr):
@@ -140,3 +148,110 @@ def _ir_attribute(obj: Any) -> ir.Attribute:
elif isinstance(mlir_type, ir.FloatType):
return ir.FloatAttr.get(mlir_type, obj)
raise TypeError(f"Unsupported attribute type: {type(obj)}")


def ffi_call(
target_name: str,
result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
*args: ArrayLike,
vectorized: bool = False,
**kwargs: Any,
) -> Array | list[Array]:
"""Call a foreign function interface (FFI) target.
Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
:func:`~jax.vmap` depends on the value of ``vectorized``. When ``vectorized``
is ``True``, the FFI target is assumed to satisfy: ``ffi_call(xs) ==
jnp.stack([ffi_call(x) for x in xs])``. In other words, calling the FFI target
with an extra leading dimension should return the same result as calling it
within a loop and stacking along the zeroth axis. Therefore, the FFI target
will be called directly on batched inputs (where the batch axes are the
leading dimensions). Additionally, the FFI target should return outputs that
have corresponding leading batch axes. If ``vectorized`` is ``False`` (the
default behavior), transforming this ``ffi_call`` under :func:`~jax.vmap` will
result in a :func:`~jax.lax.scan` with the ``ffi_call`` in the body.

Args:
target_name: the name of the XLA FFI custom call target that was registered
using :func:`~jaxlib.xla_client.register_custom_call_target`.
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
used to define the elements of ``result_shape_dtypes``.
*args: the arguments passed to the custom call.
vectorized: boolean specifying whether the callback function can operate in
a vectorized manner, as described above.
**kwargs: keyword arguments that are passed as named attributes to the
custom call using XLA's FFI interface.

Returns:
One or more :class:`~jax.Array` objects whose shapes and dtypes match
``result_shape_dtypes``.
"""
if isinstance(result_shape_dtypes, Sequence):
multiple_results = True
result_types = result_shape_dtypes
else:
multiple_results = False
result_types = (result_shape_dtypes,)
map(_check_shape_dtype, result_types)
result_avals = tuple(core.ShapedArray(x.shape, x.dtype) for x in result_types)
results = ffi_call_p.bind(
*args,
result_avals=result_avals,
vectorized=vectorized,
target_name=target_name,
**kwargs,
)
if multiple_results:
return results
else:
return results[0]
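
A hypothetical usage sketch of the signature above. The target name ``"my_custom_op"`` and the ``scale`` attribute are assumptions for illustration only; a real call requires that the target has already been registered with XLA (see the test at the bottom of this diff for a concrete target):

import numpy as np
import jax
import jax.extend as jex

def scale_by_two(x):
  # A single entry in result_shape_dtypes means a single jax.Array is returned.
  return jex.ffi.ffi_call(
      "my_custom_op",          # assumed, previously registered FFI target name
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,                       # positional operands forwarded to the target
      scale=np.float32(2.0),   # keyword arguments become FFI call attributes
      vectorized=False,        # under vmap this call is wrapped in a scan
  )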


def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.ShapedArray, ...],
target_name: str,
vectorized: bool,
**kwargs: Any,
):
del avals_in, target_name, vectorized, kwargs
return result_avals


def ffi_call_jvp(*args, target_name, **kwargs):
del args, kwargs
raise ValueError(
f"The FFI call to `{target_name}` cannot be differentiated. "
"You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.")


def ffi_call_transpose(*args, target_name, **kwargs):
del args, kwargs
raise ValueError(
f"The FFI call to `{target_name}` cannot be differentiated. "
"You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.")


def ffi_call_lowering(
ctx: mlir.LoweringRuleContext,
*operands: ir.Value,
result_avals: tuple[core.ShapedArray, ...],
target_name: str,
vectorized: bool,
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
return ffi_lowering(target_name)(ctx, *operands, **kwargs)


ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))
ffi_call_p.def_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
batching.primitive_batchers[ffi_call_p] = functools.partial(
callback_batching_rule, ffi_call_p)
mlir.register_lowering(ffi_call_p, ffi_call_lowering)
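
With these registrations in place, ``ffi_call`` is a thin wrapper that binds ``ffi_call_p``; abstract evaluation, batching, and lowering are all supplied by the rules above. Roughly, with a hypothetical target name:

# out = jex.ffi.ffi_call("demo_target", jax.ShapeDtypeStruct((3,), np.float32), x)
# is approximately equivalent to
# out, = ffi_call_p.bind(
#     x,
#     result_avals=(core.ShapedArray((3,), np.float32),),
#     vectorized=False,
#     target_name="demo_target",
# )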
1 change: 1 addition & 0 deletions jax/extend/ffi.py
@@ -16,6 +16,7 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.extend.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
55 changes: 54 additions & 1 deletion tests/extend_test.py
@@ -15,9 +15,11 @@
import os

import numpy as np
import unittest
from absl.testing import absltest, parameterized

import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp

@@ -28,6 +30,7 @@
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.extend import ffi

Expand Down Expand Up @@ -99,7 +102,7 @@ def testHeadersExist(self):
@parameterized.parameters(
[True, int(1), float(5.0),
np.int32(-5), np.float32(0.5)])
def testIrAttribute(sel, value):
def testIrAttribute(self, value):
with mlir.make_ir_context(), ir.Location.unknown():
const = mlir.ir_constant(value)
attr = ffi._ir_attribute(value)
@@ -116,6 +119,56 @@ def testParams(self, param):
func = jax.jit(lambda *args: prim.bind(*args, param=param))
func.lower(jnp.linspace(0, 5, 10))

@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCall(self, shape, dtype):
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size)
self.assertArraysEqual(actual, expected)

@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
vectorized=(False, True),
)
@jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCallBatching(self, shape, dtype, vectorized):
shape = (10,) + shape
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
x, permutation_size, vectorized=vectorized))(pivots)
self.assertArraysEqual(actual, expected)


# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
# custom call target because that's the only one in jaxlib that uses the
# new FFI interface. Once more are available, consider using something that
# can be run on multiple platforms.
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True):
return jex.ffi.ffi_call(
"cu_lu_pivots_to_permutation",
jax.ShapeDtypeStruct(
shape=pivots.shape[:-1] + (permutation_size,),
dtype=pivots.dtype,
),
pivots,
permutation_size=np.int32(permutation_size),
vectorized=vectorized,
)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
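
The batching test above exercises the ``vectorized`` flag through the ``cu_lu_pivots_to_permutation`` target. A standalone sketch of the distinction described in the ``ffi_call`` docstring, with an assumed target name ``"batched_target"`` that is not part of this commit:

import jax
import jax.numpy as jnp
import jax.extend as jex

def apply_target(x, vectorized):
  return jex.ffi.ffi_call(
      "batched_target",                        # assumed FFI target name
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
      vectorized=vectorized,
  )

xs = jnp.ones((8, 4))
# vectorized=True: the target is called once on the full (8, 4) batch.
# jax.vmap(lambda x: apply_target(x, True))(xs)
# vectorized=False: the call is wrapped in a lax.scan-style loop over axis 0.
# jax.vmap(lambda x: apply_target(x, False))(xs)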
