Skip to content

Commit

Permalink
[ci] Adds cuda version as github actions parameter for Pytorch JNI bu…
Browse files Browse the repository at this point in the history
…ild (#3185)
  • Loading branch information
frankfliu authored May 14, 2024
1 parent d7ce12d commit 102f635
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
12 changes: 9 additions & 3 deletions .github/workflows/native_jni_s3_pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ on:
pt_version:
description: 'pytorch version'
required: false
cuda:
description: 'CUDA version'
required: true
default: 'cu121'
schedule:
- cron: '0 5 * * *'

Expand Down Expand Up @@ -83,7 +87,8 @@ jobs:
./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch"
./gradlew :engines:pytorch:pytorch-native:cleanJNI
export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0"
./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=cu121 -Ppt_version=$PYTORCH_VERSION
CUDA_VERSION=${{ github.event.inputs.cuda }}
./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=$CUDA_VERSION -Ppt_version=$PYTORCH_VERSION
./gradlew :engines:pytorch:pytorch-native:cleanJNI
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v3
Expand Down Expand Up @@ -134,7 +139,8 @@ jobs:
./gradlew :engines:pytorch:pytorch-native:cleanJNI
rm -rf ~/.djl.ai
export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0"
./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=cu121 -Pprecxx11 -Ppt_version=$PYTORCH_VERSION
CUDA_VERSION=${{ github.event.inputs.cuda }}
./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=$CUDA_VERSION -Pprecxx11 -Ppt_version=$PYTORCH_VERSION
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v2
with:
Expand Down Expand Up @@ -185,7 +191,7 @@ jobs:
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" amd64
set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v11.7"
set "PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH%"
gradlew :engines:pytorch:pytorch-native:cleanJNI :engines:pytorch:pytorch-native:compileJNI -Pcuda=cu121 -Ppt_version=${{ github.event.inputs.pt_version }}
gradlew :engines:pytorch:pytorch-native:cleanJNI :engines:pytorch:pytorch-native:compileJNI -Pcuda=${{ github.event.inputs.cuda }} -Ppt_version=${{ github.event.inputs.pt_version }}
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
Expand Down
13 changes: 11 additions & 2 deletions engines/pytorch/pytorch-native/build.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,17 @@ if exist %FILEPATH% (
echo Finished downloading libtorch
)

if "%VERSION%" == "1.11.0" (
set PT_VERSION=V1_11_X
if "%VERSION%" == "1.13.1" (
set PT_VERSION=V1_13_X
)
if "%VERSION%" == "2.0.1" (
set PT_VERSION=V1_13_X
)
if "%VERSION%" == "2.1.1" (
set PT_VERSION=V1_13_X
)
if "%VERSION%" == "2.1.2" (
set PT_VERSION=V1_13_X
)

if /i "%2:~0,2%" == "cu" (
Expand Down
2 changes: 1 addition & 1 deletion engines/pytorch/pytorch-native/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ if [[ ! -d "libtorch" ]]; then
fi
fi

if [[ "$VERSION" == "1.13.1" ]]; then
if [[ "$VERSION" == "1.13.1" || "$VERSION" == "2.0.1" || "$VERSION" =~ ^(2.1.*)$ ]]; then
PT_VERSION=V1_13_X
fi

Expand Down

0 comments on commit 102f635

Please sign in to comment.