Skip to content

Commit

Permalink
Patch jsonargparse for Python >= 3.12.8 (#20479)
Browse files Browse the repository at this point in the history
* Patch argparse _parse_known_args

* Add patch to test

* Avoid importing lightning in assistant

* Fix return type
  • Loading branch information
lantiga authored Dec 9, 2024
1 parent c09fc66 commit 38971a0
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 21 deletions.
15 changes: 15 additions & 0 deletions .actions/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:


if __name__ == "__main__":
import sys

import jsonargparse
from jsonargparse import ArgumentParser

def patch_jsonargparse_python_3_12_8():
if sys.version_info < (3, 12, 8):
return

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args

setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

jsonargparse.CLI(AssistantCLI, as_positional=False)
18 changes: 9 additions & 9 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
# "oldest" versions tests, only on minimum Python
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
- {
Expand Down
18 changes: 9 additions & 9 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
# "oldest" versions tests, only on minimum Python
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
- {
Expand Down
3 changes: 2 additions & 1 deletion examples/fabric/tensor_parallel/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import lightning as L
import torch
import torch.nn.functional as F
from data import RandomTokenDataset
from lightning.fabric.strategies import ModelParallelStrategy
from model import ModelArgs, Transformer
from parallelism import parallelize
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils.data import DataLoader

from data import RandomTokenDataset


def train():
strategy = ModelParallelStrategy(
Expand Down
3 changes: 2 additions & 1 deletion examples/pytorch/tensor_parallel/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import lightning as L
import torch
import torch.nn.functional as F
from data import RandomTokenDataset
from lightning.pytorch.strategies import ModelParallelStrategy
from model import ModelArgs, Transformer
from parallelism import parallelize
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils.data import DataLoader

from data import RandomTokenDataset


class Llama3(L.LightningModule):
def __init__(self):
Expand Down
14 changes: 14 additions & 0 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7")


def patch_jsonargparse_python_3_12_8() -> None:
if sys.version_info < (3, 12, 8):
return

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args

setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)


if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
from jsonargparse import (
Expand All @@ -48,6 +60,8 @@
set_config_read_mode,
)

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
set_config_read_mode(fsspec_enabled=True)
else:
Expand Down
3 changes: 3 additions & 0 deletions tests/parity_fabric/test_parity_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float

if __name__ == "__main__":
from jsonargparse import CLI
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

CLI(run_parity_test)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import pytest
import torch
import yaml
from jsonargparse import ArgumentParser
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningArgumentParser as ArgumentParser
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down

0 comments on commit 38971a0

Please sign in to comment.