diff --git a/.github/workflows/native_jni_s3_pytorch.yml b/.github/workflows/native_jni_s3_pytorch.yml index b31bf97ea0f..18aab758461 100644 --- a/.github/workflows/native_jni_s3_pytorch.yml +++ b/.github/workflows/native_jni_s3_pytorch.yml @@ -6,6 +6,10 @@ on: pt_version: description: 'pytorch version' required: false + cuda: + description: 'CUDA version' + required: true + default: 'cu121' schedule: - cron: '0 5 * * *' @@ -83,7 +87,8 @@ jobs: ./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch" ./gradlew :engines:pytorch:pytorch-native:cleanJNI export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0" - ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=cu121 -Ppt_version=$PYTORCH_VERSION + CUDA_VERSION=${{ github.event.inputs.cuda }} + ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=$CUDA_VERSION -Ppt_version=$PYTORCH_VERSION ./gradlew :engines:pytorch:pytorch-native:cleanJNI - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v3 @@ -134,7 +139,8 @@ jobs: ./gradlew :engines:pytorch:pytorch-native:cleanJNI rm -rf ~/.djl.ai export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0" - ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=cu121 -Pprecxx11 -Ppt_version=$PYTORCH_VERSION + CUDA_VERSION=${{ github.event.inputs.cuda }} + ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcuda=$CUDA_VERSION -Pprecxx11 -Ppt_version=$PYTORCH_VERSION - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v2 with: @@ -185,7 +191,7 @@ jobs: call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" amd64 set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v11.7" set "PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH%" - gradlew :engines:pytorch:pytorch-native:cleanJNI :engines:pytorch:pytorch-native:compileJNI -Pcuda=cu121 -Ppt_version=${{ github.event.inputs.pt_version }} + gradlew :engines:pytorch:pytorch-native:cleanJNI :engines:pytorch:pytorch-native:compileJNI -Pcuda=${{ github.event.inputs.cuda }} -Ppt_version=${{ github.event.inputs.pt_version }} - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v4 with: diff --git a/engines/pytorch/pytorch-native/build.cmd b/engines/pytorch/pytorch-native/build.cmd index 256636c611a..adb3e63741d 100644 --- a/engines/pytorch/pytorch-native/build.cmd +++ b/engines/pytorch/pytorch-native/build.cmd @@ -19,8 +19,17 @@ if exist %FILEPATH% ( echo Finished downloading libtorch ) -if "%VERSION%" == "1.11.0" ( - set PT_VERSION=V1_11_X +if "%VERSION%" == "1.13.1" ( + set PT_VERSION=V1_13_X +) +if "%VERSION%" == "2.0.1" ( + set PT_VERSION=V1_13_X +) +if "%VERSION%" == "2.1.1" ( + set PT_VERSION=V1_13_X +) +if "%VERSION%" == "2.1.2" ( + set PT_VERSION=V1_13_X ) if /i "%2:~0,2%" == "cu" ( diff --git a/engines/pytorch/pytorch-native/build.sh b/engines/pytorch/pytorch-native/build.sh index b020368d3e6..3e4f03d48a9 100755 --- a/engines/pytorch/pytorch-native/build.sh +++ b/engines/pytorch/pytorch-native/build.sh @@ -53,7 +53,7 @@ if [[ ! -d "libtorch" ]]; then fi fi -if [[ "$VERSION" == "1.13.1" ]]; then +if [[ "$VERSION" == "1.13.1" || "$VERSION" == "2.0.1" || "$VERSION" =~ ^(2.1.*)$ ]]; then PT_VERSION=V1_13_X fi