Update CUDA custom call example code to use ffi_call.
Following up on jax-ml#21925, we can update the example code in
`docs/cuda_custom_call` to use `ffi_call` instead of manually
registering `core.Primitive`s. This removes quite a bit of boilerplate
and doesn't require direct use of MLIR.
dfm committed Aug 2, 2024 · commit 5474e0e (parent e88887e)
1 changed file: docs/cuda_custom_call/cuda_custom_call_test.py (18 additions, 83 deletions)
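For orientation before the diff: the new pattern registers the compiled FFI handler with XLA once, then invokes it through ffi.ffi_call, describing the outputs with jax.ShapeDtypeStruct. A minimal sketch, assuming a hypothetical shared library and handler symbol (the names below are illustrative, not code from this commit):

import ctypes

import jax
import numpy as np
from jax.extend import ffi
from jax.lib import xla_client

# illustrative: a shared library exporting an XLA FFI handler "ExampleFn"
library = ctypes.cdll.LoadLibrary("libexample.so")

# one-time registration of the handler under a custom call target name;
# api_version=1 selects the typed XLA FFI API
xla_client.register_custom_call_target(name="example-fn",
                                       fn=ffi.pycapsule(library.ExampleFn),
                                       platform="CUDA",
                                       api_version=1)

def example_fn(x):
  # ffi.ffi_call builds the custom call itself; keyword arguments such as
  # n are forwarded to the handler as custom call attributes
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return ffi.ffi_call("example-fn", out_type, x, n=np.uint64(x.size))

No core.Primitive, abstract evaluation rule, or MLIR lowering registration is needed; that is the boilerplate this commit deletes.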
@@ -28,7 +28,6 @@
 import jax.numpy as jnp
 from jax.extend import ffi
 from jax.lib import xla_client
-from jax.interpreters import mlir

 # start test boilerplate
 from absl.testing import absltest
@@ -53,109 +52,45 @@
 XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
 XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"

-# independently, corresponding JAX primitives must also be named,
-# names can be different from XLA targets, here they are the same
-JAX_PRIMITIVE_FWD = "foo-fwd"
-JAX_PRIMITIVE_BWD = "foo-bwd"
-
 # load the shared library with the FFI target definitions
 if jtu.is_running_under_pytest():
   raise unittest.SkipTest("libfoo.so hasn't been built")
 SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")

 library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)

-#-----------------------------------------------------------------------------#
-#                                 Forward pass                                 #
-#-----------------------------------------------------------------------------#
-
-# register the XLA FFI binding pointer with XLA
+# register the custom calls targets with XLA
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
                                        fn=ffi.pycapsule(library.FooFwd),
                                        platform=XLA_PLATFORM,
                                        api_version=XLA_CUSTOM_CALL_API_VERSION)
-
-
-# our forward primitive will also return the intermediate output b+1
-# so it can be reused in the backward pass computation
-def _foo_fwd_abstract_eval(a, b):
-  assert a.shape == b.shape
-  assert a.dtype == b.dtype
-  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
-  return (
-      shaped_array,  # output c
-      shaped_array,  # intermediate output b+1
-  )
-
-
-def _foo_fwd_lowering(ctx, a, b):
-  # ffi.ffi_lowering does most of the heavy lifting building a lowering.
-  # Keyword arguments passed to the lowering constructed by ffi_lowering are
-  # turned into custom call backend_config entries, which we take advantage of
-  # here for the dynamically computed n.
-  n = np.prod(a.type.shape).astype(np.uint64)
-  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n)
-
-
-# construct a new JAX primitive
-foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD)
-# register the abstract evaluation rule for the forward primitive
-foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval)
-foo_fwd_p.multiple_results = True
-mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
-
-#-----------------------------------------------------------------------------#
-#                                 Backward pass                                #
-#-----------------------------------------------------------------------------#
-
-# register the XLA FFI binding pointer with XLA
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
                                        fn=ffi.pycapsule(library.FooBwd),
                                        platform=XLA_PLATFORM,
                                        api_version=XLA_CUSTOM_CALL_API_VERSION)


-def _foo_bwd_abstract_eval(c_grad, a, b_plus_1):
-  assert c_grad.shape == a.shape
-  assert a.shape == b_plus_1.shape
-  assert c_grad.dtype == a.dtype
-  assert a.dtype == b_plus_1.dtype
-
-  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
-  return (
-      shaped_array,  # a_grad
-      shaped_array,  # b_grad
-  )
-
-
-def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1):
-  n = np.prod(a.type.shape).astype(np.uint64)
-  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx,
-                                                      c_grad,
-                                                      a,
-                                                      b_plus_1,
-                                                      n=n)
-
-
-# construct a new JAX primitive
-foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD)
-# register the abstract evaluation rule for the backward primitive
-foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval)
-foo_bwd_p.multiple_results = True
-mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM)
-
-#-----------------------------------------------------------------------------#
-#                                User facing API                               #
-#-----------------------------------------------------------------------------#
-
-
 def foo_fwd(a, b):
-  c, b_plus_1 = foo_fwd_p.bind(a, b)
+  assert a.dtype == jnp.float32
+  assert a.shape == b.shape
+  assert a.dtype == b.dtype
+  n = np.prod(a.shape).astype(np.uint64)
+  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+  c, b_plus_1 = ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_FWD, (out_type, out_type),
+                             a, b, n=n)
   return c, (a, b_plus_1)


 def foo_bwd(res, c_grad):
   a, b_plus_1 = res
-  return foo_bwd_p.bind(c_grad, a, b_plus_1)
+  assert c_grad.dtype == jnp.float32
+  assert c_grad.shape == a.shape
+  assert a.shape == b_plus_1.shape
+  assert c_grad.dtype == a.dtype
+  assert a.dtype == b_plus_1.dtype
+  n = np.prod(a.shape).astype(np.uint64)
+  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+  return ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_BWD, (out_type, out_type),
+                      c_grad, a, b_plus_1, n=n)


 @jax.custom_vjp
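The diff is truncated at the @jax.custom_vjp decorator; the rest of the file is unchanged by this commit. For context, a minimal sketch of how foo_fwd and foo_bwd conventionally plug into the decorated function (the body of foo here is illustrative, not the file's elided content). Note that foo_fwd returns the primal output c together with the residuals (a, b_plus_1) that foo_bwd later consumes; this is why the forward kernel also emits the intermediate b + 1.

@jax.custom_vjp
def foo(a, b):
  output, _ = foo_fwd(a, b)  # residuals are only needed under differentiation
  return output

# use foo_fwd/foo_bwd when differentiating through foo
foo.defvjp(foo_fwd, foo_bwd)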
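A usage sketch under the same assumptions, with the file's imports in scope (the accompanying CUDA kernels compute c = a * (b + 1), so the expected cotangents are da = b + 1 and db = a; treat the exact arithmetic as illustrative):

a = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
b = jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32)

c = foo(a, b)  # runs the forward FFI kernel on device
da, db = jax.grad(lambda a, b: jnp.sum(foo(a, b)), argnums=(0, 1))(a, b)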
