diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index d261612a4cd4..5fae0d5838d0 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -7,7 +7,7 @@ JAX tutorials draft .. note:: - This is a + This is a The tutorials below are a work in progress; for the time being, please refer to the older tutorial content, including :ref:`beginner-guide`, :ref:`user-guides`, and the now-deleted *JAX 101* tutorials. @@ -44,7 +44,7 @@ JAX 201 advanced-debugging external-callbacks profiling-and-performance - + ../ffi/ffi.ipynb JAX 301 ------- diff --git a/docs/ffi/.gitignore b/docs/ffi/.gitignore new file mode 100644 index 000000000000..f6608699925e --- /dev/null +++ b/docs/ffi/.gitignore @@ -0,0 +1,5 @@ +CMake* +cmake* +Makefile +*.so +*.dylib diff --git a/docs/ffi/CMakeLists.txt b/docs/ffi/CMakeLists.txt new file mode 100644 index 000000000000..c219df46876f --- /dev/null +++ b/docs/ffi/CMakeLists.txt @@ -0,0 +1,20 @@ +cmake_minimum_required(VERSION 3.18...3.27) +project(rms_norm LANGUAGES CXX) +find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) +list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") +find_package(nanobind CONFIG REQUIRED) + +execute_process( + COMMAND "${Python_EXECUTABLE}" + "-c" "from jax.extend import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) +message(STATUS "XLA include directory: ${XLA_DIR}") + +nanobind_add_module(rms_norm NOMINSIZE "rms_norm.cc") +target_include_directories(rms_norm PUBLIC ${XLA_DIR}) +install(TARGETS rms_norm LIBRARY DESTINATION ${CMAKE_CURRENT_LIST_DIR}) + diff --git a/docs/ffi/ffi.ipynb b/docs/ffi/ffi.ipynb new file mode 100644 index 000000000000..4d09b71a7a4d --- /dev/null +++ b/docs/ffi/ffi.ipynb @@ -0,0 +1,468 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# JAX's foreign function interface\n", + "\n", + "While a wide range of numerical operations can be easily and efficiently implemented using JAX's built in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a \"foreign function interface\" (FFI).\n", + "This can be particularly useful when particular operations have been previously implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement these computations directly using JAX, but it can also be useful for optimizing runtime or memory performance of JAX programs.\n", + "That being said, the FFI should typically be considered a last resort option because the XLA compiler that sits in the backend, or the Pallas kernel language, which provides lower level control, typically produce performant code with a lower development and maintenance cost.\n", + "\n", + "One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_.\n", + "This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules.\n", + "We will discuss some possible approaches below, but it is important to call this limitation out right from the start!\n", + "\n", + "JAX's FFI support is provided in two parts:\n", + "\n", + "1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n", + "2. A Python front end, available in the `jax.extend.ffi` submodule.\n", + "\n", + "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", + "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", + "\n", + "## A simple example\n", + "\n", + "To demonstrate the use of the FFI interface, we will implement a simple \"root-mean-square (RMS)\" normalization function.\n", + "RMS normalization takes an array $x$ with shape $(N,)$ and returns\n", + "\n", + "$$\n", + "y_n = \\frac{x_n}{\\sqrt{\\frac{1}{N}\\sum_{n=1}^N {x_n}^2 + \\epsilon}}\n", + "$$\n", + "\n", + "where $\\epsilon$ is a tuning parameter used for numerical stability.\n", + "\n", + "This is a somewhat silly example, because it can be easily implemented using JAX as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "def rms_norm_ref(x, eps=1e-5):\n", + " scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)\n", + " return x / scale" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand.\n", + "We will use this reference implementation to test our FFI version below.\n", + "\n", + "## Backend code\n", + "\n", + "To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI.\n", + "This isn't meant to be particularly performant, but you could imagine that if you had some new better implementation of RMS normalization in a C++ library, it might have an interface like the following.\n", + "So, here's a simple implementation of RMS normalization in C++:\n", + "\n", + "```c++\n", + "#include \n", + "#include \n", + "\n", + "float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {\n", + " float sm = 0.0f;\n", + " for (int64_t n = 0; n < size; ++n) {\n", + " sm += x[n] * x[n];\n", + " }\n", + " float scale = 1.0f / std::sqrt(sm / float(size) + eps);\n", + " for (int64_t n = 0; n < size; ++n) {\n", + " y[n] = x[n] * scale;\n", + " }\n", + " return scale;\n", + "}\n", + "```\n", + "\n", + "and, for our example, this is the function that we want to expose to JAX via the FFI.\n", + "\n", + "### C++ interface\n", + "\n", + "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", + "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call), but for our purposes, it's sufficient to know that we can define our implementation as follows:\n", + "\n", + "```c++\n", + "#include \n", + "#include \n", + "#include \n", + "\n", + "#include \"xla/ffi/api/c_api.h\"\n", + "#include \"xla/ffi/api/ffi.h\"\n", + "\n", + "namespace ffi = xla::ffi;\n", + "\n", + "std::pair GetDims(ffi::Span dims) {\n", + " if (dims.size() == 0) {\n", + " return std::make_pair(0, 0);\n", + " }\n", + " int64_t totalSize =\n", + " std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());\n", + " int64_t lastDim = dims.back();\n", + " return std::make_pair(totalSize, lastDim);\n", + "}\n", + "\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::Result> y) {\n", + " auto [totalSize, lastDim] = GetDims(x.dimensions);\n", + " if (lastDim == 0) {\n", + " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", + " \"RmsNorm input must be an array\");\n", + " }\n", + " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", + " ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n]));\n", + " }\n", + " return ffi::Error::Success();\n", + "}\n", + "\n", + "XLA_FFI_DEFINE_HANDLER(\n", + " RmsNorm, RmsNormImpl,\n", + " ffi::Ffi::Bind()\n", + " .Attr(\"eps\")\n", + " .Arg>(/* x */)\n", + " .Ret>(/* y */));\n", + "```\n", + "\n", + "Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature.\n", + "But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters.\n", + "\n", + "Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data.\n", + "In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis.\n", + "`GetDims` is a helper function providing support for this batching behavior.\n", + "\n", + "### Building and registering an FFI handler\n", + "\n", + "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", + "In this tutorial, we use [nanobind](https://nanobind.readthedocs.io) to define a tiny Python module, but it is also possible to compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), a pattern which we discuss below.\n", + "\n", + "For this example, our nanobind module is defined as follows:\n", + "\n", + "```c++\n", + "#include \n", + "\n", + "#include \"nanobind/nanobind.h\"\n", + "#include \"xla/ffi/api/c_api.h\"\n", + "\n", + "namespace nb = nanobind;\n", + "\n", + "template \n", + "nb::capsule EncapsulateFfiCall(T *fn) {\n", + " static_assert(std::is_invocable_r_v,\n", + " \"Encapsulated function must be and XLA FFI handler\");\n", + " return nb::capsule(reinterpret_cast(fn), \"xla._CUSTOM_CALL_TARGET\");\n", + "}\n", + "\n", + "NB_MODULE(rms_norm, m) {\n", + " m.def(\"rms_norm\", []() { return EncapsulateFfiCall(RmsNorm); });\n", + "}\n", + "```\n", + "\n", + "With this in place, we can now compile our library.\n", + "Here we use CMake, but you should be able to use your favorite build system without too much trouble." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "!cmake -DCMAKE_BUILD_TYPE=Release -B _build .\n", + "!cmake --build _build\n", + "!cmake --install _build" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.lib.xla_client.register_custom_call_target` function.\n", + "In our nanobind module above, we implemented a Python function called `rms_norm` which returns a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html) containing a function pointer to the C++ function `RmsNorm`.\n", + "This \"capsule\" is the object that we pass to `register_custom_call_target`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.lib import xla_client\n", + "import rms_norm as rms_norm_lib\n", + "\n", + "xla_client.register_custom_call_target(\n", + " \"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\", api_version=1\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We note that the `api_version=1` keyword in the call to `register_custom_call_target` is important because it indicates that our function pointer implements this FFI interface.\n", + "\n", + "## Frontend code\n", + "\n", + "Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import jax.extend as jex\n", + "\n", + "\n", + "def rms_norm(x, eps=1e-5):\n", + " # We only implemented the `float32` version of this function, so we start by\n", + " # checking the dtype. This check isn't strictly necessary because type\n", + " # checking is also performed by the FFI when decoding input and output\n", + " # buffers, but it can be useful to check types in Python to raise more\n", + " # informative errors.\n", + " if x.dtype != jnp.float32:\n", + " raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n", + "\n", + " # In this case, the output of our FFI function is just a single array with the\n", + " # same shape and dtype as the input. We discuss a case with a more interesting\n", + " # output type below.\n", + " out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + "\n", + " return jex.ffi.ffi_call(\n", + " # The target name must be the same string as we used to register the target\n", + " # above in `register_custom_call_target`\n", + " \"rms_norm\",\n", + " out_type,\n", + " x,\n", + " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n", + " # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n", + " # type (which corresponds to numpy's `float32` type), and it must be a\n", + " # static parameter (i.e. not a JAX array).\n", + " eps=np.float32(eps),\n", + " # The `vectorized` parameter controls this function's behavior under `vmap`\n", + " # as discussed below.\n", + " vectorized=True,\n", + " )\n", + "\n", + "\n", + "# Test that this gives the same result as our reference implementation\n", + "x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n", + "np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n", + "Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n", + "It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n", + "\n", + "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", + "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", + "\n", + "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", + "\n", + "### Batching with `vmap`\n", + "\n", + "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n", + "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", + "This default implementation is general purpose, but it doesn't parallelize very well.\n", + "But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n", + "\n", + "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n", + "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n", + "\n", + "```python\n", + "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n", + "```\n", + "\n", + "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax.make_jaxpr(jax.vmap(rms_norm))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rms_norm_not_vectorized(x, eps=1e-5):\n", + " return jex.ffi.ffi_call(\n", + " \"rms_norm\",\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=False, # This is the default behavior\n", + " )\n", + "\n", + "\n", + "jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it is possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Differentiation\n", + "\n", + "Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for differentiating foreign functions.\n", + "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", + "Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", + "\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "In this case, we actually define two new FFI calls:\n", + "\n", + "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", + "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", + "\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](rms_norm.cc) to see how these functions are implemented on the back end.\n", + "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", + "\n", + "This custom derivative rule can be wired in as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xla_client.register_custom_call_target(\n", + " \"rms_norm_fwd\", rms_norm_lib.rms_norm_fwd(), platform=\"cpu\", api_version=1\n", + ")\n", + "xla_client.register_custom_call_target(\n", + " \"rms_norm_bwd\", rms_norm_lib.rms_norm_bwd(), platform=\"cpu\", api_version=1\n", + ")\n", + "\n", + "\n", + "def rms_norm_fwd(x, eps=1e-5):\n", + " y, res = jex.ffi.ffi_call(\n", + " \"rms_norm_fwd\",\n", + " (\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n", + " ),\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=True,\n", + " )\n", + " return y, (res, x)\n", + "\n", + "\n", + "def rms_norm_bwd(eps, res, ct):\n", + " del eps\n", + " res, x = res\n", + " assert res.shape == ct.shape[:-1]\n", + " assert x.shape == ct.shape\n", + " return (\n", + " jex.ffi.ffi_call(\n", + " \"rms_norm_bwd\",\n", + " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n", + " res,\n", + " x,\n", + " ct,\n", + " vectorized=True,\n", + " ),\n", + " )\n", + "\n", + "\n", + "rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))\n", + "rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)\n", + "\n", + "# Check that this gives the right answer when compared to the reference version\n", + "ct_y = jnp.ones_like(x)\n", + "np.testing.assert_allclose(\n", + " jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TODO(dfm)\n", + "\n", + "- ctypes + ffi.pycapsule interface\n", + "- dtype dispatching\n", + "- CUDA\n", + "- partitioning\n", + "-" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/ffi/ffi.md b/docs/ffi/ffi.md new file mode 100644 index 000000000000..4d2212dfe7c4 --- /dev/null +++ b/docs/ffi/ffi.md @@ -0,0 +1,363 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JAX's foreign function interface + +While a wide range of numerical operations can be easily and efficiently implemented using JAX's built in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a "foreign function interface" (FFI). +This can be particularly useful when particular operations have been previously implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement these computations directly using JAX, but it can also be useful for optimizing runtime or memory performance of JAX programs. +That being said, the FFI should typically be considered a last resort option because the XLA compiler that sits in the backend, or the Pallas kernel language, which provides lower level control, typically produce performant code with a lower development and maintenance cost. + +One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_. +This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules. +We will discuss some possible approaches below, but it is important to call this limitation out right from the start! + +JAX's FFI support is provided in two parts: + +1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and +2. A Python front end, available in the `jax.extend.ffi` submodule. + +In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. +We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. + +## A simple example + +To demonstrate the use of the FFI interface, we will implement a simple "root-mean-square (RMS)" normalization function. +RMS normalization takes an array $x$ with shape $(N,)$ and returns + +$$ +y_n = \frac{x_n}{\sqrt{\frac{1}{N}\sum_{n=1}^N {x_n}^2 + \epsilon}} +$$ + +where $\epsilon$ is a tuning parameter used for numerical stability. + +This is a somewhat silly example, because it can be easily implemented using JAX as follows: + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp + + +def rms_norm_ref(x, eps=1e-5): + scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) + return x / scale +``` + +But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand. +We will use this reference implementation to test our FFI version below. + +## Backend code + +To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI. +This isn't meant to be particularly performant, but you could imagine that if you had some new better implementation of RMS normalization in a C++ library, it might have an interface like the following. +So, here's a simple implementation of RMS normalization in C++: + +```c++ +#include +#include + +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} +``` + +and, for our example, this is the function that we want to expose to JAX via the FFI. + +### C++ interface + +To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). +For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call), but for our purposes, it's sufficient to know that we can define our implementation as follows: + +```c++ +#include +#include +#include + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +std::pair GetDims(ffi::Span dims) { + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + int64_t totalSize = + std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + int64_t lastDim = dims.back(); + return std::make_pair(totalSize, lastDim); +} + +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER( + RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */)); +``` + +Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature. +But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters. + +Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data. +In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis. +`GetDims` is a helper function providing support for this batching behavior. + +### Building and registering an FFI handler + +Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python. +In this tutorial, we use [nanobind](https://nanobind.readthedocs.io) to define a tiny Python module, but it is also possible to compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), a pattern which we discuss below. + +For this example, our nanobind module is defined as follows: + +```c++ +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" + +namespace nb = nanobind; + +template +nb::capsule EncapsulateFfiCall(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +NB_MODULE(rms_norm, m) { + m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); }); +} +``` + +With this in place, we can now compile our library. +Here we use CMake, but you should be able to use your favorite build system without too much trouble. + +```{code-cell} ipython3 +:tags: [hide-output] + +!cmake -DCMAKE_BUILD_TYPE=Release -B _build . +!cmake --build _build +!cmake --install _build +``` + +With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.lib.xla_client.register_custom_call_target` function. +In our nanobind module above, we implemented a Python function called `rms_norm` which returns a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html) containing a function pointer to the C++ function `RmsNorm`. +This "capsule" is the object that we pass to `register_custom_call_target`: + +```{code-cell} ipython3 +from jax.lib import xla_client +import rms_norm as rms_norm_lib + +xla_client.register_custom_call_target( + "rms_norm", rms_norm_lib.rms_norm(), platform="cpu", api_version=1 +) +``` + +We note that the `api_version=1` keyword in the call to `register_custom_call_target` is important because it indicates that our function pointer implements this FFI interface. + +## Frontend code + +Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function: + +```{code-cell} ipython3 +import numpy as np +import jax.extend as jex + + +def rms_norm(x, eps=1e-5): + # We only implemented the `float32` version of this function, so we start by + # checking the dtype. This check isn't strictly necessary because type + # checking is also performed by the FFI when decoding input and output + # buffers, but it can be useful to check types in Python to raise more + # informative errors. + if x.dtype != jnp.float32: + raise ValueError("Only the float32 dtype is implemented by rms_norm") + + # In this case, the output of our FFI function is just a single array with the + # same shape and dtype as the input. We discuss a case with a more interesting + # output type below. + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + return jex.ffi.ffi_call( + # The target name must be the same string as we used to register the target + # above in `register_custom_call_target` + "rms_norm", + out_type, + x, + # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). + eps=np.float32(eps), + # The `vectorized` parameter controls this function's behavior under `vmap` + # as discussed below. + vectorized=True, + ) + + +# Test that this gives the same result as our reference implementation +x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) +np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5) +``` + +This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting. +Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs. +It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above. + +Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`. +Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments. + +The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. + +### Batching with `vmap` + +All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient. +By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body. +This default implementation is general purpose, but it doesn't parallelize very well. +But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation. + +The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes. +Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly: + +```python +ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs]) +``` + +Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box: + +```{code-cell} ipython3 +np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) +``` + +We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: + +```{code-cell} ipython3 +jax.make_jaxpr(jax.vmap(rms_norm))(x) +``` + +If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: + +```{code-cell} ipython3 +def rms_norm_not_vectorized(x, eps=1e-5): + return jex.ffi.ffi_call( + "rms_norm", + jax.ShapeDtypeStruct(x.shape, x.dtype), + x, + eps=np.float32(eps), + vectorized=False, # This is the default behavior + ) + + +jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) +``` + +If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it is possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). + ++++ + +### Differentiation + +Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for differentiating foreign functions. +As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. +Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule. + +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +In this case, we actually define two new FFI calls: + +1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. +2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. + +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](rms_norm.cc) to see how these functions are implemented on the back end. +The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. + +This custom derivative rule can be wired in as follows: + +```{code-cell} ipython3 +xla_client.register_custom_call_target( + "rms_norm_fwd", rms_norm_lib.rms_norm_fwd(), platform="cpu", api_version=1 +) +xla_client.register_custom_call_target( + "rms_norm_bwd", rms_norm_lib.rms_norm_bwd(), platform="cpu", api_version=1 +) + + +def rms_norm_fwd(x, eps=1e-5): + y, res = jex.ffi.ffi_call( + "rms_norm_fwd", + ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + ), + x, + eps=np.float32(eps), + vectorized=True, + ) + return y, (res, x) + + +def rms_norm_bwd(eps, res, ct): + del eps + res, x = res + assert res.shape == ct.shape[:-1] + assert x.shape == ct.shape + return ( + jex.ffi.ffi_call( + "rms_norm_bwd", + jax.ShapeDtypeStruct(ct.shape, ct.dtype), + res, + x, + ct, + vectorized=True, + ), + ) + + +rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,)) +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) + +# Check that this gives the right answer when compared to the reference version +ct_y = jnp.ones_like(x) +np.testing.assert_allclose( + jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5 +) +``` + +## TODO(dfm) + +- ctypes + ffi.pycapsule interface +- dtype dispatching +- CUDA +- partitioning +- diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc new file mode 100644 index 000000000000..9721e80b1167 --- /dev/null +++ b/docs/ffi/rms_norm.cc @@ -0,0 +1,151 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// This is the example "library function" that we want to expose to JAX. This +// isn't meant to be a particularly good implementation, it's just here as a +// placeholder for the purposes of this tutorial. +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} + +// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. +// In this example, we treat all leading dimensions as batch dimensions, so this +// function returns the total number of elements in the buffer, and the size of +// the last dimension. +std::pair GetDims(ffi::Span dims) { + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + int64_t totalSize = + std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + int64_t lastDim = dims.back(); + return std::make_pair(totalSize, lastDim); +} + +// A wrapper function providing the interface between the XLA FFI call and our +// library function `ComputeRmsNorm` above. This function handles the batch +// dimensions by calling `ComputeRmsNorm` within a loop. +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n])); + } + return ffi::Error::Success(); +} + +// Wrap `RmsNormImpl` and specify the interface to XLA. +XLA_FFI_DEFINE_HANDLER(RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */)); + +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::Result> y, + ffi::Result> res) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormFwd input must be an array"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + res->data[idx] = ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */) + .Ret>(/* res */)); + +void ComputeRmsNormBwd(int64_t size, float res, const float *x, + const float *ct_y, float *ct_x) { + float ct_res = 0.0f; + for (int64_t n = 0; n < size; ++n) { + ct_res += x[n] * ct_y[n]; + } + float factor = ct_res * res * res * res / float(size); + for (int64_t n = 0; n < size; ++n) { + ct_x[n] = res * ct_y[n] - factor * x[n]; + } +} + +ffi::Error RmsNormBwdImpl(ffi::Buffer res, + ffi::Buffer x, + ffi::Buffer ct_y, + ffi::Result> ct_x) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormBwd inputs must be arrays"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + ComputeRmsNormBwd(lastDim, res.data[idx], &(x.data[n]), &(ct_y.data[n]), + &(ct_x->data[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>(/* res */) + .Arg>(/* x */) + .Arg>(/* ct_y */) + .Ret>(/* ct_x */)); + +template +nb::capsule EncapsulateFfiCall(T *fn) { + // This check is optional, but it can be useful to catch invalid function + // pointers at compile time. + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn)); +} + +NB_MODULE(rms_norm, m) { + m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); }); + m.def("rms_norm_fwd", []() { return EncapsulateFfiCall(RmsNormFwd); }); + m.def("rms_norm_bwd", []() { return EncapsulateFfiCall(RmsNormBwd); }); +} diff --git a/docs/requirements.txt b/docs/requirements.txt index 643d4086d8be..7bf7a5350a33 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -18,4 +18,6 @@ matplotlib scikit-learn numpy rich[jupyter] +cmake +nanobind .[ci] # Install jax from the current directory; jaxlib from pypi.