Add an ffi_call function with a similar signature to pure_callback #21925

Merged
merged 1 commit into jax-ml:main from the ffi-call branch on Jul 2, 2024

Conversation

@dfm dfm (Collaborator) commented Jun 17, 2024

Currently, JAX users who want to use XLA custom calls must interact with private APIs (e.g. core.Primitive) and MLIR. This doesn’t provide a great developer experience, and it would be useful to provide some sort of public API. This has been previously discussed in several different contexts (including #12632), and this PR builds on ideas from this previous work to present a simple API that covers some core use cases.

There are more advanced use cases that would require finer-grained customization, and those would continue to rely on the private API. But there do appear to be use cases that would be satisfied by this simpler interface.

Example

The general idea is to provide a function called (something like) jax.extend.ffi.ffi_call, with a signature similar to jax.pure_callback, which lowers to an XLA custom call. For example, the existing implementation of lu_pivots_to_permutation on GPU (the only FFI custom call currently in jaxlib) could, to a first approximation, be written as:

import numpy as np
import jax
import jax.extend as jex

def ffi_call_lu_pivots_to_permutation(pivots, permutation_size):
  dims = pivots.shape
  return jex.ffi.ffi_call(
      "cu_lu_pivots_to_permutation",

      # Output types are specified without reference to MLIR
      jax.ShapeDtypeStruct(
          shape=dims[:-1] + (permutation_size,),
          dtype=pivots.dtype,
      ),

      # Input arguments
      pivots,

      # Keyword arguments are passed to the FFI custom call as attributes
      permutation_size=np.int32(permutation_size),  # Note: np, not jnp
  )

from jax.lib import xla_client

# The elided second argument is the PyCapsule wrapping the FFI handler.
xla_client.register_custom_call_target(
    "cu_lu_pivots_to_permutation", ..., platform="CUDA", api_version=1)

Platform-dependent behavior should be handled in user code with the help of lax.platform_dependent. (Currently this doesn't work, but @gnecula is looking into it.) Like jax.pure_callback, this could be combined with custom_jvp or custom_vjp to support autodiff. vmap is discussed below.
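
For concreteness, here is a minimal sketch of how autodiff might be wired up on top of ffi_call; the targets "my_op" and "my_op_vjp" and the wrapper functions are hypothetical and would need to be registered separately, so this illustrates the pattern rather than anything in this PR:

import jax
import jax.extend as jex

# "my_op" and "my_op_vjp" are hypothetical FFI targets used for illustration;
# they are assumed to be registered via xla_client like the example above.

def _my_op_impl(x):
  return jex.ffi.ffi_call(
      "my_op",
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
  )

@jax.custom_vjp
def my_op(x):
  return _my_op_impl(x)

def _my_op_fwd(x):
  # Save the primal input as a residual for the backward pass.
  return _my_op_impl(x), (x,)

def _my_op_bwd(residuals, cotangent):
  (x,) = residuals
  # A second FFI target computes the pullback; how the backward kernel is
  # structured is entirely up to the user.
  dx = jex.ffi.ffi_call(
      "my_op_vjp",
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
      cotangent,
  )
  return (dx,)

my_op.defvjp(_my_op_fwd, _my_op_bwd)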

Batching

This proof-of-concept implementation includes a vectorized parameter, which has the same behavior as the vectorized parameter to jax.pure_callback (in fact, it uses exactly the same batching rule). The tl;dr is that when vectorized is False, the base custom call is executed in a while loop, but when vectorized is True, the vmapped primitive calls the same custom call once with an extra batch dimension on the left. This behavior could potentially work with the FFI interface, since the input buffers include dimension metadata, but it’s a restrictive interface. Is there a better approach (don’t say custom_vmap! Or do...)?
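
To make the vectorized flag concrete from the caller's side, a minimal sketch, assuming a hypothetical "my_batched_op" target whose kernel already handles a leading batch dimension:

import jax
import jax.numpy as jnp
import jax.extend as jex

def my_batched_op(x):
  # vectorized=True tells the batching rule that the FFI target accepts an
  # extra leading batch dimension, so vmap makes a single call on the
  # stacked buffer instead of looping over the batch.
  return jex.ffi.ffi_call(
      "my_batched_op",
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
      vectorized=True,
  )

# One custom call on an (8, 5) buffer. With vectorized=False, the same vmap
# would instead lower to a while loop over eight (5,)-shaped calls.
ys = jax.vmap(my_batched_op)(jnp.ones((8, 5)))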

Alternatives

If we’re not totally wedded to aligning with the jax.pure_callback interface, it’s possible that a "builder" interface would be more future-proof. For example, the syntax for the demo above would be something like:

do_call = jex.ffi.make_ffi_call("cu_lu_pivots_to_permutation")
do_call(
    jax.ShapeDtypeStruct(
        shape=dims[:-1] + (permutation_size,),
        dtype=pivots.dtype,
    ),
    pivots,
    batch_size=np.int64(batch_size),
    pivot_size=np.int32(pivot_size),
    permutation_size=np.int32(permutation_size),
)

This has the potential benefit that do_call could include metadata like a reference to the underlying core.Primitive so that users could use that for further customization.

@dfm dfm added the pull ready Ready for copybara import and testing label Jun 20, 2024
@dfm dfm self-assigned this Jun 20, 2024
dfm added a commit to dfm/jax that referenced this pull request Jun 25, 2024
Building on jax-ml#21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.

As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a tool chain that
open source users would use in practice.
@dfm dfm mentioned this pull request Jun 25, 2024
@dfm dfm force-pushed the ffi-call branch 2 times, most recently from 812e2d1 to ed56df0 Compare June 27, 2024 13:08
dfm added a commit to dfm/jax that referenced this pull request Jun 27, 2024
Following up on jax-ml#21925, we can update the example code in
`docs/cuda_custom_call` to use `ffi_call` instead of manually
registering `core.Primitive`s. This removes quite a bit of boilerplate
and doesn't require direct use of MLIR.
@dfm dfm force-pushed the ffi-call branch 3 times, most recently from ff80960 to 6bb5325 Compare June 27, 2024 20:49
return ffi_lowering(ffi_target_name)(ctx, *operands, **kwargs)


ffi_call_p = core.Primitive("ffi_call")
Contributor

Love the potential for reducing boilerplate!
How does this work when ffi_call is called multiple times over the life of a process? There seems to be only one ffi_call_p declared globally; does that mean ffi_call will overwrite the abstract eval rule, impl, etc.? Or are you planning to have a global ffi_call_p_registry or similar that adds a new primitive every time ffi_call is called? Or is that not necessary somehow?

Collaborator Author

Good question! There's actually only ever a need for one primitive, because all of the relevant parameters are passed at bind time. You'll see that all the rules take the target name as a parameter, for example. This means that the same primitive can be reused for any number of calls to the same (or many different) FFI targets.

This is actually the same pattern that is used for the callback primitives (here's the primitive for pure_callback, for example), and all other higher-order primitives (custom_vjp, etc.).
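
For example (a sketch with made-up target names; this only traces the computation and never executes the targets), you can see the parameterization by inspecting a jaxpr that binds the primitive twice:

import jax
import jax.extend as jex
import jax.numpy as jnp

def f(x):
  a = jex.ffi.ffi_call("target_a", jax.ShapeDtypeStruct(x.shape, x.dtype), x)
  b = jex.ffi.ffi_call("target_b", jax.ShapeDtypeStruct(x.shape, x.dtype), x)
  return a + b

# Both calls bind the same ffi_call primitive; they are distinguished only by
# their parameters (target_name, result_avals, etc.) in the jaxpr.
print(jax.make_jaxpr(f)(jnp.ones(3)))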

Contributor

I see, thanks. So if I'm understanding this correctly, ffi_call_p assumes a new "persona" every time ffi_call is invoked, and that works fine.

Collaborator Author

Exactly! If you inspect the jaxpr with an ffi_call in it, you'll see something like:

ffi_call[
  target_name=my_ffi_call_target_name
  result_avals=(ShapedArray(float32[5,3]), ...)
  vectorized=True
  a_parameter_that_i_set=1.5
]

@dfm dfm marked this pull request as ready for review June 27, 2024 23:31
@dfm dfm force-pushed the ffi-call branch 3 times, most recently from b8f6d91 to e296c4d Compare June 28, 2024 15:47
@dfm dfm requested a review from superbobry June 28, 2024 15:48

def ffi_call(
target_name: str,
result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
Collaborator

Out of curiosity, why not jax.ShapeDtypeStruct as in pure_callback?

Collaborator Author

pure_callback is typed to take Any for result_shape_dtypes because it accepts a pytree, but the leaves are implicitly expected to satisfy DuckTypedArray. In both cases, the docstrings say that ShapeDtypeStruct is useful for constructing this argument, but the only requirement in each case is that the leaves have .shape and .dtype properties. With this in mind, it seemed right to use this annotation here, rather than the stricter ShapeDtypeStruct!
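
For illustration, the contract is roughly the following protocol; this sketch is an assumption about the spirit of the annotation, not the literal definition used in the PR:

from typing import Protocol

import numpy as np
import jax

class DuckTypedArrayLike(Protocol):  # illustrative name, not the PR's
  @property
  def shape(self) -> tuple[int, ...]: ...
  @property
  def dtype(self) -> np.dtype: ...

# jax.ShapeDtypeStruct satisfies this, but so does any object exposing
# .shape and .dtype, e.g. a concrete jax.Array or a np.ndarray.
spec = jax.ShapeDtypeStruct(shape=(5, 3), dtype=np.float32)
assert hasattr(spec, "shape") and hasattr(spec, "dtype")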

One or more :class:`~jax.Array` objects whose shapes and dtypes match
``result_shape_dtypes``.
"""
if isinstance(result_shape_dtypes, Sequence):
Collaborator

Have you considered always returning multiple results? That is, result_shape_dtypes would always be a sequence, and the result would always be a sequence. The typing declarations would be simpler, and so would the usage.

Collaborator Author

Good question! I did consider that, but decided that it was probably worth supporting single outputs, since a large fraction of the uses of custom_call that I could find in Google code and across GitHub only produce a single output. Furthermore, pure_callback also supports both single and multiple outputs (note: that's not really a fair comparison, since it also supports arbitrary pytree outputs), so I aimed to support as much of that interface as we can for an FFI. Since the implementation is fairly straightforward, I came down on the side of supporting single or multiple outputs transparently, but I don't feel strongly about it if there is pressure otherwise!
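
To make the two conventions concrete, a usage sketch (the "my_op" target and the output shapes are made up, and actually running this would require the target to be registered first):

import numpy as np
import jax
import jax.extend as jex

x = np.zeros((5, 3), dtype=np.float32)

# Single output: pass one shape/dtype spec and get back a single array.
y = jex.ffi.ffi_call(
    "my_op", jax.ShapeDtypeStruct((5, 3), np.float32), x)

# Multiple outputs: pass a sequence of specs and get back a sequence of arrays.
y1, y2 = jex.ffi.ffi_call(
    "my_op",
    [jax.ShapeDtypeStruct((5, 3), np.float32),
     jax.ShapeDtypeStruct((5,), np.float32)],
    x)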

Collaborator

SGTM

This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive`-based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part, so I'll have to
think some more about it, and feedback would be appreciated!
@copybara-service copybara-service bot merged commit b669ab7 into jax-ml:main Jul 2, 2024
15 checks passed