Skip to content

Commit

Permalink
Support PyTorch 1.10. (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 24, 2021
1 parent cae610a commit d061bc6
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 52 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-18.04, macos-10.15]
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"]
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0,
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10
python-version: [3.6, 3.7, 3.8, 3.9]
exclude:
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
Expand Down
47 changes: 41 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,42 @@ jobs:
matrix:
os: [ubuntu-18.04]
# from https://download.pytorch.org/whl/torch_stable.html
# 1.9.0 supports: cuda10.2 (default), 11.1
# Note: There are no torch versions for CUDA 11.2
#
# 1.10 supports: cuda10.2 (default), 11.1, 11.3
# 1.9.x supports: cuda10.2 (default), 11.1
# PyTorch 1.8.x supports: cuda 10.1, 10.2 (default), 11.1
# PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0
# PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default)
# PyTorch 1.5.x supports: cuda 10.1, 10.2 (default)
# Other PyTorch versions are not tested
cuda: ["10.1", "10.2", "11.0", "11.1"]
# CUDA 11.3 is for torch 1.10
cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"]
gcc: ["7"]
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"]
# Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.0
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
#
# Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10
python-version: [3.6, 3.7, 3.8, 3.9]
exclude:
- cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0]
- cuda: "11.3" # exclude 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1]
torch: "1.5.0"
- cuda: "11.3"
torch: "1.5.1"
- cuda: "11.3"
torch: "1.6.0"
- cuda: "11.3"
torch: "1.7.0"
- cuda: "11.3"
torch: "1.7.1"
- cuda: "11.3"
torch: "1.8.0"
- cuda: "11.3"
torch: "1.8.1"
- cuda: "11.3"
torch: "1.9.0"
- cuda: "11.3"
torch: "1.9.1"
- cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10]
torch: "1.5.0"
- cuda: "11.0"
torch: "1.5.1"
Expand All @@ -61,6 +84,10 @@ jobs:
torch: "1.8.1"
- cuda: "11.0"
torch: "1.9.0"
- cuda: "11.0"
torch: "1.9.1"
- cuda: "11.0"
torch: "1.10"
- cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1]
torch: "1.5.0"
- cuda: "11.1"
Expand All @@ -71,8 +98,12 @@ jobs:
torch: "1.7.0"
- cuda: "11.1"
torch: "1.7.1"
- cuda: "10.1" # exclude CUDA 10.1 for [1.9.0]
- cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10]
torch: "1.9.0"
- cuda: "10.1"
torch: "1.9.1"
- cuda: "10.1"
torch: "1.10"
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.5.0"
- python-version: 3.9
Expand Down Expand Up @@ -117,6 +148,10 @@ jobs:
echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV
echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV
- name: Install git lfs
run: |
sudo apt-get install -y git-lfs
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
Expand Down
27 changes: 19 additions & 8 deletions .github/workflows/build_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,19 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-18.04]
# anaconda does not support 3.9 as of 2021.05.08
python-version: [3.6, 3.7, 3.8, 3.9]
# python-version: [3.6, 3.7, 3.8]
cuda: ["10.1", "10.2", "11.0", "11.1"]
cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"]
# from https://download.pytorch.org/whl/torch_stable.html
#
# PyTorch 1.9.0 supports: 10.2 (default), 11.1
# PyTorch 1.10 supports: 10.2 (default), 11.1, 11.3
# PyTorch 1.9.x supports: 10.2 (default), 11.1
# PyTorch 1.8.1 supports: cuda 10.1, 10.2 (default), 11.1
# PyTorch 1.8.0 supports: cuda 10.1, 10.2 (default), 11.1
# PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0, 9.2 (not included in this setup)
# PyTorch 1.6.0 supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup)
# PyTorch 1.5.x supports: cuda 10.1, 10.2 (default), 9.2 (not included in this setup)
#
# PyTorch 1.8.x and 1.7.1 support 3.6, 3.7, 3.8, 3.9
# PyTorch 1.7.1, 1.8.x, 1.9.x, and 1.10 support 3.6, 3.7, 3.8, 3.9
# PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8
#
# Other PyTorch versions are not tested
Expand All @@ -57,9 +56,9 @@ jobs:
# https://github.com/csukuangfj/k2/runs/2533830771?check_suite_focus=true
# and
# https://github.com/NVIDIA/apex/issues/805
torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"]
torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
exclude:
# - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0]
# - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10]
# torch: "1.5.0"
# - cuda: "11.0"
# torch: "1.5.1"
Expand All @@ -71,6 +70,10 @@ jobs:
torch: "1.8.1"
- cuda: "11.0"
torch: "1.9.0"
- cuda: "11.0"
torch: "1.9.1"
- cuda: "11.0"
torch: "1.10"
# - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1]
# torch: "1.5.0"
# - cuda: "11.1"
Expand All @@ -81,8 +84,12 @@ jobs:
torch: "1.7.0"
- cuda: "11.1"
torch: "1.7.1"
- cuda: "10.1" # exclude 10.1 for [1.9.0]
- cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10]
torch: "1.9.0"
- cuda: "10.1"
torch: "1.9.1"
- cuda: "10.1"
torch: "1.10"
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.5.0"
- python-version: 3.9
Expand Down Expand Up @@ -142,6 +149,10 @@ jobs:
conda info
nproc
- name: Install git lfs
run: |
sudo apt-get install -y git-lfs
- name: Download cudnn 8.0
shell: bash -l {0}
env:
Expand Down
28 changes: 13 additions & 15 deletions .github/workflows/build_conda_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,25 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-16.04, macos-10.15]
# anaconda does not support 3.9 as of 2021.05.08
# python-version: [3.6, 3.7, 3.8, 3.9]
python-version: [3.6, 3.7, 3.8]
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
# from https://download.pytorch.org/whl/torch_stable.html
#
# PyTorch 1.9.0, 1.8.x, and 1.7.1 support 3.6, 3.7, 3.8, 3.9
# PyTorch 1.10, 1.9.x, 1.8.x, and 1.7.1 support 3.6, 3.7, 3.8, 3.9
# PyTorch 1.7.0, 1.6.0, and 1.5.x support 3.6, 3.7, 3.8
#
# Other PyTorch versions are not tested
#
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"]
# exclude:
# - python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
# torch: "1.5.0"
# - python-version: 3.9
# torch: "1.5.1"
# - python-version: 3.9
# torch: "1.6.0"
# - python-version: 3.9
# torch: "1.7.0"
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
exclude:
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.5.0"
- python-version: 3.9
torch: "1.5.1"
- python-version: 3.9
torch: "1.6.0"
- python-version: 3.9
torch: "1.7.0"

steps:
# refer to https://github.com/actions/checkout
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/nightly-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-18.04, macos-10.15]
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.0
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0"]
torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
exclude:
- python-version: 3.9 # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.4.0"
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
name: nightly

on:
push:
branches:
- nightly
schedule:
# minute (0-59)
# hour (0-23)
Expand Down Expand Up @@ -80,6 +83,10 @@ jobs:
./scripts/github_actions/install_torch.sh
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Install git lfs
run: |
sudo apt-get install -y git-lfs
- name: Download cudnn 8.0
env:
cuda: ${{ matrix.cuda }}
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ jobs:
./scripts/github_actions/install_torch.sh
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Install git lfs
run: |
sudo apt-get install -y git-lfs
- name: Download cudnn 8.0
env:
cuda: ${{ matrix.cuda }}
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/wheel-stable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ jobs:
./scripts/github_actions/install_torch.sh
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Install git lfs
run: |
sudo apt-get install -y git-lfs
- name: Download cudnn 8.0
env:
cuda: ${{ matrix.cuda }}
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ jobs:
./scripts/github_actions/install_torch.sh
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Install git lfs
run: |
sudo apt-get install -y git-lfs
- name: Download cudnn 8.0
env:
cuda: ${{ matrix.cuda }}
Expand Down
11 changes: 9 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,17 @@ if(K2_WITH_CUDA)
cuda_select_nvcc_arch_flags(K2_COMPUTE_ARCH_FLAGS)
message(STATUS "K2_COMPUTE_ARCH_FLAGS: ${K2_COMPUTE_ARCH_FLAGS}")

# set(K2_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72)
message(WARNING "arch 62/72 are not supported for now")
# set(K2_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72)
# message(WARNING "arch 62/72 are not supported for now")

# see https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/
# https://www.myzhar.com/blog/tutorials/tutorial-nvidia-gpu-cuda-compute-capability/
set(K2_COMPUTE_ARCH_CANDIDATES 35 50 60 61 70 75)
if(CUDA_VERSION VERSION_GREATER "11.0")
list(APPEND K2_COMPUTE_ARCH_CANDIDATES 80 86)
endif()
message(STATUS "K2_COMPUTE_ARCH_CANDIDATES ${K2_COMPUTE_ARCH_CANDIDATES}")

set(K2_COMPUTE_ARCHS)

foreach(COMPUTE_ARCH IN LISTS K2_COMPUTE_ARCH_CANDIDATES)
Expand Down
4 changes: 4 additions & 0 deletions scripts/github_actions/install_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ case "$cuda" in
# url=https://developer.download.nvidia.com/compute/cuda/11.1.0/local_installers/cuda_11.1.0_455.23.05_linux.run
url=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
;;
11.3)
# url=https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run
url=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run
;;
*)
echo "Unknown cuda version: $cuda"
exit 1
Expand Down
27 changes: 11 additions & 16 deletions scripts/github_actions/install_cudnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,37 @@
case $cuda in
10.0)
filename=cudnn-10.0-linux-x64-v7.6.5.32.tgz
url=http://www.mediafire.com/file/1037lb1vmj9qdtq/cudnn-10.0-linux-x64-v7.6.5.32.tgz/file
;;
10.1)
filename=cudnn-10.1-linux-x64-v8.0.2.39.tgz
url=http://www.mediafire.com/file/fnl2wg0h757qhd7/cudnn-10.1-linux-x64-v8.0.2.39.tgz/file
;;
10.2)
filename=cudnn-10.2-linux-x64-v8.0.2.39.tgz
url=http://www.mediafire.com/file/sc2nvbtyg0f7ien/cudnn-10.2-linux-x64-v8.0.2.39.tgz/file
;;
11.0)
filename=cudnn-11.0-linux-x64-v8.0.5.39.tgz
url=https://www.mediafire.com/file/abyhnls106ko9kp/cudnn-11.0-linux-x64-v8.0.5.39.tgz/file
;;
11.1)
filename=cudnn-11.1-linux-x64-v8.0.5.39.tgz
url=https://www.mediafire.com/file/qx55zd65773xonv/cudnn-11.1-linux-x64-v8.0.5.39.tgz/file
filename=cudnn-11.1-linux-x64-v8.0.4.30.tgz
;;
11.3)
filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz
;;
*)
echo "Unsupported cuda version: $cuda"
exit 1
;;
esac

function retry() {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
command -v git-lfs >/dev/null 2>&1 || { echo >&2 "\nPlease install 'git-lfs' first."; exit 2; }

# It is forked from https://github.com/Juvenal-Yescas/mediafire-dl
# https://github.com/Juvenal-Yescas/mediafire-dl/pull/2 changes the filename and breaks the CI.
# We use a separate fork to keep the link fixed.
retry wget https://raw.githubusercontent.com/csukuangfj/mediafire-dl/master/mediafire_dl.py
git clone https://huggingface.co/csukuangfj/cudnn
cd cudnn
git lfs pull --include="$filename"

sed -i 's/quiet=False/quiet=True/' mediafire_dl.py
retry python3 mediafire_dl.py "$url"
sudo tar xf ./$filename -C /usr/local
rm -v ./$filename

# save disk space
git lfs prune && cd .. && rm -rf cudnn

sudo sed -i '59i#define CUDNN_MAJOR 8' /usr/local/cuda/include/cudnn.h
19 changes: 18 additions & 1 deletion scripts/github_actions/install_torch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ case ${torch} in
;;
esac
;;
1.9.0)
1.9.*)
case ${cuda} in
10.2)
package="torch==${torch}"
Expand All @@ -91,6 +91,23 @@ case ${torch} in
;;
esac
;;
1.10)
case ${cuda} in
10.2)
package="torch==${torch}"
# Leave it empty to use PyPI.
url=
;;
11.1)
package="torch==${torch}+cu111"
url=https://download.pytorch.org/whl/torch_stable.html
;;
11.3)
package="torch==${torch}+cu113"
url=https://download.pytorch.org/whl/torch_stable.html
;;
esac
;;
*)
echo "Unsupported PyTorch version: ${torch}"
exit 1
Expand Down

0 comments on commit d061bc6

Please sign in to comment.