From f0122b7535fa479d1cedd953e6bcf9717a091e10 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 25 Jun 2024 14:08:49 -0400 Subject: [PATCH] Add ffi_call tutorial Building on #21925, this tutorial demonstrates the use of the FFI using `ffi_call` with a simple example. I don't think this should cover all of the most advanced use cases, but it should be sufficient for the most common examples. I think it would be useful to eventually replace the existing CUDA tutorial, but I'm not sure that it'll get there in the first draft. As an added benefit, this also runs a simple test (akin to `docs/cuda_custom_call`) which actually executes using a tool chain that open source users would use in practice. --- docs/_tutorials/index.rst | 4 +- docs/conf.py | 1 + docs/ffi/.gitignore | 5 + docs/ffi/CMakeLists.txt | 20 ++ docs/ffi/ffi.ipynb | 468 ++++++++++++++++++++++++++++++++++++++ docs/ffi/ffi.md | 363 +++++++++++++++++++++++++++++ docs/ffi/rms_norm.cc | 151 ++++++++++++ docs/requirements.txt | 2 + 8 files changed, 1012 insertions(+), 2 deletions(-) create mode 100644 docs/ffi/.gitignore create mode 100644 docs/ffi/CMakeLists.txt create mode 100644 docs/ffi/ffi.ipynb create mode 100644 docs/ffi/ffi.md create mode 100644 docs/ffi/rms_norm.cc diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index d261612a4cd4..3e2da4bc141e 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -7,7 +7,7 @@ JAX tutorials draft .. note:: - This is a + This is a The tutorials below are a work in progress; for the time being, please refer to the older tutorial content, including :ref:`beginner-guide`, :ref:`user-guides`, and the now-deleted *JAX 101* tutorials. 
@@ -44,7 +44,7 @@ JAX 201 advanced-debugging external-callbacks profiling-and-performance - + ../ffi/ffi JAX 301 ------- diff --git a/docs/conf.py b/docs/conf.py index fdbb157ecf39..cde7f118f180 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -136,6 +136,7 @@ def _do_not_evaluate_in_jax( 'jep/9407-type-promotion.md', 'autodidax.md', 'sharded-computation.md', + 'ffi/ffi.ipynb', ] # The name of the Pygments (syntax highlighting) style to use. diff --git a/docs/ffi/.gitignore b/docs/ffi/.gitignore new file mode 100644 index 000000000000..f6608699925e --- /dev/null +++ b/docs/ffi/.gitignore @@ -0,0 +1,5 @@ +CMake* +cmake* +Makefile +*.so +*.dylib diff --git a/docs/ffi/CMakeLists.txt b/docs/ffi/CMakeLists.txt new file mode 100644 index 000000000000..c219df46876f --- /dev/null +++ b/docs/ffi/CMakeLists.txt @@ -0,0 +1,20 @@ +cmake_minimum_required(VERSION 3.18...3.27) +project(rms_norm LANGUAGES CXX) +find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) +list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") +find_package(nanobind CONFIG REQUIRED) + +execute_process( + COMMAND "${Python_EXECUTABLE}" + "-c" "from jax.extend import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) +message(STATUS "XLA include directory: ${XLA_DIR}") + +nanobind_add_module(rms_norm NOMINSIZE "rms_norm.cc") +target_include_directories(rms_norm PUBLIC ${XLA_DIR}) +install(TARGETS rms_norm LIBRARY DESTINATION ${CMAKE_CURRENT_LIST_DIR}) + diff --git a/docs/ffi/ffi.ipynb b/docs/ffi/ffi.ipynb new file mode 100644 index 000000000000..4d09b71a7a4d --- /dev/null +++ b/docs/ffi/ffi.ipynb @@ -0,0 +1,468 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# JAX's foreign function interface\n", + "\n", + "While a wide range of numerical operations can be easily and efficiently 
implemented using JAX's built in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a \"foreign function interface\" (FFI).\n", + "This can be particularly useful when particular operations have been previously implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement these computations directly using JAX, but it can also be useful for optimizing runtime or memory performance of JAX programs.\n", + "That being said, the FFI should typically be considered a last resort option because the XLA compiler that sits in the backend, or the Pallas kernel language, which provides lower level control, typically produce performant code with a lower development and maintenance cost.\n", + "\n", + "One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_.\n", + "This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules.\n", + "We will discuss some possible approaches below, but it is important to call this limitation out right from the start!\n", + "\n", + "JAX's FFI support is provided in two parts:\n", + "\n", + "1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n", + "2. 
A Python front end, available in the `jax.extend.ffi` submodule.\n", + "\n", + "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", + "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", + "\n", + "## A simple example\n", + "\n", + "To demonstrate the use of the FFI interface, we will implement a simple \"root-mean-square (RMS)\" normalization function.\n", + "RMS normalization takes an array $x$ with shape $(N,)$ and returns\n", + "\n", + "$$\n", + "y_n = \\frac{x_n}{\\sqrt{\\frac{1}{N}\\sum_{n=1}^N {x_n}^2 + \\epsilon}}\n", + "$$\n", + "\n", + "where $\\epsilon$ is a tuning parameter used for numerical stability.\n", + "\n", + "This is a somewhat silly example, because it can be easily implemented using JAX as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "def rms_norm_ref(x, eps=1e-5):\n", + " scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)\n", + " return x / scale" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand.\n", + "We will use this reference implementation to test our FFI version below.\n", + "\n", + "## Backend code\n", + "\n", + "To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI.\n", + "This isn't meant to be particularly performant, but you could imagine that if you had some new better implementation of RMS normalization in a C++ library, it might have an interface like the following.\n", + "So, here's a simple implementation of RMS normalization in C++:\n", + "\n", + "```c++\n", + 
"#include <cmath>\n", + "#include <cstdint>\n", + "\n", + "float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {\n", + " float sm = 0.0f;\n", + " for (int64_t n = 0; n < size; ++n) {\n", + " sm += x[n] * x[n];\n", + " }\n", + " float scale = 1.0f / std::sqrt(sm / float(size) + eps);\n", + " for (int64_t n = 0; n < size; ++n) {\n", + " y[n] = x[n] * scale;\n", + " }\n", + " return scale;\n", + "}\n", + "```\n", + "\n", + "and, for our example, this is the function that we want to expose to JAX via the FFI.\n", + "\n", + "### C++ interface\n", + "\n", + "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", + "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call), but for our purposes, it's sufficient to know that we can define our implementation as follows:\n", + "\n", + "```c++\n", + "#include <functional>\n", + "#include <numeric>\n", + "#include <utility>\n", + "\n", + "#include \"xla/ffi/api/c_api.h\"\n", + "#include \"xla/ffi/api/ffi.h\"\n", + "\n", + "namespace ffi = xla::ffi;\n", + "\n", + "std::pair<int64_t, int64_t> GetDims(ffi::Span<const int64_t> dims) {\n", + " if (dims.size() == 0) {\n", + " return std::make_pair(0, 0);\n", + " }\n", + " int64_t totalSize =\n", + " std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());\n", + " int64_t lastDim = dims.back();\n", + " return std::make_pair(totalSize, lastDim);\n", + "}\n", + "\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n", + " ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {\n", + " auto [totalSize, lastDim] = GetDims(x.dimensions);\n", + " if (lastDim == 0) {\n", + " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", + " \"RmsNorm input must be an array\");\n", + " }\n", + " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", + " ComputeRmsNorm(eps, lastDim, 
&(x.data[n]), &(y->data[n]));\n", + " }\n", + " return ffi::Error::Success();\n", + "}\n", + "\n", + "XLA_FFI_DEFINE_HANDLER(\n", + " RmsNorm, RmsNormImpl,\n", + " ffi::Ffi::Bind()\n", + " .Attr<float>(\"eps\")\n", + " .Arg<ffi::Buffer<ffi::DataType::F32>>(/* x */)\n", + " .Ret<ffi::Buffer<ffi::DataType::F32>>(/* y */));\n", + "```\n", + "\n", + "Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature.\n", + "But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters.\n", + "\n", + "Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data.\n", + "In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis.\n", + "`GetDims` is a helper function providing support for this batching behavior.\n", + "\n", + "### Building and registering an FFI handler\n", + "\n", + "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", + "In this tutorial, we use [nanobind](https://nanobind.readthedocs.io) to define a tiny Python module, but it is also possible to compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), a pattern which we discuss below.\n", + "\n", + "For this example, our nanobind module is defined as follows:\n", + "\n", + "```c++\n", + "#include <type_traits>\n", + "\n", + "#include \"nanobind/nanobind.h\"\n", + "#include \"xla/ffi/api/c_api.h\"\n", + "\n", + "namespace nb = nanobind;\n", + "\n", + "template <typename T>\n", + "nb::capsule EncapsulateFfiCall(T *fn) {\n", + " static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,\n", + " \"Encapsulated function must be an XLA FFI handler\");\n", + " return nb::capsule(reinterpret_cast<void *>(fn), \"xla._CUSTOM_CALL_TARGET\");\n", + 
"}\n", + "\n", + "NB_MODULE(rms_norm, m) {\n", + " m.def(\"rms_norm\", []() { return EncapsulateFfiCall(RmsNorm); });\n", + "}\n", + "```\n", + "\n", + "With this in place, we can now compile our library.\n", + "Here we use CMake, but you should be able to use your favorite build system without too much trouble." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "!cmake -DCMAKE_BUILD_TYPE=Release -B _build .\n", + "!cmake --build _build\n", + "!cmake --install _build" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.lib.xla_client.register_custom_call_target` function.\n", + "In our nanobind module above, we implemented a Python function called `rms_norm` which returns a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html) containing a function pointer to the C++ function `RmsNorm`.\n", + "This \"capsule\" is the object that we pass to `register_custom_call_target`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.lib import xla_client\n", + "import rms_norm as rms_norm_lib\n", + "\n", + "xla_client.register_custom_call_target(\n", + " \"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\", api_version=1\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We note that the `api_version=1` keyword in the call to `register_custom_call_target` is important because it indicates that our function pointer implements this FFI interface.\n", + "\n", + "## Frontend code\n", + "\n", + "Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "import numpy as np\n", + "import jax.extend as jex\n", + "\n", + "\n", + "def rms_norm(x, eps=1e-5):\n", + " # We only implemented the `float32` version of this function, so we start by\n", + " # checking the dtype. This check isn't strictly necessary because type\n", + " # checking is also performed by the FFI when decoding input and output\n", + " # buffers, but it can be useful to check types in Python to raise more\n", + " # informative errors.\n", + " if x.dtype != jnp.float32:\n", + " raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n", + "\n", + " # In this case, the output of our FFI function is just a single array with the\n", + " # same shape and dtype as the input. We discuss a case with a more interesting\n", + " # output type below.\n", + " out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + "\n", + " return jex.ffi.ffi_call(\n", + " # The target name must be the same string as we used to register the target\n", + " # above in `register_custom_call_target`\n", + " \"rms_norm\",\n", + " out_type,\n", + " x,\n", + " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n", + " # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n", + " # type (which corresponds to numpy's `float32` type), and it must be a\n", + " # static parameter (i.e. 
not a JAX array).\n", + " eps=np.float32(eps),\n", + " # The `vectorized` parameter controls this function's behavior under `vmap`\n", + " # as discussed below.\n", + " vectorized=True,\n", + " )\n", + "\n", + "\n", + "# Test that this gives the same result as our reference implementation\n", + "x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n", + "np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n", + "Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n", + "It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n", + "\n", + "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", + "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", + "\n", + "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", + "\n", + "### Batching with `vmap`\n", + "\n", + "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n", + "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", + "This default implementation is general purpose, but it doesn't parallelize very well.\n", + "But, many FFI calls provide 
more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n", + "\n", + "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n", + "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n", + "\n", + "```python\n", + "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n", + "```\n", + "\n", + "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax.make_jaxpr(jax.vmap(rms_norm))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rms_norm_not_vectorized(x, eps=1e-5):\n", + " return jex.ffi.ffi_call(\n", + " \"rms_norm\",\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=False, # This is the default behavior\n", + " )\n", + "\n", + "\n", + 
"jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it is possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Differentiation\n", + "\n", + "Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for differentiating foreign functions.\n", + "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", + "Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", + "\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "In this case, we actually define two new FFI calls:\n", + "\n", + "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", + "2. 
`rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", + "\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](rms_norm.cc) to see how these functions are implemented on the back end.\n", + "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output; therefore, in the {func}`~jax.extend.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.\n", + "\n", + "This custom derivative rule can be wired in as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xla_client.register_custom_call_target(\n", + " \"rms_norm_fwd\", rms_norm_lib.rms_norm_fwd(), platform=\"cpu\", api_version=1\n", + ")\n", + "xla_client.register_custom_call_target(\n", + " \"rms_norm_bwd\", rms_norm_lib.rms_norm_bwd(), platform=\"cpu\", api_version=1\n", + ")\n", + "\n", + "\n", + "def rms_norm_fwd(x, eps=1e-5):\n", + " y, res = jex.ffi.ffi_call(\n", + " \"rms_norm_fwd\",\n", + " (\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n", + " ),\n", + " x,\n", + " eps=np.float32(eps),\n", + " vectorized=True,\n", + " )\n", + " return y, (res, x)\n", + "\n", + "\n", + "def rms_norm_bwd(eps, res, ct):\n", + " del eps\n", + " res, x = res\n", + " assert res.shape == ct.shape[:-1]\n", + " assert x.shape == ct.shape\n", + " return (\n", + " jex.ffi.ffi_call(\n", + " \"rms_norm_bwd\",\n", + " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n", + " res,\n", + " x,\n", + " ct,\n", + " vectorized=True,\n", + " ),\n", + " )\n", + "\n", + "\n", + "rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))\n", + "rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)\n", + "\n", + "# Check that this gives the right answer when compared to the reference version\n", + "ct_y = jnp.ones_like(x)\n", + 
"np.testing.assert_allclose(\n", + " jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TODO(dfm)\n", + "\n", + "- ctypes + ffi.pycapsule interface\n", + "- dtype dispatching\n", + "- CUDA\n", + "- partitioning\n", + "-" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/ffi/ffi.md b/docs/ffi/ffi.md new file mode 100644 index 000000000000..4d2212dfe7c4 --- /dev/null +++ b/docs/ffi/ffi.md @@ -0,0 +1,363 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JAX's foreign function interface + +While a wide range of numerical operations can be easily and efficiently implemented using JAX's built in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a "foreign function interface" (FFI). +This can be particularly useful when particular operations have been previously implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement these computations directly using JAX, but it can also be useful for optimizing runtime or memory performance of JAX programs. 
+That being said, the FFI should typically be considered a last resort option because the XLA compiler that sits in the backend, or the Pallas kernel language, which provides lower level control, typically produce performant code with a lower development and maintenance cost. + +One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_. +This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules. +We will discuss some possible approaches below, but it is important to call this limitation out right from the start! + +JAX's FFI support is provided in two parts: + +1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and +2. A Python front end, available in the `jax.extend.ffi` submodule. + +In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. +We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. + +## A simple example + +To demonstrate the use of the FFI interface, we will implement a simple "root-mean-square (RMS)" normalization function. +RMS normalization takes an array $x$ with shape $(N,)$ and returns + +$$ +y_n = \frac{x_n}{\sqrt{\frac{1}{N}\sum_{n=1}^N {x_n}^2 + \epsilon}} +$$ + +where $\epsilon$ is a tuning parameter used for numerical stability. 
+ +This is a somewhat silly example, because it can be easily implemented using JAX as follows: + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp + + +def rms_norm_ref(x, eps=1e-5): + scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) + return x / scale +``` + +But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand. +We will use this reference implementation to test our FFI version below. + +## Backend code + +To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI. +This isn't meant to be particularly performant, but you could imagine that if you had some new better implementation of RMS normalization in a C++ library, it might have an interface like the following. +So, here's a simple implementation of RMS normalization in C++: + +```c++ +#include <cmath> +#include <cstdint> + +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} +``` + +and, for our example, this is the function that we want to expose to JAX via the FFI. + +### C++ interface + +To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). 
+For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call), but for our purposes, it's sufficient to know that we can define our implementation as follows: + +```c++ +#include <functional> +#include <numeric> +#include <utility> + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +std::pair<int64_t, int64_t> GetDims(ffi::Span<const int64_t> dims) { + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + int64_t totalSize = + std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + int64_t lastDim = dims.back(); + return std::make_pair(totalSize, lastDim); +} + +ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x, + ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER( + RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr<float>("eps") + .Arg<ffi::Buffer<ffi::DataType::F32>>(/* x */) + .Ret<ffi::Buffer<ffi::DataType::F32>>(/* y */)); +``` + +Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature. +But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters. + +Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data. +In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis. +`GetDims` is a helper function providing support for this batching behavior. 
+ +### Building and registering an FFI handler + +Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python. +In this tutorial, we use [nanobind](https://nanobind.readthedocs.io) to define a tiny Python module, but it is also possible to compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), a pattern which we discuss below. + +For this example, our nanobind module is defined as follows: + +```c++ +#include <type_traits> + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" + +namespace nb = nanobind; + +template <typename T> +nb::capsule EncapsulateFfiCall(T *fn) { + static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>, + "Encapsulated function must be an XLA FFI handler"); + return nb::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET"); +} + +NB_MODULE(rms_norm, m) { + m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); }); +} +``` + +With this in place, we can now compile our library. +Here we use CMake, but you should be able to use your favorite build system without too much trouble. + +```{code-cell} ipython3 +:tags: [hide-output] + +!cmake -DCMAKE_BUILD_TYPE=Release -B _build . +!cmake --build _build +!cmake --install _build +``` + +With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.lib.xla_client.register_custom_call_target` function. +In our nanobind module above, we implemented a Python function called `rms_norm` which returns a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html) containing a function pointer to the C++ function `RmsNorm`. 
+This "capsule" is the object that we pass to `register_custom_call_target`: + +```{code-cell} ipython3 +from jax.lib import xla_client +import rms_norm as rms_norm_lib + +xla_client.register_custom_call_target( + "rms_norm", rms_norm_lib.rms_norm(), platform="cpu", api_version=1 +) +``` + +We note that the `api_version=1` keyword in the call to `register_custom_call_target` is important because it indicates that our function pointer implements this FFI interface. + +## Frontend code + +Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function: + +```{code-cell} ipython3 +import numpy as np +import jax.extend as jex + + +def rms_norm(x, eps=1e-5): + # We only implemented the `float32` version of this function, so we start by + # checking the dtype. This check isn't strictly necessary because type + # checking is also performed by the FFI when decoding input and output + # buffers, but it can be useful to check types in Python to raise more + # informative errors. + if x.dtype != jnp.float32: + raise ValueError("Only the float32 dtype is implemented by rms_norm") + + # In this case, the output of our FFI function is just a single array with the + # same shape and dtype as the input. We discuss a case with a more interesting + # output type below. + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + return jex.ffi.ffi_call( + # The target name must be the same string as we used to register the target + # above in `register_custom_call_target` + "rms_norm", + out_type, + x, + # Note that here we're using `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). + eps=np.float32(eps), + # The `vectorized` parameter controls this function's behavior under `vmap` + # as discussed below. 
+ vectorized=True, + ) + + +# Test that this gives the same result as our reference implementation +x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) +np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5) +``` + +This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting. +Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs. +It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above. + +Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`. +Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments. + +The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. + +### Batching with `vmap` + +All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient. +By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body. +This default implementation is general purpose, but it doesn't parallelize very well. +But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation. + +The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes. 
+Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly: + +```python +ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs]) +``` + +Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box: + +```{code-cell} ipython3 +np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) +``` + +We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: + +```{code-cell} ipython3 +jax.make_jaxpr(jax.vmap(rms_norm))(x) +``` + +If `vectorized` is `False` or omitted, `vmap`ping an `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: + +```{code-cell} ipython3 +def rms_norm_not_vectorized(x, eps=1e-5): + return jex.ffi.ffi_call( + "rms_norm", + jax.ShapeDtypeStruct(x.shape, x.dtype), + x, + eps=np.float32(eps), + vectorized=False, # This is the default behavior + ) + + +jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) +``` + +If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it is possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). + ++++ + +### Differentiation + +Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for differentiating foreign functions. +As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. +Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.
+ +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +In this case, we actually define two new FFI calls: + +1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. +2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. + +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](rms_norm.cc) to see how these functions are implemented on the back end. +The main point to emphasize here is that the "residual" computed has a different shape than the primal output; therefore, in the {func}`~jax.extend.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.
+ +This custom derivative rule can be wired in as follows: + +```{code-cell} ipython3 +xla_client.register_custom_call_target( + "rms_norm_fwd", rms_norm_lib.rms_norm_fwd(), platform="cpu", api_version=1 +) +xla_client.register_custom_call_target( + "rms_norm_bwd", rms_norm_lib.rms_norm_bwd(), platform="cpu", api_version=1 +) + + +def rms_norm_fwd(x, eps=1e-5): + y, res = jex.ffi.ffi_call( + "rms_norm_fwd", + ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + ), + x, + eps=np.float32(eps), + vectorized=True, + ) + return y, (res, x) + + +def rms_norm_bwd(eps, res, ct): + del eps + res, x = res + assert res.shape == ct.shape[:-1] + assert x.shape == ct.shape + return ( + jex.ffi.ffi_call( + "rms_norm_bwd", + jax.ShapeDtypeStruct(ct.shape, ct.dtype), + res, + x, + ct, + vectorized=True, + ), + ) + + +rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,)) +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) + +# Check that this gives the right answer when compared to the reference version +ct_y = jnp.ones_like(x) +np.testing.assert_allclose( + jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5 +) +``` + +## TODO(dfm) + +- ctypes + ffi.pycapsule interface +- dtype dispatching +- CUDA +- partitioning +- diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc new file mode 100644 index 000000000000..9721e80b1167 --- /dev/null +++ b/docs/ffi/rms_norm.cc @@ -0,0 +1,151 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// This is the example "library function" that we want to expose to JAX. This +// isn't meant to be a particularly good implementation, it's just here as a +// placeholder for the purposes of this tutorial. +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} + +// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. +// In this example, we treat all leading dimensions as batch dimensions, so this +// function returns the total number of elements in the buffer, and the size of +// the last dimension. +std::pair GetDims(ffi::Span dims) { + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + int64_t totalSize = + std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + int64_t lastDim = dims.back(); + return std::make_pair(totalSize, lastDim); +} + +// A wrapper function providing the interface between the XLA FFI call and our +// library function `ComputeRmsNorm` above. This function handles the batch +// dimensions by calling `ComputeRmsNorm` within a loop. 
+ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n])); + } + return ffi::Error::Success(); +} + +// Wrap `RmsNormImpl` and specify the interface to XLA. +XLA_FFI_DEFINE_HANDLER(RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */)); + +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::Result> y, + ffi::Result> res) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormFwd input must be an array"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + res->data[idx] = ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>(/* x */) + .Ret>(/* y */) + .Ret>(/* res */)); + +void ComputeRmsNormBwd(int64_t size, float res, const float *x, + const float *ct_y, float *ct_x) { + float ct_res = 0.0f; + for (int64_t n = 0; n < size; ++n) { + ct_res += x[n] * ct_y[n]; + } + float factor = ct_res * res * res * res / float(size); + for (int64_t n = 0; n < size; ++n) { + ct_x[n] = res * ct_y[n] - factor * x[n]; + } +} + +ffi::Error RmsNormBwdImpl(ffi::Buffer res, + ffi::Buffer x, + ffi::Buffer ct_y, + ffi::Result> ct_x) { + auto [totalSize, lastDim] = GetDims(x.dimensions); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormBwd inputs must be arrays"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + ComputeRmsNormBwd(lastDim, res.data[idx], &(x.data[n]), &(ct_y.data[n]), + &(ct_x->data[n])); + } + return ffi::Error::Success(); +} + 
+XLA_FFI_DEFINE_HANDLER(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>(/* res */) + .Arg>(/* x */) + .Arg>(/* ct_y */) + .Ret>(/* ct_x */)); + +template +nb::capsule EncapsulateFfiCall(T *fn) { + // This check is optional, but it can be useful to catch invalid function + // pointers at compile time. + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn)); +} + +NB_MODULE(rms_norm, m) { + m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); }); + m.def("rms_norm_fwd", []() { return EncapsulateFfiCall(RmsNormFwd); }); + m.def("rms_norm_bwd", []() { return EncapsulateFfiCall(RmsNormBwd); }); +} diff --git a/docs/requirements.txt b/docs/requirements.txt index 643d4086d8be..7bf7a5350a33 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -18,4 +18,6 @@ matplotlib scikit-learn numpy rich[jupyter] +cmake +nanobind .[ci] # Install jax from the current directory; jaxlib from pypi.