Add an `ffi_call` function with a similar signature to `pure_callback` #21925
Conversation
Building on jax-ml#21925, this tutorial demonstrates the use of the FFI using `ffi_call` with a simple example. I don't think this should cover all of the most advanced use cases, but it should be sufficient for the most common examples. I think it would be useful to eventually replace the existing CUDA tutorial, but I'm not sure that it'll get there in the first draft. As an added benefit, this also runs a simple test (akin to `docs/cuda_custom_call`) which actually executes using a tool chain that open source users would use in practice.
Following up on jax-ml#21925, we can update the example code in `docs/cuda_custom_call` to use `ffi_call` instead of manually registering `core.Primitive`s. This removes quite a bit of boilerplate and doesn't require direct use of MLIR.
```python
  return ffi_lowering(ffi_target_name)(ctx, *operands, **kwargs)


ffi_call_p = core.Primitive("ffi_call")
```
Love the potential for reducing boilerplate!
How does this work when `ffi_call` is called multiple times over the life of a process? There seems to be only one `ffi_call_p` declared globally, so does that mean `ffi_call` will overwrite the abstract eval rule, impl, etc.? Or are you planning to have a global `ffi_call_p_registry` or similar that adds a new primitive every time `ffi_call` is called? Or is that not necessary somehow?
Good question! There's actually only ever a need for one primitive, because all of the relevant parameters are passed at bind time. You'll see that all the rules take the target name as a parameter, for example. This means that the same primitive can be reused for any number of calls to the same (or many different) FFI targets.
This is actually the same pattern used for the callback primitives (here's the primitive for `pure_callback`, for example), and all other higher-order primitives (`custom_vjp`, etc.).
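As a rough illustration of that pattern (a minimal sketch, not the actual implementation in this PR; the rule and parameter names are assumptions):

```python
from jax import core

# One global primitive; everything call-specific arrives as bind-time parameters.
ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True

def _ffi_call_abstract_eval(*avals, result_avals, target_name, **kwargs):
  # The output types come from the bind-time parameters, so the same primitive
  # serves any FFI target; nothing is overwritten between calls.
  del avals, target_name, kwargs
  return result_avals

ffi_call_p.def_abstract_eval(_ffi_call_abstract_eval)
# The impl and lowering rules would similarly read `target_name` from the
# bind-time parameters instead of being re-registered per call.

def ffi_call(target_name, result_shape_dtypes, *args, **kwargs):
  result_avals = tuple(
      core.ShapedArray(x.shape, x.dtype) for x in result_shape_dtypes)
  return ffi_call_p.bind(
      *args, result_avals=result_avals, target_name=target_name, **kwargs)
```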
I see, thanks. So if I'm understanding this correctly, `ffi_call_p` assumes a new "persona" every time `ffi_call` is invoked, and that works fine.
Exactly! If you inspect the jaxpr with an `ffi_call` in it, you'll see something like:

```
ffi_call[
  target_name=my_ffi_call_target_name
  result_avals=(ShapedArray(float32[5,3]), ...)
  vectorized=True
  a_parameter_that_i_set=1.5
]
```
```python
def ffi_call(
    target_name: str,
    result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
```
OOC why not `jax.ShapeDtypeStruct` as in `pure_callback`?
`pure_callback` is typed to take `Any` for `result_shape_dtypes` because it accepts a pytree, but the leaves are implicitly expected to satisfy `DuckTypedArray`. In both cases, the docstrings say that `ShapeDtypeStruct` is useful for constructing this argument, but the only requirement in each case is that the leaves have `.shape` and `.dtype` properties. With this in mind, it seemed right to use this annotation here, rather than the stricter `ShapeDtypeStruct`!
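To illustrate the duck typing being described (a small sketch; `MyResultSpec` is just an illustrative stand-in), any object exposing `.shape` and `.dtype` satisfies the contract:

```python
import dataclasses

import numpy as np
import jax

@dataclasses.dataclass(frozen=True)
class MyResultSpec:
  shape: tuple[int, ...]
  dtype: np.dtype

# Both of these would be valid leaves for `result_shape_dtypes`.
spec_a = jax.ShapeDtypeStruct((5, 3), np.float32)
spec_b = MyResultSpec(shape=(5, 3), dtype=np.dtype(np.float32))
```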
```python
    One or more :class:`~jax.Array` objects whose shapes and dtypes match
    ``result_shape_dtypes``.
  """
  if isinstance(result_shape_dtypes, Sequence):
```
Have you considered always returning multiple results: `result_shape_dtypes` would always be a sequence, and the result would always be a sequence. The typing declarations would be simpler, and so would the usage.
Good question! I did consider that, but decided that it was probably worth supporting single outputs, since a large fraction of the uses of `custom_call` that I could find in Google code and across GitHub only produce a single output. Furthermore, `pure_callback` also supports both single and multiple outputs (note: that's not really a fair comparison, since it also supports arbitrary pytree output), so I aimed to support as much of that interface as we can for an FFI. Since the implementation is fairly straightforward, I came down on the side of supporting single or multiple outputs transparently, but I don't feel strongly if there is pressure otherwise!
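For example, usage along these lines would be supported (a sketch of the proposed API; the target names are placeholders, not registered FFI targets):

```python
import numpy as np
import jax
from jax.extend import ffi

x = jax.numpy.ones((5, 3), dtype=np.float32)

# A single result spec returns a single jax.Array.
y = ffi.ffi_call(
    "my_single_output_target",  # placeholder target name
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    x,
)

# A sequence of result specs returns a sequence of jax.Arrays.
y0, y1 = ffi.ffi_call(
    "my_multi_output_target",  # placeholder target name
    (jax.ShapeDtypeStruct(x.shape, x.dtype),
     jax.ShapeDtypeStruct((x.shape[0],), np.float32)),
    x,
)
```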
SGTM
This could be useful for supporting the most common use cases for FFI custom calls. It has several benefits over using the `Primitive` based approach, but the biggest one (in my opinion) is that it doesn't require interacting with `mlir` at all. It does have the limitation that transforms would need to be registered using interfaces like `custom_vjp`, but many users of custom calls already do that. ~~The easiest to-do item (I think) is to implement batching using a `vectorized` parameter like `pure_callback`, but we could also think about more sophisticated vmapping interfaces in the future.~~ Done. The more difficult to-do is to think about how to support sharding, and we might actually want to expose an interface similar to the one from `custom_partitioning`. I have less experience with this part so I'll have to think some more about it, and feedback would be appreciated!
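For instance, attaching differentiation with `jax.custom_vjp` could look something like this (a sketch of the proposed API; `my_op_fwd` and `my_op_bwd` are placeholder target names):

```python
import jax
from jax.extend import ffi

@jax.custom_vjp
def my_op(x):
  return ffi.ffi_call(
      "my_op_fwd",  # placeholder FFI target name
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x)

def _my_op_fwd(x):
  # Save the primal input as a residual for the backward pass.
  return my_op(x), (x,)

def _my_op_bwd(residuals, cotangent):
  (x,) = residuals
  grad = ffi.ffi_call(
      "my_op_bwd",  # placeholder FFI target name
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      cotangent, x)
  return (grad,)

my_op.defvjp(_my_op_fwd, _my_op_bwd)
```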
Currently, JAX users who want to use XLA custom calls must interact with private APIs (e.g. `core.Primitive`) and MLIR. This doesn't provide a great developer experience, and it would be useful to provide some sort of public API. This has been previously discussed in several different contexts (including #12632), and this PR builds on ideas from this previous work to present a simple API that covers some core use cases. There are more advanced use cases which would require finer-grained customization, and these would continue to rely on the private API. But there do appear to be examples of use cases that would be satisfied by this simpler interface.
Example
The general idea is to provide a function called (something like) `jax.extend.ffi.ffi_call`, with a signature similar to `jax.pure_callback`, that lowers to a custom call. For example, the existing implementation of `lu_pivots_to_permutation` on GPU (the only FFI custom call currently in `jaxlib`) could (to first approximation) be written as:
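A rough sketch of that wrapper (the target name `cu_lu_pivots_to_permutation` and the `permutation_size` attribute are illustrative assumptions, not necessarily the real registered values):

```python
import numpy as np
import jax
from jax.extend import ffi

def lu_pivots_to_permutation(pivots, permutation_size: int):
  # The output keeps the leading dimensions of `pivots`, with the last
  # dimension replaced by `permutation_size`.
  out_shape = jax.ShapeDtypeStruct(
      pivots.shape[:-1] + (permutation_size,), pivots.dtype)
  return ffi.ffi_call(
      "cu_lu_pivots_to_permutation",  # assumed/illustrative FFI target name
      out_shape,
      pivots,
      # Extra keyword arguments are forwarded to the FFI handler as attributes;
      # the name and dtype here are assumptions for illustration.
      permutation_size=np.int32(permutation_size),
      vectorized=True,
  )
```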
Platform-dependent behavior should be handled in user code with the help of `lax.platform_dependent`. (Currently this doesn't work, but @gnecula is looking into it.) Like `jax.pure_callback`, this could be combined with `custom_jvp` or `custom_vjp` to support autodiff. `vmap` is discussed below.
Batching
This proof-of-concept implementation includes a `vectorized` parameter which has the same behavior as the `vectorized` parameter to `jax.pure_callback` (in fact, it uses exactly the same batching rule). The tl;dr is that when `vectorized` is `False`, the base custom call is executed in a while loop, but when `vectorized` is `True`, the `vmap`ped primitive calls the same custom call with an extra batch dimension on the left. This behavior could potentially work with the FFI interface since the input buffers include dimension metadata, but it's a restrictive interface. Is there a better approach (don't say `custom_vmap`! Or do...)?
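As a usage sketch (reusing the illustrative `lu_pivots_to_permutation` wrapper from the example above), `vmap` of a `vectorized=True` call would map onto a single batched custom call rather than a loop:

```python
import jax
import jax.numpy as jnp

# With vectorized=True, vmapping the wrapper results in one FFI call over the
# batched input instead of a while loop over the leading dimension.
batched_pivots = jnp.zeros((16, 4), dtype=jnp.int32)
permutations = jax.vmap(
    lambda p: lu_pivots_to_permutation(p, permutation_size=8))(batched_pivots)
```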
Alternatives
If we're not totally wedded to aligning with the `jax.pure_callback` interface, it's possible that a "builder" interface would be more future-proof. For example, the syntax for the demo from above would be something like:
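One possible shape for that (entirely hypothetical; `ffi_call_builder` and `do_call.primitive` are not real APIs, just illustrations of the idea):

```python
import numpy as np
import jax
from jax.extend import ffi

def lu_pivots_to_permutation(pivots, permutation_size: int):
  out_shape = jax.ShapeDtypeStruct(
      pivots.shape[:-1] + (permutation_size,), pivots.dtype)
  # Hypothetical builder: construct the call object once, then invoke it.
  do_call = ffi.ffi_call_builder(
      "cu_lu_pivots_to_permutation", out_shape, vectorized=True)
  return do_call(pivots, permutation_size=np.int32(permutation_size))

# The returned `do_call` could expose metadata, e.g. `do_call.primitive`
# referencing the underlying core.Primitive for further customization.
```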
This has the potential benefit that `do_call` could include metadata like a reference to the underlying `core.Primitive`, so that users could use that for further customization.