From 38971a0cae3104affea29ffd458a5f68b5a42692 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 9 Dec 2024 05:35:44 -0800 Subject: [PATCH] Patch jsonargparse for Python >= 3.12.8 (#20479) * Patch argparse _parse_known_args * Add patch to test * Avoid importing lightning in assistant * Fix return type --- .actions/assistant.py | 15 +++++++++++++++ .github/workflows/ci-tests-fabric.yml | 18 +++++++++--------- .github/workflows/ci-tests-pytorch.yml | 18 +++++++++--------- examples/fabric/tensor_parallel/train.py | 3 ++- examples/pytorch/tensor_parallel/train.py | 3 ++- src/lightning/pytorch/cli.py | 14 ++++++++++++++ tests/parity_fabric/test_parity_ddp.py | 3 +++ .../checkpointing/test_model_checkpoint.py | 2 +- 8 files changed, 55 insertions(+), 21 deletions(-) diff --git a/.actions/assistant.py b/.actions/assistant.py index dc1fa055d5e62..6bb0bc201ed05 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -483,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None: if __name__ == "__main__": + import sys + import jsonargparse + from jsonargparse import ArgumentParser + + def patch_jsonargparse_python_3_12_8(): + if sys.version_info < (3, 12, 8): + return + + def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: + namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore + return namespace, args + + setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 jsonargparse.CLI(AssistantCLI, as_positional=False) diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index c6d1dbf5b5ff2..c2fda73dcf1f4 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -49,16 +49,16 @@ jobs: - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" } # "oldest" versions tests, only on minimum Python - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } - { diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 0112d21d33a5c..ba8519ea8ed8a 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -53,16 +53,16 @@ jobs: - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" } # "oldest" versions tests, only on minimum Python - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } - { diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py index 4a98f12cf6168..1435e5c2003e1 100644 --- a/examples/fabric/tensor_parallel/train.py +++ b/examples/fabric/tensor_parallel/train.py @@ -1,13 +1,14 @@ import lightning as L import torch import torch.nn.functional as F -from data import RandomTokenDataset from lightning.fabric.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +from data import RandomTokenDataset + def train(): strategy = ModelParallelStrategy( diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py index 6a91e1242e4af..37c620f4582f0 100644 --- a/examples/pytorch/tensor_parallel/train.py +++ b/examples/pytorch/tensor_parallel/train.py @@ -1,13 +1,14 @@ import lightning as L import torch import torch.nn.functional as F -from data import RandomTokenDataset from lightning.pytorch.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +from data import RandomTokenDataset + class Llama3(L.LightningModule): def __init__(self): diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index c79f2481c8af4..5e70a9bdd8b9e 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -37,6 +37,18 @@ _JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7") + +def patch_jsonargparse_python_3_12_8() -> None: + if sys.version_info < (3, 12, 8): + return + + def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: + namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore + return namespace, args + + setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + + if _JSONARGPARSE_SIGNATURES_AVAILABLE: import docstring_parser from jsonargparse import ( @@ -48,6 +60,8 @@ set_config_read_mode, ) + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 + register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483 set_config_read_mode(fsspec_enabled=True) else: diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py index 217d401ad6fba..aebd9064b31fd 100644 --- a/tests/parity_fabric/test_parity_ddp.py +++ b/tests/parity_fabric/test_parity_ddp.py @@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float if __name__ == "__main__": from jsonargparse import CLI + from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8 + + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 CLI(run_parity_test) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d43f07179e7bb..ef0abc7c463a8 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -29,10 +29,10 @@ import pytest import torch import yaml -from jsonargparse import ArgumentParser from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.cli import LightningArgumentParser as ArgumentParser from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException