Skip to content

Commit

Permalink
Add ffi_call tutorial
Browse files Browse the repository at this point in the history
Building on jax-ml#21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.

As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a tool chain that
open source users would use in practice.
  • Loading branch information
dfm committed Jun 27, 2024
1 parent ed56df0 commit 7185032
Show file tree
Hide file tree
Showing 6 changed files with 722 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/_tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ JAX tutorials draft

.. note::

This is a
This is a
The tutorials below are a work in progress; for the time being, please refer
to the older tutorial content, including :ref:`beginner-guide`,
:ref:`user-guides`, and the now-deleted *JAX 101* tutorials.
Expand Down Expand Up @@ -44,7 +44,7 @@ JAX 201
advanced-debugging
external-callbacks
profiling-and-performance

../ffi/ffi.ipynb

JAX 301
-------
Expand Down
5 changes: 5 additions & 0 deletions docs/ffi/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CMake*
cmake*
Makefile
*.so
*.dylib
20 changes: 20 additions & 0 deletions docs/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.18...3.27)
project(rms_norm LANGUAGES CXX)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)

execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

execute_process(
COMMAND "${Python_EXECUTABLE}"
"-c" "from jax.extend import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

nanobind_add_module(rms_norm NOMINSIZE "rms_norm.cc")
target_include_directories(rms_norm PUBLIC ${XLA_DIR})
install(TARGETS rms_norm LIBRARY DESTINATION ${CMAKE_CURRENT_LIST_DIR})

Loading

0 comments on commit 7185032

Please sign in to comment.