Skip to content

Commit

Permalink
ci: bump TF to 2.18, PT to 2.5 (#4228)
Browse files Browse the repository at this point in the history
This is prepared for the upcoming TF 2.18, which needs CUDNN 9. In the
future, I may move all pinnings into pyproject.toml...


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
	- Enhanced dependency management for CUDA and Python workflows.
- Introduced new jobs for better organization of test duration handling.

- **Bug Fixes**
- Updated TensorFlow and Torch versions for improved compatibility and
performance.
- Refined version requirements for TensorFlow based on detected CUDA
versions.

- **Documentation**
- Adjusted testing commands and linting configurations for clarity and
compliance.

- **Chores**
	- Streamlined caching mechanisms to optimize test duration tracking.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 24, 2024
1 parent 4d50048 commit 02580c2
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
if: false # skip as we use nvidia image
- run: python -m pip install -U uv
- run: source/install/uv_with_retry.sh pip install --system "tensorflow>=2.15.0rc0" "torch==2.3.1.*"
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0"
- run: |
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
Expand All @@ -63,7 +63,7 @@ jobs:
CUDA_VISIBLE_DEVICES: 0
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip -O libtorch.zip
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
unzip libtorch.zip
- run: |
export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- run: python -m pip install -U uv
- run: |
source/install/uv_with_retry.sh pip install --system mpich
source/install/uv_with_retry.sh pip install --system "torch==2.3.0+cpu.cxx11.abi" -i https://download.pytorch.org/whl/
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
env:
Expand Down
18 changes: 18 additions & 0 deletions backend/find_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import os
import platform
import site
from functools import (
lru_cache,
Expand All @@ -22,6 +23,9 @@
Union,
)

from packaging.specifiers import (
SpecifierSet,
)
from packaging.version import (
Version,
)
Expand Down Expand Up @@ -104,6 +108,20 @@ def get_pt_requirement(pt_version: str = "") -> dict:
"""
if pt_version is None:
return {"torch": []}
if (
os.environ.get("CIBUILDWHEEL", "0") == "1"
and platform.system() == "Linux"
and platform.machine() == "x86_64"
):
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
# CUDA 12.2, cudnn 9
pt_version = "2.5.0"
elif cuda_version in SpecifierSet(">=11,<12"):
# CUDA 11.8, cudnn 8
pt_version = "2.3.1"
else:
raise RuntimeError("Unsupported CUDA version") from None
if pt_version == "":
pt_version = os.environ.get("PYTORCH_VERSION", "")

Expand Down
6 changes: 3 additions & 3 deletions backend/find_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]:
if os.environ.get("CIBUILDWHEEL", "0") == "1":
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
# CUDA 12.2
# CUDA 12.2, cudnn 9
requires.extend(
[
"tensorflow-cpu>=2.15.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'",
"tensorflow-cpu>=2.18.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'",
]
)
elif cuda_version in SpecifierSet(">=11,<12"):
# CUDA 11.8
# CUDA 11.8, cudnn 8
requires.extend(
[
"tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'",
Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ cu12 = [
"nvidia-curand-cu12",
"nvidia-cusolver-cu12",
"nvidia-cusparse-cu12",
"nvidia-cudnn-cu12<9",
"nvidia-cudnn-cu12",
"nvidia-cuda-nvcc-cu12",
]
jax = [
Expand Down Expand Up @@ -279,8 +279,6 @@ PATH = "/usr/lib64/mpich/bin:$PATH"
UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu"
# trick to find the correction version of mpich
CMAKE_PREFIX_PATH="/opt/python/cp311-cp311/"
# PT 2.4.0 requires cudnn 9, incompatible with TF with cudnn 8
PYTORCH_VERSION = "2.3.1"

[tool.cibuildwheel.windows]
test-extras = ["cpu", "torch"]
Expand Down

0 comments on commit 02580c2

Please sign in to comment.