Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for Apple Metal #430

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ jobs:
name: pjrt_plugin_xla_cpu-darwin-aarch64
path: pjrt_plugin_xla_cpu.dylib
if-no-files-found: error
pjrt-plugin-apple-metal:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2
- name: Build or fetch Apple Metal PJRT plugin
run: |
prefix=jax_metal-0.1.0-py3-none-macosx_11_0_arm64
curl -fsL "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" \
-o "$prefix.zip"
unzip "$prefix.zip"
mv "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib" pjrt_plugin_apple_metal.dylib
- name: Upload binary
uses: actions/upload-artifact@v4
with:
name: pjrt_plugin_apple_metal
path: pjrt_plugin_apple_metal.dylib
if-no-files-found: error
pjrt-plugin-xla-cuda-linux-x86_64:
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -181,6 +200,27 @@ jobs:
name: tests-xla-cpu-darwin-aarch64
path: test/xla-cpu/tests-xla-cpu.tar.gz
if-no-files-found: error
build-tests-apple-metal:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- name: Install build dependencies
run: |
brew install chezscheme
git clone https://github.com/stefan-hoeck/idris2-pack.git
(cd idris2-pack && make micropack SCHEME=chez)
~/.pack/bin/pack switch HEAD
- name: Build tests
working-directory: test/apple-metal
run: |
SPIDR_INSTALL_SUPPORT_LIBS=false ~/.pack/bin/pack --no-prompt build apple-metal.ipkg
tar cfz tests-apple-metal.tar.gz -C build/exec .
- name: Upload tests
uses: actions/upload-artifact@v4
with:
name: tests-apple-metal
path: test/apple-metal/tests-apple-metal.tar.gz
if-no-files-found: error
build-tests-xla-cuda-linux-x86_64:
runs-on: ubuntu-latest
container: ghcr.io/stefan-hoeck/idris2-pack
Expand Down Expand Up @@ -239,6 +279,25 @@ jobs:
run: |
tar xfz tests-xla-cpu.tar.gz && rm tests-xla-cpu.tar.gz
./test
test-apple-metal:
needs:
- pjrt-darwin-aarch64
- pjrt-plugin-apple-metal
- build-tests-apple-metal
runs-on: macos-latest
steps:
- name: Download artifacts
uses: actions/download-artifact@v4
with:
pattern: "{libc_xla-darwin-aarch64,pjrt_plugin_apple_metal,tests-apple-metal}"
merge-multiple: true
- name: Install runtime dependencies
run: |
brew install chezscheme
- name: Run tests
run: |
tar xfz tests-apple-metal.tar.gz && rm tests-apple-metal.tar.gz
./test
test-xla-cuda-linux-x86_64:
needs:
- pjrt-linux-x86_64
Expand Down
5 changes: 5 additions & 0 deletions pack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ type = "local"
path = ""
ipkg = "test/runner/runner.ipkg"

[custom.all.pjrt-plugin-apple-metal]
type = "local"
path = ""
ipkg = "pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg"

[custom.all.pjrt-plugin-xla-cpu]
type = "local"
path = ""
Expand Down
2 changes: 1 addition & 1 deletion pjrt-plugins/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# PJRT Plugins

A PJRT plugin provides the compiler and hardware device support required to execute spidr graphs. We provide plugins for [CPU](xla-cpu/README.md) and [CUDA-enabled GPUs](xla-cuda/README.md). You can also use third-party plugins, or make your own.
A PJRT plugin provides the compiler and hardware device support required to execute spidr graphs. We provide plugins for [CPU](xla-cpu/README.md), [Apple Metal](apple-metal/README.md), and [CUDA-enabled GPUs](xla-cuda/README.md). You can also use third-party plugins, or make your own.

## How to integrate your own plugin

Expand Down
30 changes: 30 additions & 0 deletions pjrt-plugins/apple-metal/PjrtPluginAppleMetal.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{--
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module PjrtPluginAppleMetal

import System.FFI

import public Compiler.Xla.PJRT.C.PjrtCApi
import public Device

%foreign "C:GetPjrtApi,pjrt_plugin_apple_metal"
prim__getPjrtApi : PrimIO AnyPtr

export
device : Pjrt Device
device = do
api <- MkPjrtApi <$> primIO prim__getPjrtApi
MkDevice api <$> pjrtClientCreate api
10 changes: 10 additions & 0 deletions pjrt-plugins/apple-metal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# PJRT plugin for Apple Metal

This is the PJRT plugin for Apple Metal, which provides hardware acceleration with GPU on Apple silicon (AArch64, ARM64).

## Install

Run
```
pack install pjrt-plugin-apple-metal
```
33 changes: 33 additions & 0 deletions pjrt-plugins/apple-metal/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/sh -e

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."
. ./dev.sh
rev=$(cat XLA_VERSION)

osu="$(uname)"
case $osu in
'Linux')
os=linux
arch=x86_64
ext=so
;;
'Darwin')
os=darwin
arch=aarch64
ext=dylib
;;
*)
echo "OS $osu not handled"
exit 1
;;
esac

xla_dir=$(mktemp -d)
install_xla "$rev" "$xla_dir"
(
cd "$xla_dir"
./configure.py --backend=CPU --os=$os
bazel build //xla/pjrt/c:pjrt_c_api_cpu_plugin.so
)
mv "$xla_dir/bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so" "pjrt_plugin_xla_cpu-$os-$arch.$ext"
11 changes: 11 additions & 0 deletions pjrt-plugins/apple-metal/pjrt-plugin-apple-metal.ipkg
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package pjrt-plugin-apple-metal
version = 0.0.1

depends = spidr
modules = PjrtPluginAppleMetal

brief = "XLA PJRT plugin for Apple Metal."
readme = "README.md"
license = "Apache License, Version 2.0"

postinstall = "./postinstall.sh"
25 changes: 25 additions & 0 deletions pjrt-plugins/apple-metal/postinstall.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/sh -e

if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."

os="$(uname)"
case $os in
'Darwin')
;;
*)
echo "WARNING: OS $os not supported, unable to fetch supporting libraries."
exit 0
;;
esac

prefix=jax_metal-0.1.0-py3-none-macosx_11_0_arm64
curl -fsL "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" \
-o "$prefix.zip"
unzip "$prefix.zip"
libdir="$(idris2 --libdir)/pjrt-plugin-xla-cpu-0.0.1/lib"
mkdir -p libdir
mv "$prefix/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib" "$libdir/pjrt_plugin_apple_metal.dylib"
rm -rf "$prefix.zip" $prefix
25 changes: 25 additions & 0 deletions test/apple-metal/Main.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{--
Copyright 2024 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module Main

import System

import TestRunner
import PjrtPluginAppleMetal

partial
main : IO ()
main = eitherT (die . show) run device
8 changes: 8 additions & 0 deletions test/apple-metal/apple-metal.ipkg
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package apple-metal

depends =
pjrt-plugin-apple-metal,
runner

executable = test
main = Main
2 changes: 1 addition & 1 deletion test/xla-cpu/XlaCpu.idr → test/xla-cpu/Main.idr
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module XlaCpu
module Main

import System

Expand Down
2 changes: 1 addition & 1 deletion test/xla-cpu/xla-cpu.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ depends =
runner

executable = test
main = XlaCpu
main = Main
2 changes: 1 addition & 1 deletion test/xla-cuda/XlaCuda.idr → test/xla-cuda/Main.idr
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module XlaCuda
module Main

import System

Expand Down
2 changes: 1 addition & 1 deletion test/xla-cuda/xla-cuda.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ depends =
runner

executable = test
main = XlaCuda
main = Main
Loading