Skip to content

Commit

Permalink
Merge branch 'master' into cli-jsonargparse-change
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored Nov 12, 2024
2 parents ffc3c21 + b0aa504 commit fcf9805
Show file tree
Hide file tree
Showing 30 changed files with 132 additions and 103 deletions.
2 changes: 1 addition & 1 deletion .azure/gpu-benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
variables:
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
container:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
options: "--gpus=all --shm-size=32g"
strategy:
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .azure/gpu-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "fabric"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
PACKAGE_NAME: "lightning"
workspace:
clean: all
Expand Down
2 changes: 1 addition & 1 deletion .azure/gpu-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "pytorch"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
PACKAGE_NAME: "lightning"
pool: lit-rtx-3090
variables:
Expand Down
64 changes: 36 additions & 28 deletions .github/checkgroup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,26 @@ subprojects:
checks:
- "pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
- "pl-cpu (macOS-14, lightning, 3.11, 2.2)"
- "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)"
- "pl-cpu (macOS-14, lightning, 3.11, 2.3)"
- "pl-cpu (macOS-14, lightning, 3.12, 2.4)"
- "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)"
- "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)"
- "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
- "pl-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
- "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
- "pl-cpu (windows-2022, lightning, 3.10, 2.1)"
- "pl-cpu (windows-2022, lightning, 3.11, 2.2)"
- "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)"
- "pl-cpu (windows-2022, lightning, 3.11, 2.3)"
- "pl-cpu (windows-2022, lightning, 3.12, 2.4)"
- "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)"
- "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)"
- "pl-cpu (macOS-14, pytorch, 3.9, 2.1)"
- "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)"
- "pl-cpu (windows-2022, pytorch, 3.9, 2.1)"
- "pl-cpu (macOS-12, pytorch, 3.10, 2.1)"
- "pl-cpu (macOS-13, pytorch, 3.10, 2.1)"
- "pl-cpu (ubuntu-22.04, pytorch, 3.10, 2.1)"
- "pl-cpu (windows-2022, pytorch, 3.10, 2.1)"

Expand Down Expand Up @@ -141,15 +144,17 @@ subprojects:
- "!*.md"
- "!**/*.md"
checks:
- "build-cuda (3.11, 2.1, 12.1.0)"
- "build-cuda (3.11, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.3, 12.1.0)"
- "build-cuda (3.12, 2.4, 12.1.0)"
- "build-cuda (3.10, 2.1.2, 12.1.0)"
- "build-cuda (3.11, 2.2.2, 12.1.0)"
- "build-cuda (3.11, 2.3.1, 12.1.0)"
- "build-cuda (3.11, 2.4.1, 12.1.0)"
- "build-cuda (3.12, 2.5.1, 12.1.0)"
#- "build-NGC"
- "build-pl (3.11, 2.1, 12.1.0)"
- "build-pl (3.10, 2.1, 12.1.0)"
- "build-pl (3.11, 2.2, 12.1.0)"
- "build-pl (3.11, 2.3, 12.1.0)"
- "build-pl (3.12, 2.4, 12.1.0)"
- "build-pl (3.11, 2.4, 12.1.0)"
- "build-pl (3.12, 2.5, 12.1.0)"

# SECTION: lightning_fabric

Expand All @@ -168,23 +173,26 @@ subprojects:
checks:
- "fabric-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
- "fabric-cpu (macOS-14, lightning, 3.10, 2.1)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.3)"
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4)"
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)"
- "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
- "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
- "fabric-cpu (windows-2022, lightning, 3.10, 2.1)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.3)"
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4)"
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)"
- "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)"
- "fabric-cpu (macOS-14, fabric, 3.9, 2.1)"
- "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)"
- "fabric-cpu (windows-2022, fabric, 3.9, 2.1)"
- "fabric-cpu (macOS-12, fabric, 3.10, 2.1)"
- "fabric-cpu (macOS-13, fabric, 3.10, 2.1)"
- "fabric-cpu (ubuntu-22.04, fabric, 3.10, 2.1)"
- "fabric-cpu (windows-2022, fabric, 3.10, 2.1)"

Expand Down Expand Up @@ -258,14 +266,14 @@ subprojects:
- "install-pkg (ubuntu-22.04, lightning, 3.11)"
- "install-pkg (ubuntu-22.04, notset, 3.9)"
- "install-pkg (ubuntu-22.04, notset, 3.11)"
- "install-pkg (macOS-12, fabric, 3.9)"
- "install-pkg (macOS-12, fabric, 3.11)"
- "install-pkg (macOS-12, pytorch, 3.9)"
- "install-pkg (macOS-12, pytorch, 3.11)"
- "install-pkg (macOS-12, lightning, 3.9)"
- "install-pkg (macOS-12, lightning, 3.11)"
- "install-pkg (macOS-12, notset, 3.9)"
- "install-pkg (macOS-12, notset, 3.11)"
- "install-pkg (macOS-13, fabric, 3.9)"
- "install-pkg (macOS-13, fabric, 3.11)"
- "install-pkg (macOS-13, pytorch, 3.9)"
- "install-pkg (macOS-13, pytorch, 3.11)"
- "install-pkg (macOS-13, lightning, 3.9)"
- "install-pkg (macOS-13, lightning, 3.11)"
- "install-pkg (macOS-13, notset, 3.9)"
- "install-pkg (macOS-13, notset, 3.11)"
- "install-pkg (windows-2022, fabric, 3.9)"
- "install-pkg (windows-2022, fabric, 3.11)"
- "install-pkg (windows-2022, pytorch, 3.9)"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-pkg-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-22.04", "macOS-12", "windows-2022"]
os: ["ubuntu-22.04", "macOS-13", "windows-2022"]
pkg-name: ["fabric", "pytorch", "lightning", "notset"]
python-version: ["3.9", "3.11"]
steps:
Expand Down
17 changes: 10 additions & 7 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,20 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-13", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
# "oldest" versions tests, only on minimum Python
Expand Down
17 changes: 10 additions & 7 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,20 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-13", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
# "oldest" versions tests, only on minimum Python
Expand Down
22 changes: 15 additions & 7 deletions .github/workflows/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ jobs:
include:
# We only release one docker image per PyTorch version.
# Make sure the matrix here matches the one below.
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.5", cuda_version: "12.1.0" }
steps:
- uses: actions/checkout@v4
with:
Expand Down Expand Up @@ -103,10 +104,11 @@ jobs:
include:
# These are the base images for PL release docker images.
# Make sure the matrix here matches the one above.
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.1.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.0" }
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
Expand All @@ -115,6 +117,12 @@ jobs:
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}

- name: shorten Torch version
run: |
# convert 1.10.2 to 1.10
pt_version=$(echo ${{ matrix.pytorch_version }} | cut -d. -f1,2)
echo "PT_VERSION=$pt_version" >> $GITHUB_ENV
- uses: docker/build-push-action@v6
with:
build-args: |
Expand All @@ -123,7 +131,7 @@ jobs:
CUDA_VERSION=${{ matrix.cuda_version }}
file: dockers/base-cuda/Dockerfile
push: ${{ env.PUSH_NIGHTLY }}
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}"
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}-cuda${{ matrix.cuda_version }}"
timeout-minutes: 95
- uses: ravsamhq/notify-slack-action@v2
if: failure() && env.PUSH_NIGHTLY == 'true'
Expand Down
2 changes: 1 addition & 1 deletion _notebooks
2 changes: 1 addition & 1 deletion docs/source-pytorch/levels/expert.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Customize and extend Lightning for things like custom hardware or distributed st
:header: Level 24: Add a new accelerator or Strategy
:description: Integrate a new accelerator or distributed strategy.
:col_css: col-md-6
:button_link: expert_level_27.html
:button_link: expert_level_24.html
:height: 150
:tag: expert

Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torch >=2.1.0, <2.5.0
torch >=2.1.0, <2.6.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
Expand Down
4 changes: 2 additions & 2 deletions requirements/fabric/examples.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.16.0, <0.20.0
torchmetrics >=0.10.0, <1.3.0
torchvision >=0.16.0, <0.21.0
torchmetrics >=0.10.0, <1.5.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/fabric/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
click ==8.1.7
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
4 changes: 2 additions & 2 deletions requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torch >=2.1.0, <2.5.0
torch >=2.1.0, <2.6.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2024.4.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.10.0, <0.12.0
4 changes: 2 additions & 2 deletions requirements/pytorch/examples.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

requests <2.32.0
torchvision >=0.16.0, <0.20.0
torchvision >=0.16.0, <0.21.0
ipython[all] <8.15.0
torchmetrics >=0.10.0, <1.3.0
torchmetrics >=0.10.0, <1.5.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mypy==1.11.0
torch==2.4.1
torch==2.5.1

types-Markdown
types-PyYAML
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import sys

from lightning_utilities.core.imports import package_available

Expand All @@ -26,6 +27,10 @@
# https://github.com/pytorch/pytorch/issues/83973
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

# see https://github.com/pytorch/pytorch/issues/139990
if sys.platform == "win32":
os.environ["USE_LIBUV"] = "0"


from lightning.fabric.fabric import Fabric # noqa: E402
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
Expand Down
6 changes: 4 additions & 2 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def log(
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx, # type: ignore[arg-type]
reduce_fx=reduce_fx,
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
Expand Down Expand Up @@ -1405,7 +1405,9 @@ def forward(self, x):
input_sample = self._apply_batch_transfer_handler(input_sample)

file_path = str(file_path) if isinstance(file_path, Path) else file_path
torch.onnx.export(self, input_sample, file_path, **kwargs)
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
# BytesIO does work, too.
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
self.train(mode)

@torch.no_grad()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], m

return batch_size

@torch.compiler.disable
def log(
self,
fx: str,
Expand Down Expand Up @@ -413,6 +414,7 @@ def log(
batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

@torch.compiler.disable
def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
Expand Down
Loading

0 comments on commit fcf9805

Please sign in to comment.