Skip to content

Commit

Permalink
[rust] Build .so file for each cuda arch (#3410)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Aug 13, 2024
1 parent a22b905 commit a30355e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 40 deletions.
16 changes: 4 additions & 12 deletions .github/workflows/native_s3_huggingface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,18 @@ jobs:
build-tokenizers-jni-cu122:
if: github.repository == 'deepjavalibrary/djl'
runs-on: [ self-hosted, g5 ]
timeout-minutes: 30
timeout-minutes: 60
needs: create-runners
container:
image: nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04
image: nvidia/cuda:12.2.2-devel-ubuntu20.04
options: --gpus all --runtime=nvidia
env:
CUDA_VERSION: cu122
steps:
- name: Install Environment
run: |
apt-get -y update
apt-get -y install curl git
- name: Set up Python3
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install Python Dependencies
run: |
python -m pip install --upgrade pip
apt-get -y install curl git python3-pip
pip install awscli wheel setuptools --upgrade
- uses: actions-rs/toolchain@v1
with:
Expand All @@ -223,7 +216,6 @@ jobs:
run: |
. "$HOME/.cargo/env"
./gradlew :extensions:tokenizers:compileJNI -Pcuda=${{ env.CUDA_VERSION }}
./gradlew -Pjni :extensions:tokenizers:test
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v2
with:
Expand All @@ -234,7 +226,7 @@ jobs:
run: |
DJL_VERSION=$(awk -F '=' '/djl / {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)
TOKENIZERS_VERSION="$(awk -F '=' '/tokenizers/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)"
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/${{ env.CUDA_VERSION }} s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/${DJL_VERSION}/linux-x86_64/${{ env.CUDA_VERSION }}/
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/ s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/${DJL_VERSION}/linux-x86_64/
aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*"
stop-runners:
Expand Down
49 changes: 31 additions & 18 deletions extensions/tokenizers/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@

set -e
WORK_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
NUM_PROC=1
if [[ -n $(command -v nproc) ]]; then
NUM_PROC=$(nproc)
elif [[ -n $(command -v sysctl) ]]; then
NUM_PROC=$(sysctl -n hw.ncpu)
fi
PLATFORM=$(uname | tr '[:upper:]' '[:lower:]')

VERSION=v$1
ARCH=$2
FLAVOR=$3

pushd $WORK_DIR
pushd "$WORK_DIR"
if [ ! -d "tokenizers" ]; then
git clone https://github.com/huggingface/tokenizers -b $VERSION
git clone https://github.com/huggingface/tokenizers -b "$VERSION"
fi

if [ ! -d "build" ]; then
Expand All @@ -28,20 +22,39 @@ mkdir build/classes
javac -sourcepath src/main/java/ src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java -h build/include -d build/classes
javac -sourcepath src/main/java/ src/main/java/ai/djl/engine/rust/RustLibrary.java -h build/include -d build/classes

function copy_files() {
# for nightly ci
arch="$1"
flavor="$2"
if [[ $PLATFORM == 'darwin' ]]; then
mkdir -p "build/jnilib/osx-$arch/$flavor"
cp -f rust/target/release/libdjl.dylib "build/jnilib/osx-$arch/$flavor/libtokenizers.dylib"
elif [[ $PLATFORM == 'linux' ]]; then
mkdir -p "build/jnilib/linux-$arch/$flavor"
cp -f rust/target/release/libdjl.so "build/jnilib/linux-$arch/$flavor/libtokenizers.so"
fi
}

RUST_MANIFEST=rust/Cargo.toml
if [[ "$FLAVOR" = "cpu"* ]]; then
cargo build --manifest-path $RUST_MANIFEST --release
copy_files "$ARCH" "$FLAVOR"
elif [[ "$FLAVOR" = "cu"* && "$FLAVOR" > "cu121" ]]; then
cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn
CUDA_COMPUTE_CAP=80 cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn
copy_files "$ARCH" "${FLAVOR}-80"

cargo clean --manifest-path $RUST_MANIFEST
CUDA_COMPUTE_CAP=86 cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn
copy_files "$ARCH" "${FLAVOR}-86"

cargo clean --manifest-path $RUST_MANIFEST
CUDA_COMPUTE_CAP=89 cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn
copy_files "$ARCH" "${FLAVOR}-89"

cargo clean --manifest-path $RUST_MANIFEST
CUDA_COMPUTE_CAP=90 cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn
copy_files "$ARCH" "${FLAVOR}-90"
else
cargo build --manifest-path $RUST_MANIFEST --release
fi

# for nightly ci
if [[ $PLATFORM == 'darwin' ]]; then
mkdir -p build/jnilib/osx-$ARCH/$FLAVOR
cp -f rust/target/release/libdjl.dylib build/jnilib/osx-$ARCH/$FLAVOR/libtokenizers.dylib
elif [[ $PLATFORM == 'linux' ]]; then
mkdir -p build/jnilib/linux-$ARCH/$FLAVOR
cp -f rust/target/release/libdjl.so build/jnilib/linux-$ARCH/$FLAVOR/libtokenizers.so
copy_files "$ARCH" "$FLAVOR"
fi
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -39,6 +42,8 @@ public final class LibUtils {
Pattern.compile(
"(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)-(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?");
private static final int[] SUPPORTED_CUDA_VERSIONS = {122};
private static final Set<String> SUPPORTED_CUDA_ARCH =
new HashSet<>(Arrays.asList("80", "86", "89", "90"));

private static EngineException exception;

Expand Down Expand Up @@ -90,6 +95,10 @@ private static Path copyJniLibrary(String[] libs) {
String os = platform.getOsPrefix();
String classifier = platform.getClassifier();
String version = platform.getVersion();
String cudaArch = platform.getCudaArch();
if (cudaArch == null) {
cudaArch = "";
}
String flavor = Utils.getEnvOrSystemProperty("TOKENIZERS_FLAVOR");
boolean override = flavor != null && !flavor.isEmpty();
if (override) {
Expand All @@ -104,20 +113,26 @@ private static Path copyJniLibrary(String[] libs) {

// Find the highest matching CUDA version
if (flavor.startsWith("cu")) {
int cudaVersion = Integer.parseInt(flavor.substring(2, 5));
boolean match = false;
for (int v : SUPPORTED_CUDA_VERSIONS) {
if (override && cudaVersion == v) {
match = true;
break;
} else if (cudaVersion >= v) {
flavor = "cu" + v;
match = true;
break;
if (SUPPORTED_CUDA_ARCH.contains(cudaArch)) {
int cudaVersion = Integer.parseInt(flavor.substring(2, 5));
for (int v : SUPPORTED_CUDA_VERSIONS) {
if (override && cudaVersion == v) {
match = true;
break;
} else if (cudaVersion >= v) {
flavor = "cu" + v + "-" + cudaArch;
match = true;
break;
}
}
}
if (!match) {
logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor);
logger.warn(
"No matching cuda flavor for {} found: {}/sm_{}.",
classifier,
flavor,
cudaArch);
flavor = "cpu"; // Fallback to CPU
}
}
Expand Down

0 comments on commit a30355e

Please sign in to comment.