From bc11437f1f7fcaa62441e662ececbd6f582c413b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 25 Jun 2024 14:08:49 -0400 Subject: [PATCH] Add ffi_call tutorial Building on #21925, this tutorial demonstrates the use of the FFI using `ffi_call` with a simple example. I don't think this should cover all of the most advanced use cases, but it should be sufficient for the most common examples. I think it would be useful to eventually replace the existing CUDA tutorial, but I'm not sure that it'll get there in the first draft. As an added benefit, this also runs a simple test (akin to `docs/cuda_custom_call`) which actually executes using a tool chain that open source users would use in practice. --- docs/_tutorials/index.rst | 3 +- docs/conf.py | 1 + docs/ffi/.gitignore | 5 + docs/ffi/CMakeLists.txt | 30 ++ docs/ffi/ffi.ipynb | 616 ++++++++++++++++++++++++++++++++++++++ docs/ffi/ffi.md | 476 +++++++++++++++++++++++++++++ docs/ffi/rms_norm.cc | 151 ++++++++++ docs/jax.lax.rst | 1 + docs/requirements.txt | 2 + 9 files changed, 1283 insertions(+), 2 deletions(-) create mode 100644 docs/ffi/.gitignore create mode 100644 docs/ffi/CMakeLists.txt create mode 100644 docs/ffi/ffi.ipynb create mode 100644 docs/ffi/ffi.md create mode 100644 docs/ffi/rms_norm.cc diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index d261612a4cd4..2c55aeb4dd67 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -7,7 +7,6 @@ JAX tutorials draft .. note:: - This is a The tutorials below are a work in progress; for the time being, please refer to the older tutorial content, including :ref:`beginner-guide`, :ref:`user-guides`, and the now-deleted *JAX 101* tutorials. @@ -44,7 +43,7 @@ JAX 201 advanced-debugging external-callbacks profiling-and-performance - + ../ffi/ffi JAX 301 ------- diff --git a/docs/conf.py b/docs/conf.py index fdbb157ecf39..cde7f118f180 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -136,6 +136,7 @@ def _do_not_evaluate_in_jax( 'jep/9407-type-promotion.md', 'autodidax.md', 'sharded-computation.md', + 'ffi/ffi.ipynb', ] # The name of the Pygments (syntax highlighting) style to use. diff --git a/docs/ffi/.gitignore b/docs/ffi/.gitignore new file mode 100644 index 000000000000..f6608699925e --- /dev/null +++ b/docs/ffi/.gitignore @@ -0,0 +1,5 @@ +CMake* +cmake* +Makefile +*.so +*.dylib diff --git a/docs/ffi/CMakeLists.txt b/docs/ffi/CMakeLists.txt new file mode 100644 index 000000000000..64b2d32709fe --- /dev/null +++ b/docs/ffi/CMakeLists.txt @@ -0,0 +1,30 @@ +cmake_minimum_required(VERSION 3.18...3.27) +project(rms_norm LANGUAGES CXX) +find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) +list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") +find_package(nanobind CONFIG REQUIRED) + +# TODO(dfm): Remove this "FetchContent" version and replace with the python +# command which is commented out below once jaxlib 0.4.31 is released. +include(FetchContent) +FetchContent_Declare( + xla + GIT_REPOSITORY https://github.com/openxla/xla.git + GIT_TAG 0b35a7fcb1c2b58d657994c588d049f9fe4ad048 +) +FetchContent_MakeAvailable(xla) +set(XLA_DIR "${xla_SOURCE_DIR}") +# execute_process( +# COMMAND "${Python_EXECUTABLE}" +# "-c" "from jax.extend import ffi; print(ffi.include_dir())" +# OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) +message(STATUS "XLA include directory: ${XLA_DIR}") + +nanobind_add_module(rms_norm NOMINSIZE "rms_norm.cc") +target_include_directories(rms_norm PUBLIC ${XLA_DIR}) +install(TARGETS rms_norm LIBRARY DESTINATION ${CMAKE_CURRENT_LIST_DIR}) + diff --git a/docs/ffi/ffi.ipynb b/docs/ffi/ffi.ipynb new file mode 100644 index 000000000000..6cfb32542f48 --- /dev/null +++ b/docs/ffi/ffi.ipynb @@ -0,0 +1,616 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# JAX's foreign function interface\n", + "\n", + "While a wide range of numerical operations can be easily and efficiently implemented using JAX's built in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a \"foreign function interface\" (FFI).\n", + "This can be particularly useful when particular operations have been previously implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement these computations directly using JAX, but it can also be useful for optimizing runtime or memory performance of JAX programs.\n", + "That being said, the FFI should typically be considered a last resort option because the XLA compiler that sits in the backend, or the Pallas kernel language, which provides lower level control, typically produce performant code with a lower development and maintenance cost.\n", + "\n", + "One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_.\n", + "This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules.\n", + "We will discuss some possible approaches below, but it is important to call this limitation out right from the start!\n", + "\n", + "JAX's FFI support is provided in two parts:\n", + "\n", + "1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n", + "2. A Python front end, available in the `jax.extend.ffi` submodule.\n", + "\n", + "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", + "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", + "\n", + "This tutorial comes with two supplementary files:\n", + "\n", + "* [`rms_norm.cc`](rms_norm.cc), which includes all the backend code, and\n", + "* [`CMakeLists.txt`](CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n", + "\n", + "## A simple example\n", + "\n", + "To demonstrate the use of the FFI interface, we will implement a simple \"root-mean-square (RMS)\" normalization function.\n", + "RMS normalization takes an array $x$ with shape $(N,)$ and returns\n", + "\n", + "$$\n", + "y_n = \\frac{x_n}{\\sqrt{\\frac{1}{N}\\sum_{n=1}^N {x_n}^2 + \\epsilon}}\n", + "$$\n", + "\n", + "where $\\epsilon$ is a tuning parameter used for numerical stability.\n", + "\n", + "This is a somewhat silly example, because it can be easily implemented using JAX as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "def rms_norm_ref(x, eps=1e-5):\n", + " scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)\n", + " return x / scale" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand.\n", + "We will use this reference implementation to test our FFI version below.\n", + "\n", + "## Backend code\n", + "\n", + "To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI.\n", + "This isn't meant to be particularly performant, but you could imagine that if you had some new better implementation of RMS normalization in a C++ library, it might have an interface like the following.\n", + "So, here's a simple implementation of RMS normalization in C++:\n", + "\n", + "```c++\n", + "#include \n", + "#include \n", + "\n", + "float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {\n", + " float sm = 0.0f;\n", + " for (int64_t n = 0; n < size; ++n) {\n", + " sm += x[n] * x[n];\n", + " }\n", + " float scale = 1.0f / std::sqrt(sm / float(size) + eps);\n", + " for (int64_t n = 0; n < size; ++n) {\n", + " y[n] = x[n] * scale;\n", + " }\n", + " return scale;\n", + "}\n", + "```\n", + "\n", + "and, for our example, this is the function that we want to expose to JAX via the FFI.\n", + "\n", + "### C++ interface\n", + "\n", + "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", + "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n", + "The full source listing can be downloaded [here](rms_norm.cc), but the key implementation details are reproduced here:\n", + "\n", + "```c++\n", + "#include \n", + "#include \n", + "#include \n", + "\n", + "#include \"xla/ffi/api/c_api.h\"\n", + "#include \"xla/ffi/api/ffi.h\"\n", + "\n", + "namespace ffi = xla::ffi;\n", + "\n", + "// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.\n", + "// In this example, we treat all leading dimensions as batch dimensions, so this\n", + "// function returns the total number of elements in the buffer, and the size of\n", + "// the last dimension.\n", + "template \n", + "std::pair GetDims(const ffi::Buffer &buffer) {\n", + " auto dims = buffer.dimensions();\n", + " if (dims.size() == 0) {\n", + " return std::make_pair(0, 0);\n", + " }\n", + " return std::make_pair(buffer.element_count(), dims.back());\n", + "}\n", + "\n", + "// A wrapper function providing the interface between the XLA FFI call and our\n", + "// library function `ComputeRmsNorm` above. This function handles the batch\n", + "// dimensions by calling `ComputeRmsNorm` within a loop.\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::Result> y) {\n", + " auto [totalSize, lastDim] = GetDims(x);\n", + " if (lastDim == 0) {\n", + " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", + " \"RmsNorm input must be an array\");\n", + " }\n", + " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", + " ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n", + " }\n", + " return ffi::Error::Success();\n", + "}\n", + "\n", + "// Wrap `RmsNormImpl` and specify the interface to XLA.\n", + "XLA_FFI_DEFINE_HANDLER(\n", + " RmsNorm, RmsNormImpl,\n", + " ffi::Ffi::Bind()\n", + " .Attr(\"eps\")\n", + " .Arg>(/* x */)\n", + " .Ret>(/* y */));\n", + "```\n", + "\n", + "Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature.\n", + "But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters.\n", + "\n", + "Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data.\n", + "In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis.\n", + "`GetDims` is a helper function providing support for this batching behavior.\n", + "\n", + "### Building and registering an FFI handler\n", + "\n", + "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", + "In this tutorial, we use [nanobind](https://nanobind.readthedocs.io) to define a tiny Python module, but it is also possible to compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), a pattern which we discuss below.\n", + "\n", + "For this example, our nanobind module is defined as follows:\n", + "\n", + "```c++\n", + "#include \n", + "\n", + "#include \"nanobind/nanobind.h\"\n", + "#include \"xla/ffi/api/c_api.h\"\n", + "\n", + "namespace nb = nanobind;\n", + "\n", + "template \n", + "nb::capsule EncapsulateFfiCall(T *fn) {\n", + " static_assert(std::is_invocable_r_v,\n", + " \"Encapsulated function must be and XLA FFI handler\");\n", + " return nb::capsule(reinterpret_cast(fn), \"xla._CUSTOM_CALL_TARGET\");\n", + "}\n", + "\n", + "NB_MODULE(rms_norm, m) {\n", + " m.def(\"rms_norm\", []() { return EncapsulateFfiCall(RmsNorm); });\n", + "}\n", + "```\n", + "\n", + "With this in place, we can now compile our library.\n", + "Here we use CMake, but you should be able to use your favorite build system without too much trouble.\n", + "The full `CMakeLists.txt` can be downloaded [here](CMakeLists.txt)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "!cmake -DCMAKE_BUILD_TYPE=Release -B _build .\n", + "!cmake --build _build\n", + "!cmake --install _build" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.\n", + "In our nanobind module above, we implemented a Python function called `rms_norm` which returns a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html) containing a function pointer to the C++ function `RmsNorm`.\n", + "This \"capsule\" is the object that we pass to {func}`~jax.extend.ffi.register_ffi_target`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.extend as jex\n", + "import rms_norm as rms_norm_lib\n", + "\n", + "jex.ffi.register_ffi_target(\"rms_norm\", {\"execute\": rms_norm_lib.rms_norm()}, platform=\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{tip}\n", + "If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n", + "```\n", + "\n", + "## Frontend code\n", + "\n", + "Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "\n", + "def rms_norm(x, eps=1e-5):\n", + " # We only implemented the `float32` version of this function, so we start by\n", + " # checking the dtype. This check isn't strictly necessary because type\n", + " # checking is also performed by the FFI when decoding input and output\n", + " # buffers, but it can be useful to check types in Python to raise more\n", + " # informative errors.\n", + " if x.dtype != jnp.float32:\n", + " raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n", + "\n", + " # In this case, the output of our FFI function is just a single array with the\n", + " # same shape and dtype as the input. We discuss a case with a more interesting\n", + " # output type below.\n", + " out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + "\n", + " return jex.ffi.ffi_call(\n", + " # The target name must be the same string as we used to register the target\n", + " # above in `register_custom_call_target`\n", + " \"rms_norm\",\n", + " out_type,\n", + " x,\n", + " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n", + " # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n", + " # type (which corresponds to numpy's `float32` type), and it must be a\n", + " # static parameter (i.e. not a JAX array).\n", + " eps=np.float32(eps),\n", + " # The `vectorized` parameter controls this function's behavior under `vmap`\n", + " # as discussed below.\n", + " vectorized=True,\n", + " )\n", + "\n", + "\n", + "# Test that this gives the same result as our reference implementation\n", + "x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n", + "np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n", + "Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n", + "It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n", + "\n", + "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", + "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", + "\n", + "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", + "\n", + "```{tip}\n", + "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n", + "In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n", + "One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n", + "```\n", + "\n", + "### Batching with `vmap`\n", + "\n", + "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n", + "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", + "This default implementation is general purpose, but it doesn't parallelize very well.\n", + "But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n", + "\n", + "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n", + "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n", + "\n", + "```python\n", + "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n", + "```\n", + "\n", + "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax.make_jaxpr(jax.vmap(rms_norm))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rms_norm_not_vectorized(x, eps=1e-5):\n", + " return jex.ffi.ffi_call(\n", + " \"rms_norm\",\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=False, # This is the default behavior\n", + " )\n", + "\n", + "\n", + "jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Differentiation\n", + "\n", + "Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n", + "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", + "Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", + "\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "In this case, we actually define two new FFI calls:\n", + "\n", + "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", + "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", + "\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](rms_norm.cc) to see how these functions are implemented on the back end.\n", + "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", + "\n", + "This custom derivative rule can be wired in as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jex.ffi.register_ffi_target(\n", + " \"rms_norm_fwd\", rms_norm_lib.rms_norm_fwd(), platform=\"cpu\"\n", + ")\n", + "jex.ffi.register_ffi_target(\n", + " \"rms_norm_bwd\", rms_norm_lib.rms_norm_bwd(), platform=\"cpu\"\n", + ")\n", + "\n", + "\n", + "def rms_norm_fwd(x, eps=1e-5):\n", + " y, res = jex.ffi.ffi_call(\n", + " \"rms_norm_fwd\",\n", + " (\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n", + " ),\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=True,\n", + " )\n", + " return y, (res, x)\n", + "\n", + "\n", + "def rms_norm_bwd(eps, res, ct):\n", + " del eps\n", + " res, x = res\n", + " assert res.shape == ct.shape[:-1]\n", + " assert x.shape == ct.shape\n", + " return (\n", + " jex.ffi.ffi_call(\n", + " \"rms_norm_bwd\",\n", + " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n", + " res,\n", + " x,\n", + " ct,\n", + " vectorized=True,\n", + " ),\n", + " )\n", + "\n", + "\n", + "rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))\n", + "rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)\n", + "\n", + "# Check that this gives the right answer when compared to the reference version\n", + "ct_y = jnp.ones_like(x)\n", + "np.testing.assert_allclose(\n", + " jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n", + "One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n", + "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n", + "\n", + "One other JAX feature that this example doesn't support is higher-order AD.\n", + "It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n", + "\n", + "## FFI calls on a GPU\n", + "\n", + "So far, we have been interfacing only with foreign functions running on the CPU, but JAX's FFI also supports calls to GPU code.\n", + "Since this documentation page is automatically generated on a machine without access to a GPU, we can't execute any GPU-specific examples here, but we will go over the key points.\n", + "\n", + "When defining our FFI wrapper for CPU, the function signature that we used was:\n", + "\n", + "```c++\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::Result> y)\n", + "```\n", + "\n", + "To update this to interface with a CUDA kernel, this signature becomes:\n", + "\n", + "```c++\n", + "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n", + " ffi::Buffer x,\n", + " ffi::Result> y)\n", + "```\n", + "\n", + "And the handler definition is updated to include a `Ctx` in its binding:\n", + "\n", + "```c++\n", + "XLA_FFI_DEFINE_HANDLER(\n", + " RmsNorm, RmsNormImpl,\n", + " ffi::Ffi::Bind()\n", + " .Ctx>()\n", + " .Attr(\"eps\")\n", + " .Arg>(/* x */)\n", + " .Ret>(/* y */));\n", + "```\n", + "\n", + "Then, the `RmsNormImpl` can use the CUDA stream to launch CUDA kernels.\n", + "\n", + "On the front end, the registration code would be updated to specify the appropriate platform:\n", + "\n", + "```python\n", + "jex.ffi.register_ffi_target(\n", + " \"rms_norm_cuda\", rms_norm_lib_cuda.rms_norm(), platform=\"CUDA\"\n", + ")\n", + "```\n", + "\n", + "### Supporting multiple platforms\n", + "\n", + "To support running our `rms_norm` function on both GPU and CPU, we can combine our implementation above with the {func}`jax.lax.platform_dependent` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rms_norm_cross_platform(x, eps=1e-5):\n", + " assert x.dtype == jnp.float32\n", + " out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + "\n", + " def impl(target_name):\n", + " return lambda x: jex.ffi.ffi_call(\n", + " target_name,\n", + " out_type,\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=True,\n", + " )\n", + "\n", + " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n", + "\n", + "\n", + "np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This version of the function will call the appropriate FFI target depending on the runtime platform.\n", + "\n", + "As an aside, it may be interesting to note that while the jaxpr and lowered HLO both contain a reference to both FFI targets:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax.make_jaxpr(rms_norm_cross_platform)(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "by the time the function is compiled, the appropriate FFI has been selected:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect=\"hlo\").strip())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "and there will be no runtime overhead to using {func}`jax.lax.platform_dependent`, and the compiled program won't include any references to unavailable FFI targets.\n", + "\n", + "## Advanced topics\n", + "\n", + "This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n", + "We will leave these topics to future tutorials, but here are some possibly useful references:\n", + "\n", + "* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n", + "\n", + "* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/ffi/ffi.md b/docs/ffi/ffi.md new file mode 100644 index 000000000000..6c6dbf9cf128 --- /dev/null +++ b/docs/ffi/ffi.md @@ -0,0 +1,476 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JAX's foreign function interface + +While a wide range of numerical operations can be easily and efficiently implemented using JAX's built in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a "foreign function interface" (FFI). +This can be particularly useful when particular operations have been previously implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement these computations directly using JAX, but it can also be useful for optimizing runtime or memory performance of JAX programs. +That being said, the FFI should typically be considered a last resort option because the XLA compiler that sits in the backend, or the Pallas kernel language, which provides lower level control, typically produce performant code with a lower development and maintenance cost. + +One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_. +This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules. +We will discuss some possible approaches below, but it is important to call this limitation out right from the start! + +JAX's FFI support is provided in two parts: + +1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and +2. A Python front end, available in the `jax.extend.ffi` submodule. + +In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. +We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. + +This tutorial comes with two supplementary files: + +* [`rms_norm.cc`](rms_norm.cc), which includes all the backend code, and +* [`CMakeLists.txt`](CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code. + +## A simple example + +To demonstrate the use of the FFI interface, we will implement a simple "root-mean-square (RMS)" normalization function. +RMS normalization takes an array $x$ with shape $(N,)$ and returns + +$$ +y_n = \frac{x_n}{\sqrt{\frac{1}{N}\sum_{n=1}^N {x_n}^2 + \epsilon}} +$$ + +where $\epsilon$ is a tuning parameter used for numerical stability. + +This is a somewhat silly example, because it can be easily implemented using JAX as follows: + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp + + +def rms_norm_ref(x, eps=1e-5): + scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) + return x / scale +``` + +But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand. +We will use this reference implementation to test our FFI version below. + +## Backend code + +To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI. +This isn't meant to be particularly performant, but you could imagine that if you had some new better implementation of RMS normalization in a C++ library, it might have an interface like the following. +So, here's a simple implementation of RMS normalization in C++: + +```c++ +#include +#include + +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} +``` + +and, for our example, this is the function that we want to expose to JAX via the FFI. + +### C++ interface + +To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). +For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call). +The full source listing can be downloaded [here](rms_norm.cc), but the key implementation details are reproduced here: + +```c++ +#include +#include +#include + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. +// In this example, we treat all leading dimensions as batch dimensions, so this +// function returns the total number of elements in the buffer, and the size of +// the last dimension. +template +std::pair GetDims(const ffi::Buffer &buffer) { + auto dims = buffer.dimensions(); + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + return std::make_pair(buffer.element_count(), dims.back()); +} + +// A wrapper function providing the interface between the XLA FFI call and our +// library function `ComputeRmsNorm` above. This function handles the batch +// dimensions by calling `ComputeRmsNorm` within a loop. +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +// Wrap `RmsNormImpl` and specify the interface to XLA. +XLA_FFI_DEFINE_HANDLER( + RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */)); +``` + +Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature. +But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters. + +Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data. +In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis. +`GetDims` is a helper function providing support for this batching behavior. + +### Building and registering an FFI handler + +Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python. +In this tutorial, we use [nanobind](https://nanobind.readthedocs.io) to define a tiny Python module, but it is also possible to compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), a pattern which we discuss below. + +For this example, our nanobind module is defined as follows: + +```c++ +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" + +namespace nb = nanobind; + +template +nb::capsule EncapsulateFfiCall(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +NB_MODULE(rms_norm, m) { + m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); }); +} +``` + +With this in place, we can now compile our library. +Here we use CMake, but you should be able to use your favorite build system without too much trouble. +The full `CMakeLists.txt` can be downloaded [here](CMakeLists.txt). + +```{code-cell} ipython3 +:tags: [hide-output] + +!cmake -DCMAKE_BUILD_TYPE=Release -B _build . +!cmake --build _build +!cmake --install _build +``` + +With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function. +In our nanobind module above, we implemented a Python function called `rms_norm` which returns a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html) containing a function pointer to the C++ function `RmsNorm`. +This "capsule" is the object that we pass to {func}`~jax.extend.ffi.register_ffi_target`: + +```{code-cell} ipython3 +import jax.extend as jex +import rms_norm as rms_norm_lib + +jex.ffi.register_ffi_target("rms_norm", {"execute": rms_norm_lib.rms_norm()}, platform="cpu") +``` + +```{tip} +If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here. +``` + +## Frontend code + +Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function: + +```{code-cell} ipython3 +import numpy as np + + +def rms_norm(x, eps=1e-5): + # We only implemented the `float32` version of this function, so we start by + # checking the dtype. This check isn't strictly necessary because type + # checking is also performed by the FFI when decoding input and output + # buffers, but it can be useful to check types in Python to raise more + # informative errors. + if x.dtype != jnp.float32: + raise ValueError("Only the float32 dtype is implemented by rms_norm") + + # In this case, the output of our FFI function is just a single array with the + # same shape and dtype as the input. We discuss a case with a more interesting + # output type below. + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + return jex.ffi.ffi_call( + # The target name must be the same string as we used to register the target + # above in `register_custom_call_target` + "rms_norm", + out_type, + x, + # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). + eps=np.float32(eps), + # The `vectorized` parameter controls this function's behavior under `vmap` + # as discussed below. + vectorized=True, + ) + + +# Test that this gives the same result as our reference implementation +x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) +np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5) +``` + +This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting. +Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs. +It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above. + +Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`. +Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments. + +The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. + +```{tip} +If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`. +In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering. +One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below. +``` + +### Batching with `vmap` + +All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient. +By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body. +This default implementation is general purpose, but it doesn't parallelize very well. +But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation. + +The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes. +Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly: + +```python +ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs]) +``` + +Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box: + +```{code-cell} ipython3 +np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) +``` + +We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: + +```{code-cell} ipython3 +jax.make_jaxpr(jax.vmap(rms_norm))(x) +``` + +If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: + +```{code-cell} ipython3 +def rms_norm_not_vectorized(x, eps=1e-5): + return jex.ffi.ffi_call( + "rms_norm", + jax.ShapeDtypeStruct(x.shape, x.dtype), + x, + eps=np.float32(eps), + vectorized=False, # This is the default behavior + ) + + +jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) +``` + +If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). + ++++ + +### Differentiation + +Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions. +As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. +Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule. + +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +In this case, we actually define two new FFI calls: + +1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. +2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. + +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](rms_norm.cc) to see how these functions are implemented on the back end. +The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. + +This custom derivative rule can be wired in as follows: + +```{code-cell} ipython3 +jex.ffi.register_ffi_target( + "rms_norm_fwd", rms_norm_lib.rms_norm_fwd(), platform="cpu" +) +jex.ffi.register_ffi_target( + "rms_norm_bwd", rms_norm_lib.rms_norm_bwd(), platform="cpu" +) + + +def rms_norm_fwd(x, eps=1e-5): + y, res = jex.ffi.ffi_call( + "rms_norm_fwd", + ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + ), + x, + eps=np.float32(eps), + vectorized=True, + ) + return y, (res, x) + + +def rms_norm_bwd(eps, res, ct): + del eps + res, x = res + assert res.shape == ct.shape[:-1] + assert x.shape == ct.shape + return ( + jex.ffi.ffi_call( + "rms_norm_bwd", + jax.ShapeDtypeStruct(ct.shape, ct.dtype), + res, + x, + ct, + vectorized=True, + ), + ) + + +rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,)) +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) + +# Check that this gives the right answer when compared to the reference version +ct_y = jnp.ones_like(x) +np.testing.assert_allclose( + jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5 +) +``` + +At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`. +One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode. +JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice. + +One other JAX feature that this example doesn't support is higher-order AD. +It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here. + +## FFI calls on a GPU + +So far, we have been interfacing only with foreign functions running on the CPU, but JAX's FFI also supports calls to GPU code. +Since this documentation page is automatically generated on a machine without access to a GPU, we can't execute any GPU-specific examples here, but we will go over the key points. + +When defining our FFI wrapper for CPU, the function signature that we used was: + +```c++ +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) +``` + +To update this to interface with a CUDA kernel, this signature becomes: + +```c++ +ffi::Error RmsNormImpl(cudaStream_t stream, float eps, + ffi::Buffer x, + ffi::Result> y) +``` + +And the handler definition is updated to include a `Ctx` in its binding: + +```c++ +XLA_FFI_DEFINE_HANDLER( + RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Ctx>() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */)); +``` + +Then, the `RmsNormImpl` can use the CUDA stream to launch CUDA kernels. + +On the front end, the registration code would be updated to specify the appropriate platform: + +```python +jex.ffi.register_ffi_target( + "rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA" +) +``` + +### Supporting multiple platforms + +To support running our `rms_norm` function on both GPU and CPU, we can combine our implementation above with the {func}`jax.lax.platform_dependent` function: + +```{code-cell} ipython3 +def rms_norm_cross_platform(x, eps=1e-5): + assert x.dtype == jnp.float32 + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + def impl(target_name): + return lambda x: jex.ffi.ffi_call( + target_name, + out_type, + x, + eps=np.float32(eps), + vectorized=True, + ) + + return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda")) + + +np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5) +``` + +This version of the function will call the appropriate FFI target depending on the runtime platform. + +As an aside, it may be interesting to note that while the jaxpr and lowered HLO both contain a reference to both FFI targets: + +```{code-cell} ipython3 +jax.make_jaxpr(rms_norm_cross_platform)(x) +``` + +```{code-cell} ipython3 +print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip()) +``` + +by the time the function is compiled, the appropriate FFI has been selected: + +```{code-cell} ipython3 +print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip()) +``` + +and there will be no runtime overhead to using {func}`jax.lax.platform_dependent`, and the compiled program won't include any references to unavailable FFI targets. + +## Advanced topics + +This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features. +We will leave these topics to future tutorials, but here are some possibly useful references: + +* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend. + +* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`. diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc new file mode 100644 index 000000000000..a49726d35a4e --- /dev/null +++ b/docs/ffi/rms_norm.cc @@ -0,0 +1,151 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// This is the example "library function" that we want to expose to JAX. This +// isn't meant to be a particularly good implementation, it's just here as a +// placeholder for the purposes of this tutorial. +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} + +// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. +// In this example, we treat all leading dimensions as batch dimensions, so this +// function returns the total number of elements in the buffer, and the size of +// the last dimension. +template +std::pair GetDims(const ffi::Buffer &buffer) { + auto dims = buffer.dimensions(); + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + return std::make_pair(buffer.element_count(), dims.back()); +} + +// A wrapper function providing the interface between the XLA FFI call and our +// library function `ComputeRmsNorm` above. This function handles the batch +// dimensions by calling `ComputeRmsNorm` within a loop. +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +// Wrap `RmsNormImpl` and specify the interface to XLA. +XLA_FFI_DEFINE_HANDLER(RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */)); + +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::Result> y, + ffi::Result> res) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormFwd input must be an array"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), + &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */) + .Ret>(/* res */)); + +void ComputeRmsNormBwd(int64_t size, float res, const float *x, + const float *ct_y, float *ct_x) { + float ct_res = 0.0f; + for (int64_t n = 0; n < size; ++n) { + ct_res += x[n] * ct_y[n]; + } + float factor = ct_res * res * res * res / float(size); + for (int64_t n = 0; n < size; ++n) { + ct_x[n] = res * ct_y[n] - factor * x[n]; + } +} + +ffi::Error RmsNormBwdImpl(ffi::Buffer res, + ffi::Buffer x, + ffi::Buffer ct_y, + ffi::Result> ct_x) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormBwd inputs must be arrays"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), + &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>(/* res */) + .Arg>(/* x */) + .Arg>(/* ct_y */) + .Ret>(/* ct_x */)); + +template +nb::capsule EncapsulateFfiCall(T *fn) { + // This check is optional, but it can be useful to catch invalid function + // pointers at compile time. + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn)); +} + +NB_MODULE(rms_norm, m) { + m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); }); + m.def("rms_norm_fwd", []() { return EncapsulateFfiCall(RmsNormFwd); }); + m.def("rms_norm_bwd", []() { return EncapsulateFfiCall(RmsNormBwd); }); +} diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 32db1ba77dea..877ca231567b 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -120,6 +120,7 @@ Operators neg nextafter pad + platform_dependent polygamma population_count pow diff --git a/docs/requirements.txt b/docs/requirements.txt index 643d4086d8be..7bf7a5350a33 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -18,4 +18,6 @@ matplotlib scikit-learn numpy rich[jupyter] +cmake +nanobind .[ci] # Install jax from the current directory; jaxlib from pypi.