forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Building on jax-ml#21925, this tutorial demonstrates the use of the FFI using `ffi_call` with a simple example. I don't think this should cover all of the most advanced use cases, but it should be sufficient for the most common examples. I think it would be useful to eventually replace the existing CUDA tutorial, but I'm not sure that it'll get there in the first draft. As an added benefit, this also runs a simple test (akin to `docs/cuda_custom_call`) which actually executes using a tool chain that open source users would use in practice.
- Loading branch information
Showing
7 changed files
with
1,011 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
CMake* | ||
cmake* | ||
Makefile | ||
*.so | ||
*.dylib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
cmake_minimum_required(VERSION 3.18...3.27) | ||
project(rms_norm LANGUAGES CXX) | ||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) | ||
|
||
execute_process( | ||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir | ||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) | ||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") | ||
find_package(nanobind CONFIG REQUIRED) | ||
|
||
execute_process( | ||
COMMAND "${Python_EXECUTABLE}" | ||
"-c" "from jax.extend import ffi; print(ffi.include_dir())" | ||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) | ||
message(STATUS "XLA include directory: ${XLA_DIR}") | ||
|
||
nanobind_add_module(rms_norm NOMINSIZE "rms_norm.cc") | ||
target_include_directories(rms_norm PUBLIC ${XLA_DIR}) | ||
install(TARGETS rms_norm LIBRARY DESTINATION ${CMAKE_CURRENT_LIST_DIR}) | ||
|
Oops, something went wrong.