Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ffi_call function with a similar signature to pure_callback. #21743

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ def pure_callback_transpose_rule(*args, **kwargs):
ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule


def pure_callback_batching_rule(
def callback_batching_rule(
prim,
args,
dims,
*,
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
axis_size = next(a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped)
Expand All @@ -141,30 +141,30 @@ def pure_callback_batching_rule(
result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
for aval in result_avals)
outvals = pure_callback_p.bind(
outvals = prim.bind(
*new_args,
callback=callback,
sharding=sharding,
vectorized=vectorized,
result_avals=result_avals,
**kwargs,
)
else:
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return pure_callback_p.bind(
return prim.bind(
*merged_args,
callback=callback,
sharding=sharding,
result_avals=result_avals,
vectorized=vectorized,
**kwargs,
)
outvals = lax_map(_batch_fun, batched_args)
return tuple(outvals), (0,) * len(outvals)


batching.primitive_batchers[pure_callback_p] = pure_callback_batching_rule
batching.primitive_batchers[pure_callback_p] = functools.partial(
callback_batching_rule, pure_callback_p
)


def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
Expand Down
125 changes: 120 additions & 5 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,26 @@

from __future__ import annotations

import os
import ctypes
from collections.abc import Iterable, Mapping, Sequence
import ctypes
import functools
import os
from typing import Any

import numpy as np

from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import util
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import jaxlib
from jax._src.lib.mlir import ir
from jax._src.typing import DimSize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray
import numpy as np

# Within this module `map` refers to `util.safe_map`; the builtin `map` stays
# reachable under the name `unsafe_map`.
map, unsafe_map = util.safe_map, map


def pycapsule(funcptr):
Expand Down Expand Up @@ -140,3 +148,110 @@ def _ir_attribute(obj: Any) -> ir.Attribute:
elif isinstance(mlir_type, ir.FloatType):
return ir.FloatAttr.get(mlir_type, obj)
raise TypeError(f"Unsupported attribute type: {type(obj)}")


# Primitive that backs `ffi_call`. It is marked as producing multiple results,
# and its implementation rule is delegated to `dispatch.apply_primitive`.
ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))


@ffi_call_p.def_abstract_eval
def ffi_call_abstract_eval(
    *avals_in,
    result_avals: tuple[core.ShapedArray, ...],
    platforms: tuple[str, ...],
    target_names: tuple[str, ...],
    vectorized: bool,
    **kwargs: Any,
):
  """Abstract evaluation rule for ``ffi_call_p``.

  The output abstract values are taken verbatim from ``result_avals``; the
  input avals and every other parameter are ignored.

  Args:
    *avals_in: abstract values of the operands (unused).
    result_avals: the abstract values of the outputs.
    platforms: platform names (unused here).
    target_names: custom call target names (unused here).
    vectorized: batching flag (unused here).
    **kwargs: additional primitive parameters (unused here).

  Returns:
    ``result_avals``, unchanged.
  """
  # Nothing but `result_avals` influences the output types.
  del avals_in
  del platforms, target_names
  del vectorized, kwargs
  return result_avals


# Batching for `ffi_call_p` reuses the shared `callback_batching_rule`, with the
# primitive to re-bind supplied as its first positional argument.
batching.primitive_batchers[ffi_call_p] = functools.partial(
callback_batching_rule, ffi_call_p
)


def ffi_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *operands: ir.Value,
    result_avals: tuple[core.ShapedArray, ...],
    platforms: tuple[str, ...],
    target_names: tuple[str, ...],
    vectorized: bool,
    **kwargs: Any,
) -> Sequence[ir.Value]:
  """MLIR lowering rule for ``ffi_call_p``.

  Builds one ``ffi_lowering`` rule per ``(platform, target_name)`` pair and
  hands the resulting table to ``mlir.lower_per_platform``.

  Args:
    ctx: the lowering rule context.
    *operands: the IR values of the call's operands.
    result_avals: output abstract values (not used by this rule).
    platforms: platform names, paired positionally with ``target_names``.
    target_names: the custom call target name to use on each platform.
    vectorized: batching flag (not used by this rule).
    **kwargs: remaining primitive parameters, forwarded to the per-platform
      rule.

  Returns:
    The values returned by ``mlir.lower_per_platform``.
  """
  # These parameters play no role in lowering.
  del result_avals, vectorized

  # `safe_zip` pairs each platform with its target name.
  rules_by_platform = {}
  for platform, target_name in util.safe_zip(platforms, target_names):
    rules_by_platform[platform] = ffi_lowering(target_name)

  return mlir.lower_per_platform(
      ctx,
      "ffi_call",
      rules_by_platform,
      None,
      effects.no_effects,
      *operands,
      **kwargs,
  )


# Install `ffi_call_lowering` as the MLIR lowering rule for `ffi_call_p`.
mlir.register_lowering(ffi_call_p, ffi_call_lowering)


def ffi_call(
    platform_target_names: Mapping[str, str],
    result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
    *args: ArrayLike,
    vectorized: bool = False,
    **kwargs: Any,
) -> Array | Sequence[Array]:
  """Calls a foreign function interface (FFI) target.

  The ``vectorized`` flag only matters when the call is batched (for example
  under :func:`jax.vmap`). Batching is handled by the same rule that is used
  for :func:`jax.pure_callback`: with ``vectorized=True`` the target is
  expected to be invoked once on the batched inputs and to produce outputs
  that carry the batch as a new leading dimension; otherwise the target is
  applied to the elements of the batch one at a time using a sequential map.

  TODO(dfm): Expand on the exact contract that ``vectorized`` targets must
  satisfy.

  Args:
    platform_target_names: a dictionary where the key is the platform name and
      the value is the name of the XLA FFI custom call target that was
      previously registered for that platform using
      :func:`xla_client.register_custom_call_target`. Use ``"cuda"`` or
      ``"rocm"`` rather than ``"gpu"``.
    result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
      ``dtype`` attributes which are expected to match the shape and dtype of
      the custom call output or outputs. :class:`jax.ShapeDtypeStruct` is often
      used to define the elements of ``result_shape_dtypes``.
    *args: the arguments passed to the custom call.
    vectorized: boolean specifying whether the target can operate in a
      vectorized manner when batched, as described above.
    **kwargs: keyword arguments that are passed as named attributes to the
      custom call using XLA's FFI interface.

  Returns:
    One or more :class:`jax.Array` objects whose shapes and dtypes match
    ``result_shape_dtypes``. A single array is returned when
    ``result_shape_dtypes`` is not a sequence; otherwise the sequence of
    results is returned.

  Raises:
    ValueError: if ``"gpu"`` is used as a platform name.
  """
  if "gpu" in platform_target_names:
    raise ValueError("Use 'cuda' or 'rocm' instead of 'gpu' for ffi_call")

  # Remember whether the caller asked for one result or a sequence, so that
  # the return value can mirror the structure of `result_shape_dtypes`.
  multiple_results = isinstance(result_shape_dtypes, Sequence)
  if multiple_results:
    result_shape_dtypes_ = result_shape_dtypes
  else:
    result_shape_dtypes_ = (result_shape_dtypes,)

  # Validate eagerly with an explicit loop: this must not depend on whether
  # `map` in the enclosing module is the lazy builtin or an eager replacement.
  for shape_dtype in result_shape_dtypes_:
    _check_shape_dtype(shape_dtype)
  result_avals = tuple(
      core.ShapedArray(x.shape, x.dtype) for x in result_shape_dtypes_
  )

  # Split the mapping into parallel sequences of platforms and target names;
  # the lowering rule zips them back together.
  platforms, target_names = util.unzip2(platform_target_names.items())
  results = ffi_call_p.bind(
      *args,
      result_avals=result_avals,
      vectorized=vectorized,
      platforms=platforms,
      target_names=target_names,
      **kwargs,
  )
  return results if multiple_results else results[0]
1 change: 1 addition & 0 deletions jax/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.extend.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
Expand Down
62 changes: 60 additions & 2 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import unittest

import numpy as np
from absl.testing import absltest, parameterized

import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp

Expand All @@ -31,7 +33,7 @@
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.extend import ffi
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension_version, xla_client

jax.config.parse_flags_with_absl()

Expand Down Expand Up @@ -102,7 +104,7 @@ def testHeadersExist(self):
@parameterized.parameters(
[True, int(1), float(5.0),
np.int32(-5), np.float32(0.5)])
def testIrAttribute(sel, value):
def testIrAttribute(self, value):
with mlir.make_ir_context(), ir.Location.unknown():
const = mlir.ir_constant(value)
attr = ffi._ir_attribute(value)
Expand All @@ -119,6 +121,62 @@ def testParams(self, param):
func = jax.jit(lambda *args: prim.bind(*args, param=param))
func.lower(jnp.linspace(0, 5, 10))

@jtu.sample_product(
    shape=[(1,), (4,), (5,)],
    dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
def testFfiCall(self, shape, dtype):
  # Compare the FFI-based call against `lax.linalg.lu_pivots_to_permutation`
  # on a descending run of pivots broadcast to `shape`.
  num_pivots = shape[-1]
  perm_size = 2 * num_pivots
  descending = jnp.arange(perm_size - 1, num_pivots - 1, -1, dtype=dtype)
  pivots = jnp.broadcast_to(descending, shape)
  reference = lax.linalg.lu_pivots_to_permutation(pivots, perm_size)
  result = ffi_call_lu_pivots_to_permutation(pivots, perm_size)
  self.assertArraysEqual(result, reference)

@jtu.sample_product(
    shape=[(1,), (4,), (5,)],
    dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
def testFfiCallBatching(self, shape, dtype):
  # TODO(dfm): Add a test for `vectorized = True`.
  # Prepend a batch dimension of size 10 and map the FFI call over it.
  batched_shape = (10,) + shape
  num_pivots = batched_shape[-1]
  perm_size = 2 * num_pivots
  descending = jnp.arange(perm_size - 1, num_pivots - 1, -1, dtype=dtype)
  pivots = jnp.broadcast_to(descending, batched_shape)
  reference = lax.linalg.lu_pivots_to_permutation(pivots, perm_size)

  def per_example(x):
    return ffi_call_lu_pivots_to_permutation(x, perm_size)

  result = jax.vmap(per_example)(pivots)
  self.assertArraysEqual(result, reference)


# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
# custom call target because that's the only one in jaxlib that uses the
# new FFI interface. Once more are available, consider using something that
# can be run on multiple platforms.
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size):
  """Invokes the LU pivots-to-permutation custom call through ``ffi_call``.

  Args:
    pivots: array of pivots; its last axis is the pivot axis and any leading
      axes are flattened into the ``batch_size`` attribute.
    permutation_size: size of the last axis of the result.

  Returns:
    The result of ``jex.ffi.ffi_call``, declared with shape
    ``pivots.shape[:-1] + (permutation_size,)`` and the dtype of ``pivots``.
  """
  leading_dims = pivots.shape[:-1]
  targets = {
      "cuda": "cu_lu_pivots_to_permutation",
      "rocm": "hip_lu_pivots_to_permutation",
  }
  out_type = jax.ShapeDtypeStruct(
      shape=leading_dims + (permutation_size,),
      dtype=pivots.dtype,
  )
  # Attributes are passed with explicit NumPy integer widths.
  attributes = dict(
      batch_size=np.int64(math.prod(leading_dims)),
      pivot_size=np.int32(pivots.shape[-1]),
      permutation_size=np.int32(permutation_size),
  )
  return jex.ffi.ffi_call(targets, out_type, pivots, **attributes)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Loading