diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py
index 563462feb472..e4430e7c9fc0 100644
--- a/docs/cuda_custom_call/cuda_custom_call_test.py
+++ b/docs/cuda_custom_call/cuda_custom_call_test.py
@@ -28,7 +28,6 @@
 import jax.numpy as jnp
 from jax.extend import ffi
 from jax.lib import xla_client
-from jax.interpreters import mlir
 
 # start test boilerplate
 from absl.testing import absltest
@@ -53,109 +52,45 @@
 XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
 XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"
 
-# independently, corresponding JAX primitives must also be named,
-# names can be different from XLA targets, here they are the same
-JAX_PRIMITIVE_FWD = "foo-fwd"
-JAX_PRIMITIVE_BWD = "foo-bwd"
-
+# load the shared library with the FFI target definitions
 if jtu.is_running_under_pytest():
   raise unittest.SkipTest("libfoo.so hasn't been built")
 SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
-
 library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
 
-#-----------------------------------------------------------------------------#
-#                                Forward pass                                 #
-#-----------------------------------------------------------------------------#
-
-# register the XLA FFI binding pointer with XLA
+# register the custom call targets with XLA
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
                                        fn=ffi.pycapsule(library.FooFwd),
                                        platform=XLA_PLATFORM,
                                        api_version=XLA_CUSTOM_CALL_API_VERSION)
-
-
-# our forward primitive will also return the intermediate output b+1
-# so it can be reused in the backward pass computation
-def _foo_fwd_abstract_eval(a, b):
-  assert a.shape == b.shape
-  assert a.dtype == b.dtype
-  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
-  return (
-      shaped_array,  # output c
-      shaped_array,  # intermediate output b+1
-  )
-
-
-def _foo_fwd_lowering(ctx, a, b):
-  # ffi.ffi_lowering does most of the heavy lifting building a lowering.
-  # Keyword arguments passed to the lowering constructed by ffi_lowering are
-  # turned into custom call backend_config entries, which we take advantage of
-  # here for the dynamically computed n.
-  n = np.prod(a.type.shape).astype(np.uint64)
-  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n)
-
-
-# construct a new JAX primitive
-foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD)
-# register the abstract evaluation rule for the forward primitive
-foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval)
-foo_fwd_p.multiple_results = True
-mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
-
-#-----------------------------------------------------------------------------#
-#                                Backward pass                                #
-#-----------------------------------------------------------------------------#
-
-# register the XLA FFI binding pointer with XLA
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
                                        fn=ffi.pycapsule(library.FooBwd),
                                        platform=XLA_PLATFORM,
                                        api_version=XLA_CUSTOM_CALL_API_VERSION)
 
 
-def _foo_bwd_abstract_eval(c_grad, a, b_plus_1):
-  assert c_grad.shape == a.shape
-  assert a.shape == b_plus_1.shape
-  assert c_grad.dtype == a.dtype
-  assert a.dtype == b_plus_1.dtype
-
-  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
-  return (
-      shaped_array,  # a_grad
-      shaped_array,  # b_grad
-  )
-
-
-def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1):
-  n = np.prod(a.type.shape).astype(np.uint64)
-  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx,
-                                                      c_grad,
-                                                      a,
-                                                      b_plus_1,
-                                                      n=n)
-
-
-# construct a new JAX primitive
-foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD)
-# register the abstract evaluation rule for the backward primitive
-foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval)
-foo_bwd_p.multiple_results = True
-mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM)
-
-#-----------------------------------------------------------------------------#
-#                               User facing API                               #
-#-----------------------------------------------------------------------------#
-
-
 def foo_fwd(a, b):
-  c, b_plus_1 = foo_fwd_p.bind(a, b)
+  assert a.dtype == jnp.float32
+  assert a.shape == b.shape
+  assert a.dtype == b.dtype
+  n = np.prod(a.shape).astype(np.uint64)
+  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+  c, b_plus_1 = ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_FWD, (out_type, out_type),
+                             a, b, n=n)
  return c, (a, b_plus_1)
 
 
 def foo_bwd(res, c_grad):
   a, b_plus_1 = res
-  return foo_bwd_p.bind(c_grad, a, b_plus_1)
+  assert c_grad.dtype == jnp.float32
+  assert c_grad.shape == a.shape
+  assert a.shape == b_plus_1.shape
+  assert c_grad.dtype == a.dtype
+  assert a.dtype == b_plus_1.dtype
+  n = np.prod(a.shape).astype(np.uint64)
+  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+  return ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_BWD, (out_type, out_type),
+                      c_grad, a, b_plus_1, n=n)
 
 
 @jax.custom_vjp
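
Reviewer note: the `@jax.custom_vjp` decorator is the last context line of the
hunk, so the wiring of `foo_fwd`/`foo_bwd` into the public `foo` sits just
below this diff and is unchanged by the patch. For review convenience, here is
a minimal sketch of how the refactored functions are expected to plug in; the
body of `foo` is an assumption inferred from that trailing context line, not
part of the patch:

    @jax.custom_vjp
    def foo(a, b):
      # primal-only path: run the forward kernel and discard the residuals;
      # under jax.grad, JAX calls the registered foo_fwd/foo_bwd pair instead
      c, _ = foo_fwd(a, b)
      return c

    foo.defvjp(foo_fwd, foo_bwd)

    # usage: gradients for both inputs flow through the CUDA FFI targets
    a = jnp.arange(6, dtype=jnp.float32)
    b = jnp.arange(6, dtype=jnp.float32) + 1
    da, db = jax.grad(lambda x, y: jnp.sum(foo(x, y)), argnums=(0, 1))(a, b)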