Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive`-based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
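For orientation, here is a minimal sketch of what a call through the new API could look like. The target name `my_ffi_target` and its `scale` attribute are hypothetical; the test added in this commit uses the real `cu_lu_pivots_to_permutation` target from jaxlib.

```python
import jax
import jax.extend as jex
import numpy as np

# Hypothetical FFI target: this sketch assumes an XLA FFI custom call target
# named "my_ffi_target" has already been registered for the current backend,
# takes one array operand plus a float attribute, and returns one array with
# the same shape and dtype as its input.
def my_op(x):
  return jex.ffi.ffi_call(
      "my_ffi_target",
      # Expected output shape/dtype; pass a sequence here for multiple outputs.
      jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
      x,
      # Extra keyword arguments are forwarded as named FFI attributes.
      scale=np.float32(2.0),
  )
```

Since the call itself is opaque to JAX's autodiff, gradients would be attached by wrapping `my_op` with `jax.custom_vjp`, as noted above.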
dfm committed Jun 27, 2024
1 parent 2408212 commit ed56df0
Showing 5 changed files with 179 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/jax.extend.ffi.rst
@@ -6,5 +6,6 @@
.. autosummary::
:toctree: _autosummary

ffi_call
ffi_lowering
pycapsule
20 changes: 10 additions & 10 deletions jax/_src/callback.py
@@ -124,14 +124,14 @@ def pure_callback_transpose_rule(*args, **kwargs):
ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule


def pure_callback_batching_rule(
def callback_batching_rule(
prim,
args,
dims,
*,
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
axis_size = next(a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped)
@@ -141,30 +141,30 @@ def pure_callback_batching_rule(
result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
for aval in result_avals)
outvals = pure_callback_p.bind(
outvals = prim.bind(
*new_args,
callback=callback,
sharding=sharding,
vectorized=vectorized,
result_avals=result_avals,
**kwargs,
)
else:
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return pure_callback_p.bind(
return prim.bind(
*merged_args,
callback=callback,
sharding=sharding,
result_avals=result_avals,
vectorized=vectorized,
**kwargs,
)
outvals = lax_map(_batch_fun, batched_args)
return tuple(outvals), (0,) * len(outvals)


batching.primitive_batchers[pure_callback_p] = pure_callback_batching_rule
batching.primitive_batchers[pure_callback_p] = functools.partial(
callback_batching_rule, pure_callback_p
)


def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
116 changes: 111 additions & 5 deletions jax/_src/extend/ffi.py
@@ -14,18 +14,25 @@

from __future__ import annotations

import os
import ctypes
from collections.abc import Iterable, Mapping, Sequence
import ctypes
import functools
import os
from typing import Any

import numpy as np

from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import jaxlib
from jax._src.lib.mlir import ir
from jax._src.typing import DimSize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray
import numpy as np

map, unsafe_map = util.safe_map, map


def pycapsule(funcptr):
@@ -140,3 +147,102 @@ def _ir_attribute(obj: Any) -> ir.Attribute:
elif isinstance(mlir_type, ir.FloatType):
return ir.FloatAttr.get(mlir_type, obj)
raise TypeError(f"Unsupported attribute type: {type(obj)}")


ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))


@ffi_call_p.def_abstract_eval
def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.ShapedArray, ...],
ffi_target_name: str,
vectorized: bool,
**kwargs: Any,
):
del avals_in, ffi_target_name, vectorized, kwargs
return result_avals


batching.primitive_batchers[ffi_call_p] = functools.partial(
callback_batching_rule, ffi_call_p
)


def ffi_call_lowering(
ctx: mlir.LoweringRuleContext,
*operands: ir.Value,
result_avals: tuple[core.ShapedArray, ...],
ffi_target_name: str,
vectorized: bool,
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
return ffi_lowering(ffi_target_name)(ctx, *operands, **kwargs)


mlir.register_lowering(ffi_call_p, ffi_call_lowering)


def ffi_call(
ffi_target_name: str,
result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
*args: ArrayLike,
vectorized: bool = False,
**kwargs: Any,
) -> Array | Sequence[Array]:
"""Call a foreign function interface (FFI) target.
Like :func:`jax.pure_callback`, the behavior of ``ffi_call`` under ``vmap``
depends on the value of ``vectorized``. When ``vectorized`` is ``True``, the
FFI target is assumed to satisfy:
``ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])``. In other words,
calling the FFI target with an extra leading dimension should return the same
result as calling it within a loop and stacking along the zeroth axis.
Therefore, the FFI target will be called directly on batched inputs (where the
batch axes are the leading dimensions). Additionally, the callbacks should
return outputs that have corresponding leading batch axes. If ``vectorized``
is ``False`` (the default behavior), transforming this ``ffi_call`` under
``vmap`` will result in a :func:`jax.lax.scan` with the ``ffi_call`` in the
body.

Args:
ffi_target_name: the name of the XLA FFI custom call target that was
registered using :func:`xla_client.register_custom_call_target`.
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`jax.ShapeDtypeStruct` is often
used to define the elements of ``result_shape_dtypes``.
*args: the arguments passed to the custom call.
vectorized: boolean specifying whether the callback function can operate in
a vectorized manner, as described above.
**kwargs: keyword arguments that are passed as named attributes to the
custom call using XLA's FFI interface.

Returns:
One or more :class:`jax.Array` objects whose shapes and dtypes match
``result_shape_dtypes``.
"""
if isinstance(result_shape_dtypes, Sequence):
multiple_results = True
result_shape_dtypes_ = result_shape_dtypes
else:
multiple_results = False
result_shape_dtypes_ = (result_shape_dtypes,)
map(_check_shape_dtype, result_shape_dtypes_)
result_avals = [
core.ShapedArray(x.shape, x.dtype) for x in result_shape_dtypes_
]
results = ffi_call_p.bind(
*args,
result_avals=tuple(result_avals),
vectorized=vectorized,
ffi_target_name=ffi_target_name,
**kwargs,
)
if multiple_results:
return results
else:
return results[0]
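
To make the `vectorized` contract described in the docstring above concrete, here is a small sketch, again assuming a hypothetical `my_ffi_target` that maps each input array to an output of the same shape and dtype:

```python
import jax
import jax.extend as jex
import jax.numpy as jnp

# Hypothetical target: assumes "my_ffi_target" is registered and returns an
# array with the same shape and dtype as its single input.
def call_one(x):
  return jex.ffi.ffi_call(
      "my_ffi_target",
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
      vectorized=False,  # the default: vmap lowers to a lax.scan over the batch
  )

# With vectorized=False, vmap calls the target once per batch element inside a
# scan; with vectorized=True it would call the target once on the batched
# input, which must then return outputs with matching leading batch axes.
ys = jax.vmap(call_one)(jnp.ones((8, 4)))
```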
1 change: 1 addition & 0 deletions jax/extend/ffi.py
@@ -16,6 +16,7 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.extend.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
57 changes: 56 additions & 1 deletion tests/extend_test.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import numpy as np
from absl.testing import absltest, parameterized

import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp

@@ -99,7 +101,7 @@ def testHeadersExist(self):
@parameterized.parameters(
[True, int(1), float(5.0),
np.int32(-5), np.float32(0.5)])
def testIrAttribute(sel, value):
def testIrAttribute(self, value):
with mlir.make_ir_context(), ir.Location.unknown():
const = mlir.ir_constant(value)
attr = ffi._ir_attribute(value)
@@ -116,6 +118,59 @@ def testParams(self, param):
func = jax.jit(lambda *args: prim.bind(*args, param=param))
func.lower(jnp.linspace(0, 5, 10))

@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
def testFfiCall(self, shape, dtype):
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size)
self.assertArraysEqual(actual, expected)

@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
def testFfiCallBatching(self, shape, dtype):
# TODO(dfm): Add a test for `vectorized = True`.
shape = (10,) + shape
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = jax.vmap(
lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size)
)(pivots)
self.assertArraysEqual(actual, expected)


# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
# custom call target because that's the only one in jaxlib that uses the
# new FFI interface. Once more are available, consider using something that
# can be run on multiple platforms.
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size):
dims = pivots.shape
batch_size = math.prod(dims[:-1])
pivot_size = dims[-1]
return jex.ffi.ffi_call(
"cu_lu_pivots_to_permutation",
jax.ShapeDtypeStruct(
shape=dims[:-1] + (permutation_size,),
dtype=pivots.dtype,
),
pivots,
batch_size=np.int64(batch_size),
pivot_size=np.int32(pivot_size),
permutation_size=np.int32(permutation_size),
)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
