Skip to content

Commit

Permalink
Add ffi_call tutorial
Browse files Browse the repository at this point in the history
Building on jax-ml#21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.

As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a tool chain that
open source users would use in practice.
  • Loading branch information
dfm committed Jul 18, 2024
1 parent a07b9ad commit 5a766f9
Show file tree
Hide file tree
Showing 9 changed files with 1,241 additions and 2 deletions.
3 changes: 1 addition & 2 deletions docs/_tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ JAX tutorials draft

.. note::

This is a
The tutorials below are a work in progress; for the time being, please refer
to the older tutorial content, including :ref:`beginner-guide`,
:ref:`user-guides`, and the now-deleted *JAX 101* tutorials.
Expand Down Expand Up @@ -44,7 +43,7 @@ JAX 201
advanced-debugging
external-callbacks
profiling-and-performance

../ffi/ffi

JAX 301
-------
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _do_not_evaluate_in_jax(
'jep/9407-type-promotion.md',
'autodidax.md',
'sharded-computation.md',
'ffi/ffi.ipynb',
]

# The name of the Pygments (syntax highlighting) style to use.
Expand Down
5 changes: 5 additions & 0 deletions docs/ffi/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CMake*
cmake*
Makefile
*.so
*.dylib
20 changes: 20 additions & 0 deletions docs/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.18...3.27)
project(rms_norm LANGUAGES CXX)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)

execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

execute_process(
COMMAND "${Python_EXECUTABLE}"
"-c" "from jax.extend import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

nanobind_add_module(rms_norm NOMINSIZE "rms_norm.cc")
target_include_directories(rms_norm PUBLIC ${XLA_DIR})
install(TARGETS rms_norm LIBRARY DESTINATION ${CMAKE_CURRENT_LIST_DIR})

600 changes: 600 additions & 0 deletions docs/ffi/ffi.ipynb

Large diffs are not rendered by default.

460 changes: 460 additions & 0 deletions docs/ffi/ffi.md

Large diffs are not rendered by default.

151 changes: 151 additions & 0 deletions docs/ffi/rms_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <type_traits>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;
namespace nb = nanobind;

// This is the example "library function" that we want to expose to JAX. This
// isn't meant to be a particularly good implementation, it's just here as a
// placeholder for the purposes of this tutorial.
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
return scale;
}

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
std::pair<int64_t, int64_t> GetDims(ffi::Span<const int64_t> dims) {
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
int64_t totalSize =
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
int64_t lastDim = dims.back();
return std::make_pair(totalSize, lastDim);
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
auto [totalSize, lastDim] = GetDims(x.dimensions);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n]));
}
return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA.
XLA_FFI_DEFINE_HANDLER(RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>(/* x */)
.Ret<ffi::Buffer<ffi::DataType::F32>>(/* y */));

ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> res) {
auto [totalSize, lastDim] = GetDims(x.dimensions);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormFwd input must be an array");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->data[idx] = ComputeRmsNorm(eps, lastDim, &(x.data[n]), &(y->data[n]));
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(RmsNormFwd, RmsNormFwdImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>(/* x */)
.Ret<ffi::Buffer<ffi::DataType::F32>>(/* y */)
.Ret<ffi::Buffer<ffi::DataType::F32>>(/* res */));

void ComputeRmsNormBwd(int64_t size, float res, const float *x,
const float *ct_y, float *ct_x) {
float ct_res = 0.0f;
for (int64_t n = 0; n < size; ++n) {
ct_res += x[n] * ct_y[n];
}
float factor = ct_res * res * res * res / float(size);
for (int64_t n = 0; n < size; ++n) {
ct_x[n] = res * ct_y[n] - factor * x[n];
}
}

ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
ffi::Buffer<ffi::DataType::F32> x,
ffi::Buffer<ffi::DataType::F32> ct_y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> ct_x) {
auto [totalSize, lastDim] = GetDims(x.dimensions);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormBwd inputs must be arrays");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
ComputeRmsNormBwd(lastDim, res.data[idx], &(x.data[n]), &(ct_y.data[n]),
&(ct_x->data[n]));
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(RmsNormBwd, RmsNormBwdImpl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::DataType::F32>>(/* res */)
.Arg<ffi::Buffer<ffi::DataType::F32>>(/* x */)
.Arg<ffi::Buffer<ffi::DataType::F32>>(/* ct_y */)
.Ret<ffi::Buffer<ffi::DataType::F32>>(/* ct_x */));

template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
// This check is optional, but it can be useful to catch invalid function
// pointers at compile time.
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be and XLA FFI handler");
return nb::capsule(reinterpret_cast<void *>(fn));
}

NB_MODULE(rms_norm, m) {
m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
m.def("rms_norm_fwd", []() { return EncapsulateFfiCall(RmsNormFwd); });
m.def("rms_norm_bwd", []() { return EncapsulateFfiCall(RmsNormBwd); });
}
1 change: 1 addition & 0 deletions docs/jax.lax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Operators
neg
nextafter
pad
platform_dependent
polygamma
population_count
pow
Expand Down
2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ matplotlib
scikit-learn
numpy
rich[jupyter]
cmake
nanobind
.[ci] # Install jax from the current directory; jaxlib from pypi.

0 comments on commit 5a766f9

Please sign in to comment.