From d7f00be585846b3944d7d9bdd26f9d115bd5d351 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 9 Aug 2021 20:34:12 +0200 Subject: [PATCH 01/77] add registries --- pytorch_lightning/utilities/cli.py | 116 +++++++++++++++++- pytorch_lightning/utilities/cli_registries.py | 92 ++++++++++++++ tests/utilities/test_cli.py | 57 +++++++++ 3 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 pytorch_lightning/utilities/cli_registries.py diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 752d0092d1baf..6b215ff05fe5f 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import json import os +import sys from argparse import Namespace +from contextlib import contextmanager from types import MethodType from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union +from unittest import mock from torch.optim import Optimizer from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings +from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES, OPTIMIZER_REGISTRIES, SCHEDULER_REGISTRIES from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -332,6 +337,16 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def link_optimizers_and_lr_schedulers(self) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" + if any( + True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES.keys() if f"--optimizer={optim_name}" in v + ): + optimizers = tuple(v for v in OPTIMIZER_REGISTRIES.values()) + self.parser.add_optimizer_args(optimizers) + + if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"-lr_scheduler={sch_name}" in v): + lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values()) + self.parser.add_lr_scheduler_args(lr_schdulers) + for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": continue @@ -341,9 +356,108 @@ def link_optimizers_and_lr_schedulers(self) -> None: add_class_path = _add_class_path_generator(class_type) self.parser.link_arguments(key, link_to, compute_fn=add_class_path) + @contextmanager + def prepare_optimizer(self): + """ + This context manager is used to simplify optimizer instantiation for Lightning users. + """ + optimizer_args = [v for v in sys.argv if v.startswith("--optimizer")] + should_replace = len(optimizer_args) > 0 and not any(v for v in optimizer_args if "class_path" in v) + if should_replace: + optimizer_arg = {} + init_args = {} + for v in optimizer_args: + if "optimizer." in v: + arg_path, value = v.split("=") + init_args[arg_path.split(".")[-1]] = value + else: + class_name = v.split("=")[-1] + optim_cls = OPTIMIZER_REGISTRIES.get(class_name) + optimizer_arg["class_path"] = optim_cls.__module__ + "." 
+ class_name + optimizer_arg["init_args"] = init_args + argv = [v for v in sys.argv if not v.startswith("--optimizer")] + [ + f"--optimizer={json.dumps(optimizer_arg)}" + ] + with mock.patch("sys.argv", argv): + yield + else: + yield + + @contextmanager + def prepare_callbacks(self): + """ + This context manager is used to simplify callbacks instantiation for Lightning users. + """ + all_callbacks_args = [v for v in sys.argv if v.startswith("--trainer.callbacks")] + callbacks_args = [v for v in sys.argv if v.startswith("--trainer.callbacks=")] + num_callbacks = len(callbacks_args) + should_replace = len(all_callbacks_args) > 0 and not any(v for v in all_callbacks_args if "class_path" in v) + if should_replace: + # FIXME: Add support for combining callbacks. + callbacks_argv = {} + init_args = {} + map_callback_args = {idx: [] for idx in range(num_callbacks)} + counter = -1 + callback_out = [] + for v in all_callbacks_args: + if "--trainer.callbacks=" in v: + counter += 1 + map_callback_args[counter].append(v) + callback_out = [] + for callback_idx in range(num_callbacks): + callback_args = map_callback_args[callback_idx] + callbacks_argv = {} + init_args = {} + for callback_arg in callback_args: + if "--trainer.callbacks=" in callback_arg: + class_name = callback_arg.split("=")[-1] + callback_cls = CALLBACK_REGISTRIES.get(class_name) + callbacks_argv["class_path"] = callback_cls.__module__ + "." + class_name + else: + arg_path, value = callback_arg.split("=") + init_args[arg_path.split(".")[-1]] = value + callbacks_argv["init_args"] = init_args + callback_out.append(callbacks_argv) + argv = [v for v in sys.argv if not v.startswith("--trainer.callbacks")] + [ + f"--trainer.callbacks={json.dumps(callback_out)}" + ] + with mock.patch("sys.argv", argv): + yield + else: + yield + + @contextmanager + def prepare_schedulers(self): + """ + This context manager is used to simplify schedulers instantiation for Lightning users. + """ + lr_scheduler_args = [v for v in sys.argv if v.startswith("--lr_scheduler")] + should_replace = len(lr_scheduler_args) > 0 and not any(v for v in lr_scheduler_args if "class_path" in v) + if should_replace: + lr_scheduler_arg = {} + init_args = {} + for v in lr_scheduler_args: + if "lr_scheduler." in v: + arg_path, value = v.split("=") + init_args[arg_path.split(".")[-1]] = value + else: + class_name = v.split("=")[-1] + optim_cls = SCHEDULER_REGISTRIES.get(class_name) + lr_scheduler_arg["class_path"] = optim_cls.__module__ + "." + class_name + lr_scheduler_arg["init_args"] = init_args + argv = [v for v in sys.argv if not v.startswith("--lr_scheduler")] + [ + f"--lr_scheduler={json.dumps(lr_scheduler_arg)}" + ] + breakpoint() + with mock.patch("sys.argv", argv): + yield + else: + yield + def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" - self.config = parser.parse_args() + with self.prepare_optimizer(), self.prepare_callbacks(), self.prepare_schedulers(): + self.config = parser.parse_args() def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" diff --git a/pytorch_lightning/utilities/cli_registries.py b/pytorch_lightning/utilities/cli_registries.py new file mode 100644 index 0000000000000..f778d45973910 --- /dev/null +++ b/pytorch_lightning/utilities/cli_registries.py @@ -0,0 +1,92 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from collections import UserDict +from typing import Any, Callable, List, Optional, Type + +import torch + +import pytorch_lightning as pl +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class Registry(UserDict): + def __call__( + self, + cls: Optional[Type] = None, + key: Optional[str] = None, + override: bool = False, + ) -> Callable: + """ + Registers a plugin mapped to a name and with required metadata. + + Args: + key : the name that identifies a plugin, e.g. "deepspeed_stage_3" + value : plugin class + """ + if key is None: + key = cls.__name__ + elif not isinstance(key, str): + raise TypeError(f"`key` must be a str, found {key}") + + if key in self and not override: + raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.") + + def do_register(key, cls) -> Callable: + self[key] = cls + return cls + + do_register(key, cls) + + return do_register + + def register_package(self, module, base_cls: Type) -> None: + for obj_name in dir(module): + obj_cls = getattr(module, obj_name) + if inspect.isclass(obj_cls) and issubclass(obj_cls, base_cls): + self(cls=obj_cls) + + def get(self, name: Optional[str], default: Optional[Any] = None) -> Any: + """ + Calls the registered plugin with the required parameters + and returns the plugin object + + Args: + name (str): the name + """ + if name in self: + return self[name] + else: + raise KeyError + + def remove(self, name: str) -> None: + """Removes the registered plugin by name""" + self.pop(name) + + def available_objects(self) -> List: + """Returns a list of registered plugins""" + return list(self.keys()) + + def __str__(self) -> str: + return "Registered Plugins: {}".format(", ".join(self.keys())) + + +CALLBACK_REGISTRIES = Registry() +CALLBACK_REGISTRIES.register_package(pl.callbacks, pl.callbacks.Callback) + +OPTIMIZER_REGISTRIES = Registry() +OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer) + +SCHEDULER_REGISTRIES = Registry() +SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index cc636aa9a17ed..4f940dedbc059 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -687,3 +687,60 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): assert isinstance(cli.model.optim1, torch.optim.Adam) assert isinstance(cli.model.optim2, torch.optim.SGD) assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) + + +def test_registries(tmpdir): + + from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES + + @CALLBACK_REGISTRIES + class CustomCallback(Callback): + pass + + assert CALLBACK_REGISTRIES.available_objects() == [ + "BackboneFinetuning", + "BaseFinetuning", + "BasePredictionWriter", + "Callback", + "EarlyStopping", + "GPUStatsMonitor", + "GradientAccumulationScheduler", + 
"LambdaCallback", + "LearningRateMonitor", + "ModelCheckpoint", + "ModelPruning", + "ProgressBar", + "ProgressBarBase", + "QuantizationAwareTraining", + "StochasticWeightAveraging", + "Timer", + "XLAStatsMonitor", + "CustomCallback", + ] + + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + pass + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--optimizer=Adam", + "--optimizer.lr=0.0001", + "--trainer.callbacks=LearningRateMonitor", + "--trainer.callbacks.logging_interval=epoch", + "--trainer.callbacks.log_momentum=True", + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=loss", + "--lr_scheduler=StepLR", + "--lr_scheduler.step_size=50", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(TestModel) + + assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) From a07d3050d38344e5147faee179034e8bdca92fbf Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 9 Aug 2021 20:35:47 +0200 Subject: [PATCH 02/77] simplify LightningCLI with defaults --- pytorch_lightning/utilities/cli.py | 7 +++---- pytorch_lightning/utilities/cli_registries.py | 15 +-------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6b215ff05fe5f..92007e519f6cb 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -372,7 +372,7 @@ def prepare_optimizer(self): init_args[arg_path.split(".")[-1]] = value else: class_name = v.split("=")[-1] - optim_cls = OPTIMIZER_REGISTRIES.get(class_name) + optim_cls = OPTIMIZER_REGISTRIES[class_name] optimizer_arg["class_path"] = optim_cls.__module__ + "." + class_name optimizer_arg["init_args"] = init_args argv = [v for v in sys.argv if not v.startswith("--optimizer")] + [ @@ -411,7 +411,7 @@ def prepare_callbacks(self): for callback_arg in callback_args: if "--trainer.callbacks=" in callback_arg: class_name = callback_arg.split("=")[-1] - callback_cls = CALLBACK_REGISTRIES.get(class_name) + callback_cls = CALLBACK_REGISTRIES[class_name] callbacks_argv["class_path"] = callback_cls.__module__ + "." + class_name else: arg_path, value = callback_arg.split("=") @@ -442,13 +442,12 @@ def prepare_schedulers(self): init_args[arg_path.split(".")[-1]] = value else: class_name = v.split("=")[-1] - optim_cls = SCHEDULER_REGISTRIES.get(class_name) + optim_cls = SCHEDULER_REGISTRIES[class_name] lr_scheduler_arg["class_path"] = optim_cls.__module__ + "." + class_name lr_scheduler_arg["init_args"] = init_args argv = [v for v in sys.argv if not v.startswith("--lr_scheduler")] + [ f"--lr_scheduler={json.dumps(lr_scheduler_arg)}" ] - breakpoint() with mock.patch("sys.argv", argv): yield else: diff --git a/pytorch_lightning/utilities/cli_registries.py b/pytorch_lightning/utilities/cli_registries.py index f778d45973910..94b513f75957e 100644 --- a/pytorch_lightning/utilities/cli_registries.py +++ b/pytorch_lightning/utilities/cli_registries.py @@ -13,7 +13,7 @@ # limitations under the License. 
import inspect from collections import UserDict -from typing import Any, Callable, List, Optional, Type +from typing import Callable, List, Optional, Type import torch @@ -57,19 +57,6 @@ def register_package(self, module, base_cls: Type) -> None: if inspect.isclass(obj_cls) and issubclass(obj_cls, base_cls): self(cls=obj_cls) - def get(self, name: Optional[str], default: Optional[Any] = None) -> Any: - """ - Calls the registered plugin with the required parameters - and returns the plugin object - - Args: - name (str): the name - """ - if name in self: - return self[name] - else: - raise KeyError - def remove(self, name: str) -> None: """Removes the registered plugin by name""" self.pop(name) From ce39c47b17bcdf4b6bd14b5dfd9496f79c0ccedd Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 9 Aug 2021 20:46:48 +0200 Subject: [PATCH 03/77] cleanup --- tests/utilities/test_cli.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 4f940dedbc059..6cc47b95e8486 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -33,6 +33,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback +from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -689,13 +690,12 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) -def test_registries(tmpdir): +@CALLBACK_REGISTRIES +class CustomCallback(Callback): + pass - from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES - @CALLBACK_REGISTRIES - class CustomCallback(Callback): - pass +def test_registries(tmpdir): assert CALLBACK_REGISTRIES.available_objects() == [ "BackboneFinetuning", From 3081475c44370df6bf3da4767432b158483b8c38 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 9 Aug 2021 20:52:37 +0200 Subject: [PATCH 04/77] update --- tests/utilities/test_cli.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 6cc47b95e8486..d78ea919d4c75 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -33,7 +33,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback -from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES +from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES, OPTIMIZER_REGISTRIES from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -695,6 +695,11 @@ class CustomCallback(Callback): pass +@OPTIMIZER_REGISTRIES +class MyAdamVariant(torch.optim.Adam): + pass + + def test_registries(tmpdir): assert CALLBACK_REGISTRIES.available_objects() == [ @@ -729,7 +734,7 @@ def __init__(self): cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", - "--optimizer=Adam", + "--optimizer=MyAdamVariant", 
"--optimizer.lr=0.0001", "--trainer.callbacks=LearningRateMonitor", "--trainer.callbacks.logging_interval=epoch", From 51f82d585dfc6faf64284bb67413f203fda27e94 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 12:01:41 +0200 Subject: [PATCH 05/77] updates --- pytorch_lightning/utilities/cli.py | 184 ++++++++++++++++++++--------- tests/utilities/test_cli.py | 87 +++++++++++--- 2 files changed, 194 insertions(+), 77 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 92007e519f6cb..501b59baac56e 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -17,10 +17,13 @@ import sys from argparse import Namespace from contextlib import contextmanager +from functools import partial from types import MethodType from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union from unittest import mock +import torch +from attr import dataclass from torch.optim import Optimizer from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer @@ -28,7 +31,6 @@ from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES, OPTIMIZER_REGISTRIES, SCHEDULER_REGISTRIES from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple if _JSONARGPARSE_AVAILABLE: @@ -294,6 +296,14 @@ def __init__( self.fit() self.after_fit() + @property + def optimizer_registered(self) -> Tuple[Type[Optimizer]]: + return tuple(o for o in OPTIMIZER_REGISTRIES.values()) + + @property + def lr_scheduler_registered(self) -> Tuple[Type[torch.optim.lr_scheduler._LRScheduler]]: + return tuple(o for o in SCHEDULER_REGISTRIES.values()) + def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) @@ -340,8 +350,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: if any( True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES.keys() if f"--optimizer={optim_name}" in v ): - optimizers = tuple(v for v in OPTIMIZER_REGISTRIES.values()) - self.parser.add_optimizer_args(optimizers) + self.parser.add_optimizer_args(self.optimizer_registered) if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"-lr_scheduler={sch_name}" in v): lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values()) @@ -357,26 +366,52 @@ def link_optimizers_and_lr_schedulers(self) -> None: self.parser.link_arguments(key, link_to, compute_fn=add_class_path) @contextmanager - def prepare_optimizer(self): + def prepare_optimizers(self): """ This context manager is used to simplify optimizer instantiation for Lightning users. """ - optimizer_args = [v for v in sys.argv if v.startswith("--optimizer")] - should_replace = len(optimizer_args) > 0 and not any(v for v in optimizer_args if "class_path" in v) - if should_replace: - optimizer_arg = {} - init_args = {} - for v in optimizer_args: - if "optimizer." 
in v: - arg_path, value = v.split("=") + + @dataclass + class OptimizerInfo: + optimizer_cls_arg: str + optim_cls: str + optimizer_init_args: List = [] + + def add_optimizer_init_args(self, args: Dict[str, str]) -> None: + if args != self.optimizer_cls_arg: + self.optimizer_init_args.append(args) + + @property + def optimizer_init(self) -> Dict[str, str]: + optimizer_init = {} + optimizer_init["class_path"] = self.optim_cls.__module__ + "." + self.optim_cls.__name__ + init_args = {} + for init_arg in self.optimizer_init_args: + arg_path, value = init_arg.split("=") init_args[arg_path.split(".")[-1]] = value - else: - class_name = v.split("=")[-1] - optim_cls = OPTIMIZER_REGISTRIES[class_name] - optimizer_arg["class_path"] = optim_cls.__module__ + "." + class_name - optimizer_arg["init_args"] = init_args - argv = [v for v in sys.argv if not v.startswith("--optimizer")] + [ - f"--optimizer={json.dumps(optimizer_arg)}" + optimizer_init["init_args"] = init_args + return optimizer_init + + map_arg_path = {} + for optim_name, optim_cls in OPTIMIZER_REGISTRIES.items(): + for v in sys.argv: + if f"={optim_name}" in v: + key = v.split("=")[0] + map_arg_path[key] = OptimizerInfo(optimizer_cls_arg=v, optim_cls=optim_cls) + argv = [] + for v in sys.argv: + skip = False + for key in map_arg_path: + if key in v: + skip = True + map_arg_path[key].add_optimizer_init_args(v) + if not skip: + argv.append(v) + + if len(map_arg_path) > 0: + argv += [ + f"{optimizer_key}={optimizer_args.optimizer_init}" + for optimizer_key, optimizer_args in map_arg_path.items() ] with mock.patch("sys.argv", argv): yield @@ -388,21 +423,28 @@ def prepare_callbacks(self): """ This context manager is used to simplify callbacks instantiation for Lightning users. """ - all_callbacks_args = [v for v in sys.argv if v.startswith("--trainer.callbacks")] - callbacks_args = [v for v in sys.argv if v.startswith("--trainer.callbacks=")] - num_callbacks = len(callbacks_args) - should_replace = len(all_callbacks_args) > 0 and not any(v for v in all_callbacks_args if "class_path" in v) + all_callbacks_args = [ + v for v in sys.argv if v.startswith("--trainer.callbacks") and not v.startswith("--trainer.callbacks=[") + ] + simple_callbacks_args = [ + v for v in sys.argv if v.startswith("--trainer.callbacks=") and not v.startswith("--trainer.callbacks=[") + ] + class_path_callbacks = [ + v for v in sys.argv if v.startswith("--trainer.callbacks=") and v.startswith("--trainer.callbacks=[") + ] + num_callbacks = len(simple_callbacks_args) + should_replace = len(all_callbacks_args) > 0 and not all("class_path" in v for v in all_callbacks_args) if should_replace: - # FIXME: Add support for combining callbacks. - callbacks_argv = {} - init_args = {} + if len(class_path_callbacks) > 1: + raise MisconfigurationException("When provided callbacks as list, please group them under 1 argument.") + # group arguments per callbacks map_callback_args = {idx: [] for idx in range(num_callbacks)} counter = -1 - callback_out = [] for v in all_callbacks_args: if "--trainer.callbacks=" in v: counter += 1 map_callback_args[counter].append(v) + # re-compose the grouped command line callback_out = [] for callback_idx in range(num_callbacks): callback_args = map_callback_args[callback_idx] @@ -418,6 +460,11 @@ def prepare_callbacks(self): init_args[arg_path.split(".")[-1]] = value callbacks_argv["init_args"] = init_args callback_out.append(callbacks_argv) + + # add other callback arguments. 
+ callback_out.extend(eval(class_path_callbacks[0].split("=")[-1])) + + # compose the command line argv = [v for v in sys.argv if not v.startswith("--trainer.callbacks")] + [ f"--trainer.callbacks={json.dumps(callback_out)}" ] @@ -432,7 +479,10 @@ def prepare_schedulers(self): This context manager is used to simplify schedulers instantiation for Lightning users. """ lr_scheduler_args = [v for v in sys.argv if v.startswith("--lr_scheduler")] - should_replace = len(lr_scheduler_args) > 0 and not any(v for v in lr_scheduler_args if "class_path" in v) + lr_scheduler_class_args = [v for v in sys.argv if v.startswith("--lr_scheduler=")] + should_replace = len(lr_scheduler_class_args) > 0 and not any( + "class_path" in v for v in lr_scheduler_class_args + ) if should_replace: lr_scheduler_arg = {} init_args = {} @@ -455,7 +505,7 @@ def prepare_schedulers(self): def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" - with self.prepare_optimizer(), self.prepare_callbacks(), self.prepare_schedulers(): + with self.prepare_optimizers(), self.prepare_callbacks(), self.prepare_schedulers(): self.config = parser.parse_args() def before_instantiate_classes(self) -> None: @@ -508,42 +558,60 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: if len(optimizers) == 0: return + optimizer_inits = {} + for optimizer in optimizers: + optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizer][0] + optimizer_init = self.config_init.get(optimizer, {}) + if not isinstance(optimizer_class, tuple): + optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) + optimizer_inits[optimizer] = optimizer_init + + lr_scheduler_inits = {} + lr_scheduler_init = None + if lr_schedulers: + for scheduler in lr_schedulers: + lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[scheduler][0] + lr_scheduler_init = self.config_init.get(scheduler, {}) + if not isinstance(lr_scheduler_class, tuple): + lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) + lr_scheduler_inits[scheduler] = lr_scheduler_init + if len(optimizers) > 1 or len(lr_schedulers) > 1: - raise MisconfigurationException( - f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer " - f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user " - "is expected to link the argument groups and implement `configure_optimizers`, see " - "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html" - "#optimizers-and-learning-rate-schedulers" - ) + configure_optimizers_params = inspect.signature(self.model.configure_optimizers).parameters + if len(configure_optimizers_params) > 1: + expected_params = set(optimizers + lr_schedulers) + if expected_params.difference(configure_optimizers_params): + raise MisconfigurationException( + f"The model ``configure_optimizers`` should expose optional arguments: {expected_params}" + ) + configure_optimizers = partial(self.model.configure_optimizers, **optimizer_inits, **lr_scheduler_inits) + configure_optimizers.__code__ = self.model.configure_optimizers.__code__ + self.model.configure_optimizers = configure_optimizers - if is_overridden("configure_optimizers", self.model): + else: warnings._warn( f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`." 
) - optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizers[0]][0] - optimizer_init = self.config_init.get(optimizers[0], {}) - if not isinstance(optimizer_class, tuple): - optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) - lr_scheduler_init = None - if lr_schedulers: - lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[lr_schedulers[0]][0] - lr_scheduler_init = self.config_init.get(lr_schedulers[0], {}) - if not isinstance(lr_scheduler_class, tuple): - lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) - - def configure_optimizers( - self: LightningModule, - ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]: - optimizer = instantiate_class(self.parameters(), optimizer_init) - if not lr_scheduler_init: - return optimizer - lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) - return [optimizer], [lr_scheduler] - - self.model.configure_optimizers = MethodType(configure_optimizers, self.model) + configure_optimizers = partial( + self.configure_optimizers, optimizer_init=optimizer_init, lr_scheduler_init=lr_scheduler_init + ) + configure_optimizers.__code__ = self.model.configure_optimizers.__code__ + + self.model.configure_optimizers = MethodType(configure_optimizers, self.model) + + @staticmethod + def configure_optimizers( + pl_module: LightningModule, + optimizer_init: Union[str, List[str]], + lr_scheduler_init: Optional[Union[str, List[str]]] = None, + ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]: + optimizer = instantiate_class(pl_module.parameters(), optimizer_init) + if not lr_scheduler_init: + return optimizer + lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) + return [optimizer], [lr_scheduler] def prepare_fit_kwargs(self) -> None: """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index d78ea919d4c75..efb1d9700473c 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -27,13 +27,14 @@ import torch import yaml from packaging import version +from torch import nn from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback -from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES, OPTIMIZER_REGISTRIES +from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -690,16 +691,6 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) -@CALLBACK_REGISTRIES -class CustomCallback(Callback): - pass - - -@OPTIMIZER_REGISTRIES -class MyAdamVariant(torch.optim.Adam): - pass - - def test_registries(tmpdir): assert CALLBACK_REGISTRIES.available_objects() == [ @@ -720,32 +711,90 @@ def test_registries(tmpdir): "StochasticWeightAveraging", "Timer", "XLAStatsMonitor", - "CustomCallback", ] - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - pass - class 
TestModel(BoringModel): def __init__(self): super().__init__() + callbacks = [ + dict( + class_path="pytorch_lightning.callbacks.Callback", + init_args=dict(), + ), + dict( + class_path="pytorch_lightning.callbacks.Callback", + init_args=dict(), + ), + ] + cli_args = [ f"--trainer.default_root_dir={tmpdir}", - "--trainer.max_epochs=1", - "--optimizer=MyAdamVariant", + "--trainer.fast_dev_run=True", + "--trainer.progress_bar_refresh_rate=0", + "--optimizer=Adam", "--optimizer.lr=0.0001", "--trainer.callbacks=LearningRateMonitor", "--trainer.callbacks.logging_interval=epoch", "--trainer.callbacks.log_momentum=True", "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=loss", + f"--trainer.callbacks={json.dumps(callbacks)}", "--lr_scheduler=StepLR", "--lr_scheduler.step_size=50", ] with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = MyLightningCLI(TestModel) + cli = LightningCLI(TestModel) assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) + assert len(cli.trainer.callbacks) == 4 + + +@pytest.mark.parametrize("use_scheduler", [False, True]) +def test_configure_optimizers(use_scheduler, tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args(self.optimizer_registered, nested_key="optim1") + parser.add_optimizer_args(self.optimizer_registered, nested_key="optim2") + if use_scheduler: + parser.add_lr_scheduler_args(self.lr_scheduler_registered) + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2)) + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self, optim1: dict = None, optim2: dict = None, lr_scheduler: dict = None): + optim1 = instantiate_class(self.layer[:2].parameters(), optim1) + optim2 = instantiate_class(self.layer[2:].parameters(), optim2) + if lr_scheduler: + scheduler = instantiate_class(optim1, lr_scheduler) + return [optim1, optim2], [scheduler] + return [optim1, optim2] + + training_epoch_end = None + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--optim1.lr=0.01", + "--optim1=Adam", + "--optim2={'class_path': 'torch.optim.SGD', 'init_args': {'lr': '0.01'}}", + ] + if use_scheduler: + lr_scheduler_arg = dict(class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50)) + cli_args += [ + f"--lr_scheduler={json.dumps(lr_scheduler_arg)}", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(TestModel) + + assert isinstance(cli.model.optimizers(use_pl_optimizer=False)[0], torch.optim.Adam) + assert isinstance(cli.model.optimizers(use_pl_optimizer=False)[1], torch.optim.SGD) + if use_scheduler: + assert isinstance(cli.model.lr_schedulers(), torch.optim.lr_scheduler.StepLR) From 7197d6ee55ee07b47388a6b763a74a67a9d2bca9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 13:29:48 +0200 Subject: [PATCH 06/77] cleanup --- pytorch_lightning/utilities/cli.py | 265 ++++++++++++----------------- tests/utilities/test_cli.py | 50 ------ 2 files changed, 111 insertions(+), 204 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 501b59baac56e..97d7677aa4ef1 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -12,25 +12,30 @@ # See the License for the specific language governing permissions and # 
limitations under the License. import inspect -import json import os import sys from argparse import Namespace from contextlib import contextmanager +from dataclasses import dataclass, field from functools import partial from types import MethodType from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union from unittest import mock import torch -from attr import dataclass from torch.optim import Optimizer from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings -from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES, OPTIMIZER_REGISTRIES, SCHEDULER_REGISTRIES +from pytorch_lightning.utilities.cli_registries import ( + CALLBACK_REGISTRIES, + OPTIMIZER_REGISTRIES, + Registry, + SCHEDULER_REGISTRIES, +) from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple if _JSONARGPARSE_AVAILABLE: @@ -41,6 +46,30 @@ ArgumentParser = object +@dataclass +class ClassInfo: + """This class is an helper to easily build the mocked command line""" + + class_arg: str + cls: str + class_init_args: List[str] = field(default_factory=lambda: []) + + def add_class_init_args(self, args: Dict[str, str]) -> None: + if args != self.class_arg: + self.class_init_args.append(args) + + @property + def class_init(self) -> Dict[str, str]: + class_init = {} + class_init["class_path"] = self.cls.__module__ + "." + self.cls.__name__ + init_args = {} + for init_arg in self.class_init_args: + arg_path, value = init_arg.split("=") + init_args[arg_path.split(".")[-1]] = value + class_init["init_args"] = init_args + return class_init + + class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" @@ -366,138 +395,75 @@ def link_optimizers_and_lr_schedulers(self) -> None: self.parser.link_arguments(key, link_to, compute_fn=add_class_path) @contextmanager - def prepare_optimizers(self): + def prepare_from_registry(self, registry: Registry): """ - This context manager is used to simplify optimizer instantiation for Lightning users. + This context manager is used to simplify unique class instantiation. """ - @dataclass - class OptimizerInfo: - optimizer_cls_arg: str - optim_cls: str - optimizer_init_args: List = [] - - def add_optimizer_init_args(self, args: Dict[str, str]) -> None: - if args != self.optimizer_cls_arg: - self.optimizer_init_args.append(args) - - @property - def optimizer_init(self) -> Dict[str, str]: - optimizer_init = {} - optimizer_init["class_path"] = self.optim_cls.__module__ + "." + self.optim_cls.__name__ - init_args = {} - for init_arg in self.optimizer_init_args: - arg_path, value = init_arg.split("=") - init_args[arg_path.split(".")[-1]] = value - optimizer_init["init_args"] = init_args - return optimizer_init - - map_arg_path = {} - for optim_name, optim_cls in OPTIMIZER_REGISTRIES.items(): + # find if the users is using shortcut command line. 
+ map_user_key_to_info = {} + for registered_name, registered_cls in registry.items(): for v in sys.argv: - if f"={optim_name}" in v: + if f"={registered_name}" in v: key = v.split("=")[0] - map_arg_path[key] = OptimizerInfo(optimizer_cls_arg=v, optim_cls=optim_cls) - argv = [] - for v in sys.argv: - skip = False - for key in map_arg_path: - if key in v: - skip = True - map_arg_path[key].add_optimizer_init_args(v) - if not skip: - argv.append(v) - - if len(map_arg_path) > 0: - argv += [ - f"{optimizer_key}={optimizer_args.optimizer_init}" - for optimizer_key, optimizer_args in map_arg_path.items() - ] + map_user_key_to_info[key] = ClassInfo(class_arg=v, cls=registered_cls) + + if len(map_user_key_to_info) > 0: + # for each shortcut command line, add its init arguments and skip them from `sys.argv`. + argv = [] + for v in sys.argv: + skip = False + for key in map_user_key_to_info: + if key in v: + skip = True + map_user_key_to_info[key].add_class_init_args(v) + if not skip: + argv.append(v) + + # re-create the global command line and mock `sys.argv`. + argv += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()] with mock.patch("sys.argv", argv): yield else: yield @contextmanager - def prepare_callbacks(self): + def prepare_class_list_from_registry(self, pattern: str, registry: Registry): """ - This context manager is used to simplify callbacks instantiation for Lightning users. + This context manager is used to simplify instantiation of a list of class. """ - all_callbacks_args = [ - v for v in sys.argv if v.startswith("--trainer.callbacks") and not v.startswith("--trainer.callbacks=[") - ] - simple_callbacks_args = [ - v for v in sys.argv if v.startswith("--trainer.callbacks=") and not v.startswith("--trainer.callbacks=[") - ] - class_path_callbacks = [ - v for v in sys.argv if v.startswith("--trainer.callbacks=") and v.startswith("--trainer.callbacks=[") - ] - num_callbacks = len(simple_callbacks_args) - should_replace = len(all_callbacks_args) > 0 and not all("class_path" in v for v in all_callbacks_args) + argv = [v for v in sys.argv if pattern not in v] + all_matched_args = [v for v in sys.argv if pattern in v] + all_simplified_args = [v for v in all_matched_args if f"{pattern}" in v and f"{pattern}=[" not in v] + all_cls_simplified_args = [v for v in all_simplified_args if f"{pattern}=" in v] + all_non_simplified_args = [v for v in all_matched_args if f"{pattern}=" in v and f"{pattern}=[" in v] + + num_simplified_cls = len(all_simplified_args) + should_replace = num_simplified_cls > 0 and not all("class_path" in v for v in all_matched_args) + if should_replace: - if len(class_path_callbacks) > 1: - raise MisconfigurationException("When provided callbacks as list, please group them under 1 argument.") - # group arguments per callbacks - map_callback_args = {idx: [] for idx in range(num_callbacks)} - counter = -1 - for v in all_callbacks_args: - if "--trainer.callbacks=" in v: - counter += 1 - map_callback_args[counter].append(v) - # re-compose the grouped command line - callback_out = [] - for callback_idx in range(num_callbacks): - callback_args = map_callback_args[callback_idx] - callbacks_argv = {} - init_args = {} - for callback_arg in callback_args: - if "--trainer.callbacks=" in callback_arg: - class_name = callback_arg.split("=")[-1] - callback_cls = CALLBACK_REGISTRIES[class_name] - callbacks_argv["class_path"] = callback_cls.__module__ + "." 
+ class_name - else: - arg_path, value = callback_arg.split("=") - init_args[arg_path.split(".")[-1]] = value - callbacks_argv["init_args"] = init_args - callback_out.append(callbacks_argv) + # verify the user is properly ordering arguments. + assert all_cls_simplified_args[0] == all_simplified_args[0] + if len(all_non_simplified_args) > 1: + raise MisconfigurationException(f"When provided {pattern} as list, please group them under 1 argument.") + # group arguments per callbacks + infos = [] + for class_arg in all_cls_simplified_args: + class_name = class_arg.split("=")[1] + registered_cls = registry[class_name] + infos.append(ClassInfo(class_arg=class_arg, cls=registered_cls)) + + for v in all_simplified_args: + if v in all_cls_simplified_args: + current_info = infos[all_cls_simplified_args.index(v)] + current_info.add_class_init_args(v) + + class_args = [info.class_init for info in infos] # add other callback arguments. - callback_out.extend(eval(class_path_callbacks[0].split("=")[-1])) - - # compose the command line - argv = [v for v in sys.argv if not v.startswith("--trainer.callbacks")] + [ - f"--trainer.callbacks={json.dumps(callback_out)}" - ] - with mock.patch("sys.argv", argv): - yield - else: - yield + class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) - @contextmanager - def prepare_schedulers(self): - """ - This context manager is used to simplify schedulers instantiation for Lightning users. - """ - lr_scheduler_args = [v for v in sys.argv if v.startswith("--lr_scheduler")] - lr_scheduler_class_args = [v for v in sys.argv if v.startswith("--lr_scheduler=")] - should_replace = len(lr_scheduler_class_args) > 0 and not any( - "class_path" in v for v in lr_scheduler_class_args - ) - if should_replace: - lr_scheduler_arg = {} - init_args = {} - for v in lr_scheduler_args: - if "lr_scheduler." in v: - arg_path, value = v.split("=") - init_args[arg_path.split(".")[-1]] = value - else: - class_name = v.split("=")[-1] - optim_cls = SCHEDULER_REGISTRIES[class_name] - lr_scheduler_arg["class_path"] = optim_cls.__module__ + "." + class_name - lr_scheduler_arg["init_args"] = init_args - argv = [v for v in sys.argv if not v.startswith("--lr_scheduler")] + [ - f"--lr_scheduler={json.dumps(lr_scheduler_arg)}" - ] + argv += [f"{pattern}={class_args}"] with mock.patch("sys.argv", argv): yield else: @@ -505,7 +471,9 @@ def prepare_schedulers(self): def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" - with self.prepare_optimizers(), self.prepare_callbacks(), self.prepare_schedulers(): + with self.prepare_from_registry(OPTIMIZER_REGISTRIES), self.prepare_from_registry( + SCHEDULER_REGISTRIES + ), self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES): self.config = parser.parse_args() def before_instantiate_classes(self) -> None: @@ -538,7 +506,6 @@ def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) def add_configure_optimizers_method_to_model(self) -> None: """ Adds to the model an automatically generated ``configure_optimizers`` method. - If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a `configure_optimizers` method is automatically implemented in the model class. 
""" @@ -558,48 +525,38 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: if len(optimizers) == 0: return - optimizer_inits = {} - for optimizer in optimizers: - optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizer][0] - optimizer_init = self.config_init.get(optimizer, {}) - if not isinstance(optimizer_class, tuple): - optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) - optimizer_inits[optimizer] = optimizer_init - - lr_scheduler_inits = {} - lr_scheduler_init = None - if lr_schedulers: - for scheduler in lr_schedulers: - lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[scheduler][0] - lr_scheduler_init = self.config_init.get(scheduler, {}) - if not isinstance(lr_scheduler_class, tuple): - lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) - lr_scheduler_inits[scheduler] = lr_scheduler_init - if len(optimizers) > 1 or len(lr_schedulers) > 1: - configure_optimizers_params = inspect.signature(self.model.configure_optimizers).parameters - if len(configure_optimizers_params) > 1: - expected_params = set(optimizers + lr_schedulers) - if expected_params.difference(configure_optimizers_params): - raise MisconfigurationException( - f"The model ``configure_optimizers`` should expose optional arguments: {expected_params}" - ) - configure_optimizers = partial(self.model.configure_optimizers, **optimizer_inits, **lr_scheduler_inits) - configure_optimizers.__code__ = self.model.configure_optimizers.__code__ - self.model.configure_optimizers = configure_optimizers + raise MisconfigurationException( + f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer " + f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user " + "is expected to link the argument groups and implement `configure_optimizers`, see " + "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html" + "#optimizers-and-learning-rate-schedulers" + ) - else: + if is_overridden("configure_optimizers", self.model): warnings._warn( f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`." 
) - configure_optimizers = partial( - self.configure_optimizers, optimizer_init=optimizer_init, lr_scheduler_init=lr_scheduler_init - ) - configure_optimizers.__code__ = self.model.configure_optimizers.__code__ + optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizers[0]][0] + optimizer_init = self.config_init.get(optimizers[0], {}) + if not isinstance(optimizer_class, tuple): + optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) + lr_scheduler_init = None + if lr_schedulers: + lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[lr_schedulers[0]][0] + lr_scheduler_init = self.config_init.get(lr_schedulers[0], {}) + if not isinstance(lr_scheduler_class, tuple): + lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) + + configure_optimizers = partial( + self.configure_optimizers, optimizer_init=optimizer_init, lr_scheduler_init=lr_scheduler_init + ) + configure_optimizers.__code__ = self.model.configure_optimizers.__code__ - self.model.configure_optimizers = MethodType(configure_optimizers, self.model) + self.model.configure_optimizers = MethodType(configure_optimizers, self.model) @staticmethod def configure_optimizers( diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index efb1d9700473c..92120740989aa 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -27,7 +27,6 @@ import torch import yaml from packaging import version -from torch import nn from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint @@ -749,52 +748,3 @@ def __init__(self): assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) assert len(cli.trainer.callbacks) == 4 - - -@pytest.mark.parametrize("use_scheduler", [False, True]) -def test_configure_optimizers(use_scheduler, tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(self.optimizer_registered, nested_key="optim1") - parser.add_optimizer_args(self.optimizer_registered, nested_key="optim2") - if use_scheduler: - parser.add_lr_scheduler_args(self.lr_scheduler_registered) - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2)) - - def training_step(self, batch, batch_idx, optimizer_idx): - return super().training_step(batch, batch_idx) - - def configure_optimizers(self, optim1: dict = None, optim2: dict = None, lr_scheduler: dict = None): - optim1 = instantiate_class(self.layer[:2].parameters(), optim1) - optim2 = instantiate_class(self.layer[2:].parameters(), optim2) - if lr_scheduler: - scheduler = instantiate_class(optim1, lr_scheduler) - return [optim1, optim2], [scheduler] - return [optim1, optim2] - - training_epoch_end = None - - cli_args = [ - f"--trainer.default_root_dir={tmpdir}", - "--trainer.max_epochs=1", - "--optim1.lr=0.01", - "--optim1=Adam", - "--optim2={'class_path': 'torch.optim.SGD', 'init_args': {'lr': '0.01'}}", - ] - if use_scheduler: - lr_scheduler_arg = dict(class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50)) - cli_args += [ - f"--lr_scheduler={json.dumps(lr_scheduler_arg)}", - ] - - with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = MyLightningCLI(TestModel) - - assert isinstance(cli.model.optimizers(use_pl_optimizer=False)[0], torch.optim.Adam) - assert 
isinstance(cli.model.optimizers(use_pl_optimizer=False)[1], torch.optim.SGD) - if use_scheduler: - assert isinstance(cli.model.lr_schedulers(), torch.optim.lr_scheduler.StepLR) From 9a6e81e56003c14866ab78ccc16c9edf644f6ba5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 14:01:28 +0200 Subject: [PATCH 07/77] update on comments --- pytorch_lightning/utilities/cli.py | 16 ++++++++++------ tests/utilities/test_cli.py | 14 ++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 97d7677aa4ef1..7ce108523663e 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -379,11 +379,13 @@ def link_optimizers_and_lr_schedulers(self) -> None: if any( True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES.keys() if f"--optimizer={optim_name}" in v ): - self.parser.add_optimizer_args(self.optimizer_registered) + if "optimizer" not in self.parser.groups: + self.parser.add_optimizer_args(self.optimizer_registered) - if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"-lr_scheduler={sch_name}" in v): + if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"--lr_scheduler={sch_name}" in v): lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values()) - self.parser.add_lr_scheduler_args(lr_schdulers) + if "lr_scheduler" not in self.parser.groups: + self.parser.add_lr_scheduler_args(lr_schdulers) for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": @@ -471,10 +473,12 @@ def prepare_class_list_from_registry(self, pattern: str, registry: Registry): def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" - with self.prepare_from_registry(OPTIMIZER_REGISTRIES), self.prepare_from_registry( - SCHEDULER_REGISTRIES - ), self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES): + # fmt: off + with self.prepare_from_registry(OPTIMIZER_REGISTRIES), \ + self.prepare_from_registry(SCHEDULER_REGISTRIES), \ + self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES): self.config = parser.parse_args() + # fmt: on def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 92120740989aa..4fe2f567fba4e 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -663,9 +663,9 @@ def add_arguments_to_parser(self, parser): def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1") - parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") - parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") + parser.add_optimizer_args(self.optimizer_registered, nested_key="optim1", link_to="model.optim1") + parser.add_optimizer_args(torch.optim.SGD, nested_key="optim2", link_to="model.optim2") + parser.add_lr_scheduler_args(self.lr_scheduler_registered, link_to="model.scheduler") class TestModel(BoringModel): def __init__(self, optim1: dict, optim2: dict, scheduler: dict): @@ -677,9 +677,11 @@ def __init__(self, optim1: dict, optim2: 
dict, scheduler: dict): cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", - "--optim2.class_path=torch.optim.SGD", - "--optim2.init_args.lr=0.01", - "--lr_scheduler.gamma=0.2", + "--optim1=Adam", + "--optim1.weight_decay=0.001", + "--optim2.lr=0.005", + "--lr_scheduler=ExponentialLR", + "--lr_scheduler.gamma=0.1", ] with mock.patch("sys.argv", ["any.py"] + cli_args): From 41f5d78a7e2056a1d5e7a068e0965772e89db4c8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 14:12:54 +0200 Subject: [PATCH 08/77] update --- pytorch_lightning/utilities/cli.py | 27 ++++++++++++++++++--------- tests/utilities/test_cli.py | 8 ++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 7ce108523663e..9d7cb32f1df05 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -22,7 +22,6 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union from unittest import mock -import torch from torch.optim import Optimizer from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer @@ -64,7 +63,8 @@ def class_init(self) -> Dict[str, str]: class_init["class_path"] = self.cls.__module__ + "." + self.cls.__name__ init_args = {} for init_arg in self.class_init_args: - arg_path, value = init_arg.split("=") + separator = "=" if "=" in init_arg else " " + arg_path, value = init_arg.split(separator) init_args[arg_path.split(".")[-1]] = value class_init["init_args"] = init_args return class_init @@ -327,11 +327,11 @@ def __init__( @property def optimizer_registered(self) -> Tuple[Type[Optimizer]]: - return tuple(o for o in OPTIMIZER_REGISTRIES.values()) + return tuple(OPTIMIZER_REGISTRIES.values()) @property - def lr_scheduler_registered(self) -> Tuple[Type[torch.optim.lr_scheduler._LRScheduler]]: - return tuple(o for o in SCHEDULER_REGISTRIES.values()) + def lr_scheduler_registered(self) -> Tuple[LRSchedulerType]: + return tuple(SCHEDULER_REGISTRIES.values()) def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" @@ -377,12 +377,20 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def link_optimizers_and_lr_schedulers(self) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" if any( - True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES.keys() if f"--optimizer={optim_name}" in v + True + for v in sys.argv + for optim_name in OPTIMIZER_REGISTRIES + if f"--optimizer={optim_name}" in v or f"--optimizer {optim_name}" in v ): if "optimizer" not in self.parser.groups: self.parser.add_optimizer_args(self.optimizer_registered) - if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"--lr_scheduler={sch_name}" in v): + if any( + True + for v in sys.argv + for sch_name in SCHEDULER_REGISTRIES + if f"--lr_scheduler={sch_name}" in v or f"--lr_scheduler {sch_name}" in v + ): lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values()) if "lr_scheduler" not in self.parser.groups: self.parser.add_lr_scheduler_args(lr_schdulers) @@ -406,8 +414,9 @@ def prepare_from_registry(self, registry: Registry): map_user_key_to_info = {} for registered_name, registered_cls in registry.items(): for v in sys.argv: - if f"={registered_name}" in v: - key = v.split("=")[0] + separator = "=" if "=" in v else " " + if f"{separator}{registered_name}" 
in v: + key = v.split(separator)[0] map_user_key_to_info[key] = ClassInfo(class_arg=v, cls=registered_cls) if len(map_user_key_to_info) > 0: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 4fe2f567fba4e..f8bb1bc77d966 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -677,8 +677,8 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", - "--optim1=Adam", - "--optim1.weight_decay=0.001", + "--optim1 Adam", + "--optim1.weight_decay 0.001", "--optim2.lr=0.005", "--lr_scheduler=ExponentialLR", "--lr_scheduler.gamma=0.1", @@ -733,8 +733,8 @@ def __init__(self): f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=True", "--trainer.progress_bar_refresh_rate=0", - "--optimizer=Adam", - "--optimizer.lr=0.0001", + "--optimizer Adam", + "--optimizer.lr 0.0001", "--trainer.callbacks=LearningRateMonitor", "--trainer.callbacks.logging_interval=epoch", "--trainer.callbacks.log_momentum=True", From 06e49999fca341db761644d73ef2f80a14a0321b Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 14:17:22 +0200 Subject: [PATCH 09/77] cleanup --- pytorch_lightning/utilities/cli.py | 5 +++-- tests/utilities/test_cli.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 9d7cb32f1df05..6efeed044e3e2 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect import os +import re import sys from argparse import Namespace from contextlib import contextmanager @@ -380,7 +381,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES - if f"--optimizer={optim_name}" in v or f"--optimizer {optim_name}" in v + if re.match(fr"^--optimizer[^\S+=]*?{optim_name}?", v) ): if "optimizer" not in self.parser.groups: self.parser.add_optimizer_args(self.optimizer_registered) @@ -389,7 +390,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES - if f"--lr_scheduler={sch_name}" in v or f"--lr_scheduler {sch_name}" in v + if re.match(fr"^--lr_scheduler[^\S+=]*{sch_name}?", v) ): lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values()) if "lr_scheduler" not in self.parser.groups: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index f8bb1bc77d966..11b870cd66d28 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -741,7 +741,7 @@ def __init__(self): "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=loss", f"--trainer.callbacks={json.dumps(callbacks)}", - "--lr_scheduler=StepLR", + "--lr_scheduler StepLR", "--lr_scheduler.step_size=50", ] From e91ea47934cbf9bc1247f690a7edf3c6fe9aa17f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 15:48:14 +0200 Subject: [PATCH 10/77] update on comments --- pytorch_lightning/utilities/cli.py | 66 +++++++++-- pytorch_lightning/utilities/cli_registries.py | 79 ------------- tests/utilities/test_cli.py | 105 ++++++++++++++++-- 3 files changed, 152 insertions(+), 98 deletions(-) delete mode 100644 pytorch_lightning/utilities/cli_registries.py diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6efeed044e3e2..27accc5277cbc 100644 --- a/pytorch_lightning/utilities/cli.py +++ 
b/pytorch_lightning/utilities/cli.py @@ -16,23 +16,20 @@ import re import sys from argparse import Namespace +from collections import UserDict from contextlib import contextmanager from dataclasses import dataclass, field from functools import partial -from types import MethodType +from types import MethodType, ModuleType from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union from unittest import mock +import torch from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings -from pytorch_lightning.utilities.cli_registries import ( - CALLBACK_REGISTRIES, - OPTIMIZER_REGISTRIES, - Registry, - SCHEDULER_REGISTRIES, -) from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -46,6 +43,60 @@ ArgumentParser = object +class Registry(UserDict): + def __call__( + self, + cls: Optional[Type] = None, + key: Optional[str] = None, + override: bool = False, + ) -> Callable: + """ + Registers a class mapped to a name. + + Args: + cls: the class to be mapped. + key : the name that identifies the provided class. + """ + if key is None: + key = cls.__name__ + elif not isinstance(key, str): + raise TypeError(f"`key` must be a str, found {key}") + + if key in self and not override: + raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.") + + def do_register(key, cls) -> Callable: + self[key] = cls + return cls + + do_register(key, cls) + + return do_register + + def register_package(self, module: ModuleType, base_cls: Type) -> None: + """This function is an utility to register all classes from a module.""" + for _, cls in inspect.getmembers(module, predicate=inspect.isclass): + if issubclass(cls, base_cls) and cls != base_cls: + self(cls=cls) + + def available_objects(self) -> List: + """Returns a list of registered objects""" + return list(self.keys()) + + def __str__(self) -> str: + return "Registered objects: {}".format(", ".join(self.keys())) + + +CALLBACK_REGISTRIES = Registry() +CALLBACK_REGISTRIES.register_package(pl.callbacks, pl.callbacks.Callback) + +OPTIMIZER_REGISTRIES = Registry() +OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer) + +SCHEDULER_REGISTRIES = Registry() +SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) + + @dataclass class ClassInfo: """This class is an helper to easily build the mocked command line""" @@ -520,6 +571,7 @@ def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) def add_configure_optimizers_method_to_model(self) -> None: """ Adds to the model an automatically generated ``configure_optimizers`` method. + If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a `configure_optimizers` method is automatically implemented in the model class. """ diff --git a/pytorch_lightning/utilities/cli_registries.py b/pytorch_lightning/utilities/cli_registries.py deleted file mode 100644 index 94b513f75957e..0000000000000 --- a/pytorch_lightning/utilities/cli_registries.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from collections import UserDict -from typing import Callable, List, Optional, Type - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class Registry(UserDict): - def __call__( - self, - cls: Optional[Type] = None, - key: Optional[str] = None, - override: bool = False, - ) -> Callable: - """ - Registers a plugin mapped to a name and with required metadata. - - Args: - key : the name that identifies a plugin, e.g. "deepspeed_stage_3" - value : plugin class - """ - if key is None: - key = cls.__name__ - elif not isinstance(key, str): - raise TypeError(f"`key` must be a str, found {key}") - - if key in self and not override: - raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.") - - def do_register(key, cls) -> Callable: - self[key] = cls - return cls - - do_register(key, cls) - - return do_register - - def register_package(self, module, base_cls: Type) -> None: - for obj_name in dir(module): - obj_cls = getattr(module, obj_name) - if inspect.isclass(obj_cls) and issubclass(obj_cls, base_cls): - self(cls=obj_cls) - - def remove(self, name: str) -> None: - """Removes the registered plugin by name""" - self.pop(name) - - def available_objects(self) -> List: - """Returns a list of registered plugins""" - return list(self.keys()) - - def __str__(self) -> str: - return "Registered Plugins: {}".format(", ".join(self.keys())) - - -CALLBACK_REGISTRIES = Registry() -CALLBACK_REGISTRIES.register_package(pl.callbacks, pl.callbacks.Callback) - -OPTIMIZER_REGISTRIES = Registry() -OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer) - -SCHEDULER_REGISTRIES = Registry() -SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 11b870cd66d28..f396486f35196 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -32,8 +32,15 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback -from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES +from pytorch_lightning.utilities.cli import ( + CALLBACK_REGISTRIES, + instantiate_class, + LightningArgumentParser, + LightningCLI, + OPTIMIZER_REGISTRIES, + SaveConfigCallback, + SCHEDULER_REGISTRIES, +) from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -663,9 +670,9 @@ def add_arguments_to_parser(self, parser): def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): class MyLightningCLI(LightningCLI): 
def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(self.optimizer_registered, nested_key="optim1", link_to="model.optim1") - parser.add_optimizer_args(torch.optim.SGD, nested_key="optim2", link_to="model.optim2") - parser.add_lr_scheduler_args(self.lr_scheduler_registered, link_to="model.scheduler") + parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1") + parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") class TestModel(BoringModel): def __init__(self, optim1: dict, optim2: dict, scheduler: dict): @@ -677,11 +684,9 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", - "--optim1 Adam", - "--optim1.weight_decay 0.001", - "--optim2.lr=0.005", - "--lr_scheduler=ExponentialLR", - "--lr_scheduler.gamma=0.1", + "--optim2.class_path=torch.optim.SGD", + "--optim2.init_args.lr=0.01", + "--lr_scheduler.gamma=0.2", ] with mock.patch("sys.argv", ["any.py"] + cli_args): @@ -692,13 +697,62 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) +@pytest.mark.parametrize("use_registries", [False, True]) +def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to_2(use_registries, tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args( + self.optimizer_registered if use_registries else torch.optim.Adam, + nested_key="optim1", + link_to="model.optim1", + ) + parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") + parser.add_lr_scheduler_args( + self.lr_scheduler_registered if use_registries else torch.optim.lr_scheduler.ExponentialLR, + link_to="model.scheduler", + ) + + class TestModel(BoringModel): + def __init__(self, optim1: dict, optim2: dict, scheduler: dict): + super().__init__() + self.optim1 = instantiate_class(self.parameters(), optim1) + self.optim2 = instantiate_class(self.parameters(), optim2) + self.scheduler = instantiate_class(self.optim1, scheduler) + + if use_registries: + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--optim1 Adam", + "--optim1.weight_decay 0.001", + "--optim2=SGD", + "--optim2.lr=0.005", + "--lr_scheduler=ExponentialLR", + "--lr_scheduler.gamma=0.1", + ] + else: + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--optim2.class_path=torch.optim.SGD", + "--optim2.init_args.lr=0.01", + "--lr_scheduler.gamma=0.2", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(TestModel) + + assert isinstance(cli.model.optim1, torch.optim.Adam) + assert isinstance(cli.model.optim2, torch.optim.SGD) + assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) + + def test_registries(tmpdir): assert CALLBACK_REGISTRIES.available_objects() == [ "BackboneFinetuning", "BaseFinetuning", "BasePredictionWriter", - "Callback", "EarlyStopping", "GPUStatsMonitor", "GradientAccumulationScheduler", @@ -714,6 +768,34 @@ def test_registries(tmpdir): "XLAStatsMonitor", ] + assert OPTIMIZER_REGISTRIES.available_objects() == [ + "ASGD", + "Adadelta", + "Adagrad", + "Adam", + "AdamW", + "Adamax", + "LBFGS", + "RMSprop", + "Rprop", + "SGD", + "SparseAdam", + ] + + 
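The registries exercised by this test are plain name-to-class mappings that ``register_package`` fills by walking a module for subclasses of a base class. A minimal sketch of the lookup they provide, assuming ``OPTIMIZER_REGISTRIES`` is imported from ``pytorch_lightning.utilities.cli`` (its location as of this commit; later commits in the series rename it to ``OPTIMIZER_REGISTRY``):

.. code-block:: python

    import torch

    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRIES

    model = torch.nn.Linear(1, 1)
    # dict-style lookup: the key is the class name recorded by register_package
    optim_cls = OPTIMIZER_REGISTRIES["Adam"]
    # roughly the call the CLI ends up making for --optimizer=Adam --optimizer.lr=0.01
    optimizer = optim_cls(model.parameters(), lr=0.01)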
assert SCHEDULER_REGISTRIES.available_objects() == [ + "CosineAnnealingLR", + "CosineAnnealingWarmRestarts", + "CyclicLR", + "ExponentialLR", + "LambdaLR", + "MultiStepLR", + "MultiplicativeLR", + "OneCycleLR", + "StepLR", + ] + + +def test_registries_resolution(tmpdir): class TestModel(BoringModel): def __init__(self): super().__init__() @@ -721,7 +803,6 @@ def __init__(self): callbacks = [ dict( class_path="pytorch_lightning.callbacks.Callback", - init_args=dict(), ), dict( class_path="pytorch_lightning.callbacks.Callback", From 705c0bd1f62a9a939e518bed13868227702aa01b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Aug 2021 13:50:08 +0000 Subject: [PATCH 11/77] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 8aade105c79ff..74ec8b8eebb55 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -832,7 +832,7 @@ def __init__(self): assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) assert len(cli.trainer.callbacks) == 4 - + @pytest.mark.parametrize("run", (False, True)) def test_lightning_cli_disabled_run(run): with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock: From 78a439827ba24abc313ba1078b4690a8a9cde38a Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 16:26:09 +0200 Subject: [PATCH 12/77] add docs --- docs/source/common/lightning_cli.rst | 61 ++++++++++++++++++++++++++++ pytorch_lightning/utilities/cli.py | 6 +-- tests/utilities/test_cli.py | 42 ++++--------------- 3 files changed, 73 insertions(+), 36 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index a8d65ce0ec853..3eb0aff163f4d 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -312,6 +312,45 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured the same way using :code:`class_path` and :code:`init_args`. +Alternatively, the user can provide the list of callbacks directly from the command line. + +.. code-block:: bash + + $ python ... --trainer.callbacks=[{"class_path": pytorch_lightning.callbacks.EarlyStopping, "init_args": "patience": "5"}, ...]. + +Lightning optionally simplifies the user command line where only the :class:`~pytorch_lightning.callbacks.Callback` name is required. +The argument's order matters and the user needs to pass the arguments in the following way. +This is supported only for PyTorch Lightning built-in :class:`~pytorch_lightning.callbacks.Callback`. + +.. code-block:: bash + + $ python ... --trainer.callbacks={CALLBACK_NAME_1} --trainer.{CALLBACK_1_ARGS_1}=... --trainer.{CALLBACK_1_ARGS_2}=... --trainer.callbacks={CALLBACK_N} --trainer.{CALLBACK_N_ARGS_1}=... + +Here is an example: + +.. code-block:: bash + + $ python ... --trainer.callbacks=EarlyStopping --trainer.patience=5 --trainer.callbacks=LearningRateMonitor + +However, a user can register its own callbacks as follow. + +.. 
code-block:: python + + from pytorch_lightning.utilities.cli import CALLBACK_REGISTRIES + from pytorch_lightning.callbacks import Callback + + + @CALLBACK_REGISTRIES + class CustomCallback(Callback): + pass + + + cli = LightningCLI(...) + +.. code-block:: bash + + $ python ... --trainer.callbacks=CustomCallback ... + Multiple models and/or datasets ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -683,6 +722,12 @@ And the same through command line: $ python train.py --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01 +Optionally, the command line can be simplified for PyTorch build-in `optimizers` and `schedulers`: + +.. code-block:: bash + + $ python train.py --optimizer=Adam --optimizer.lr=0.01 + The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. This would be: @@ -717,6 +762,22 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th cli = MyLightningCLI(MyModel) +For code simplification, the LightningCLI provides properties with already registered PyTorch built-in `optimizers` and `schedulers`. + +.. code-block:: + + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args( + self.registered_optimizers, + link_to="model.optimizer_init", + ) + parser.add_lr_scheduler_args( + self.registered_lr_schedulers, + link_to="model.lr_scheduler_init", + ) + + For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and :code:`init_args` entries. 
The function diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index fa6542d99c714..4ba88b8b17074 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -382,11 +382,11 @@ def __init__( self.after_fit() @property - def optimizer_registered(self) -> Tuple[Type[Optimizer]]: + def registered_optimizers(self) -> Tuple[Type[Optimizer]]: return tuple(OPTIMIZER_REGISTRIES.values()) @property - def lr_scheduler_registered(self) -> Tuple[LRSchedulerType]: + def registered_lr_schedulers(self) -> Tuple[LRSchedulerType]: return tuple(SCHEDULER_REGISTRIES.values()) def init_parser(self, **kwargs: Any) -> LightningArgumentParser: @@ -439,7 +439,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: if re.match(fr"^--optimizer[^\S+=]*?{optim_name}?", v) ): if "optimizer" not in self.parser.groups: - self.parser.add_optimizer_args(self.optimizer_registered) + self.parser.add_optimizer_args(self.registered_optimizers) if any( True diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 74ec8b8eebb55..67a5c21a912b9 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -667,48 +667,18 @@ def add_arguments_to_parser(self, parser): assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 -def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1") - parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") - parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") - - class TestModel(BoringModel): - def __init__(self, optim1: dict, optim2: dict, scheduler: dict): - super().__init__() - self.optim1 = instantiate_class(self.parameters(), optim1) - self.optim2 = instantiate_class(self.parameters(), optim2) - self.scheduler = instantiate_class(self.optim1, scheduler) - - cli_args = [ - f"--trainer.default_root_dir={tmpdir}", - "--trainer.max_epochs=1", - "--optim2.class_path=torch.optim.SGD", - "--optim2.init_args.lr=0.01", - "--lr_scheduler.gamma=0.2", - ] - - with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = MyLightningCLI(TestModel) - - assert isinstance(cli.model.optim1, torch.optim.Adam) - assert isinstance(cli.model.optim2, torch.optim.SGD) - assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) - - @pytest.mark.parametrize("use_registries", [False, True]) -def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to_2(use_registries, tmpdir): +def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args( - self.optimizer_registered if use_registries else torch.optim.Adam, + self.registered_optimizers if use_registries else torch.optim.Adam, nested_key="optim1", link_to="model.optim1", ) parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") parser.add_lr_scheduler_args( - self.lr_scheduler_registered if use_registries else torch.optim.lr_scheduler.ExponentialLR, + self.registered_lr_schedulers if use_registries else torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler", ) @@ -747,6 +717,11 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): assert 
isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) +@CALLBACK_REGISTRIES +class CustomCallback(Callback): + pass + + def test_registries(tmpdir): assert CALLBACK_REGISTRIES.available_objects() == [ @@ -766,6 +741,7 @@ def test_registries(tmpdir): "StochasticWeightAveraging", "Timer", "XLAStatsMonitor", + "CustomCallback", ] assert OPTIMIZER_REGISTRIES.available_objects() == [ From e96dc28b24317e59763ac9937cec168b19a26860 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 16:34:44 +0200 Subject: [PATCH 13/77] doc updates --- docs/source/common/lightning_cli.rst | 34 +++++++++++++++++++++++----- pytorch_lightning/utilities/cli.py | 12 +++++----- tests/utilities/test_cli.py | 4 ++-- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 3eb0aff163f4d..d176767095a49 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -762,6 +762,14 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th cli = MyLightningCLI(MyModel) + +For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with +a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including +:code:`class_path` and :code:`init_args` entries. The function +:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in +:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the +:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. + For code simplification, the LightningCLI provides properties with already registered PyTorch built-in `optimizers` and `schedulers`. .. code-block:: @@ -777,13 +785,27 @@ For code simplification, the LightningCLI provides properties with already regis link_to="model.lr_scheduler_init", ) +However, a user can register its own optimizers or schedulers as follow. + +.. code-block:: python + + import torch + from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRIES, LR_SCHEDULER_REGISTRIES + from pytorch_lightning.callbacks import Callback + + + @CALLBACK_REGISTRIES + class CustomAdam(torch.optim.Adam): + pass + + + @LR_SCHEDULER_REGISTRIES + class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): + pass + + + cli = LightningCLI(...) -For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with -a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including -:code:`class_path` and :code:`init_args` entries. The function -:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in -:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the -:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. 
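A condensed sketch of the receiving side, assuming the parser links ``optimizer_init`` to the model as described above (``MyModel`` is an illustrative name):

.. code-block:: python

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyModel(LightningModule):
        def __init__(self, optimizer_init: dict):
            super().__init__()
            self.optimizer_init = optimizer_init

        def configure_optimizers(self):
            # optimizer_init looks like {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}}:
            # the class named in "class_path" is imported and called with
            # self.parameters() followed by the "init_args".
            return instantiate_class(self.parameters(), self.optimizer_init)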
Notes related to reproducibility diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 4ba88b8b17074..5ccd68ab8b2dc 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -93,8 +93,8 @@ def __str__(self) -> str: OPTIMIZER_REGISTRIES = Registry() OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer) -SCHEDULER_REGISTRIES = Registry() -SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) +LR_SCHEDULER_REGISTRIES = Registry() +LR_SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) @dataclass @@ -387,7 +387,7 @@ def registered_optimizers(self) -> Tuple[Type[Optimizer]]: @property def registered_lr_schedulers(self) -> Tuple[LRSchedulerType]: - return tuple(SCHEDULER_REGISTRIES.values()) + return tuple(LR_SCHEDULER_REGISTRIES.values()) def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" @@ -444,10 +444,10 @@ def link_optimizers_and_lr_schedulers(self) -> None: if any( True for v in sys.argv - for sch_name in SCHEDULER_REGISTRIES + for sch_name in LR_SCHEDULER_REGISTRIES if re.match(fr"^--lr_scheduler[^\S+=]*{sch_name}?", v) ): - lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values()) + lr_schdulers = tuple(v for v in LR_SCHEDULER_REGISTRIES.values()) if "lr_scheduler" not in self.parser.groups: self.parser.add_lr_scheduler_args(lr_schdulers) @@ -540,7 +540,7 @@ def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" # fmt: off with self.prepare_from_registry(OPTIMIZER_REGISTRIES), \ - self.prepare_from_registry(SCHEDULER_REGISTRIES), \ + self.prepare_from_registry(LR_SCHEDULER_REGISTRIES), \ self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES): self.config = parser.parse_args() # fmt: on diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 67a5c21a912b9..ac54033ec8adf 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -37,9 +37,9 @@ instantiate_class, LightningArgumentParser, LightningCLI, + LR_SCHEDULER_REGISTRIES, OPTIMIZER_REGISTRIES, SaveConfigCallback, - SCHEDULER_REGISTRIES, ) from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel @@ -758,7 +758,7 @@ def test_registries(tmpdir): "SparseAdam", ] - assert SCHEDULER_REGISTRIES.available_objects() == [ + assert LR_SCHEDULER_REGISTRIES.available_objects() == [ "CosineAnnealingLR", "CosineAnnealingWarmRestarts", "CyclicLR", From 631aa72a4ca803a97d990eef1500fb707381bfa8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 16:50:22 +0200 Subject: [PATCH 14/77] update --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/cli.py | 2 +- tests/utilities/test_cli.py | 17 +++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3055b15011a2f..8601b8fc9aa80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) +- Added simplification for handling optimizers, schedulers and callbacks with `LightningCLI` ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 5ccd68ab8b2dc..237a31d495874 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -71,7 +71,7 @@ def do_register(key, cls) -> Callable: do_register(key, cls) - return do_register + return cls def register_package(self, module: ModuleType, base_cls: Type) -> None: """This function is an utility to register all classes from a module.""" diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index ac54033ec8adf..ce20172ed5809 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -816,3 +816,20 @@ def test_lightning_cli_disabled_run(run): fit_mock.call_count == run assert isinstance(cli.trainer, Trainer) assert isinstance(cli.model, LightningModule) + + +@pytest.mark.skipif(True, reason="typing from json-argparse is failing.") +def test_custom_callbacks(tmpdir): + class TestModel(BoringModel): + def on_fit_start(self): + callbacks = [c for c in self.trainer.callbacks if isinstance(c, CustomCallback)] + assert len(callbacks) == 1 + + with mock.patch("sys.argv", ["any.py"]): + LightningCLI( + TestModel, + trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True, callbacks=CustomCallback()), + ) + + with mock.patch("sys.argv", ["any.py", "--trainer.callbacks=[{'class_path': tests.utilities.CustomCallback}]"]): + LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) From 2fc3c0a30425b9aaca19a3fd620ee5979cf0bc02 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 17:36:39 +0200 Subject: [PATCH 15/77] update --- docs/source/common/lightning_cli.rst | 10 +++--- pytorch_lightning/utilities/cli.py | 28 ++++++++--------- tests/utilities/test_cli.py | 46 ++++++++++++++++++++++------ 3 files changed, 56 insertions(+), 28 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index d176767095a49..c2a95b9e59bab 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -336,11 +336,11 @@ However, a user can register its own callbacks as follow. .. code-block:: python - from pytorch_lightning.utilities.cli import CALLBACK_REGISTRIES + from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY from pytorch_lightning.callbacks import Callback - @CALLBACK_REGISTRIES + @CALLBACK_REGISTRY class CustomCallback(Callback): pass @@ -790,16 +790,16 @@ However, a user can register its own optimizers or schedulers as follow. .. 
code-block:: python import torch - from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRIES, LR_SCHEDULER_REGISTRIES + from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY from pytorch_lightning.callbacks import Callback - @CALLBACK_REGISTRIES + @OPTIMIZER_REGISTRY class CustomAdam(torch.optim.Adam): pass - @LR_SCHEDULER_REGISTRIES + @LR_SCHEDULER_REGISTRY class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): pass diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 237a31d495874..d306813c71409 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -87,14 +87,14 @@ def __str__(self) -> str: return "Registered objects: {}".format(", ".join(self.keys())) -CALLBACK_REGISTRIES = Registry() -CALLBACK_REGISTRIES.register_package(pl.callbacks, pl.callbacks.Callback) +CALLBACK_REGISTRY = Registry() +CALLBACK_REGISTRY.register_package(pl.callbacks, pl.callbacks.Callback) -OPTIMIZER_REGISTRIES = Registry() -OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer) +OPTIMIZER_REGISTRY = Registry() +OPTIMIZER_REGISTRY.register_package(torch.optim, torch.optim.Optimizer) -LR_SCHEDULER_REGISTRIES = Registry() -LR_SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) +LR_SCHEDULER_REGISTRY = Registry() +LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) @dataclass @@ -383,11 +383,11 @@ def __init__( @property def registered_optimizers(self) -> Tuple[Type[Optimizer]]: - return tuple(OPTIMIZER_REGISTRIES.values()) + return tuple(OPTIMIZER_REGISTRY.values()) @property def registered_lr_schedulers(self) -> Tuple[LRSchedulerType]: - return tuple(LR_SCHEDULER_REGISTRIES.values()) + return tuple(LR_SCHEDULER_REGISTRY.values()) def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" @@ -435,7 +435,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: if any( True for v in sys.argv - for optim_name in OPTIMIZER_REGISTRIES + for optim_name in OPTIMIZER_REGISTRY if re.match(fr"^--optimizer[^\S+=]*?{optim_name}?", v) ): if "optimizer" not in self.parser.groups: @@ -444,10 +444,10 @@ def link_optimizers_and_lr_schedulers(self) -> None: if any( True for v in sys.argv - for sch_name in LR_SCHEDULER_REGISTRIES + for sch_name in LR_SCHEDULER_REGISTRY if re.match(fr"^--lr_scheduler[^\S+=]*{sch_name}?", v) ): - lr_schdulers = tuple(v for v in LR_SCHEDULER_REGISTRIES.values()) + lr_schdulers = tuple(v for v in LR_SCHEDULER_REGISTRY.values()) if "lr_scheduler" not in self.parser.groups: self.parser.add_lr_scheduler_args(lr_schdulers) @@ -539,9 +539,9 @@ def prepare_class_list_from_registry(self, pattern: str, registry: Registry): def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" # fmt: off - with self.prepare_from_registry(OPTIMIZER_REGISTRIES), \ - self.prepare_from_registry(LR_SCHEDULER_REGISTRIES), \ - self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES): + with self.prepare_from_registry(OPTIMIZER_REGISTRY), \ + self.prepare_from_registry(LR_SCHEDULER_REGISTRY), \ + self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): self.config = parser.parse_args() # fmt: on diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index ce20172ed5809..8e89700ff7c3b 
100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -19,6 +19,7 @@ import sys from argparse import Namespace from contextlib import redirect_stdout +from importlib.util import module_from_spec, spec_from_file_location from io import StringIO from typing import List, Optional, Union from unittest import mock @@ -33,12 +34,12 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import ( - CALLBACK_REGISTRIES, + CALLBACK_REGISTRY, instantiate_class, LightningArgumentParser, LightningCLI, - LR_SCHEDULER_REGISTRIES, - OPTIMIZER_REGISTRIES, + LR_SCHEDULER_REGISTRY, + OPTIMIZER_REGISTRY, SaveConfigCallback, ) from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE @@ -717,14 +718,24 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) -@CALLBACK_REGISTRIES +@CALLBACK_REGISTRY class CustomCallback(Callback): pass +@OPTIMIZER_REGISTRY +class CustomAdam(torch.optim.Adam): + pass + + +@LR_SCHEDULER_REGISTRY +class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): + pass + + def test_registries(tmpdir): - assert CALLBACK_REGISTRIES.available_objects() == [ + assert CALLBACK_REGISTRY.available_objects() == [ "BackboneFinetuning", "BaseFinetuning", "BasePredictionWriter", @@ -744,7 +755,7 @@ def test_registries(tmpdir): "CustomCallback", ] - assert OPTIMIZER_REGISTRIES.available_objects() == [ + assert OPTIMIZER_REGISTRY.available_objects() == [ "ASGD", "Adadelta", "Adagrad", @@ -756,9 +767,10 @@ def test_registries(tmpdir): "Rprop", "SGD", "SparseAdam", + "CustomAdam", ] - assert LR_SCHEDULER_REGISTRIES.available_objects() == [ + assert LR_SCHEDULER_REGISTRY.available_objects() == [ "CosineAnnealingLR", "CosineAnnealingWarmRestarts", "CyclicLR", @@ -768,6 +780,7 @@ def test_registries(tmpdir): "MultiplicativeLR", "OneCycleLR", "StepLR", + "CustomCosineAnnealingLR", ] @@ -788,7 +801,7 @@ def __init__(self): cli_args = [ f"--trainer.default_root_dir={tmpdir}", - "--trainer.fast_dev_run=True", + "--trainer.fast_dev_run=1", "--trainer.progress_bar_refresh_rate=0", "--optimizer Adam", "--optimizer.lr 0.0001", @@ -831,5 +844,20 @@ def on_fit_start(self): trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True, callbacks=CustomCallback()), ) - with mock.patch("sys.argv", ["any.py", "--trainer.callbacks=[{'class_path': tests.utilities.CustomCallback}]"]): + code = """from pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.utilities.cli import CALLBACK_REGISTRY\n\nclass TestCallback(Callback):\n\tpass\n\nCALLBACK_REGISTRY(cls=TestCallback)""" # noqa E501 + + f = open(tmpdir / "test.py", "w") + f.write(code) + f.close() + + spec = spec_from_file_location("test", f.name) + mod = module_from_spec(spec) + sys.modules["test"] = mod + spec.loader.exec_module(mod) + callback_cls = getattr(mod, "TestCallback") + assert issubclass(callback_cls, Callback) + + callback = {"class_path": f"{tmpdir}.test.CustomCallback"} + + with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks=[{callback}]"]): LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) From 43dd8b49502983cb894cfc25b8d3eb715bd3d7d9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 17:42:39 +0200 Subject: [PATCH 16/77] resolve comments --- docs/source/common/lightning_cli.rst | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index c2a95b9e59bab..f729a93e4cf09 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -722,7 +722,7 @@ And the same through command line: $ python train.py --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01 -Optionally, the command line can be simplified for PyTorch build-in `optimizers` and `schedulers`: +Optionally, the command line can be simplified for PyTorch built-in `optimizers` and `schedulers`: .. code-block:: bash @@ -785,7 +785,7 @@ For code simplification, the LightningCLI provides properties with already regis link_to="model.lr_scheduler_init", ) -However, a user can register its own optimizers or schedulers as follow. +However, a user can register its own optimizers or schedulers as follows. .. code-block:: python From c6ae6691bb9c2a27342080bbcf37ac0e1ecfb3b2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 17:49:24 +0200 Subject: [PATCH 17/77] comment --- tests/utilities/test_cli.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 8e89700ff7c3b..1f39c5e86c6b3 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -833,6 +833,10 @@ def test_lightning_cli_disabled_run(run): @pytest.mark.skipif(True, reason="typing from json-argparse is failing.") def test_custom_callbacks(tmpdir): + """ + Test that registered callbacks can be used with LightningCLI. + """ + class TestModel(BoringModel): def on_fit_start(self): callbacks = [c for c in self.trainer.callbacks if isinstance(c, CustomCallback)] From 5c21b1cf317db225266bd182d093dbe53ab38bab Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 18:06:52 +0200 Subject: [PATCH 18/77] add comment --- tests/utilities/test_cli.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 1f39c5e86c6b3..8caf5677ac708 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -785,6 +785,10 @@ def test_registries(tmpdir): def test_registries_resolution(tmpdir): + """ + This test validates registries are used when simplified command line are being used. + """ + class TestModel(BoringModel): def __init__(self): super().__init__() From b370deb6d94dcbd81f35ad5f266c2543dff5754f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 10 Aug 2021 20:06:43 +0200 Subject: [PATCH 19/77] typo --- tests/utilities/test_cli.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 8caf5677ac708..5283fca018a73 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -785,9 +785,7 @@ def test_registries(tmpdir): def test_registries_resolution(tmpdir): - """ - This test validates registries are used when simplified command line are being used. 
- """ + """This test validates registries are used when simplified command line are being used.""" class TestModel(BoringModel): def __init__(self): From f8e7ca7d2635b582f3f6e01851a66719008c8833 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 10:10:59 +0200 Subject: [PATCH 20/77] update on comments --- docs/source/common/lightning_cli.rst | 8 +++---- pytorch_lightning/utilities/cli.py | 36 ++++++++++++++-------------- tests/utilities/test_cli.py | 15 ++++++++---- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index f729a93e4cf09..d1e3bbf9ab681 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -318,7 +318,7 @@ Alternatively, the user can provide the list of callbacks directly from the comm $ python ... --trainer.callbacks=[{"class_path": pytorch_lightning.callbacks.EarlyStopping, "init_args": "patience": "5"}, ...]. -Lightning optionally simplifies the user command line where only the :class:`~pytorch_lightning.callbacks.Callback` name is required. +Lightning optionally simplifies the user command line so that only the :class:`~pytorch_lightning.callbacks.Callback` name is required. The argument's order matters and the user needs to pass the arguments in the following way. This is supported only for PyTorch Lightning built-in :class:`~pytorch_lightning.callbacks.Callback`. @@ -332,7 +332,7 @@ Here is an example: $ python ... --trainer.callbacks=EarlyStopping --trainer.patience=5 --trainer.callbacks=LearningRateMonitor -However, a user can register its own callbacks as follow. +However, a user can register their own callbacks as follows. .. code-block:: python @@ -770,7 +770,7 @@ a single class or a tuple of classes, the value given to :code:`optimizer_init` :code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the :code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. -For code simplification, the LightningCLI provides properties with already registered PyTorch built-in `optimizers` and `schedulers`. +For code simplification, the LightningCLI provides properties with PyTorch's built-in `optimizers` and `schedulers` already registered. .. code-block:: @@ -785,7 +785,7 @@ For code simplification, the LightningCLI provides properties with already regis link_to="model.lr_scheduler_init", ) -However, a user can register its own optimizers or schedulers as follows. +However, a user can register their own optimizers or schedulers as follows. .. code-block:: python diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d306813c71409..0c235f4fde2a7 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect import os -import re import sys from argparse import Namespace from collections import UserDict @@ -115,8 +114,7 @@ def class_init(self) -> Dict[str, str]: class_init["class_path"] = self.cls.__module__ + "." 
+ self.cls.__name__ init_args = {} for init_arg in self.class_init_args: - separator = "=" if "=" in init_arg else " " - arg_path, value = init_arg.split(separator) + arg_path, value = init_arg.split("=") init_args[arg_path.split(".")[-1]] = value class_init["init_args"] = init_args return class_init @@ -363,6 +361,7 @@ def __init__( parser_kwargs = parser_kwargs or {} parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse}) + self.sanetize_argv() self.setup_parser(**parser_kwargs) self.link_optimizers_and_lr_schedulers() self.parse_arguments(self.parser) @@ -393,6 +392,18 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) + def sanetize_argv(self) -> None: + args = [idx for idx, v in enumerate(sys.argv) if v.startswith("--")] + if len(args) > 0: + start_index = args[0] + argv = [] + for v in sys.argv[start_index:]: + if v.startswith("--"): + argv.append(v) + else: + argv[-1] += "=" + v + sys.argv = sys.argv[:start_index] + argv + def setup_parser(self, **kwargs: Any) -> None: """Initialize and setup the parser, and arguments.""" self.parser = self.init_parser(**kwargs) @@ -432,21 +443,11 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def link_optimizers_and_lr_schedulers(self) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - if any( - True - for v in sys.argv - for optim_name in OPTIMIZER_REGISTRY - if re.match(fr"^--optimizer[^\S+=]*?{optim_name}?", v) - ): + if any(True for v in sys.argv for optim_name in OPTIMIZER_REGISTRY if f"--optimizer={optim_name}" in v): if "optimizer" not in self.parser.groups: self.parser.add_optimizer_args(self.registered_optimizers) - if any( - True - for v in sys.argv - for sch_name in LR_SCHEDULER_REGISTRY - if re.match(fr"^--lr_scheduler[^\S+=]*{sch_name}?", v) - ): + if any(True for v in sys.argv for sch_name in LR_SCHEDULER_REGISTRY if f"--lr_scheduler={sch_name}" in v): lr_schdulers = tuple(v for v in LR_SCHEDULER_REGISTRY.values()) if "lr_scheduler" not in self.parser.groups: self.parser.add_lr_scheduler_args(lr_schdulers) @@ -470,9 +471,8 @@ def prepare_from_registry(self, registry: Registry): map_user_key_to_info = {} for registered_name, registered_cls in registry.items(): for v in sys.argv: - separator = "=" if "=" in v else " " - if f"{separator}{registered_name}" in v: - key = v.split(separator)[0] + if f"={registered_name}" in v: + key = v.split("=")[0] map_user_key_to_info[key] = ClassInfo(class_arg=v, cls=registered_cls) if len(map_user_key_to_info) > 0: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 5283fca018a73..a8cb8c887dff9 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -694,8 +694,10 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", - "--optim1 Adam", - "--optim1.weight_decay 0.001", + "--optim1", + "Adam", + "--optim1.weight_decay", + "0.001", "--optim2=SGD", "--optim2.lr=0.005", "--lr_scheduler=ExponentialLR", @@ -805,15 +807,18 @@ def __init__(self): f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", "--trainer.progress_bar_refresh_rate=0", - "--optimizer Adam", - "--optimizer.lr 0.0001", + "--optimizer", + "Adam", + "--optimizer.lr", + "0.0001", "--trainer.callbacks=LearningRateMonitor", 
"--trainer.callbacks.logging_interval=epoch", "--trainer.callbacks.log_momentum=True", "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=loss", f"--trainer.callbacks={json.dumps(callbacks)}", - "--lr_scheduler StepLR", + "--lr_scheduler", + "StepLR", "--lr_scheduler.step_size=50", ] From e428d2f86526d5ad4a059c0782f004d34872ac9c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 11:23:09 +0200 Subject: [PATCH 21/77] resolve bug --- pytorch_lightning/utilities/cli.py | 3 ++- tests/utilities/test_cli.py | 36 ++++++++++++++++++------------ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 0c235f4fde2a7..1c9c9d18732b1 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -528,7 +528,8 @@ def prepare_class_list_from_registry(self, pattern: str, registry: Registry): class_args = [info.class_init for info in infos] # add other callback arguments. - class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) + if len(all_non_simplified_args) > 0: + class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) argv += [f"{pattern}={class_args}"] with mock.patch("sys.argv", argv): diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index a8cb8c887dff9..21be22e927001 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -786,23 +786,14 @@ def test_registries(tmpdir): ] -def test_registries_resolution(tmpdir): +@pytest.mark.parametrize("use_class_path_callbacks", [False, True]) +def test_registries_resolution(use_class_path_callbacks, tmpdir): """This test validates registries are used when simplified command line are being used.""" class TestModel(BoringModel): def __init__(self): super().__init__() - callbacks = [ - dict( - class_path="pytorch_lightning.callbacks.Callback", - ), - dict( - class_path="pytorch_lightning.callbacks.Callback", - init_args=dict(), - ), - ] - cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", @@ -816,17 +807,34 @@ def __init__(self): "--trainer.callbacks.log_momentum=True", "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=loss", - f"--trainer.callbacks={json.dumps(callbacks)}", "--lr_scheduler", "StepLR", "--lr_scheduler.step_size=50", ] + expected_callbacks = 2 + + if use_class_path_callbacks: + + callbacks = [ + dict( + class_path="pytorch_lightning.callbacks.Callback", + ), + dict( + class_path="pytorch_lightning.callbacks.Callback", + init_args=dict(), + ), + ] + + cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"] + + expected_callbacks = 4 + with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = LightningCLI(TestModel) + cli = LightningCLI(BoringModel) assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.callbacks) == 4 + assert len(cli.trainer.callbacks) == expected_callbacks @pytest.mark.parametrize("run", (False, True)) From 3e979059fc7b6dc25e1224dee6ca0dbbe2ddc899 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 11:43:38 +0200 Subject: [PATCH 22/77] typo --- docs/source/common/lightning_cli.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index d1e3bbf9ab681..8b2491da524b3 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -324,13 +324,13 @@ This is supported only for PyTorch Lightning built-in 
:class:`~pytorch_lightning .. code-block:: bash - $ python ... --trainer.callbacks={CALLBACK_NAME_1} --trainer.{CALLBACK_1_ARGS_1}=... --trainer.{CALLBACK_1_ARGS_2}=... --trainer.callbacks={CALLBACK_N} --trainer.{CALLBACK_N_ARGS_1}=... + $ python ... --trainer.callbacks={CALLBACK_NAME_1} --trainer.callbacks.{CALLBACK_1_ARGS_1}=... --trainer.{CALLBACK_1_ARGS_2}=... --trainer.callbacks={CALLBACK_N} --trainer.callbacks.{CALLBACK_N_ARGS_1}=... Here is an example: .. code-block:: bash - $ python ... --trainer.callbacks=EarlyStopping --trainer.patience=5 --trainer.callbacks=LearningRateMonitor + $ python ... --trainer.callbacks=EarlyStopping --trainer.callbacks.patience=5 --trainer.callbacks=LearningRateMonitor However, a user can register their own callbacks as follows. From 3b1bdb6b0cc3ad280c21f170d39a0f4c03a6590d Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 11:46:21 +0200 Subject: [PATCH 23/77] update --- docs/source/common/lightning_cli.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 8b2491da524b3..11996557b8ba8 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -324,7 +324,7 @@ This is supported only for PyTorch Lightning built-in :class:`~pytorch_lightning .. code-block:: bash - $ python ... --trainer.callbacks={CALLBACK_NAME_1} --trainer.callbacks.{CALLBACK_1_ARGS_1}=... --trainer.{CALLBACK_1_ARGS_2}=... --trainer.callbacks={CALLBACK_N} --trainer.callbacks.{CALLBACK_N_ARGS_1}=... + $ python ... --trainer.callbacks={CALLBACK_NAME_1} --trainer.callbacks.{CALLBACK_1_ARGS_1}=... --trainer.callbacks.{CALLBACK_1_ARGS_2}=... --trainer.callbacks={CALLBACK_N} --trainer.callbacks.{CALLBACK_N_ARGS_1}=... Here is an example: From 0d1db29e68f07b715967e2fe2737a8a85c084c32 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 18:31:21 +0200 Subject: [PATCH 24/77] resolve comments --- docs/source/common/lightning_cli.rst | 8 +-- pytorch_lightning/plugins/plugins_registry.py | 3 +- pytorch_lightning/utilities/cli.py | 65 +++++++++---------- 3 files changed, 32 insertions(+), 44 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 11996557b8ba8..17992944072a5 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -312,12 +312,6 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured the same way using :code:`class_path` and :code:`init_args`. -Alternatively, the user can provide the list of callbacks directly from the command line. - -.. code-block:: bash - - $ python ... --trainer.callbacks=[{"class_path": pytorch_lightning.callbacks.EarlyStopping, "init_args": "patience": "5"}, ...]. - Lightning optionally simplifies the user command line so that only the :class:`~pytorch_lightning.callbacks.Callback` name is required. The argument's order matters and the user needs to pass the arguments in the following way. This is supported only for PyTorch Lightning built-in :class:`~pytorch_lightning.callbacks.Callback`. @@ -342,7 +336,7 @@ However, a user can register their own callbacks as follows. @CALLBACK_REGISTRY class CustomCallback(Callback): - pass + ... cli = LightningCLI(...) 
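Besides the decorator form shown above, the registry can also be called directly, which allows registering a class under an explicit key or overriding an existing entry. A sketch, assuming the ``cls``/``key``/``override`` parameters of the registry's ``__call__`` introduced in this series:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback
    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY


    class MyCallback(Callback):
        ...


    # register under a custom key; override=True replaces a previous entry with the same key
    CALLBACK_REGISTRY(cls=MyCallback, key="my_callback", override=True)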
diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 81307157f527e..c8896f4beedf6 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -13,7 +13,6 @@ # limitations under the License. import importlib import inspect -from collections import UserDict from inspect import getmembers, isclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -22,7 +21,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException -class _TrainingTypePluginsRegistry(UserDict): +class _TrainingTypePluginsRegistry(dict): """ This class is a Registry that stores information about the Training Type Plugins. diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 1c9c9d18732b1..0a8fe91cad3fc 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -15,7 +15,6 @@ import os import sys from argparse import Namespace -from collections import UserDict from contextlib import contextmanager from dataclasses import dataclass, field from functools import partial @@ -42,19 +41,19 @@ ArgumentParser = object -class Registry(UserDict): +class _Registry(dict): def __call__( self, cls: Optional[Type] = None, key: Optional[str] = None, override: bool = False, - ) -> Callable: + ) -> "Optional[Type]": """ Registers a class mapped to a name. Args: cls: the class to be mapped. - key : the name that identifies the provided class. + key: the name that identifies the provided class. """ if key is None: key = cls.__name__ @@ -64,12 +63,7 @@ def __call__( if key in self and not override: raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.") - def do_register(key, cls) -> Callable: - self[key] = cls - return cls - - do_register(key, cls) - + self[key] = cls return cls def register_package(self, module: ModuleType, base_cls: Type) -> None: @@ -78,40 +72,40 @@ def register_package(self, module: ModuleType, base_cls: Type) -> None: if issubclass(cls, base_cls) and cls != base_cls: self(cls=cls) - def available_objects(self) -> List: + def available_objects(self) -> List[str]: """Returns a list of registered objects""" return list(self.keys()) def __str__(self) -> str: - return "Registered objects: {}".format(", ".join(self.keys())) + objects = ", ".join(self.keys()) + return f"Registered objects: {objects}" -CALLBACK_REGISTRY = Registry() +CALLBACK_REGISTRY = _Registry() CALLBACK_REGISTRY.register_package(pl.callbacks, pl.callbacks.Callback) -OPTIMIZER_REGISTRY = Registry() +OPTIMIZER_REGISTRY = _Registry() OPTIMIZER_REGISTRY.register_package(torch.optim, torch.optim.Optimizer) -LR_SCHEDULER_REGISTRY = Registry() +LR_SCHEDULER_REGISTRY = _Registry() LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) @dataclass -class ClassInfo: +class _ClassInfo: """This class is an helper to easily build the mocked command line""" class_arg: str - cls: str + cls: Type class_init_args: List[str] = field(default_factory=lambda: []) - def add_class_init_args(self, args: Dict[str, str]) -> None: + def add_class_init_args(self, args: str) -> None: if args != self.class_arg: self.class_init_args.append(args) @property def class_init(self) -> Dict[str, str]: - class_init = {} - class_init["class_path"] = self.cls.__module__ + "." + self.cls.__name__ + class_init = {"class_path": self.cls.__module__ + "." 
+ self.cls.__name__} init_args = {} for init_arg in self.class_init_args: arg_path, value = init_arg.split("=") @@ -361,7 +355,7 @@ def __init__( parser_kwargs = parser_kwargs or {} parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse}) - self.sanetize_argv() + self.sanitize_argv() self.setup_parser(**parser_kwargs) self.link_optimizers_and_lr_schedulers() self.parse_arguments(self.parser) @@ -392,17 +386,18 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) - def sanetize_argv(self) -> None: + def sanitize_argv(self) -> None: args = [idx for idx, v in enumerate(sys.argv) if v.startswith("--")] - if len(args) > 0: - start_index = args[0] - argv = [] - for v in sys.argv[start_index:]: - if v.startswith("--"): - argv.append(v) - else: - argv[-1] += "=" + v - sys.argv = sys.argv[:start_index] + argv + if not args: + return + start_index = args[0] + argv = [] + for v in sys.argv[start_index:]: + if v.startswith("-"): + argv.append(v) + else: + argv[-1] += "=" + v + sys.argv = sys.argv[:start_index] + argv def setup_parser(self, **kwargs: Any) -> None: """Initialize and setup the parser, and arguments.""" @@ -462,7 +457,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: self.parser.link_arguments(key, link_to, compute_fn=add_class_path) @contextmanager - def prepare_from_registry(self, registry: Registry): + def prepare_from_registry(self, registry: _Registry): """ This context manager is used to simplify unique class instantiation. """ @@ -473,7 +468,7 @@ def prepare_from_registry(self, registry: Registry): for v in sys.argv: if f"={registered_name}" in v: key = v.split("=")[0] - map_user_key_to_info[key] = ClassInfo(class_arg=v, cls=registered_cls) + map_user_key_to_info[key] = _ClassInfo(class_arg=v, cls=registered_cls) if len(map_user_key_to_info) > 0: # for each shortcut command line, add its init arguments and skip them from `sys.argv`. @@ -495,7 +490,7 @@ def prepare_from_registry(self, registry: Registry): yield @contextmanager - def prepare_class_list_from_registry(self, pattern: str, registry: Registry): + def prepare_class_list_from_registry(self, pattern: str, registry: _Registry): """ This context manager is used to simplify instantiation of a list of class. 
""" @@ -519,7 +514,7 @@ def prepare_class_list_from_registry(self, pattern: str, registry: Registry): for class_arg in all_cls_simplified_args: class_name = class_arg.split("=")[1] registered_cls = registry[class_name] - infos.append(ClassInfo(class_arg=class_arg, cls=registered_cls)) + infos.append(_ClassInfo(class_arg=class_arg, cls=registered_cls)) for v in all_simplified_args: if v in all_cls_simplified_args: From 3d35c82b4dd47aa230b18c67e2de8bc448d6220b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 19:14:01 +0200 Subject: [PATCH 25/77] add unittesting --- pytorch_lightning/utilities/cli.py | 14 ++-- tests/utilities/test_cli.py | 105 +++++++++++++++++++++++++++-- 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 0a8fe91cad3fc..791781ddba146 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -98,6 +98,7 @@ class _ClassInfo: class_arg: str cls: Type class_init_args: List[str] = field(default_factory=lambda: []) + class_arg_idx: Optional[int] = None def add_class_init_args(self, args: str) -> None: if args != self.class_arg: @@ -511,14 +512,15 @@ def prepare_class_list_from_registry(self, pattern: str, registry: _Registry): # group arguments per callbacks infos = [] - for class_arg in all_cls_simplified_args: - class_name = class_arg.split("=")[1] - registered_cls = registry[class_name] - infos.append(_ClassInfo(class_arg=class_arg, cls=registered_cls)) + for class_arg_idx, class_arg in enumerate(all_simplified_args): + if class_arg in all_cls_simplified_args: + class_name = class_arg.split("=")[1] + registered_cls = registry[class_name] + infos.append(_ClassInfo(class_arg=class_arg, cls=registered_cls, class_arg_idx=class_arg_idx)) - for v in all_simplified_args: + for idx, v in enumerate(all_simplified_args): if v in all_cls_simplified_args: - current_info = infos[all_cls_simplified_args.index(v)] + current_info = [info for info in infos if idx == info.class_arg_idx][0] current_info.add_class_init_args(v) class_args = [info.class_init for info in infos] diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 21be22e927001..684679c8f37b2 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -790,10 +790,6 @@ def test_registries(tmpdir): def test_registries_resolution(use_class_path_callbacks, tmpdir): """This test validates registries are used when simplified command line are being used.""" - class TestModel(BoringModel): - def __init__(self): - super().__init__() - cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", @@ -880,3 +876,104 @@ def on_fit_start(self): with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks=[{callback}]"]): LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + + +def test_argv_modifiers(): + """ + This test validates ``sys.argv`` from `LightningCLI` are properly transforming the command line. 
+ """ + + class TestLightningCLI(LightningCLI): + def __init__(self, *args, expected=None, **kwargs): + self.expected = expected + super().__init__(*args, **kwargs) + + def parse_arguments(self, parser: LightningArgumentParser) -> None: + with self.prepare_from_registry(OPTIMIZER_REGISTRY), self.prepare_from_registry( + LR_SCHEDULER_REGISTRY + ), self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + self.config = parser.parse_args() + + base = ["any.py", "--trainer.max_epochs=1"] + + with mock.patch("sys.argv", base): + expected = base + TestLightningCLI(BoringModel, run=False, expected=expected) + + with mock.patch("sys.argv", base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]): + callbacks = [ + dict( + class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + init_args=dict(monitor="val_loss"), + ), + ] + expected = base + [ + f"--trainer.callbacks={str(callbacks)}", + ] + TestLightningCLI(BoringModel, run=False, expected=expected) + + cli_args = [ + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=val_loss", + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=val_acc", + ] + + with mock.patch("sys.argv", base + cli_args): + callbacks = [ + dict( + class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + init_args=dict(monitor="val_loss"), + ), + dict( + class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + init_args=dict(monitor="val_acc"), + ), + ] + expected = base + [f"--trainer.callbacks={str(callbacks)}"] + TestLightningCLI(BoringModel, run=False, expected=expected) + + cli_args = [ + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=val_loss", + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=val_acc", + "--trainer.callbacks=[{'class_path': 'pytorch_lightning.callbacks.Callback'}]", + ] + + with mock.patch("sys.argv", base + cli_args): + callbacks = [ + dict( + class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + init_args=dict(monitor="val_loss"), + ), + dict( + class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + init_args=dict(monitor="val_acc"), + ), + dict( + class_path="pytorch_lightning.callbacks.Callback", + ), + ] + expected = base + [f"--trainer.callbacks={str(callbacks)}"] + TestLightningCLI(BoringModel, run=False, expected=expected) + + with mock.patch("sys.argv", base + ["--optimizer", "Adadelta"]): + expected = base + ['--optimizer={"class_path":"torch.optim.Adadelta"}'] + TestLightningCLI(BoringModel, run=False, expected=expected) + + with mock.patch("sys.argv", base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"]): + expected = base + ['--optimizer={"class_path": "torch.optim.Adadelta", "init_args": {"lr": "10"}}'] + TestLightningCLI(BoringModel, run=False, expected=expected) + + with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR"]): + expected = base + ['--lr_scheduler={"class_path": "torch.optim.lr_scheduler.OneCycleLR"}'] + TestLightningCLI(BoringModel, run=False, expected=expected) + + with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"]): + lr_scheduler = dict( + class_path="torch.optim.lr_scheduler.OneCycleLR", + init_args=dict(anneal_strategy="linear"), + ) + expected = base + [f"--lr_scheduler={lr_scheduler}"] + TestLightningCLI(BoringModel, run=False, expected=expected) From 4c0f9602c634ccbc9475e066d2eee1370193c89e Mon Sep 17 
00:00:00 2001 From: tchaton Date: Wed, 11 Aug 2021 19:23:15 +0200 Subject: [PATCH 26/77] resolve tests --- tests/utilities/test_cli.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 684679c8f37b2..e8f342dd85319 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -889,10 +889,13 @@ def __init__(self, *args, expected=None, **kwargs): super().__init__(*args, **kwargs) def parse_arguments(self, parser: LightningArgumentParser) -> None: - with self.prepare_from_registry(OPTIMIZER_REGISTRY), self.prepare_from_registry( - LR_SCHEDULER_REGISTRY - ), self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + # fmt: off + with self.prepare_from_registry(OPTIMIZER_REGISTRY), \ + self.prepare_from_registry(LR_SCHEDULER_REGISTRY), \ + self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + assert sys.argv == self.expected self.config = parser.parse_args() + # fmt: on base = ["any.py", "--trainer.max_epochs=1"] @@ -959,15 +962,27 @@ def parse_arguments(self, parser: LightningArgumentParser) -> None: TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--optimizer", "Adadelta"]): - expected = base + ['--optimizer={"class_path":"torch.optim.Adadelta"}'] + optimizer = dict( + class_path="torch.optim.adadelta.Adadelta", + init_args=dict(), + ) + expected = base + [f"--optimizer={optimizer}"] TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"]): - expected = base + ['--optimizer={"class_path": "torch.optim.Adadelta", "init_args": {"lr": "10"}}'] + optimizer = dict( + class_path="torch.optim.adadelta.Adadelta", + init_args=dict(lr="10"), + ) + expected = base + [f"--optimizer={optimizer}"] TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR"]): - expected = base + ['--lr_scheduler={"class_path": "torch.optim.lr_scheduler.OneCycleLR"}'] + lr_scheduler = dict( + class_path="torch.optim.lr_scheduler.OneCycleLR", + init_args=dict(), + ) + expected = base + [f"--lr_scheduler={lr_scheduler}"] TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"]): From d3a62ca0f15f41a60038a9deb834bb0f473c7d3d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 12 Aug 2021 09:20:59 +0200 Subject: [PATCH 27/77] resolve comments --- pytorch_lightning/utilities/cli.py | 17 +++++++++-------- tests/utilities/test_cli.py | 6 +++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 791781ddba146..dc64b46bf5838 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -356,7 +356,7 @@ def __init__( parser_kwargs = parser_kwargs or {} parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse}) - self.sanitize_argv() + self._sanitize_argv() self.setup_parser(**parser_kwargs) self.link_optimizers_and_lr_schedulers() self.parse_arguments(self.parser) @@ -387,8 +387,9 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) - def sanitize_argv(self) -> None: - args = [idx 
for idx, v in enumerate(sys.argv) if v.startswith("--")] + def _sanitize_argv(self) -> None: + """This function is used to replace space within `sys.argv` with its equal sign counter-part.""" + args = [idx for idx, v in enumerate(sys.argv) if v.startswith("-")] if not args: return start_index = args[0] @@ -458,7 +459,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: self.parser.link_arguments(key, link_to, compute_fn=add_class_path) @contextmanager - def prepare_from_registry(self, registry: _Registry): + def _prepare_from_registry(self, registry: _Registry): """ This context manager is used to simplify unique class instantiation. """ @@ -491,7 +492,7 @@ def prepare_from_registry(self, registry: _Registry): yield @contextmanager - def prepare_class_list_from_registry(self, pattern: str, registry: _Registry): + def _prepare_class_list_from_registry(self, pattern: str, registry: _Registry): """ This context manager is used to simplify instantiation of a list of class. """ @@ -537,9 +538,9 @@ def prepare_class_list_from_registry(self, pattern: str, registry: _Registry): def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" # fmt: off - with self.prepare_from_registry(OPTIMIZER_REGISTRY), \ - self.prepare_from_registry(LR_SCHEDULER_REGISTRY), \ - self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + with self._prepare_from_registry(OPTIMIZER_REGISTRY), \ + self._prepare_from_registry(LR_SCHEDULER_REGISTRY), \ + self._prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): self.config = parser.parse_args() # fmt: on diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e8f342dd85319..f552e3e353965 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -890,9 +890,9 @@ def __init__(self, *args, expected=None, **kwargs): def parse_arguments(self, parser: LightningArgumentParser) -> None: # fmt: off - with self.prepare_from_registry(OPTIMIZER_REGISTRY), \ - self.prepare_from_registry(LR_SCHEDULER_REGISTRY), \ - self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + with self._prepare_from_registry(OPTIMIZER_REGISTRY), \ + self._prepare_from_registry(LR_SCHEDULER_REGISTRY), \ + self._prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): assert sys.argv == self.expected self.config = parser.parse_args() # fmt: on From 39781a15090efbfc04f9ba95308ca8f181770e26 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 15:36:58 +0200 Subject: [PATCH 28/77] update on comments --- docs/source/common/lightning_cli.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 17992944072a5..32d2b83d89ffc 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -326,7 +326,10 @@ Here is an example: $ python ... --trainer.callbacks=EarlyStopping --trainer.callbacks.patience=5 --trainer.callbacks=LearningRateMonitor -However, a user can register their own callbacks as follows. +Register your callbacks +^^^^^^^^^^^^^^^^^^^^^^^ + +Lightning provides registries for you to add your own callbacks and benefit from the command line simplification as described above: .. 
code-block:: python From 68c03dedb35ccc39978d872f34039e11f5e30c6c Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 15:42:51 +0200 Subject: [PATCH 29/77] doc updates --- docs/source/common/lightning_cli.rst | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 32d2b83d89ffc..12288762dae2c 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -767,22 +767,39 @@ a single class or a tuple of classes, the value given to :code:`optimizer_init` :code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the :code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. +Built in schedulers & optimizers and registering your own +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + For code simplification, the LightningCLI provides properties with PyTorch's built-in `optimizers` and `schedulers` already registered. +Only the optimizer or scheduler name needs to be passed along its arguments. + +.. code-block:: bash + + $ python train.py --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=CosineAnnealingLR + +If you model requires multiple optimizers, the LightningCLI provides already registered optimizers and schedulers under the properties `registered_optimizers` and `registered_lr_schedulers` + .. code-block:: class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args( self.registered_optimizers, + nested="gen_optimizer", link_to="model.optimizer_init", ) - parser.add_lr_scheduler_args( - self.registered_lr_schedulers, - link_to="model.lr_scheduler_init", + parser.add_optimizer_args( + self.registered_optimizers, + nested="gen_discriminator", + link_to="model.optimizer_init", ) -However, a user can register their own optimizers or schedulers as follows. +.. code-block:: bash + + $ python train.py --gen_optimizer=Adam --optimizer.lr=0.01 -gen_discriminator=Adam --optimizer.lr=0.0001 + +Furthermore, a user can register their own optimizers or schedulers as follows. .. code-block:: python @@ -804,7 +821,6 @@ However, a user can register their own optimizers or schedulers as follows. cli = LightningCLI(...) - Notes related to reproducibility ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From b01828b0431b2793c93e13b28f347379f1626e80 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 15:43:21 +0200 Subject: [PATCH 30/77] update --- docs/source/common/lightning_cli.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 12288762dae2c..e6bcd13bf9450 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -820,6 +820,10 @@ Furthermore, a user can register their own optimizers or schedulers as follows. cli = LightningCLI(...) +.. 
code-block:: bash + + $ python train.py --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR + Notes related to reproducibility ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 5935ec4944ff600f0ed25d3b664d8be39156af03 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 17 Aug 2021 19:14:47 +0100 Subject: [PATCH 31/77] update on comments --- docs/source/common/lightning_cli.rst | 2 +- pytorch_lightning/utilities/cli.py | 42 +++++++++++++++----- tests/utilities/test_cli.py | 57 ++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 11 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index e6bcd13bf9450..6e5a7f7df18c0 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -778,7 +778,7 @@ Only the optimizer or scheduler name needs to be passed along its arguments. $ python train.py --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=CosineAnnealingLR -If you model requires multiple optimizers, the LightningCLI provides already registered optimizers and schedulers under the properties `registered_optimizers` and `registered_lr_schedulers` +If your model requires multiple optimizers, the LightningCLI provides already registered optimizers and schedulers under the properties `registered_optimizers` and `registered_lr_schedulers` .. code-block:: diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index dc64b46bf5838..d35de5b988260 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -44,7 +44,7 @@ class _Registry(dict): def __call__( self, - cls: Optional[Type] = None, + cls: Type, key: Optional[str] = None, override: bool = False, ) -> "Optional[Type]": @@ -356,7 +356,6 @@ def __init__( parser_kwargs = parser_kwargs or {} parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse}) - self._sanitize_argv() self.setup_parser(**parser_kwargs) self.link_optimizers_and_lr_schedulers() self.parse_arguments(self.parser) @@ -387,18 +386,31 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) - def _sanitize_argv(self) -> None: + def _sanitize_registry_argv(self) -> None: """This function is used to replace space within `sys.argv` with its equal sign counter-part.""" - args = [idx for idx, v in enumerate(sys.argv) if v.startswith("-")] + + def validate_arg(v: str) -> bool: + keys = {"--optimizer", "--lr_scheduler", "--trainer.callbacks"} + keys.update({f"--{key}" for key in self.parser.optimizers_and_lr_schedulers.keys()}) + return any(v.startswith(k) for k in keys) + + args = [idx for idx, v in enumerate(sys.argv) if validate_arg(v)] if not args: return start_index = args[0] argv = [] + should_add = False for v in sys.argv[start_index:]: - if v.startswith("-"): + if validate_arg(v): argv.append(v) + should_add = True else: - argv[-1] += "=" + v + if should_add and not v.startswith("--"): + argv[-1] += "=" + v + else: + argv.append(v) + should_add = False + sys.argv = sys.argv[:start_index] + argv def setup_parser(self, **kwargs: Any) -> None: @@ -438,13 +450,21 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser: The parser object to which arguments can be added """ + @staticmethod + def _contains_from_registry(pattern: str, registry: _Registry) -> bool: + return any(True for v in sys.argv for registered_name in registry if 
f"--{pattern}={registered_name}" in v) + def link_optimizers_and_lr_schedulers(self) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - if any(True for v in sys.argv for optim_name in OPTIMIZER_REGISTRY if f"--optimizer={optim_name}" in v): + + # sanetize registry arguments + self._sanitize_registry_argv() + + if self._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): if "optimizer" not in self.parser.groups: self.parser.add_optimizer_args(self.registered_optimizers) - if any(True for v in sys.argv for sch_name in LR_SCHEDULER_REGISTRY if f"--lr_scheduler={sch_name}" in v): + if self._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): lr_schdulers = tuple(v for v in LR_SCHEDULER_REGISTRY.values()) if "lr_scheduler" not in self.parser.groups: self.parser.add_lr_scheduler_args(lr_schdulers) @@ -468,8 +488,10 @@ def _prepare_from_registry(self, registry: _Registry): map_user_key_to_info = {} for registered_name, registered_cls in registry.items(): for v in sys.argv: - if f"={registered_name}" in v: - key = v.split("=")[0] + if "=" not in v: + continue + key, name = v.split("=") + if registered_name == name: map_user_key_to_info[key] = _ClassInfo(class_arg=v, cls=registered_cls) if len(map_user_key_to_info) > 0: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index f552e3e353965..c4f81ab3806d9 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -992,3 +992,60 @@ def parse_arguments(self, parser: LightningArgumentParser) -> None: ) expected = base + [f"--lr_scheduler={lr_scheduler}"] TestLightningCLI(BoringModel, run=False, expected=expected) + + class MyLightningCLI(TestLightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args( + self.registered_optimizers, + nested_key="optim1", + link_to="model.optim1", + ) + parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") + parser.add_lr_scheduler_args( + self.registered_lr_schedulers, + link_to="model.scheduler", + ) + + def parse_arguments(self, parser: LightningArgumentParser) -> None: + # fmt: off + with self._prepare_from_registry(OPTIMIZER_REGISTRY), \ + self._prepare_from_registry(LR_SCHEDULER_REGISTRY), \ + self._prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + assert sys.argv == self.expected + raise Exception("Should raise") + + class TestModel(BoringModel): + def __init__(self, optim1: dict, optim2: dict, scheduler: dict): + super().__init__() + self.optim1 = instantiate_class(self.parameters(), optim1) + self.optim2 = instantiate_class(self.parameters(), optim2) + self.scheduler = instantiate_class(self.optim1, scheduler) + + cli_args = [ + "--lr_scheduler", "OneCycleLR", + "--optim1", "Adam", + "--optim1.lr=0.1", + "--optim2", "ASGD", + "--lr_scheduler.anneal_strategy=linear", + "--something", "a", "b", "c" + ] + + with pytest.raises(Exception, match='Should raise'), mock.patch("sys.argv", base + cli_args): + optim_2 = dict( + class_path="torch.optim.asgd.ASGD", + init_args=dict(), + ) + optim_1 = dict( + class_path="torch.optim.adam.Adam", + init_args=dict(lr="0.1"), + ) + lr_scheduler = dict( + class_path="torch.optim.lr_scheduler.OneCycleLR", + init_args=dict(anneal_strategy="linear"), + ) + expected = base + expected += ["--something", "a", "b", "c"] + expected += [f"--optim2={optim_2}"] + expected += [f"--optim1={optim_1}"] + expected += [f"--lr_scheduler={lr_scheduler}"] + 
MyLightningCLI(TestModel, run=False, expected=expected) From 37fd6793d18dcbc5577129483af90ffda700feae Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 13:22:39 +0200 Subject: [PATCH 32/77] Fix mypy --- pytorch_lightning/utilities/cli.py | 33 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d35de5b988260..05b6a2e472547 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -19,7 +19,7 @@ from dataclasses import dataclass, field from functools import partial from types import MethodType, ModuleType -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Type, TypedDict, Union from unittest import mock import torch @@ -100,19 +100,22 @@ class _ClassInfo: class_init_args: List[str] = field(default_factory=lambda: []) class_arg_idx: Optional[int] = None - def add_class_init_args(self, args: str) -> None: - if args != self.class_arg: - self.class_init_args.append(args) + class _ClassConfig(TypedDict): + class_path: str + init_args: Dict[str, str] + + def add_class_init_arg(self, arg: str) -> None: + if arg != self.class_arg: + self.class_init_args.append(arg) @property - def class_init(self) -> Dict[str, str]: - class_init = {"class_path": self.cls.__module__ + "." + self.cls.__name__} + def class_init(self) -> _ClassConfig: init_args = {} for init_arg in self.class_init_args: arg_path, value = init_arg.split("=") - init_args[arg_path.split(".")[-1]] = value - class_init["init_args"] = init_args - return class_init + key = arg_path.split(".")[-1] + init_args[key] = value + return self._ClassConfig(class_path=self.cls.__module__ + "." + self.cls.__name__, init_args=init_args) class LightningArgumentParser(ArgumentParser): @@ -375,11 +378,11 @@ def __init__( self.after_fit() @property - def registered_optimizers(self) -> Tuple[Type[Optimizer]]: + def registered_optimizers(self) -> Tuple[Type[Optimizer], ...]: return tuple(OPTIMIZER_REGISTRY.values()) @property - def registered_lr_schedulers(self) -> Tuple[LRSchedulerType]: + def registered_lr_schedulers(self) -> Tuple[LRSchedulerType, ...]: return tuple(LR_SCHEDULER_REGISTRY.values()) def init_parser(self, **kwargs: Any) -> LightningArgumentParser: @@ -479,7 +482,7 @@ def link_optimizers_and_lr_schedulers(self) -> None: self.parser.link_arguments(key, link_to, compute_fn=add_class_path) @contextmanager - def _prepare_from_registry(self, registry: _Registry): + def _prepare_from_registry(self, registry: _Registry) -> Generator[None, None, None]: """ This context manager is used to simplify unique class instantiation. """ @@ -502,7 +505,7 @@ def _prepare_from_registry(self, registry: _Registry): for key in map_user_key_to_info: if key in v: skip = True - map_user_key_to_info[key].add_class_init_args(v) + map_user_key_to_info[key].add_class_init_arg(v) if not skip: argv.append(v) @@ -514,7 +517,7 @@ def _prepare_from_registry(self, registry: _Registry): yield @contextmanager - def _prepare_class_list_from_registry(self, pattern: str, registry: _Registry): + def _prepare_class_list_from_registry(self, pattern: str, registry: _Registry) -> Generator[None, None, None]: """ This context manager is used to simplify instantiation of a list of class. 
""" @@ -544,7 +547,7 @@ def _prepare_class_list_from_registry(self, pattern: str, registry: _Registry): for idx, v in enumerate(all_simplified_args): if v in all_cls_simplified_args: current_info = [info for info in infos if idx == info.class_arg_idx][0] - current_info.add_class_init_args(v) + current_info.add_class_init_arg(v) class_args = [info.class_init for info in infos] # add other callback arguments. From f16db3d5a96fd3e888ddeb9a866296f2e20e2ca2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 13:27:45 +0200 Subject: [PATCH 33/77] Revert unrelated change which had broken mypy --- pytorch_lightning/utilities/cli.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 05b6a2e472547..f98ae67ed7b14 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -17,7 +17,6 @@ from argparse import Namespace from contextlib import contextmanager from dataclasses import dataclass, field -from functools import partial from types import MethodType, ModuleType from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Type, TypedDict, Union from unittest import mock @@ -645,25 +644,17 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: if not isinstance(lr_scheduler_class, tuple): lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) - configure_optimizers = partial( - self.configure_optimizers, optimizer_init=optimizer_init, lr_scheduler_init=lr_scheduler_init - ) - configure_optimizers.__code__ = self.model.configure_optimizers.__code__ + def configure_optimizers( + self: LightningModule, + ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]: + optimizer = instantiate_class(self.parameters(), optimizer_init) + if not lr_scheduler_init: + return optimizer + lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) + return [optimizer], [lr_scheduler] self.model.configure_optimizers = MethodType(configure_optimizers, self.model) - @staticmethod - def configure_optimizers( - pl_module: LightningModule, - optimizer_init: Union[str, List[str]], - lr_scheduler_init: Optional[Union[str, List[str]]] = None, - ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]: - optimizer = instantiate_class(pl_module.parameters(), optimizer_init) - if not lr_scheduler_init: - return optimizer - lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) - return [optimizer], [lr_scheduler] - def prepare_fit_kwargs(self) -> None: """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" self.fit_kwargs = {"model": self.model} From 572488c37a455e02f73acf1582c2e62373b1f907 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 14:31:50 +0200 Subject: [PATCH 34/77] Convert to staticmethod --- pytorch_lightning/utilities/cli.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index f98ae67ed7b14..fc7b38311c438 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -100,6 +100,8 @@ class _ClassInfo: class_arg_idx: Optional[int] = None class _ClassConfig(TypedDict): + """Defines the config structure that ``jsonargparse`` uses for instantiation""" + class_path: str init_args: Dict[str, str] @@ -388,12 +390,13 @@ def init_parser(self, **kwargs: Any) -> 
LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) - def _sanitize_registry_argv(self) -> None: - """This function is used to replace space within `sys.argv` with its equal sign counter-part.""" + @staticmethod + def _sanitize_argv(optimizers_and_lr_schedulers: List[str]) -> None: + """This function is used to replace ```` in ``sys.argv`` with ``=``.""" def validate_arg(v: str) -> bool: keys = {"--optimizer", "--lr_scheduler", "--trainer.callbacks"} - keys.update({f"--{key}" for key in self.parser.optimizers_and_lr_schedulers.keys()}) + keys.update({f"--{key}" for key in optimizers_and_lr_schedulers}) return any(v.startswith(k) for k in keys) args = [idx for idx, v in enumerate(sys.argv) if validate_arg(v)] @@ -459,8 +462,7 @@ def _contains_from_registry(pattern: str, registry: _Registry) -> bool: def link_optimizers_and_lr_schedulers(self) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - # sanetize registry arguments - self._sanitize_registry_argv() + self._sanitize_argv(list(self.parser.optimizers_and_lr_schedulers)) if self._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): if "optimizer" not in self.parser.groups: From 2fc46084d27416f1fea53b2ebfc7c8b73d29e856 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 14:49:24 +0200 Subject: [PATCH 35/77] Replace context managers for functional static transformations --- pytorch_lightning/utilities/cli.py | 54 ++++++++++++------------------ 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index fc7b38311c438..81e7c73d3901d 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -15,14 +15,14 @@ import os import sys from argparse import Namespace -from contextlib import contextmanager from dataclasses import dataclass, field from types import MethodType, ModuleType -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Type, TypedDict, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union from unittest import mock import torch from torch.optim import Optimizer +from typing_extensions import TypedDict import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer @@ -482,12 +482,8 @@ def link_optimizers_and_lr_schedulers(self) -> None: add_class_path = _add_class_path_generator(class_type) self.parser.link_arguments(key, link_to, compute_fn=add_class_path) - @contextmanager - def _prepare_from_registry(self, registry: _Registry) -> Generator[None, None, None]: - """ - This context manager is used to simplify unique class instantiation. - """ - + @staticmethod + def _prepare_from_registry(argv: List[str], registry: _Registry) -> List[str]: # find if the users is using shortcut command line. map_user_key_to_info = {} for registered_name, registered_cls in registry.items(): @@ -500,30 +496,25 @@ def _prepare_from_registry(self, registry: _Registry) -> Generator[None, None, N if len(map_user_key_to_info) > 0: # for each shortcut command line, add its init arguments and skip them from `sys.argv`. 
- argv = [] - for v in sys.argv: + out = [] + for v in argv: skip = False for key in map_user_key_to_info: if key in v: skip = True map_user_key_to_info[key].add_class_init_arg(v) if not skip: - argv.append(v) + out.append(v) # re-create the global command line and mock `sys.argv`. - argv += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()] - with mock.patch("sys.argv", argv): - yield - else: - yield + out += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()] + return out + return argv - @contextmanager - def _prepare_class_list_from_registry(self, pattern: str, registry: _Registry) -> Generator[None, None, None]: - """ - This context manager is used to simplify instantiation of a list of class. - """ - argv = [v for v in sys.argv if pattern not in v] - all_matched_args = [v for v in sys.argv if pattern in v] + @staticmethod + def _prepare_class_list_from_registry(argv: List[str], pattern: str, registry: _Registry) -> List[str]: + out = [v for v in argv if pattern not in v] + all_matched_args = [v for v in argv if pattern in v] all_simplified_args = [v for v in all_matched_args if f"{pattern}" in v and f"{pattern}=[" not in v] all_cls_simplified_args = [v for v in all_simplified_args if f"{pattern}=" in v] all_non_simplified_args = [v for v in all_matched_args if f"{pattern}=" in v and f"{pattern}=[" in v] @@ -555,20 +546,17 @@ def _prepare_class_list_from_registry(self, pattern: str, registry: _Registry) - if len(all_non_simplified_args) > 0: class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) - argv += [f"{pattern}={class_args}"] - with mock.patch("sys.argv", argv): - yield - else: - yield + out += [f"{pattern}={class_args}"] + return out + return argv def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" - # fmt: off - with self._prepare_from_registry(OPTIMIZER_REGISTRY), \ - self._prepare_from_registry(LR_SCHEDULER_REGISTRY), \ - self._prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): + argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) + argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + with mock.patch("sys.argv", argv): self.config = parser.parse_args() - # fmt: on def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" From 9f383dc93e5caa7816fd5bc1f8d5c2123762bdea Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 15:32:05 +0200 Subject: [PATCH 36/77] Split tests --- tests/utilities/test_cli.py | 228 +++++++++++++++++------------------- 1 file changed, 110 insertions(+), 118 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 530ade4e2d575..e9f35ac95caef 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -812,13 +812,8 @@ def test_registries_resolution(use_class_path_callbacks, tmpdir): if use_class_path_callbacks: callbacks = [ - dict( - class_path="pytorch_lightning.callbacks.Callback", - ), - dict( - class_path="pytorch_lightning.callbacks.Callback", - init_args=dict(), - ), + dict(class_path="pytorch_lightning.callbacks.Callback"), + dict(class_path="pytorch_lightning.callbacks.Callback", init_args=dict()), ] cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"] @@ -877,141 +872,141 @@ def on_fit_start(self): 
LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) -def test_argv_modifiers(): - """ - This test validates ``sys.argv`` from `LightningCLI` are properly transforming the command line. - """ - - class TestLightningCLI(LightningCLI): - def __init__(self, *args, expected=None, **kwargs): - self.expected = expected - super().__init__(*args, **kwargs) +def test_argv_transformation_noop(): + base = ["any.py", "--trainer.max_epochs=1"] + argv = LightningCLI._prepare_from_registry(base, OPTIMIZER_REGISTRY) + assert argv == base + argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + assert argv == base + argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + assert argv == base - def parse_arguments(self, parser: LightningArgumentParser) -> None: - # fmt: off - with self._prepare_from_registry(OPTIMIZER_REGISTRY), \ - self._prepare_from_registry(LR_SCHEDULER_REGISTRY), \ - self._prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): - assert sys.argv == self.expected - self.config = parser.parse_args() - # fmt: on +def test_argv_transformation_single_callback(): base = ["any.py", "--trainer.max_epochs=1"] + input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"] + callbacks = [ + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_loss"}, + } + ] + expected = base + [f"--trainer.callbacks={str(callbacks)}"] + argv = LightningCLI._prepare_from_registry(input, OPTIMIZER_REGISTRY) + assert argv == input + argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + assert argv == input + argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + assert argv == expected - with mock.patch("sys.argv", base): - expected = base - TestLightningCLI(BoringModel, run=False, expected=expected) - with mock.patch("sys.argv", base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]): - callbacks = [ - dict( - class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", - init_args=dict(monitor="val_loss"), - ), - ] - expected = base + [ - f"--trainer.callbacks={str(callbacks)}", - ] - TestLightningCLI(BoringModel, run=False, expected=expected) - - cli_args = [ +def test_argv_transformation_multiple_callbacks(): + base = ["any.py", "--trainer.max_epochs=1"] + input = base + [ "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss", "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_acc", ] + callbacks = [ + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_loss"}, + }, + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_acc"}, + }, + ] + expected = base + [f"--trainer.callbacks={str(callbacks)}"] + argv = LightningCLI._prepare_from_registry(input, OPTIMIZER_REGISTRY) + assert argv == input + argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + assert argv == input + argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + assert argv == expected - with mock.patch("sys.argv", base + cli_args): - callbacks = [ - dict( - class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", - init_args=dict(monitor="val_loss"), - ), - dict( - 
class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", - init_args=dict(monitor="val_acc"), - ), - ] - expected = base + [f"--trainer.callbacks={str(callbacks)}"] - TestLightningCLI(BoringModel, run=False, expected=expected) - cli_args = [ +def test_argv_transformation_multiple_callbacks_with_config(): + base = ["any.py", "--trainer.max_epochs=1"] + input = base + [ "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss", "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_acc", "--trainer.callbacks=[{'class_path': 'pytorch_lightning.callbacks.Callback'}]", ] + callbacks = [ + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_loss"}, + }, + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_acc"}, + }, + {"class_path": "pytorch_lightning.callbacks.Callback"}, + ] + expected = base + [f"--trainer.callbacks={str(callbacks)}"] + argv = LightningCLI._prepare_from_registry(input, OPTIMIZER_REGISTRY) + assert argv == input + argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + assert argv == input + argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + assert argv == expected - with mock.patch("sys.argv", base + cli_args): - callbacks = [ - dict( - class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", - init_args=dict(monitor="val_loss"), - ), - dict( - class_path="pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", - init_args=dict(monitor="val_acc"), - ), - dict( - class_path="pytorch_lightning.callbacks.Callback", - ), - ] - expected = base + [f"--trainer.callbacks={str(callbacks)}"] - TestLightningCLI(BoringModel, run=False, expected=expected) + +def test_argv_modifiers(): + """ + This test validates ``sys.argv`` from `LightningCLI` are properly transforming the command line. 
+ """ + + class TestLightningCLI(LightningCLI): + def __init__(self, *args, expected=None, **kwargs): + self.expected = expected + super().__init__(*args, **kwargs) + + def parse_arguments(self, parser: LightningArgumentParser) -> None: + argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) + argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + assert argv == self.expected + super().parse_arguments(parser) + + base = ["any.py", "--trainer.max_epochs=1"] with mock.patch("sys.argv", base + ["--optimizer", "Adadelta"]): - optimizer = dict( - class_path="torch.optim.adadelta.Adadelta", - init_args=dict(), - ) + optimizer = dict(class_path="torch.optim.adadelta.Adadelta", init_args=dict()) expected = base + [f"--optimizer={optimizer}"] TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"]): - optimizer = dict( - class_path="torch.optim.adadelta.Adadelta", - init_args=dict(lr="10"), - ) + optimizer = dict(class_path="torch.optim.adadelta.Adadelta", init_args=dict(lr="10")) expected = base + [f"--optimizer={optimizer}"] TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR"]): - lr_scheduler = dict( - class_path="torch.optim.lr_scheduler.OneCycleLR", - init_args=dict(), - ) + lr_scheduler = dict(class_path="torch.optim.lr_scheduler.OneCycleLR", init_args=dict()) expected = base + [f"--lr_scheduler={lr_scheduler}"] TestLightningCLI(BoringModel, run=False, expected=expected) with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"]): - lr_scheduler = dict( - class_path="torch.optim.lr_scheduler.OneCycleLR", - init_args=dict(anneal_strategy="linear"), - ) + lr_scheduler = dict(class_path="torch.optim.lr_scheduler.OneCycleLR", init_args=dict(anneal_strategy="linear")) expected = base + [f"--lr_scheduler={lr_scheduler}"] TestLightningCLI(BoringModel, run=False, expected=expected) class MyLightningCLI(TestLightningCLI): def add_arguments_to_parser(self, parser): - parser.add_optimizer_args( - self.registered_optimizers, - nested_key="optim1", - link_to="model.optim1", - ) + parser.add_optimizer_args(self.registered_optimizers, nested_key="optim1", link_to="model.optim1") parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") - parser.add_lr_scheduler_args( - self.registered_lr_schedulers, - link_to="model.scheduler", - ) + parser.add_lr_scheduler_args(self.registered_lr_schedulers, link_to="model.scheduler") def parse_arguments(self, parser: LightningArgumentParser) -> None: - # fmt: off - with self._prepare_from_registry(OPTIMIZER_REGISTRY), \ - self._prepare_from_registry(LR_SCHEDULER_REGISTRY), \ - self._prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRY): - assert sys.argv == self.expected - raise Exception("Should raise") + argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) + argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + assert argv == self.expected + raise Exception("Should raise") class TestModel(BoringModel): def __init__(self, optim1: dict, optim2: dict, scheduler: dict): @@ -1021,27 +1016,24 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): 
self.scheduler = instantiate_class(self.optim1, scheduler) cli_args = [ - "--lr_scheduler", "OneCycleLR", - "--optim1", "Adam", + "--lr_scheduler", + "OneCycleLR", + "--optim1", + "Adam", "--optim1.lr=0.1", - "--optim2", "ASGD", + "--optim2", + "ASGD", "--lr_scheduler.anneal_strategy=linear", - "--something", "a", "b", "c" + "--something", + "a", + "b", + "c", ] - with pytest.raises(Exception, match='Should raise'), mock.patch("sys.argv", base + cli_args): - optim_2 = dict( - class_path="torch.optim.asgd.ASGD", - init_args=dict(), - ) - optim_1 = dict( - class_path="torch.optim.adam.Adam", - init_args=dict(lr="0.1"), - ) - lr_scheduler = dict( - class_path="torch.optim.lr_scheduler.OneCycleLR", - init_args=dict(anneal_strategy="linear"), - ) + with pytest.raises(Exception, match="Should raise"), mock.patch("sys.argv", base + cli_args): + optim_2 = dict(class_path="torch.optim.asgd.ASGD", init_args=dict()) + optim_1 = dict(class_path="torch.optim.adam.Adam", init_args=dict(lr="0.1")) + lr_scheduler = dict(class_path="torch.optim.lr_scheduler.OneCycleLR", init_args=dict(anneal_strategy="linear")) expected = base expected += ["--something", "a", "b", "c"] expected += [f"--optim2={optim_2}"] From 2a7dfa88db3b1a69f3b9d1b3c6a872209675b6e0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 15:59:55 +0200 Subject: [PATCH 37/77] Refactor optimizer tests --- tests/utilities/test_cli.py | 101 +++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e9f35ac95caef..2b0ecbae07886 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -956,57 +956,53 @@ def test_argv_transformation_multiple_callbacks_with_config(): assert argv == expected -def test_argv_modifiers(): - """ - This test validates ``sys.argv`` from `LightningCLI` are properly transforming the command line. 
- """ - +def test_argv_transformations_with_optimizers_and_lr_schedulers(): class TestLightningCLI(LightningCLI): - def __init__(self, *args, expected=None, **kwargs): + def __init__(self, expected, *args): self.expected = expected - super().__init__(*args, **kwargs) + super().__init__(*args, run=False) - def parse_arguments(self, parser: LightningArgumentParser) -> None: + def before_instantiate_classes(self): argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == self.expected - super().parse_arguments(parser) base = ["any.py", "--trainer.max_epochs=1"] - with mock.patch("sys.argv", base + ["--optimizer", "Adadelta"]): - optimizer = dict(class_path="torch.optim.adadelta.Adadelta", init_args=dict()) - expected = base + [f"--optimizer={optimizer}"] - TestLightningCLI(BoringModel, run=False, expected=expected) - - with mock.patch("sys.argv", base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"]): - optimizer = dict(class_path="torch.optim.adadelta.Adadelta", init_args=dict(lr="10")) - expected = base + [f"--optimizer={optimizer}"] - TestLightningCLI(BoringModel, run=False, expected=expected) - - with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR"]): - lr_scheduler = dict(class_path="torch.optim.lr_scheduler.OneCycleLR", init_args=dict()) - expected = base + [f"--lr_scheduler={lr_scheduler}"] - TestLightningCLI(BoringModel, run=False, expected=expected) - - with mock.patch("sys.argv", base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"]): - lr_scheduler = dict(class_path="torch.optim.lr_scheduler.OneCycleLR", init_args=dict(anneal_strategy="linear")) - expected = base + [f"--lr_scheduler={lr_scheduler}"] - TestLightningCLI(BoringModel, run=False, expected=expected) - - class MyLightningCLI(TestLightningCLI): + input = base + ["--optimizer", "Adadelta"] + optimizer = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}} + expected = base + [f"--optimizer={optimizer}"] + with mock.patch("sys.argv", input): + TestLightningCLI(expected, BoringModel) + + input = base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"] + optimizer = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}} + expected = base + [f"--optimizer={optimizer}"] + with mock.patch("sys.argv", input): + TestLightningCLI(expected, BoringModel) + + input = base + ["--lr_scheduler", "OneCycleLR"] + lr_scheduler = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}} + expected = base + [f"--lr_scheduler={lr_scheduler}"] + with mock.patch("sys.argv", input): + TestLightningCLI(expected, BoringModel) + + input = base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"] + lr_scheduler = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}} + expected = base + [f"--lr_scheduler={lr_scheduler}"] + with mock.patch("sys.argv", input): + TestLightningCLI(expected, BoringModel) + + class TestLightningCLI2(TestLightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args(self.registered_optimizers, nested_key="optim1", link_to="model.optim1") parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") parser.add_lr_scheduler_args(self.registered_lr_schedulers, link_to="model.scheduler") + parser.add_argument("--something", 
type=str, nargs="+") - def parse_arguments(self, parser: LightningArgumentParser) -> None: - argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) - argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) - assert argv == self.expected - raise Exception("Should raise") + def instantiate_classes(self): + pass class TestModel(BoringModel): def __init__(self, optim1: dict, optim2: dict, scheduler: dict): @@ -1015,12 +1011,13 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): self.optim2 = instantiate_class(self.parameters(), optim2) self.scheduler = instantiate_class(self.optim1, scheduler) - cli_args = [ + input = base + [ "--lr_scheduler", "OneCycleLR", + "--lr_scheduler.total_steps=10", "--optim1", "Adam", - "--optim1.lr=0.1", + "--optim2.lr=0.1", "--optim2", "ASGD", "--lr_scheduler.anneal_strategy=linear", @@ -1029,14 +1026,20 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): "b", "c", ] - - with pytest.raises(Exception, match="Should raise"), mock.patch("sys.argv", base + cli_args): - optim_2 = dict(class_path="torch.optim.asgd.ASGD", init_args=dict()) - optim_1 = dict(class_path="torch.optim.adam.Adam", init_args=dict(lr="0.1")) - lr_scheduler = dict(class_path="torch.optim.lr_scheduler.OneCycleLR", init_args=dict(anneal_strategy="linear")) - expected = base - expected += ["--something", "a", "b", "c"] - expected += [f"--optim2={optim_2}"] - expected += [f"--optim1={optim_1}"] - expected += [f"--lr_scheduler={lr_scheduler}"] - MyLightningCLI(TestModel, run=False, expected=expected) + optim_1 = {"class_path": "torch.optim.adam.Adam", "init_args": {}} + optim_2 = {"class_path": "torch.optim.asgd.ASGD", "init_args": {"lr": "0.1"}} + lr_scheduler = { + "class_path": "torch.optim.lr_scheduler.OneCycleLR", + "init_args": {"total_steps": "10", "anneal_strategy": "linear"}, + } + expected = base + [ + "--something", + "a", + "b", + "c", + f"--optim2={optim_2}", + f"--optim1={optim_1}", + f"--lr_scheduler={lr_scheduler}", + ] + with mock.patch("sys.argv", input): + TestLightningCLI2(expected, TestModel) From 423ab7b6e15ce125959777fc5f3b114a7044d204 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 16:22:48 +0200 Subject: [PATCH 38/77] Cleaning tests --- tests/utilities/test_cli.py | 68 ++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 2b0ecbae07886..7400b315d6ba8 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -30,7 +30,7 @@ from packaging import version from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import ( @@ -689,36 +689,38 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): self.optim2 = instantiate_class(self.parameters(), optim2) self.scheduler = instantiate_class(self.optim1, scheduler) + cli_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--lr_scheduler.gamma=0.2"] if use_registries: - cli_args = [ - f"--trainer.default_root_dir={tmpdir}", - "--trainer.max_epochs=1", + 
cli_args += [ "--optim1", "Adam", "--optim1.weight_decay", "0.001", "--optim2=SGD", - "--optim2.lr=0.005", + "--optim2.lr=0.01", "--lr_scheduler=ExponentialLR", - "--lr_scheduler.gamma=0.1", ] else: - cli_args = [ - f"--trainer.default_root_dir={tmpdir}", - "--trainer.max_epochs=1", - "--optim2.class_path=torch.optim.SGD", - "--optim2.init_args.lr=0.01", - "--lr_scheduler.gamma=0.2", - ] + cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"] with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(TestModel) assert isinstance(cli.model.optim1, torch.optim.Adam) assert isinstance(cli.model.optim2, torch.optim.SGD) + assert cli.model.optim2.param_groups[0]["lr"] == 0.01 assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) +@pytest.mark.parametrize("run", (False, True)) +def test_lightning_cli_disabled_run(run): + with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock: + cli = LightningCLI(BoringModel, run=run) + fit_mock.call_count == run + assert isinstance(cli.trainer, Trainer) + assert isinstance(cli.model, LightningModule) + + @CALLBACK_REGISTRY class CustomCallback(Callback): pass @@ -735,7 +737,6 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): def test_registries(tmpdir): - assert CALLBACK_REGISTRY.available_objects() == [ "BackboneFinetuning", "BaseFinetuning", @@ -786,13 +787,9 @@ def test_registries(tmpdir): @pytest.mark.parametrize("use_class_path_callbacks", [False, True]) -def test_registries_resolution(use_class_path_callbacks, tmpdir): +def test_registries_resolution(use_class_path_callbacks): """This test validates registries are used when simplified command line are being used.""" - cli_args = [ - f"--trainer.default_root_dir={tmpdir}", - "--trainer.fast_dev_run=1", - "--trainer.progress_bar_refresh_rate=0", "--optimizer", "Adam", "--optimizer.lr", @@ -807,33 +804,28 @@ def test_registries_resolution(use_class_path_callbacks, tmpdir): "--lr_scheduler.step_size=50", ] - expected_callbacks = 2 - + extras = [] if use_class_path_callbacks: - callbacks = [ - dict(class_path="pytorch_lightning.callbacks.Callback"), - dict(class_path="pytorch_lightning.callbacks.Callback", init_args=dict()), + {"class_path": "pytorch_lightning.callbacks.Callback"}, + {"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}}, ] - cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"] - - expected_callbacks = 4 + extras = [Callback, Callback] with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = LightningCLI(BoringModel) + cli = LightningCLI(BoringModel, run=False) - assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.callbacks) == expected_callbacks + optimizers, lr_scheduler = cli.model.configure_optimizers() + assert isinstance(optimizers[0], torch.optim.Adam) + assert optimizers[0].param_groups[0]["lr"] == 0.0001 + assert lr_scheduler[0].step_size == 50 - -@pytest.mark.parametrize("run", (False, True)) -def test_lightning_cli_disabled_run(run): - with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock: - cli = LightningCLI(BoringModel, run=run) - fit_mock.call_count == run - assert isinstance(cli.trainer, Trainer) - assert isinstance(cli.model, LightningModule) + assert [type(c) for c in cli.trainer.callbacks] == [LearningRateMonitor] + extras + [ + SaveConfigCallback, + ProgressBar, + ModelCheckpoint, + ] @pytest.mark.skipif(True, reason="typing from 
json-argparse is failing.") From 7c2e39e9802702458e5d52f2c407047cfc34742f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 16:24:09 +0200 Subject: [PATCH 39/77] Delete broken test --- tests/utilities/test_cli.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 7400b315d6ba8..10aa841254973 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -19,7 +19,6 @@ import sys from argparse import Namespace from contextlib import redirect_stdout -from importlib.util import module_from_spec, spec_from_file_location from io import StringIO from typing import List, Optional, Union from unittest import mock @@ -828,42 +827,6 @@ def test_registries_resolution(use_class_path_callbacks): ] -@pytest.mark.skipif(True, reason="typing from json-argparse is failing.") -def test_custom_callbacks(tmpdir): - """ - Test that registered callbacks can be used with LightningCLI. - """ - - class TestModel(BoringModel): - def on_fit_start(self): - callbacks = [c for c in self.trainer.callbacks if isinstance(c, CustomCallback)] - assert len(callbacks) == 1 - - with mock.patch("sys.argv", ["any.py"]): - LightningCLI( - TestModel, - trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True, callbacks=CustomCallback()), - ) - - code = """from pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.utilities.cli import CALLBACK_REGISTRY\n\nclass TestCallback(Callback):\n\tpass\n\nCALLBACK_REGISTRY(cls=TestCallback)""" # noqa E501 - - f = open(tmpdir / "test.py", "w") - f.write(code) - f.close() - - spec = spec_from_file_location("test", f.name) - mod = module_from_spec(spec) - sys.modules["test"] = mod - spec.loader.exec_module(mod) - callback_cls = getattr(mod, "TestCallback") - assert issubclass(callback_cls, Callback) - - callback = {"class_path": f"{tmpdir}.test.CustomCallback"} - - with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks=[{callback}]"]): - LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) - - def test_argv_transformation_noop(): base = ["any.py", "--trainer.max_epochs=1"] argv = LightningCLI._prepare_from_registry(base, OPTIMIZER_REGISTRY) From 048e1597480eaf3542bc6ccc1e9a513e5dec40a1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 16:32:37 +0200 Subject: [PATCH 40/77] Docs improvements --- docs/source/common/lightning_cli.rst | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 6e5a7f7df18c0..e858af1c57868 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -759,7 +759,6 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th cli = MyLightningCLI(MyModel) - For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and :code:`init_args` entries. 
The function @@ -770,15 +769,16 @@ a single class or a tuple of classes, the value given to :code:`optimizer_init` Built in schedulers & optimizers and registering your own ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -For code simplification, the LightningCLI provides properties with PyTorch's built-in `optimizers` and `schedulers` already registered. - +For code simplification, the CLI provides properties with PyTorch's built-in optimizers and learning rate schedulers +already registered. Only the optimizer or scheduler name needs to be passed along its arguments. .. code-block:: bash $ python train.py --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=CosineAnnealingLR -If your model requires multiple optimizers, the LightningCLI provides already registered optimizers and schedulers under the properties `registered_optimizers` and `registered_lr_schedulers` +If your model requires multiple optimizers, you can choose from all available optimizers and learning rate schedulers +by accessing `self.registered_optimizers` and `self.registered_lr_schedulers` respectively. .. code-block:: @@ -786,20 +786,20 @@ If your model requires multiple optimizers, the LightningCLI provides already re def add_arguments_to_parser(self, parser): parser.add_optimizer_args( self.registered_optimizers, - nested="gen_optimizer", + nested_key="gen_optimizer", link_to="model.optimizer_init", ) parser.add_optimizer_args( self.registered_optimizers, - nested="gen_discriminator", + nested_key="gen_discriminator", link_to="model.optimizer_init", ) .. code-block:: bash - $ python train.py --gen_optimizer=Adam --optimizer.lr=0.01 -gen_discriminator=Adam --optimizer.lr=0.0001 + $ python train.py --gen_optimizer=Adam --optimizer.lr=0.01 --gen_discriminator=Adam --optimizer.lr=0.0001 -Furthermore, a user can register their own optimizers or schedulers as follows. +Furthermore, you can register your own optimizers and/or learning rate schedulers as follows: .. code-block:: python @@ -810,12 +810,12 @@ Furthermore, a user can register their own optimizers or schedulers as follows. @OPTIMIZER_REGISTRY class CustomAdam(torch.optim.Adam): - pass + ... @LR_SCHEDULER_REGISTRY class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): - pass + ... cli = LightningCLI(...) From 86fce55762d0cc9e9de4a895040cf1388a27eccf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 16:37:40 +0200 Subject: [PATCH 41/77] Docs improvements --- docs/source/common/lightning_cli.rst | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index e858af1c57868..ab4a6b8f4119c 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -312,19 +312,30 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured the same way using :code:`class_path` and :code:`init_args`. -Lightning optionally simplifies the user command line so that only the :class:`~pytorch_lightning.callbacks.Callback` name is required. -The argument's order matters and the user needs to pass the arguments in the following way. +Lightning optionally simplifies the user command line so that only the :class:`~pytorch_lightning.callbacks.Callback` +name is required. The argument's order matters and the user needs to pass the arguments in the following way. 
This is supported only for PyTorch Lightning built-in :class:`~pytorch_lightning.callbacks.Callback`. .. code-block:: bash - $ python ... --trainer.callbacks={CALLBACK_NAME_1} --trainer.callbacks.{CALLBACK_1_ARGS_1}=... --trainer.callbacks.{CALLBACK_1_ARGS_2}=... --trainer.callbacks={CALLBACK_N} --trainer.callbacks.{CALLBACK_N_ARGS_1}=... + $ python ... \ + --trainer.callbacks={CALLBACK_1_NAME} \ + --trainer.callbacks.{CALLBACK_1_ARGS_1}=... \ + --trainer.callbacks.{CALLBACK_1_ARGS_2}=... \ + ... + --trainer.callbacks={CALLBACK_N_NAME} \ + --trainer.callbacks.{CALLBACK_N_ARGS_1}=... \ + ... Here is an example: .. code-block:: bash - $ python ... --trainer.callbacks=EarlyStopping --trainer.callbacks.patience=5 --trainer.callbacks=LearningRateMonitor + $ python ... \ + --trainer.callbacks=EarlyStopping \ + --trainer.callbacks.patience=5 \ + --trainer.callbacks=LearningRateMonitor \ + --trainer.callbacks.logging_interval=epoch Register your callbacks ^^^^^^^^^^^^^^^^^^^^^^^ From 624b0d80723abeb4956023cdf877fa212b384695 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 17:59:56 +0200 Subject: [PATCH 42/77] Restructure docs --- docs/source/common/lightning_cli.rst | 185 ++++++++++++--------------- 1 file changed, 82 insertions(+), 103 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index ab4a6b8f4119c..77d20394fcafc 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -4,7 +4,7 @@ import torch from unittest import mock from typing import List - from pytorch_lightning import LightningModule, LightningDataModule, Trainer + from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback from pytorch_lightning.utilities.cli import LightningCLI cli_fit = LightningCLI.fit @@ -665,10 +665,79 @@ Optimizers and learning rate schedulers ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a -single optimizer and optionally a single learning rate scheduler. In this case the model's -:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the -:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code -snippet shows how to implement it: +single optimizer and optionally a single learning rate scheduler. In this case, the model's +:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is +normally always the same and just adds boilerplate. + +The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when +at most one of each is used. +Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments: + +.. code-block:: bash + + $ python train.py --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1 + +A corresponding example of the config file would be: + +.. code-block:: yaml + + optimizer: + class_path: torch.optim.Adam + init_args: + lr: 0.01 + lr_scheduler: + class_path: torch.optim.lr_scheduler.ExponentialLR + init_args: + gamma: 0.1 + +Furthermore, you can register your own optimizers and/or learning rate schedulers as follows: + +.. 
code-block:: python + + from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY + + + @OPTIMIZER_REGISTRY + class CustomAdam(torch.optim.Adam): + ... + + + @LR_SCHEDULER_REGISTRY + class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): + ... + + + cli = LightningCLI(...) + +.. code-block:: bash + + $ python train.py --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR + +If you need to customize the key names or link arguments together, you can choose from all available optimizers and +learning rate schedulers by accessing `self.registered_optimizers` and `self.registered_lr_schedulers` respectively. + +.. code-block:: + + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args( + self.registered_optimizers, + nested_key="gen_optimizer", + link_to="model.optimizer_init", + ) + parser.add_optimizer_args( + self.registered_optimizers, + nested_key="gen_discriminator", + link_to="model.optimizer_init", + ) + +.. code-block:: bash + + $ python train.py --gen_optimizer=Adam --optimizer.lr=0.01 --gen_discriminator=Adam --optimizer.lr=0.0001 + +If you will not be changing the class, you can manually add the arguments for specific optimizers and/or +learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those +classes. The following code snippet shows how to implement it: .. testcode:: @@ -684,9 +753,9 @@ snippet shows how to implement it: cli = MyLightningCLI(MyModel) -With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer` -and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and -:code:`ExponentialLR`. Therefore, the config file would be structured like: +With this, in the config the :code:`optimizer` and :code:`lr_scheduler` groups would accept all of the options for the +given classes, in this example :code:`Adam` and :code:`ExponentialLR`. +Therefore, the config file would be structured like: .. code-block:: yaml @@ -705,37 +774,6 @@ And any of these arguments could be passed directly through command line. For ex $ python train.py --optimizer.lr=0.01 --lr_scheduler.gamma=0.2 -There is also the possibility of selecting among multiple classes by giving them as a tuple. For example: - -.. testcode:: - - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam)) - -In this case in the config the :code:`optimizer` group instead of having directly init settings, it should specify -:code:`class_path` and optionally :code:`init_args`. Sub-classes of the classes in the tuple would also be accepted. -A corresponding example of the config file would be: - -.. code-block:: yaml - - optimizer: - class_path: torch.optim.Adam - init_args: - lr: 0.01 - -And the same through command line: - -.. code-block:: bash - - $ python train.py --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01 - -Optionally, the command line can be simplified for PyTorch built-in `optimizers` and `schedulers`: - -.. code-block:: bash - - $ python train.py --optimizer=Adam --optimizer.lr=0.01 - The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. 
This would be: @@ -770,70 +808,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th cli = MyLightningCLI(MyModel) -For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with -a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including -:code:`class_path` and :code:`init_args` entries. The function -:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in -:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the -:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. - -Built in schedulers & optimizers and registering your own -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For code simplification, the CLI provides properties with PyTorch's built-in optimizers and learning rate schedulers -already registered. -Only the optimizer or scheduler name needs to be passed along its arguments. - -.. code-block:: bash - - $ python train.py --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=CosineAnnealingLR - -If your model requires multiple optimizers, you can choose from all available optimizers and learning rate schedulers -by accessing `self.registered_optimizers` and `self.registered_lr_schedulers` respectively. - -.. code-block:: - - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args( - self.registered_optimizers, - nested_key="gen_optimizer", - link_to="model.optimizer_init", - ) - parser.add_optimizer_args( - self.registered_optimizers, - nested_key="gen_discriminator", - link_to="model.optimizer_init", - ) - -.. code-block:: bash - - $ python train.py --gen_optimizer=Adam --optimizer.lr=0.01 --gen_discriminator=Adam --optimizer.lr=0.0001 - -Furthermore, you can register your own optimizers and/or learning rate schedulers as follows: - -.. code-block:: python - - import torch - from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY - from pytorch_lightning.callbacks import Callback - - - @OPTIMIZER_REGISTRY - class CustomAdam(torch.optim.Adam): - ... - - - @LR_SCHEDULER_REGISTRY - class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): - ... - - - cli = LightningCLI(...) - -.. code-block:: bash - - $ python train.py --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR +The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and +:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class` +takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments, +in this case :code:`self.parameters()`, and the :code:`init_args`. +Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. 
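As a concrete illustration of the paragraph above, here is a minimal sketch of how such a ``class_path``/``init_args`` dictionary is consumed by :func:`~pytorch_lightning.utilities.cli.instantiate_class` (the ``torch.nn.Linear`` stand-in model and the chosen learning rate are assumptions made for this example, not part of the patch):

.. code-block:: python

    import torch

    from pytorch_lightning.utilities.cli import instantiate_class

    # the kind of value the CLI links to ``model.optimizer_init``
    optimizer_init = {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}}

    model = torch.nn.Linear(4, 4)
    # the first argument is passed positionally, the ``init_args`` become keyword arguments
    optimizer = instantiate_class(model.parameters(), optimizer_init)
    assert isinstance(optimizer, torch.optim.Adam)

The same pattern applies to learning rate schedulers, where the positional argument is the already instantiated optimizer instead of the model parameters.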
Notes related to reproducibility From 2cc0dc59bcc8285de966770ae9b422b6bdbaddb1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 18:11:30 +0200 Subject: [PATCH 43/77] Docs for callbacks --- docs/source/common/lightning_cli.rst | 29 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 77d20394fcafc..b0c2dc680c984 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -312,9 +312,9 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured the same way using :code:`class_path` and :code:`init_args`. -Lightning optionally simplifies the user command line so that only the :class:`~pytorch_lightning.callbacks.Callback` -name is required. The argument's order matters and the user needs to pass the arguments in the following way. -This is supported only for PyTorch Lightning built-in :class:`~pytorch_lightning.callbacks.Callback`. +For callbacks in particular, Lightning simplifies the command line so that only +the :class:`~pytorch_lightning.callbacks.Callback` name is required. +The argument's order matters and the user needs to pass the arguments in the following way. .. code-block:: bash @@ -337,15 +337,12 @@ Here is an example: --trainer.callbacks=LearningRateMonitor \ --trainer.callbacks.logging_interval=epoch -Register your callbacks -^^^^^^^^^^^^^^^^^^^^^^^ - -Lightning provides registries for you to add your own callbacks and benefit from the command line simplification as described above: +Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification +as described above: .. code-block:: python from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY - from pytorch_lightning.callbacks import Callback @CALLBACK_REGISTRY @@ -359,6 +356,15 @@ Lightning provides registries for you to add your own callbacks and benefit from $ python ... --trainer.callbacks=CustomCallback ... +This callback will be included in the generated config: + +.. code-block:: yaml + + trainer: + callbacks: + - class_path: your_class_path.CustomCallback + init_args: + ... Multiple models and/or datasets ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -510,9 +516,10 @@ instantiating the trainer class can be found in :code:`self.config['trainer']`. Configurable callbacks ^^^^^^^^^^^^^^^^^^^^^^ -As explained previously, any callback can be added by including it in the config via :code:`class_path` and -:code:`init_args` entries. However, there are other cases in which a callback should always be present and be -configurable. This can be implemented as follows: +As explained previously, any Lightning callback can be added by passing it through command line or +including it in the config via :code:`class_path` and :code:`init_args` entries. +However, there are other cases in which a callback should always be present and be configurable. +This can be implemented as follows: .. 
testcode:: From f9b49febeb220af6d9c7f6ad2df2089317513af0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 18:53:20 +0200 Subject: [PATCH 44/77] Add reload test when add_optimizer_args is added by the user --- tests/utilities/test_cli.py | 84 +++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 10aa841254973..1ab6ede6d7bb1 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -949,31 +949,43 @@ def before_instantiate_classes(self): with mock.patch("sys.argv", input): TestLightningCLI(expected, BoringModel) - class TestLightningCLI2(TestLightningCLI): + +def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir): + class TestLightningCLI(LightningCLI): + def __init__(self, *args): + super().__init__(*args, run=False) + def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(self.registered_optimizers, nested_key="optim1", link_to="model.optim1") - parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") - parser.add_lr_scheduler_args(self.registered_lr_schedulers, link_to="model.scheduler") + parser.add_optimizer_args(self.registered_optimizers, nested_key="opt1", link_to="model.opt1_config") + parser.add_optimizer_args( + (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config" + ) + parser.add_lr_scheduler_args(self.registered_lr_schedulers, link_to="model.sch_config") parser.add_argument("--something", type=str, nargs="+") - def instantiate_classes(self): - pass - class TestModel(BoringModel): - def __init__(self, optim1: dict, optim2: dict, scheduler: dict): + def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict): super().__init__() - self.optim1 = instantiate_class(self.parameters(), optim1) - self.optim2 = instantiate_class(self.parameters(), optim2) - self.scheduler = instantiate_class(self.optim1, scheduler) + self.opt1_config = opt1_config + self.opt2_config = opt2_config + self.sch_config = sch_config + opt1 = instantiate_class(self.parameters(), opt1_config) + assert isinstance(opt1, torch.optim.Adam) + opt2 = instantiate_class(self.parameters(), opt2_config) + assert isinstance(opt2, torch.optim.ASGD) + sch = instantiate_class(opt1, sch_config) + assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR) + base = ["any.py", "--trainer.max_epochs=1"] input = base + [ "--lr_scheduler", "OneCycleLR", "--lr_scheduler.total_steps=10", - "--optim1", + "--lr_scheduler.max_lr=1", + "--opt1", "Adam", - "--optim2.lr=0.1", - "--optim2", + "--opt2.lr=0.1", + "--opt2", "ASGD", "--lr_scheduler.anneal_strategy=linear", "--something", @@ -981,20 +993,30 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): "b", "c", ] - optim_1 = {"class_path": "torch.optim.adam.Adam", "init_args": {}} - optim_2 = {"class_path": "torch.optim.asgd.ASGD", "init_args": {"lr": "0.1"}} - lr_scheduler = { - "class_path": "torch.optim.lr_scheduler.OneCycleLR", - "init_args": {"total_steps": "10", "anneal_strategy": "linear"}, - } - expected = base + [ - "--something", - "a", - "b", - "c", - f"--optim2={optim_2}", - f"--optim1={optim_1}", - f"--lr_scheduler={lr_scheduler}", - ] - with mock.patch("sys.argv", input): - TestLightningCLI2(expected, TestModel) + + # save config + out = StringIO() + with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): + 
TestLightningCLI(TestModel) + + # validate yaml + yaml_config = out.getvalue() + dict_config = yaml.safe_load(yaml_config) + assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam" + assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD" + assert dict_config["opt2"]["init_args"]["lr"] == 0.1 + assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear" + assert dict_config["something"] == ["a", "b", "c"] + + # reload config + yaml_config_file = tmpdir / "config.yaml" + yaml_config_file.write_text(yaml_config, "utf-8") + with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]): + cli = TestLightningCLI(TestModel) + + assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam" + assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD" + assert cli.model.opt2_config["init_args"]["lr"] == 0.1 + assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear" From afcc4ba902f9a8ad16c986f87f1674eec681566f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 18:59:08 +0200 Subject: [PATCH 45/77] Add failing config test - needs to be fixed --- tests/utilities/test_cli.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 1ab6ede6d7bb1..c3b9a8f1c7f3a 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -950,6 +950,37 @@ def before_instantiate_classes(self): TestLightningCLI(expected, BoringModel) +def test_optimizers_and_lr_schedulers_reload(tmpdir): + base = ["any.py", "--trainer.max_epochs=1"] + input = base + [ + "--lr_scheduler", + "OneCycleLR", + "--lr_scheduler.total_steps=10", + "--lr_scheduler.max_lr=1", + "--optimizer", + "Adam", + "--optimizer.lr=0.1", + ] + + # save config + out = StringIO() + with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): + LightningCLI(BoringModel) + + # validate yaml + yaml_config = out.getvalue() + dict_config = yaml.safe_load(yaml_config) + assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam" + assert dict_config["optimizer"]["init_args"]["lr"] == 0.1 + assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + + # reload config + yaml_config_file = tmpdir / "config.yaml" + yaml_config_file.write_text(yaml_config, "utf-8") + with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]): + LightningCLI(BoringModel) + + def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir): class TestLightningCLI(LightningCLI): def __init__(self, *args): From 0ed4ae8624ef2576a4c838647219307b289d97b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Aug 2021 05:11:26 +0000 Subject: [PATCH 46/77] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index c6dda74f34ac8..9487fb778d0be 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -497,7 +497,7 @@ def _prepare_subcommand_parser(self, klass: Type, subcommand: 
str, **kwargs: Any # need to save which arguments were added to pass them to the method later self._subcommand_method_arguments[subcommand] = added return parser - + @staticmethod def _contains_from_registry(pattern: str, registry: _Registry) -> bool: # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 From 4dd073277f1c5821eabf575328fc3ec8a48eb804 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 19 Aug 2021 21:20:35 +0200 Subject: [PATCH 47/77] Use property --- pytorch_lightning/utilities/cli.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 9487fb778d0be..3f7ca69791044 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -513,9 +513,8 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: parser.add_optimizer_args(self.registered_optimizers) if self._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): - lr_schdulers = tuple(v for v in LR_SCHEDULER_REGISTRY.values()) - if "lr_scheduler" not in parser.groups: - parser.add_lr_scheduler_args(lr_schdulers) + if "lr_scheduler" not in self.parser.groups: + self.parser.add_lr_scheduler_args(self.registered_lr_schedulers) for key, (class_type, link_to) in parser.optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": From e0fae4f747a31a3bd385133920ba12f12745e960 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 28 Aug 2021 07:22:07 +0200 Subject: [PATCH 48/77] Fixes after merge --- pytorch_lightning/utilities/cli.py | 33 +++++++++++++----------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 3f7ca69791044..aa18ac8bb6ff9 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import sys from argparse import Namespace from dataclasses import dataclass, field -from types import MethodType +from types import MethodType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from unittest import mock @@ -47,7 +48,7 @@ def __call__( cls: Type, key: Optional[str] = None, override: bool = False, - ) -> "Optional[Type]": + ) -> None: """ Registers a class mapped to a name. @@ -62,9 +63,7 @@ def __call__( if key in self and not override: raise MisconfigurationException(f"'{key}' is already present in the registry. 
HINT: Use `override=True`.") - self[key] = cls - return cls def register_package(self, module: ModuleType, base_cls: Type) -> None: """This function is an utility to register all classes from a module.""" @@ -73,9 +72,13 @@ def register_package(self, module: ModuleType, base_cls: Type) -> None: self(cls=cls) def available_objects(self) -> List[str]: - """Returns a list of registered objects""" + """Returns the keys of the registered objects""" return list(self.keys()) + def registered_values(self) -> Tuple[Type, ...]: + """Returns the values of the registered objects""" + return tuple(self.values()) + def __str__(self) -> str: objects = ", ".join(self.keys()) return f"Registered objects: {objects}" @@ -378,14 +381,6 @@ def _setup_parser_kwargs( main_kwargs.update(kwargs) return main_kwargs, {} - @property - def registered_optimizers(self) -> Tuple[Type[Optimizer], ...]: - return tuple(OPTIMIZER_REGISTRY.values()) - - @property - def registered_lr_schedulers(self) -> Tuple[LRSchedulerType, ...]: - return tuple(LR_SCHEDULER_REGISTRY.values()) - def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) @@ -506,15 +501,15 @@ def _contains_from_registry(pattern: str, registry: _Registry) -> bool: @staticmethod def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - self._sanitize_argv(list(parser.optimizers_and_lr_schedulers)) + LightningCLI._sanitize_argv(list(parser.optimizers_and_lr_schedulers)) - if self._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): + if LightningCLI._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): if "optimizer" not in parser.groups: - parser.add_optimizer_args(self.registered_optimizers) + parser.add_optimizer_args(OPTIMIZER_REGISTRY.registered_values()) - if self._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): - if "lr_scheduler" not in self.parser.groups: - self.parser.add_lr_scheduler_args(self.registered_lr_schedulers) + if LightningCLI._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): + if "lr_scheduler" not in parser.groups: + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.registered_values()) for key, (class_type, link_to) in parser.optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": From a22fdb34eca5f0dcdc2671a9079be8daf89e3c86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Sep 2021 01:15:18 +0000 Subject: [PATCH 49/77] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/cli.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index de5b4b77d7e4a..d9fd5f22ac49e 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -49,8 +49,7 @@ def __call__( key: Optional[str] = None, override: bool = False, ) -> None: - """ - Registers a class mapped to a name. + """Registers a class mapped to a name. Args: cls: the class to be mapped. 
@@ -72,11 +71,11 @@ def register_package(self, module: ModuleType, base_cls: Type) -> None: self(cls=cls) def available_objects(self) -> List[str]: - """Returns the keys of the registered objects""" + """Returns the keys of the registered objects.""" return list(self.keys()) def registered_values(self) -> Tuple[Type, ...]: - """Returns the values of the registered objects""" + """Returns the values of the registered objects.""" return tuple(self.values()) def __str__(self) -> str: @@ -96,7 +95,7 @@ def __str__(self) -> str: @dataclass class _ClassInfo: - """This class is an helper to easily build the mocked command line""" + """This class is an helper to easily build the mocked command line.""" class_arg: str cls: Type @@ -104,7 +103,7 @@ class _ClassInfo: class_arg_idx: Optional[int] = None class _ClassConfig(TypedDict): - """Defines the config structure that ``jsonargparse`` uses for instantiation""" + """Defines the config structure that ``jsonargparse`` uses for instantiation.""" class_path: str init_args: Dict[str, str] From 160b3f621c36a9b67e687619113bf6eb6831ea18 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 03:16:18 +0200 Subject: [PATCH 50/77] Update jsonargparse version --- requirements/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index dfffc6fce8428..b93e49eceba5f 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,6 +7,6 @@ torchtext>=0.7 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.19.0 +jsonargparse[signatures]>=3.19.2 gcsfs>=2021.5.0 rich>=10.2.2 From f185c2db4e100567ca1e09d248c33ef3a51ad313 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 03:51:19 +0200 Subject: [PATCH 51/77] Use properties in registry --- pytorch_lightning/utilities/cli.py | 22 +++---- tests/utilities/test_cli.py | 95 +++++++----------------------- 2 files changed, 32 insertions(+), 85 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d9fd5f22ac49e..f7ea1c5a5d0fd 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -70,17 +70,18 @@ def register_package(self, module: ModuleType, base_cls: Type) -> None: if issubclass(cls, base_cls) and cls != base_cls: self(cls=cls) - def available_objects(self) -> List[str]: - """Returns the keys of the registered objects.""" + @property + def names(self) -> List[str]: + """Returns the registered names.""" return list(self.keys()) - def registered_values(self) -> Tuple[Type, ...]: - """Returns the values of the registered objects.""" + @property + def classes(self) -> Tuple[Type, ...]: + """Returns the registered classes.""" return tuple(self.values()) def __str__(self) -> str: - objects = ", ".join(self.keys()) - return f"Registered objects: {objects}" + return f"Registered objects: {self.names}" CALLBACK_REGISTRY = _Registry() @@ -202,7 +203,7 @@ def add_optimizer_args( assert issubclass(optimizer_class, Optimizer) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): - self.add_subclass_arguments(optimizer_class, nested_key, required=True, **kwargs) + self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) self.optimizers_and_lr_schedulers[nested_key] = (optimizer_class, link_to) @@ -226,7 +227,7 @@ def add_lr_scheduler_args( assert issubclass(lr_scheduler_class, 
LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): - self.add_subclass_arguments(lr_scheduler_class, nested_key, required=True, **kwargs) + self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self.optimizers_and_lr_schedulers[nested_key] = (lr_scheduler_class, link_to) @@ -498,11 +499,10 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: if LightningCLI._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): if "optimizer" not in parser.groups: - parser.add_optimizer_args(OPTIMIZER_REGISTRY.registered_values()) - + parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) if LightningCLI._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): if "lr_scheduler" not in parser.groups: - parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.registered_values()) + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) for key, (class_type, link_to) in parser.optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 3b3f248fa61f0..8b2864ef28157 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -29,7 +29,7 @@ from packaging import version from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE @@ -696,13 +696,13 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args( - self.registered_optimizers if use_registries else torch.optim.Adam, + OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam, nested_key="optim1", link_to="model.optim1", ) parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") parser.add_lr_scheduler_args( - self.registered_lr_schedulers if use_registries else torch.optim.lr_scheduler.ExponentialLR, + LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler", ) @@ -872,53 +872,16 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): def test_registries(tmpdir): - assert CALLBACK_REGISTRY.available_objects() == [ - "BackboneFinetuning", - "BaseFinetuning", - "BasePredictionWriter", - "EarlyStopping", - "GPUStatsMonitor", - "GradientAccumulationScheduler", - "LambdaCallback", - "LearningRateMonitor", - "ModelCheckpoint", - "ModelPruning", - "ProgressBar", - "ProgressBarBase", - "QuantizationAwareTraining", - "StochasticWeightAveraging", - "Timer", - "XLAStatsMonitor", - "CustomCallback", - ] + assert "EarlyStopping" in CALLBACK_REGISTRY.names + assert "CustomCallback" in CALLBACK_REGISTRY.names - assert OPTIMIZER_REGISTRY.available_objects() == [ - "ASGD", - "Adadelta", - "Adagrad", - "Adam", - "AdamW", - "Adamax", - "LBFGS", - "RMSprop", - "Rprop", - "SGD", - "SparseAdam", - "CustomAdam", - ] + assert "SGD" in OPTIMIZER_REGISTRY.names + assert "RMSprop" in OPTIMIZER_REGISTRY.names + assert "CustomAdam" in 
OPTIMIZER_REGISTRY.names - assert LR_SCHEDULER_REGISTRY.available_objects() == [ - "CosineAnnealingLR", - "CosineAnnealingWarmRestarts", - "CyclicLR", - "ExponentialLR", - "LambdaLR", - "MultiStepLR", - "MultiplicativeLR", - "OneCycleLR", - "StepLR", - "CustomCosineAnnealingLR", - ] + assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names + assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names + assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names @pytest.mark.parametrize("use_class_path_callbacks", [False, True]) @@ -956,19 +919,15 @@ def test_registries_resolution(use_class_path_callbacks): assert optimizers[0].param_groups[0]["lr"] == 0.0001 assert lr_scheduler[0].step_size == 50 - assert [type(c) for c in cli.trainer.callbacks] == [LearningRateMonitor] + extras + [ - SaveConfigCallback, - ProgressBar, - ModelCheckpoint, - ] + callback_types = [type(c) for c in cli.trainer.callbacks] + expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras + assert all(t in callback_types for t in expected) def test_argv_transformation_noop(): base = ["any.py", "--trainer.max_epochs=1"] argv = LightningCLI._prepare_from_registry(base, OPTIMIZER_REGISTRY) - assert argv == base argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - assert argv == base argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == base @@ -983,11 +942,7 @@ def test_argv_transformation_single_callback(): } ] expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningCLI._prepare_from_registry(input, OPTIMIZER_REGISTRY) - assert argv == input - argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - assert argv == input - argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningCLI._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == expected @@ -1010,11 +965,7 @@ def test_argv_transformation_multiple_callbacks(): }, ] expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningCLI._prepare_from_registry(input, OPTIMIZER_REGISTRY) - assert argv == input - argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - assert argv == input - argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningCLI._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == expected @@ -1039,11 +990,7 @@ def test_argv_transformation_multiple_callbacks_with_config(): {"class_path": "pytorch_lightning.callbacks.Callback"}, ] expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningCLI._prepare_from_registry(input, OPTIMIZER_REGISTRY) - assert argv == input - argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - assert argv == input - argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningCLI._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == expected @@ -1101,7 +1048,7 @@ def test_optimizers_and_lr_schedulers_reload(tmpdir): # save config out = StringIO() with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): - LightningCLI(BoringModel) + LightningCLI(BoringModel, run=False) # validate yaml yaml_config = out.getvalue() @@ -1114,7 +1061,7 @@ def 
test_optimizers_and_lr_schedulers_reload(tmpdir): yaml_config_file = tmpdir / "config.yaml" yaml_config_file.write_text(yaml_config, "utf-8") with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]): - LightningCLI(BoringModel) + LightningCLI(BoringModel, run=False) def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir): @@ -1123,11 +1070,11 @@ def __init__(self, *args): super().__init__(*args, run=False) def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(self.registered_optimizers, nested_key="opt1", link_to="model.opt1_config") + parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config") parser.add_optimizer_args( (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config" ) - parser.add_lr_scheduler_args(self.registered_lr_schedulers, link_to="model.sch_config") + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config") parser.add_argument("--something", type=str, nargs="+") class TestModel(BoringModel): From 803385c346f0e401c7e49302b424d3376cfbacbb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 04:24:31 +0200 Subject: [PATCH 52/77] Keep hacks together --- pytorch_lightning/utilities/cli.py | 244 +++++++++++++++-------------- tests/utilities/test_cli.py | 20 +-- 2 files changed, 140 insertions(+), 124 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index f7ea1c5a5d0fd..07874bf7a7841 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -197,6 +197,11 @@ def add_optimizer_args( nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. """ + # FIXME: this should be in the __init__ + # self._sanitize_argv(list(self.optimizers_and_lr_schedulers)) + # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 + # if not self._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): + # return if isinstance(optimizer_class, tuple): assert all(issubclass(o, Optimizer) for o in optimizer_class) else: @@ -221,6 +226,11 @@ def add_lr_scheduler_args( nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. 
""" + # FIXME: this should be in the __init__ + # self._sanitize_argv(list(self.optimizers_and_lr_schedulers)) + # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 + # if not self._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): + # return if isinstance(lr_scheduler_class, tuple): assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) else: @@ -232,6 +242,116 @@ def add_lr_scheduler_args( self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self.optimizers_and_lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + def parse_args(self, *args, **kwargs) -> Union[Namespace, Dict[str, Any]]: + # hack before https://github.com/omni-us/jsonargparse/issues/84 + argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) + argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + # hack before https://github.com/omni-us/jsonargparse/issues/85 + argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + with mock.patch("sys.argv", argv): + return super().parse_args(*args, **kwargs) + + @staticmethod + def _prepare_from_registry(argv: List[str], registry: _Registry) -> List[str]: + # find if the users is using shortcut command line. + map_user_key_to_info = {} + for registered_name, registered_cls in registry.items(): + for v in sys.argv: + if "=" not in v: + continue + key, name = v.split("=") + if registered_name == name: + map_user_key_to_info[key] = _ClassInfo(class_arg=v, cls=registered_cls) + + if len(map_user_key_to_info) > 0: + # for each shortcut command line, add its init arguments and skip them from `sys.argv`. + out = [] + for v in argv: + skip = False + for key in map_user_key_to_info: + if key in v: + skip = True + map_user_key_to_info[key].add_class_init_arg(v) + if not skip: + out.append(v) + + # re-create the global command line and mock `sys.argv`. + out += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()] + return out + return argv + + @staticmethod + def _prepare_class_list_from_registry(argv: List[str], pattern: str, registry: _Registry) -> List[str]: + out = [v for v in argv if pattern not in v] + all_matched_args = [v for v in argv if pattern in v] + all_simplified_args = [v for v in all_matched_args if f"{pattern}" in v and f"{pattern}=[" not in v] + all_cls_simplified_args = [v for v in all_simplified_args if f"{pattern}=" in v] + all_non_simplified_args = [v for v in all_matched_args if f"{pattern}=" in v and f"{pattern}=[" in v] + + num_simplified_cls = len(all_simplified_args) + should_replace = num_simplified_cls > 0 and not all("class_path" in v for v in all_matched_args) + + if should_replace: + # verify the user is properly ordering arguments. 
+ assert all_cls_simplified_args[0] == all_simplified_args[0] + if len(all_non_simplified_args) > 1: + raise MisconfigurationException(f"When provided {pattern} as list, please group them under 1 argument.") + + # group arguments per callbacks + infos = [] + for class_arg_idx, class_arg in enumerate(all_simplified_args): + if class_arg in all_cls_simplified_args: + class_name = class_arg.split("=")[1] + registered_cls = registry[class_name] + infos.append(_ClassInfo(class_arg=class_arg, cls=registered_cls, class_arg_idx=class_arg_idx)) + + for idx, v in enumerate(all_simplified_args): + if v in all_cls_simplified_args: + current_info = [info for info in infos if idx == info.class_arg_idx][0] + current_info.add_class_init_arg(v) + + class_args = [info.class_init for info in infos] + # add other callback arguments. + if len(all_non_simplified_args) > 0: + class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) + + out += [f"{pattern}={class_args}"] + return out + return argv + + @staticmethod + def _sanitize_argv(optimizers_and_lr_schedulers: List[str]) -> None: + """This function is used to replace ```` in ``sys.argv`` with ``=``.""" + + def validate_arg(v: str) -> bool: + keys = {"--optimizer", "--lr_scheduler", "--trainer.callbacks"} + keys.update({f"--{key}" for key in optimizers_and_lr_schedulers}) + return any(v.startswith(k) for k in keys) + + args = [idx for idx, v in enumerate(sys.argv) if validate_arg(v)] + if not args: + return + start_index = args[0] + argv = [] + should_add = False + for v in sys.argv[start_index:]: + if validate_arg(v): + argv.append(v) + should_add = True + else: + if should_add and not v.startswith("--"): + argv[-1] += "=" + v + else: + argv.append(v) + should_add = False + + sys.argv = sys.argv[:start_index] + argv + + @staticmethod + def _contains_from_registry(pattern: str, registry: _Registry) -> bool: + # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 + return any(True for v in sys.argv for registered_name in registry if f"--{pattern}={registered_name}" in v) + class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. 
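To make the intent of the ``sys.argv`` rewrites above (``_sanitize_argv``, ``_prepare_from_registry`` and ``_prepare_class_list_from_registry``) concrete, here is a rough sketch of the transformation they perform, mirroring what the tests in this series assert; the exact string formatting of the substituted value is an implementation detail of this hack rather than a public contract:

    # shorthand form typed on the command line
    argv = ["any.py", "--trainer.max_epochs=1", "--optimizer", "Adadelta", "--optimizer.lr", "10"]

    # ``_sanitize_argv`` first folds the space-separated values into ``key=value`` pairs, and
    # ``_prepare_from_registry`` then replaces the registry shorthand with a single
    # ``class_path``/``init_args`` argument that jsonargparse can parse:
    #   ["any.py", "--trainer.max_epochs=1",
    #    "--optimizer={'class_path': 'torch.optim.adadelta.Adadelta', 'init_args': {'lr': '10'}}"]
    # ``_prepare_class_list_from_registry`` does the analogous grouping for ``--trainer.callbacks``.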
@@ -380,34 +500,6 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" return LightningArgumentParser(**kwargs) - @staticmethod - def _sanitize_argv(optimizers_and_lr_schedulers: List[str]) -> None: - """This function is used to replace ```` in ``sys.argv`` with ``=``.""" - - def validate_arg(v: str) -> bool: - keys = {"--optimizer", "--lr_scheduler", "--trainer.callbacks"} - keys.update({f"--{key}" for key in optimizers_and_lr_schedulers}) - return any(v.startswith(k) for k in keys) - - args = [idx for idx, v in enumerate(sys.argv) if validate_arg(v)] - if not args: - return - start_index = args[0] - argv = [] - should_add = False - for v in sys.argv[start_index:]: - if validate_arg(v): - argv.append(v) - should_add = True - else: - if should_add and not v.startswith("--"): - argv[-1] += "=" + v - else: - argv.append(v) - should_add = False - - sys.argv = sys.argv[:start_index] + argv - def setup_parser( self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] ) -> None: @@ -442,6 +534,14 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None: self.add_default_arguments_to_parser(parser) self.add_core_arguments_to_parser(parser) self.add_arguments_to_parser(parser) + # add default optimizer args + LightningArgumentParser._sanitize_argv(list(parser.optimizers_and_lr_schedulers)) + if LightningArgumentParser._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): + if "optimizer" not in parser.groups: # already added by the user in `add_arguments_to_parser` + parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) + if LightningArgumentParser._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): + if "lr_scheduler" not in parser.groups: # already added by the user in `add_arguments_to_parser` + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) self.link_optimizers_and_lr_schedulers(parser) def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: @@ -487,23 +587,9 @@ def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any self._subcommand_method_arguments[subcommand] = added return parser - @staticmethod - def _contains_from_registry(pattern: str, registry: _Registry) -> bool: - # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - return any(True for v in sys.argv for registered_name in registry if f"--{pattern}={registered_name}" in v) - @staticmethod def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - LightningCLI._sanitize_argv(list(parser.optimizers_and_lr_schedulers)) - - if LightningCLI._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): - if "optimizer" not in parser.groups: - parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) - if LightningCLI._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): - if "lr_scheduler" not in parser.groups: - parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) - for key, (class_type, link_to) in parser.optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": continue @@ -513,81 +599,9 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: add_class_path = _add_class_path_generator(class_type) parser.link_arguments(key, link_to, compute_fn=add_class_path) - @staticmethod - def _prepare_from_registry(argv: List[str], registry: _Registry) -> List[str]: - # 
find if the users is using shortcut command line. - map_user_key_to_info = {} - for registered_name, registered_cls in registry.items(): - for v in sys.argv: - if "=" not in v: - continue - key, name = v.split("=") - if registered_name == name: - map_user_key_to_info[key] = _ClassInfo(class_arg=v, cls=registered_cls) - - if len(map_user_key_to_info) > 0: - # for each shortcut command line, add its init arguments and skip them from `sys.argv`. - out = [] - for v in argv: - skip = False - for key in map_user_key_to_info: - if key in v: - skip = True - map_user_key_to_info[key].add_class_init_arg(v) - if not skip: - out.append(v) - - # re-create the global command line and mock `sys.argv`. - out += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()] - return out - return argv - - @staticmethod - def _prepare_class_list_from_registry(argv: List[str], pattern: str, registry: _Registry) -> List[str]: - out = [v for v in argv if pattern not in v] - all_matched_args = [v for v in argv if pattern in v] - all_simplified_args = [v for v in all_matched_args if f"{pattern}" in v and f"{pattern}=[" not in v] - all_cls_simplified_args = [v for v in all_simplified_args if f"{pattern}=" in v] - all_non_simplified_args = [v for v in all_matched_args if f"{pattern}=" in v and f"{pattern}=[" in v] - - num_simplified_cls = len(all_simplified_args) - should_replace = num_simplified_cls > 0 and not all("class_path" in v for v in all_matched_args) - - if should_replace: - # verify the user is properly ordering arguments. - assert all_cls_simplified_args[0] == all_simplified_args[0] - if len(all_non_simplified_args) > 1: - raise MisconfigurationException(f"When provided {pattern} as list, please group them under 1 argument.") - - # group arguments per callbacks - infos = [] - for class_arg_idx, class_arg in enumerate(all_simplified_args): - if class_arg in all_cls_simplified_args: - class_name = class_arg.split("=")[1] - registered_cls = registry[class_name] - infos.append(_ClassInfo(class_arg=class_arg, cls=registered_cls, class_arg_idx=class_arg_idx)) - - for idx, v in enumerate(all_simplified_args): - if v in all_cls_simplified_args: - current_info = [info for info in infos if idx == info.class_arg_idx][0] - current_info.add_class_init_arg(v) - - class_args = [info.class_init for info in infos] - # add other callback arguments. 
- if len(all_non_simplified_args) > 0: - class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) - - out += [f"{pattern}={class_args}"] - return out - return argv - def parse_arguments(self, parser: LightningArgumentParser) -> None: """Parses command line arguments and stores it in ``self.config``.""" - argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) - argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) - with mock.patch("sys.argv", argv): - self.config = parser.parse_args() + self.config = parser.parse_args() def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 8b2864ef28157..e1d210b0f1762 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -926,9 +926,9 @@ def test_registries_resolution(use_class_path_callbacks): def test_argv_transformation_noop(): base = ["any.py", "--trainer.max_epochs=1"] - argv = LightningCLI._prepare_from_registry(base, OPTIMIZER_REGISTRY) - argv = LightningCLI._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - argv = LightningCLI._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._prepare_from_registry(base, OPTIMIZER_REGISTRY) + argv = LightningArgumentParser._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + argv = LightningArgumentParser._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == base @@ -942,7 +942,7 @@ def test_argv_transformation_single_callback(): } ] expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningCLI._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == expected @@ -965,7 +965,7 @@ def test_argv_transformation_multiple_callbacks(): }, ] expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningCLI._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == expected @@ -990,7 +990,7 @@ def test_argv_transformation_multiple_callbacks_with_config(): {"class_path": "pytorch_lightning.callbacks.Callback"}, ] expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningCLI._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == expected @@ -1001,9 +1001,11 @@ def __init__(self, expected, *args): super().__init__(*args, run=False) def before_instantiate_classes(self): - argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) - argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) + argv = LightningArgumentParser._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) + argv = LightningArgumentParser._prepare_class_list_from_registry( + argv, "--trainer.callbacks", CALLBACK_REGISTRY + ) assert argv == self.expected base = 
["any.py", "--trainer.max_epochs=1"] From 8eb8b05bb71b888978f6638fc27390910b7beda4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 04:38:59 +0200 Subject: [PATCH 53/77] Add FIXMEs --- pytorch_lightning/utilities/cli.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 07874bf7a7841..142a9aa006ca0 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -197,11 +197,6 @@ def add_optimizer_args( nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. """ - # FIXME: this should be in the __init__ - # self._sanitize_argv(list(self.optimizers_and_lr_schedulers)) - # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - # if not self._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): - # return if isinstance(optimizer_class, tuple): assert all(issubclass(o, Optimizer) for o in optimizer_class) else: @@ -226,11 +221,6 @@ def add_lr_scheduler_args( nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. """ - # FIXME: this should be in the __init__ - # self._sanitize_argv(list(self.optimizers_and_lr_schedulers)) - # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - # if not self._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): - # return if isinstance(lr_scheduler_class, tuple): assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) else: @@ -350,7 +340,7 @@ def validate_arg(v: str) -> bool: @staticmethod def _contains_from_registry(pattern: str, registry: _Registry) -> bool: # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - return any(True for v in sys.argv for registered_name in registry if f"--{pattern}={registered_name}" in v) + return any(f"--{pattern}={name}" == v for v in sys.argv for name in registry) class SaveConfigCallback(Callback): @@ -535,12 +525,16 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None: self.add_core_arguments_to_parser(parser) self.add_arguments_to_parser(parser) # add default optimizer args + # FIXME: this should be done before or after + # FIXME: this shouldn't take `optimizers_and_lr_schedulers` LightningArgumentParser._sanitize_argv(list(parser.optimizers_and_lr_schedulers)) - if LightningArgumentParser._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): - if "optimizer" not in parser.groups: # already added by the user in `add_arguments_to_parser` + if "optimizer" not in parser.groups: # already added by the user in `add_arguments_to_parser` + # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 + if LightningArgumentParser._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) - if LightningArgumentParser._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): - if "lr_scheduler" not in parser.groups: # already added by the user in `add_arguments_to_parser` + if "lr_scheduler" not in parser.groups: # already added by the user in `add_arguments_to_parser` + # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 + if LightningArgumentParser._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) 
self.link_optimizers_and_lr_schedulers(parser) From 9d84127e445f1f267e4f9e6638879b10fbf1fa2b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 06:43:36 +0200 Subject: [PATCH 54/77] add_class_choices --- pytorch_lightning/utilities/cli.py | 98 ++++++++++++++++++++---------- tests/utilities/test_cli.py | 79 +++++++++++------------- 2 files changed, 99 insertions(+), 78 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 142a9aa006ca0..a56c620f053cf 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import ast import inspect import os import sys @@ -143,6 +144,7 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non ) self.callback_keys: List[str] = [] self.optimizers_and_lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._argv = sys.argv.copy() def add_lightning_class_args( self, @@ -203,7 +205,7 @@ def add_optimizer_args( assert issubclass(optimizer_class, Optimizer) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): - self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) + self.add_class_choices(optimizer_class, nested_key, **kwargs) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) self.optimizers_and_lr_schedulers[nested_key] = (optimizer_class, link_to) @@ -227,48 +229,78 @@ def add_lr_scheduler_args( assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): - self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) + self.add_class_choices(lr_scheduler_class, nested_key, **kwargs) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self.optimizers_and_lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args, **kwargs) -> Union[Namespace, Dict[str, Any]]: - # hack before https://github.com/omni-us/jsonargparse/issues/84 - argv = self._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) - argv = self._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) # hack before https://github.com/omni-us/jsonargparse/issues/85 - argv = self._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = self._prepare_class_list_from_registry(self._argv, "--trainer.callbacks", CALLBACK_REGISTRY) with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) + def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: + """Replacement for https://github.com/omni-us/jsonargparse/issues/84. + + This should be removed once implemented. 
+ """ + if self._probably_defined_in_config(nested_key, self._argv): + # parsing config files would be too difficult, fall back to what's available + self.add_subclass_arguments(classes, nested_key, *args, **kwargs) + else: + clean_argv, config = self._convert_argv_to_config(classes, nested_key, self._argv) + self.add_subclass_arguments(classes, nested_key, *args, **kwargs) + self.set_defaults({nested_key: config}) + self._argv = clean_argv + @staticmethod - def _prepare_from_registry(argv: List[str], registry: _Registry) -> List[str]: - # find if the users is using shortcut command line. - map_user_key_to_info = {} - for registered_name, registered_cls in registry.items(): - for v in sys.argv: - if "=" not in v: - continue - key, name = v.split("=") - if registered_name == name: - map_user_key_to_info[key] = _ClassInfo(class_arg=v, cls=registered_cls) - - if len(map_user_key_to_info) > 0: - # for each shortcut command line, add its init arguments and skip them from `sys.argv`. - out = [] - for v in argv: - skip = False - for key in map_user_key_to_info: - if key in v: - skip = True - map_user_key_to_info[key].add_class_init_arg(v) - if not skip: - out.append(v) - - # re-create the global command line and mock `sys.argv`. - out += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()] - return out - return argv + def _probably_defined_in_config(nested_key: str, argv: List[str]) -> bool: + key_in_argv = any(arg.startswith(f"--{nested_key}") for arg in argv) + has_config = any(arg.startswith("--config") for arg in argv) + return not key_in_argv and has_config + + @staticmethod + def _convert_argv_to_config(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> Tuple[List[str], Dict]: + passed_args = {} + clean_argv = [] + argv_key = f"--{nested_key}" + # get the argv args for this nested key + i = 0 + while i < len(argv): + arg = argv[i] + if arg.startswith(argv_key): + if "=" in arg: + key, value = arg.split("=") + else: + key = arg + i += 1 + value = argv[i] + passed_args[key] = value + else: + clean_argv.append(arg) + i += 1 + # generate the associated config file + argv_class = passed_args.pop(argv_key, None) + if argv_class is None: + # the user passed a config as a str + class_path = passed_args[f"{argv_key}.class_path"] + init_args = passed_args.get(f"{argv_key}.init_args", {}) + config = {"class_path": class_path, "init_args": init_args} + elif argv_class.startswith("{"): + # the user passed a config as a dict + config = ast.literal_eval(argv_class) + assert isinstance(config, dict) + else: + # the user passed the short format + init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period + for cls in classes: + if cls.__name__ == argv_class: + config = _global_add_class_path(cls, init_args) + break + else: + raise ValueError(f"Could not generate a config for {repr(argv_class)}") + return clean_argv, config @staticmethod def _prepare_class_list_from_registry(argv: List[str], pattern: str, registry: _Registry) -> List[str]: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e1d210b0f1762..9f9dd20899185 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -142,9 +142,9 @@ def _raise(): def test_parse_args_parsing(cli_args, expected): """Test parsing simple types and None optionals not modified.""" cli_args = cli_args.split(" ") if cli_args else [] - parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - 
parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) args = parser.parse_args() for k, v in expected.items(): @@ -163,9 +163,9 @@ def test_parse_args_parsing(cli_args, expected): ) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): """Test parsing complex types.""" - parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) args = parser.parse_args() for k, v in expected.items(): @@ -179,9 +179,9 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" monkeypatch.setattr("torch.cuda.device_count", lambda: 2) cli_args = cli_args.split(" ") if cli_args else [] - parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) args = parser.parse_args() trainer = Trainer.from_argparse_args(args) @@ -926,9 +926,7 @@ def test_registries_resolution(use_class_path_callbacks): def test_argv_transformation_noop(): base = ["any.py", "--trainer.max_epochs=1"] - argv = LightningArgumentParser._prepare_from_registry(base, OPTIMIZER_REGISTRY) - argv = LightningArgumentParser._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - argv = LightningArgumentParser._prepare_class_list_from_registry(argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._prepare_class_list_from_registry(base, "--trainer.callbacks", CALLBACK_REGISTRY) assert argv == base @@ -995,44 +993,35 @@ def test_argv_transformation_multiple_callbacks_with_config(): def test_argv_transformations_with_optimizers_and_lr_schedulers(): - class TestLightningCLI(LightningCLI): - def __init__(self, expected, *args): - self.expected = expected - super().__init__(*args, run=False) - - def before_instantiate_classes(self): - argv = LightningArgumentParser._prepare_from_registry(sys.argv, OPTIMIZER_REGISTRY) - argv = LightningArgumentParser._prepare_from_registry(argv, LR_SCHEDULER_REGISTRY) - argv = LightningArgumentParser._prepare_class_list_from_registry( - argv, "--trainer.callbacks", CALLBACK_REGISTRY - ) - assert argv == self.expected - base = ["any.py", "--trainer.max_epochs=1"] - input = base + ["--optimizer", "Adadelta"] - optimizer = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}} - expected = base + [f"--optimizer={optimizer}"] - with mock.patch("sys.argv", input): - TestLightningCLI(expected, BoringModel) - - input = base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"] - optimizer = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}} - expected = base + [f"--optimizer={optimizer}"] - with mock.patch("sys.argv", input): - TestLightningCLI(expected, BoringModel) - - input = base + ["--lr_scheduler", "OneCycleLR"] - lr_scheduler = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}} - expected = base + [f"--lr_scheduler={lr_scheduler}"] - with mock.patch("sys.argv", input): - TestLightningCLI(expected, BoringModel) - - input = base + 
["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"] - lr_scheduler = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}} - expected = base + [f"--lr_scheduler={lr_scheduler}"] - with mock.patch("sys.argv", input): - TestLightningCLI(expected, BoringModel) + argv = base + ["--optimizer", "Adadelta"] + expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}} + new_argv, actual = LightningArgumentParser._convert_argv_to_config(OPTIMIZER_REGISTRY.classes, "optimizer", argv) + assert new_argv == base + assert actual == expected + + argv = base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"] + expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}} + base, actual = LightningArgumentParser._convert_argv_to_config(OPTIMIZER_REGISTRY.classes, "optimizer", argv) + assert new_argv == base + assert actual == expected + + argv = base + ["--lr_scheduler", "OneCycleLR"] + expected = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}} + new_argv, actual = LightningArgumentParser._convert_argv_to_config( + LR_SCHEDULER_REGISTRY.classes, "lr_scheduler", argv + ) + assert new_argv == base + assert actual == expected + + argv = base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"] + expected = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}} + new_argv, actual = LightningArgumentParser._convert_argv_to_config( + LR_SCHEDULER_REGISTRY.classes, "lr_scheduler", argv + ) + assert new_argv == base + assert actual == expected def test_optimizers_and_lr_schedulers_reload(tmpdir): From cf82e1a6c4bdcd15e5802446b68d09cd3ff8ce01 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 01:45:06 +0200 Subject: [PATCH 55/77] Remove contains registry. Avoid nested_key clash for optimizers and lr schedulers --- pytorch_lightning/utilities/cli.py | 92 ++++++++++++++++-------------- tests/utilities/test_cli.py | 2 +- 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index a56c620f053cf..fcd2a0e6b0571 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -143,8 +143,10 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." 
) self.callback_keys: List[str] = [] - self.optimizers_and_lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - self._argv = sys.argv.copy() + self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + # we need a mutable global argv copy in order to support `add_class_choices` + sys._pl_argv = sys.argv.copy() def add_lightning_class_args( self, @@ -208,7 +210,7 @@ def add_optimizer_args( self.add_class_choices(optimizer_class, nested_key, **kwargs) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) - self.optimizers_and_lr_schedulers[nested_key] = (optimizer_class, link_to) + self._optimizers[nested_key] = (optimizer_class, link_to) def add_lr_scheduler_args( self, @@ -232,33 +234,40 @@ def add_lr_scheduler_args( self.add_class_choices(lr_scheduler_class, nested_key, **kwargs) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) - self.optimizers_and_lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args, **kwargs) -> Union[Namespace, Dict[str, Any]]: # hack before https://github.com/omni-us/jsonargparse/issues/85 - argv = self._prepare_class_list_from_registry(self._argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = self._prepare_class_list_from_registry(sys._pl_argv, "--trainer.callbacks", CALLBACK_REGISTRY) with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) - def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: - """Replacement for https://github.com/omni-us/jsonargparse/issues/84. + def add_class_choices( + self, classes: Tuple[Type, ...], nested_key: str, *args: Any, required: bool = False, **kwargs: Any + ) -> None: + """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. This should be removed once implemented. 
""" - if self._probably_defined_in_config(nested_key, self._argv): - # parsing config files would be too difficult, fall back to what's available - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) + if not any(arg.startswith(f"--{nested_key}") for arg in sys._pl_argv): # the key was passed + if any(arg.startswith("--config") for arg in sys._pl_argv): # a config was passed + # parsing config files would be too difficult, fall back to what's available + self.add_subclass_arguments(classes, nested_key, *args, **kwargs) + elif required: + raise MisconfigurationException(f"The {nested_key} is required but wasn't passed") else: - clean_argv, config = self._convert_argv_to_config(classes, nested_key, self._argv) + clean_argv, config = self._convert_argv_to_config(classes, nested_key, sys._pl_argv) self.add_subclass_arguments(classes, nested_key, *args, **kwargs) self.set_defaults({nested_key: config}) - self._argv = clean_argv + sys._pl_argv = clean_argv @staticmethod - def _probably_defined_in_config(nested_key: str, argv: List[str]) -> bool: - key_in_argv = any(arg.startswith(f"--{nested_key}") for arg in argv) - has_config = any(arg.startswith("--config") for arg in argv) - return not key_in_argv and has_config + def _try_eval(val: str) -> Any: + try: + val = ast.literal_eval(val) + except ValueError: + pass + return val @staticmethod def _convert_argv_to_config(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> Tuple[List[str], Dict]: @@ -285,7 +294,8 @@ def _convert_argv_to_config(classes: Tuple[Type, ...], nested_key: str, argv: Li if argv_class is None: # the user passed a config as a str class_path = passed_args[f"{argv_key}.class_path"] - init_args = passed_args.get(f"{argv_key}.init_args", {}) + init_args_key = f"{argv_key}.init_args" + init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)} config = {"class_path": class_path, "init_args": init_args} elif argv_class.startswith("{"): # the user passed a config as a dict @@ -300,6 +310,8 @@ def _convert_argv_to_config(classes: Tuple[Type, ...], nested_key: str, argv: Li break else: raise ValueError(f"Could not generate a config for {repr(argv_class)}") + # need to convert from str to the appropriate type + config["init_args"] = {k: LightningArgumentParser._try_eval(v) for k, v in config["init_args"].items()} return clean_argv, config @staticmethod @@ -369,11 +381,6 @@ def validate_arg(v: str) -> bool: sys.argv = sys.argv[:start_index] + argv - @staticmethod - def _contains_from_registry(pattern: str, registry: _Registry) -> bool: - # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - return any(f"--{pattern}={name}" == v for v in sys.argv for name in registry) - class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. 
@@ -559,15 +566,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None: # add default optimizer args # FIXME: this should be done before or after # FIXME: this shouldn't take `optimizers_and_lr_schedulers` - LightningArgumentParser._sanitize_argv(list(parser.optimizers_and_lr_schedulers)) - if "optimizer" not in parser.groups: # already added by the user in `add_arguments_to_parser` - # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - if LightningArgumentParser._contains_from_registry("optimizer", OPTIMIZER_REGISTRY): - parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) - if "lr_scheduler" not in parser.groups: # already added by the user in `add_arguments_to_parser` - # FIXME: remove me after https://github.com/omni-us/jsonargparse/issues/83 - if LightningArgumentParser._contains_from_registry("lr_scheduler", LR_SCHEDULER_REGISTRY): - parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) + LightningArgumentParser._sanitize_argv(list(parser._optimizers) + list(parser._lr_schedulers)) + if not parser._optimizers: # already added by the user in `add_arguments_to_parser` + parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) + if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser` + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) self.link_optimizers_and_lr_schedulers(parser) def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: @@ -616,7 +619,8 @@ def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any @staticmethod def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - for key, (class_type, link_to) in parser.optimizers_and_lr_schedulers.items(): + optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers} + for key, (class_type, link_to) in optimizers_and_lr_schedulers.items(): if link_to == "AUTOMATIC": continue if isinstance(class_type, tuple): @@ -665,7 +669,7 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback] config["callbacks"].append(config_callback) return self.trainer_class(**config) - def _parser(self, subcommand: Optional[str]) -> ArgumentParser: + def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: if subcommand is None: return self.parser # return the subcommand parser for the subcommand passed @@ -680,19 +684,20 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - `configure_optimizers` method is automatically implemented in the model class. 
""" parser = self._parser(subcommand) - optimizers_and_lr_schedulers = parser.optimizers_and_lr_schedulers - def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: + def get_automatic( + class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] + ) -> List[str]: automatic = [] - for key, (base_class, link_to) in optimizers_and_lr_schedulers.items(): + for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): base_class = (base_class,) if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class): automatic.append(key) return automatic - optimizers = get_automatic(Optimizer) - lr_schedulers = get_automatic(LRSchedulerTypeTuple) + optimizers = get_automatic(Optimizer, parser._optimizers) + lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers) if len(optimizers) == 0: return @@ -712,14 +717,17 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`." ) - optimizer_class = optimizers_and_lr_schedulers[optimizers[0]][0] - optimizer_init = self._get(self.config_init, optimizers[0], default={}) + optimizer_class = parser._optimizers[optimizers[0]][0] + optimizer_init = self._get(self.config_init, optimizers[0]) if not isinstance(optimizer_class, tuple): optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) + if not optimizer_init: + # optimizers were registered automatically but not passed by the user + return lr_scheduler_init = None if lr_schedulers: - lr_scheduler_class = optimizers_and_lr_schedulers[lr_schedulers[0]][0] - lr_scheduler_init = self._get(self.config_init, lr_schedulers[0], default={}) + lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0] + lr_scheduler_init = self._get(self.config_init, lr_schedulers[0]) if not isinstance(lr_scheduler_class, tuple): lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 9f9dd20899185..7d5632626f79d 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -1002,7 +1002,7 @@ def test_argv_transformations_with_optimizers_and_lr_schedulers(): assert actual == expected argv = base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"] - expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}} + expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": 10}} base, actual = LightningArgumentParser._convert_argv_to_config(OPTIMIZER_REGISTRY.classes, "optimizer", argv) assert new_argv == base assert actual == expected From b1cd083efb4fb2eadfebf736921fc5c67bac59df Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 01:48:40 +0200 Subject: [PATCH 56/77] Remove sanitize argv --- pytorch_lightning/utilities/cli.py | 33 +----------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index fcd2a0e6b0571..18cad957e6d3b 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -353,34 +353,6 @@ def _prepare_class_list_from_registry(argv: List[str], pattern: str, registry: _ return out return argv - @staticmethod - def _sanitize_argv(optimizers_and_lr_schedulers: List[str]) -> None: - """This function is used to replace ```` in ``sys.argv`` with ``=``.""" - - def validate_arg(v: str) -> 
bool: - keys = {"--optimizer", "--lr_scheduler", "--trainer.callbacks"} - keys.update({f"--{key}" for key in optimizers_and_lr_schedulers}) - return any(v.startswith(k) for k in keys) - - args = [idx for idx, v in enumerate(sys.argv) if validate_arg(v)] - if not args: - return - start_index = args[0] - argv = [] - should_add = False - for v in sys.argv[start_index:]: - if validate_arg(v): - argv.append(v) - should_add = True - else: - if should_add and not v.startswith("--"): - argv[-1] += "=" + v - else: - argv.append(v) - should_add = False - - sys.argv = sys.argv[:start_index] + argv - class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. @@ -563,10 +535,7 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None: self.add_default_arguments_to_parser(parser) self.add_core_arguments_to_parser(parser) self.add_arguments_to_parser(parser) - # add default optimizer args - # FIXME: this should be done before or after - # FIXME: this shouldn't take `optimizers_and_lr_schedulers` - LightningArgumentParser._sanitize_argv(list(parser._optimizers) + list(parser._lr_schedulers)) + # add default optimizer args if necessary if not parser._optimizers: # already added by the user in `add_arguments_to_parser` parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser` From 95d31a75815b3cea957df1aabc4c34e07c84e32f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 04:38:01 +0200 Subject: [PATCH 57/77] Better support for new callback format --- pytorch_lightning/utilities/cli.py | 145 ++++++++++++++--------------- tests/utilities/test_cli.py | 34 +++---- 2 files changed, 86 insertions(+), 93 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 18cad957e6d3b..6e67ae3c43a3c 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -16,14 +16,13 @@ import os import sys from argparse import Namespace -from dataclasses import dataclass, field from types import MethodType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from unittest import mock import torch +import yaml from torch.optim import Optimizer -from typing_extensions import TypedDict import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer @@ -95,35 +94,6 @@ def __str__(self) -> str: LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) -@dataclass -class _ClassInfo: - """This class is an helper to easily build the mocked command line.""" - - class_arg: str - cls: Type - class_init_args: List[str] = field(default_factory=lambda: []) - class_arg_idx: Optional[int] = None - - class _ClassConfig(TypedDict): - """Defines the config structure that ``jsonargparse`` uses for instantiation.""" - - class_path: str - init_args: Dict[str, str] - - def add_class_init_arg(self, arg: str) -> None: - if arg != self.class_arg: - self.class_init_args.append(arg) - - @property - def class_init(self) -> _ClassConfig: - init_args = {} - for init_arg in self.class_init_args: - arg_path, value = init_arg.split("=") - key = arg_path.split(".")[-1] - init_args[key] = value - return self._ClassConfig(class_path=self.cls.__module__ + "." 
+ self.cls.__name__, init_args=init_args) - - class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" @@ -237,8 +207,7 @@ def add_lr_scheduler_args( self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args, **kwargs) -> Union[Namespace, Dict[str, Any]]: - # hack before https://github.com/omni-us/jsonargparse/issues/85 - argv = self._prepare_class_list_from_registry(sys._pl_argv, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = self._convert_argv_issue_85("trainer.callbacks", sys._pl_argv, CALLBACK_REGISTRY) with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) @@ -254,9 +223,9 @@ def add_class_choices( # parsing config files would be too difficult, fall back to what's available self.add_subclass_arguments(classes, nested_key, *args, **kwargs) elif required: - raise MisconfigurationException(f"The {nested_key} is required but wasn't passed") + raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") else: - clean_argv, config = self._convert_argv_to_config(classes, nested_key, sys._pl_argv) + clean_argv, config = self._convert_argv_issue_84(classes, nested_key, sys._pl_argv) self.add_subclass_arguments(classes, nested_key, *args, **kwargs) self.set_defaults({nested_key: config}) sys._pl_argv = clean_argv @@ -270,9 +239,8 @@ def _try_eval(val: str) -> Any: return val @staticmethod - def _convert_argv_to_config(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> Tuple[List[str], Dict]: - passed_args = {} - clean_argv = [] + def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> Tuple[List[str], Dict]: + passed_args, clean_argv = {}, [] argv_key = f"--{nested_key}" # get the argv args for this nested key i = 0 @@ -315,43 +283,66 @@ def _convert_argv_to_config(classes: Tuple[Type, ...], nested_key: str, argv: Li return clean_argv, config @staticmethod - def _prepare_class_list_from_registry(argv: List[str], pattern: str, registry: _Registry) -> List[str]: - out = [v for v in argv if pattern not in v] - all_matched_args = [v for v in argv if pattern in v] - all_simplified_args = [v for v in all_matched_args if f"{pattern}" in v and f"{pattern}=[" not in v] - all_cls_simplified_args = [v for v in all_simplified_args if f"{pattern}=" in v] - all_non_simplified_args = [v for v in all_matched_args if f"{pattern}=" in v and f"{pattern}=[" in v] - - num_simplified_cls = len(all_simplified_args) - should_replace = num_simplified_cls > 0 and not all("class_path" in v for v in all_matched_args) - - if should_replace: - # verify the user is properly ordering arguments. - assert all_cls_simplified_args[0] == all_simplified_args[0] - if len(all_non_simplified_args) > 1: - raise MisconfigurationException(f"When provided {pattern} as list, please group them under 1 argument.") - - # group arguments per callbacks - infos = [] - for class_arg_idx, class_arg in enumerate(all_simplified_args): - if class_arg in all_cls_simplified_args: - class_name = class_arg.split("=")[1] - registered_cls = registry[class_name] - infos.append(_ClassInfo(class_arg=class_arg, cls=registered_cls, class_arg_idx=class_arg_idx)) - - for idx, v in enumerate(all_simplified_args): - if v in all_cls_simplified_args: - current_info = [info for info in infos if idx == info.class_arg_idx][0] - current_info.add_class_init_arg(v) - - class_args = [info.class_init for info in infos] - # add other callback arguments. 
- if len(all_non_simplified_args) > 0: - class_args.extend(eval(all_non_simplified_args[0].split("=")[-1])) - - out += [f"{pattern}={class_args}"] - return out - return argv + def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry) -> List[str]: + """Placeholder for https://github.com/omni-us/jsonargparse/issues/85. + + This should be removed once implemented. + """ + passed_args, clean_argv = [], [] + passed_configs = {} + argv_key = f"--{nested_key}" + # get the argv args for this nested key + i = 0 + while i < len(argv): + arg = argv[i] + if arg.startswith(argv_key): + if "=" in arg: + key, value = arg.split("=") + else: + key = arg + i += 1 + value = argv[i] + key = key[2:] # remove dashes + if "class_path" in value: + # the user passed a config as a dict + config = yaml.safe_load(value) + assert all(isinstance(cfg, dict) for cfg in config) + passed_configs[key] = config + else: + passed_args.append((key, value)) + else: + clean_argv.append(arg) + i += 1 + # generate the associated config file + out = [] + i, n = 0, len(passed_args) + while i < n - 1: + ki, vi = passed_args[i] + # convert class name to class path + cls_type = registry.get(vi) + if cls_type is None: + raise ValueError( + f"Passed the class `--{nested_key}={ki}` but it's not registered in the registry." + f"The available classes are: {registry.names}" + ) + config = _global_add_class_path(cls_type) + out.append(config) + # get any init args + j = i + 1 # in case the j-loop doesn't run + for j in range(i + 1, n): + kj, vj = passed_args[j] + if ki == kj: + break + if kj.startswith(ki): + init_arg_name = kj.split(".")[-1] + out[-1]["init_args"][init_arg_name] = vj + i = j + # update at the end to preserve the order + for k, v in passed_configs.items(): + out.extend(v) + if not out: + return clean_argv + return clean_argv + [argv_key, str(out)] class SaveConfigCallback(Callback): @@ -743,8 +734,8 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: return fn_kwargs -def _global_add_class_path(class_type: Type, init_args: Dict[str, Any]) -> Dict[str, Any]: - return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args} +def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]: + return {"class_path": class_type.__module__ + "." 
+ class_type.__name__, "init_args": init_args or {}} def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 7d5632626f79d..b0b09534533f3 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -926,7 +926,7 @@ def test_registries_resolution(use_class_path_callbacks): def test_argv_transformation_noop(): base = ["any.py", "--trainer.max_epochs=1"] - argv = LightningArgumentParser._prepare_class_list_from_registry(base, "--trainer.callbacks", CALLBACK_REGISTRY) + argv = LightningArgumentParser._convert_argv_issue_85("trainer.callbacks", base, CALLBACK_REGISTRY) assert argv == base @@ -939,8 +939,8 @@ def test_argv_transformation_single_callback(): "init_args": {"monitor": "val_loss"}, } ] - expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningArgumentParser._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) + expected = base + ["--trainer.callbacks", str(callbacks)] + argv = LightningArgumentParser._convert_argv_issue_85("trainer.callbacks", input, CALLBACK_REGISTRY) assert argv == expected @@ -962,19 +962,20 @@ def test_argv_transformation_multiple_callbacks(): "init_args": {"monitor": "val_acc"}, }, ] - expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningArgumentParser._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) + expected = base + ["--trainer.callbacks", str(callbacks)] + argv = LightningArgumentParser._convert_argv_issue_85("trainer.callbacks", input, CALLBACK_REGISTRY) assert argv == expected def test_argv_transformation_multiple_callbacks_with_config(): base = ["any.py", "--trainer.max_epochs=1"] + nested_key = "trainer.callbacks" input = base + [ - "--trainer.callbacks=ModelCheckpoint", - "--trainer.callbacks.monitor=val_loss", - "--trainer.callbacks=ModelCheckpoint", - "--trainer.callbacks.monitor=val_acc", - "--trainer.callbacks=[{'class_path': 'pytorch_lightning.callbacks.Callback'}]", + f"--{nested_key}=ModelCheckpoint", + f"--{nested_key}.monitor=val_loss", + f"--{nested_key}=ModelCheckpoint", + f"--{nested_key}.monitor=val_acc", + f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]", ] callbacks = [ { @@ -987,8 +988,9 @@ def test_argv_transformation_multiple_callbacks_with_config(): }, {"class_path": "pytorch_lightning.callbacks.Callback"}, ] - expected = base + [f"--trainer.callbacks={str(callbacks)}"] - argv = LightningArgumentParser._prepare_class_list_from_registry(input, "--trainer.callbacks", CALLBACK_REGISTRY) + expected = base + ["--trainer.callbacks", str(callbacks)] + nested_key = "trainer.callbacks" + argv = LightningArgumentParser._convert_argv_issue_85(nested_key, input, CALLBACK_REGISTRY) assert argv == expected @@ -997,19 +999,19 @@ def test_argv_transformations_with_optimizers_and_lr_schedulers(): argv = base + ["--optimizer", "Adadelta"] expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}} - new_argv, actual = LightningArgumentParser._convert_argv_to_config(OPTIMIZER_REGISTRY.classes, "optimizer", argv) + new_argv, actual = LightningArgumentParser._convert_argv_issue_84(OPTIMIZER_REGISTRY.classes, "optimizer", argv) assert new_argv == base assert actual == expected argv = base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"] expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": 10}} - base, actual = 
LightningArgumentParser._convert_argv_to_config(OPTIMIZER_REGISTRY.classes, "optimizer", argv) + base, actual = LightningArgumentParser._convert_argv_issue_84(OPTIMIZER_REGISTRY.classes, "optimizer", argv) assert new_argv == base assert actual == expected argv = base + ["--lr_scheduler", "OneCycleLR"] expected = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}} - new_argv, actual = LightningArgumentParser._convert_argv_to_config( + new_argv, actual = LightningArgumentParser._convert_argv_issue_84( LR_SCHEDULER_REGISTRY.classes, "lr_scheduler", argv ) assert new_argv == base @@ -1017,7 +1019,7 @@ def test_argv_transformations_with_optimizers_and_lr_schedulers(): argv = base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"] expected = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}} - new_argv, actual = LightningArgumentParser._convert_argv_to_config( + new_argv, actual = LightningArgumentParser._convert_argv_issue_84( LR_SCHEDULER_REGISTRY.classes, "lr_scheduler", argv ) assert new_argv == base From 231e0ed9e9f07de1703376d33be2023f7080f62a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 04:50:16 +0200 Subject: [PATCH 58/77] Avoid evaluating --- pytorch_lightning/utilities/cli.py | 32 +++----------- tests/utilities/test_cli.py | 69 +++++++++++++++--------------- 2 files changed, 40 insertions(+), 61 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6e67ae3c43a3c..7c40180a61231 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import ast import inspect import os import sys @@ -43,12 +42,7 @@ class _Registry(dict): - def __call__( - self, - cls: Type, - key: Optional[str] = None, - override: bool = False, - ) -> None: + def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None: """Registers a class mapped to a name. 
Args: @@ -225,21 +219,12 @@ def add_class_choices( elif required: raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") else: - clean_argv, config = self._convert_argv_issue_84(classes, nested_key, sys._pl_argv) + clean_argv = self._convert_argv_issue_84(classes, nested_key, sys._pl_argv) self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - self.set_defaults({nested_key: config}) sys._pl_argv = clean_argv @staticmethod - def _try_eval(val: str) -> Any: - try: - val = ast.literal_eval(val) - except ValueError: - pass - return val - - @staticmethod - def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> Tuple[List[str], Dict]: + def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: passed_args, clean_argv = {}, [] argv_key = f"--{nested_key}" # get the argv args for this nested key @@ -267,8 +252,7 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis config = {"class_path": class_path, "init_args": init_args} elif argv_class.startswith("{"): # the user passed a config as a dict - config = ast.literal_eval(argv_class) - assert isinstance(config, dict) + config = argv_class else: # the user passed the short format init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period @@ -278,9 +262,7 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis break else: raise ValueError(f"Could not generate a config for {repr(argv_class)}") - # need to convert from str to the appropriate type - config["init_args"] = {k: LightningArgumentParser._try_eval(v) for k, v in config["init_args"].items()} - return clean_argv, config + return clean_argv + [argv_key, str(config)] @staticmethod def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry) -> List[str]: @@ -305,9 +287,7 @@ def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry key = key[2:] # remove dashes if "class_path" in value: # the user passed a config as a dict - config = yaml.safe_load(value) - assert all(isinstance(cfg, dict) for cfg in config) - passed_configs[key] = config + passed_configs[key] = yaml.safe_load(value) else: passed_args.append((key, value)) else: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index b0b09534533f3..d551367a04883 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -647,12 +647,7 @@ def add_arguments_to_parser(self, parser): parser.add_optimizer_args(torch.optim.Adam) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR) - cli_args = [ - "fit", - f"--trainer.default_root_dir={tmpdir}", - "--trainer.fast_dev_run=1", - "--lr_scheduler.gamma=0.8", - ] + cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"] with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) @@ -994,36 +989,40 @@ def test_argv_transformation_multiple_callbacks_with_config(): assert argv == expected -def test_argv_transformations_with_optimizers_and_lr_schedulers(): +@pytest.mark.parametrize( + ["args", "expected", "nested_key", "registry"], + [ + ( + ["--optimizer", "Adadelta"], + {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}}, + "optimizer", + OPTIMIZER_REGISTRY, + ), + ( + ["--optimizer", "Adadelta", "--optimizer.lr", "10"], + {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": 
"10"}}, + "optimizer", + OPTIMIZER_REGISTRY, + ), + ( + ["--lr_scheduler", "OneCycleLR"], + {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}}, + "lr_scheduler", + LR_SCHEDULER_REGISTRY, + ), + ( + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"], + {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}}, + "lr_scheduler", + LR_SCHEDULER_REGISTRY, + ), + ], +) +def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry): base = ["any.py", "--trainer.max_epochs=1"] - - argv = base + ["--optimizer", "Adadelta"] - expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}} - new_argv, actual = LightningArgumentParser._convert_argv_issue_84(OPTIMIZER_REGISTRY.classes, "optimizer", argv) - assert new_argv == base - assert actual == expected - - argv = base + ["--optimizer", "Adadelta", "--optimizer.lr", "10"] - expected = {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": 10}} - base, actual = LightningArgumentParser._convert_argv_issue_84(OPTIMIZER_REGISTRY.classes, "optimizer", argv) - assert new_argv == base - assert actual == expected - - argv = base + ["--lr_scheduler", "OneCycleLR"] - expected = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}} - new_argv, actual = LightningArgumentParser._convert_argv_issue_84( - LR_SCHEDULER_REGISTRY.classes, "lr_scheduler", argv - ) - assert new_argv == base - assert actual == expected - - argv = base + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"] - expected = {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}} - new_argv, actual = LightningArgumentParser._convert_argv_issue_84( - LR_SCHEDULER_REGISTRY.classes, "lr_scheduler", argv - ) - assert new_argv == base - assert actual == expected + argv = base + args + new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv) + assert new_argv == base + [f"--{nested_key}", str(expected)] def test_optimizers_and_lr_schedulers_reload(tmpdir): From 2af596f615f4b8a9b4104284d431f0c494fb86e5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 04:52:18 +0200 Subject: [PATCH 59/77] Minor cleaning --- pytorch_lightning/utilities/cli.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 7c40180a61231..c64ac1c6f674d 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -284,7 +284,6 @@ def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry key = arg i += 1 value = argv[i] - key = key[2:] # remove dashes if "class_path" in value: # the user passed a config as a dict passed_configs[key] = yaml.safe_load(value) @@ -294,7 +293,7 @@ def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry clean_argv.append(arg) i += 1 # generate the associated config file - out = [] + config = [] i, n = 0, len(passed_args) while i < n - 1: ki, vi = passed_args[i] @@ -305,8 +304,7 @@ def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry f"Passed the class `--{nested_key}={ki}` but it's not registered in the registry." 
f"The available classes are: {registry.names}" ) - config = _global_add_class_path(cls_type) - out.append(config) + config.append(_global_add_class_path(cls_type)) # get any init args j = i + 1 # in case the j-loop doesn't run for j in range(i + 1, n): @@ -315,14 +313,14 @@ def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry break if kj.startswith(ki): init_arg_name = kj.split(".")[-1] - out[-1]["init_args"][init_arg_name] = vj + config[-1]["init_args"][init_arg_name] = vj i = j # update at the end to preserve the order for k, v in passed_configs.items(): - out.extend(v) - if not out: + config.extend(v) + if not config: return clean_argv - return clean_argv + [argv_key, str(out)] + return clean_argv + [argv_key, str(config)] class SaveConfigCallback(Callback): From 6add61927f65dac978708d8661d7548c3fbbc524 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 04:55:53 +0200 Subject: [PATCH 60/77] Mark argv as private --- pytorch_lightning/utilities/cli.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index c64ac1c6f674d..72dc43ef6fcb6 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -110,7 +110,7 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} # we need a mutable global argv copy in order to support `add_class_choices` - sys._pl_argv = sys.argv.copy() + sys.__argv = sys.argv.copy() def add_lightning_class_args( self, @@ -201,7 +201,7 @@ def add_lr_scheduler_args( self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args, **kwargs) -> Union[Namespace, Dict[str, Any]]: - argv = self._convert_argv_issue_85("trainer.callbacks", sys._pl_argv, CALLBACK_REGISTRY) + argv = self._convert_argv_issue_85("trainer.callbacks", sys.__argv, CALLBACK_REGISTRY) with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) @@ -212,16 +212,16 @@ def add_class_choices( This should be removed once implemented. 
""" - if not any(arg.startswith(f"--{nested_key}") for arg in sys._pl_argv): # the key was passed - if any(arg.startswith("--config") for arg in sys._pl_argv): # a config was passed + if not any(arg.startswith(f"--{nested_key}") for arg in sys.__argv): # the key was passed + if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed # parsing config files would be too difficult, fall back to what's available self.add_subclass_arguments(classes, nested_key, *args, **kwargs) elif required: raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") else: - clean_argv = self._convert_argv_issue_84(classes, nested_key, sys._pl_argv) + clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv) self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - sys._pl_argv = clean_argv + sys.__argv = clean_argv @staticmethod def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: From 525358a0fb1058bd7f554832254f1bb9fe0eaa14 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 05:04:24 +0200 Subject: [PATCH 61/77] Fix mypy --- .github/workflows/code-checks.yml | 1 + pytorch_lightning/utilities/cli.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index d666f60597786..ec8e53e26443c 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -18,5 +18,6 @@ jobs: - name: Install mypy run: | grep mypy requirements/test.txt | xargs -0 pip install + mypy --install-types pip list - run: mypy diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 72dc43ef6fcb6..e0f54c2c06453 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -200,14 +200,12 @@ def add_lr_scheduler_args( self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) - def parse_args(self, *args, **kwargs) -> Union[Namespace, Dict[str, Any]]: + def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: argv = self._convert_argv_issue_85("trainer.callbacks", sys.__argv, CALLBACK_REGISTRY) with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) - def add_class_choices( - self, classes: Tuple[Type, ...], nested_key: str, *args: Any, required: bool = False, **kwargs: Any - ) -> None: + def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. This should be removed once implemented. 
@@ -216,7 +214,7 @@ def add_class_choices( if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed # parsing config files would be too difficult, fall back to what's available self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - elif required: + elif kwargs.get("required", False): raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") else: clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv) @@ -249,7 +247,7 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis class_path = passed_args[f"{argv_key}.class_path"] init_args_key = f"{argv_key}.init_args" init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)} - config = {"class_path": class_path, "init_args": init_args} + config = str({"class_path": class_path, "init_args": init_args}) elif argv_class.startswith("{"): # the user passed a config as a dict config = argv_class @@ -258,11 +256,11 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period for cls in classes: if cls.__name__ == argv_class: - config = _global_add_class_path(cls, init_args) + config = str(_global_add_class_path(cls, init_args)) break else: raise ValueError(f"Could not generate a config for {repr(argv_class)}") - return clean_argv + [argv_key, str(config)] + return clean_argv + [argv_key, config] @staticmethod def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry) -> List[str]: From 84b8120e756a58a94d1fe6993522956102dd14e3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 05:07:58 +0200 Subject: [PATCH 62/77] Fix mypy --- .github/workflows/code-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index ec8e53e26443c..0411c54f23648 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -18,6 +18,6 @@ jobs: - name: Install mypy run: | grep mypy requirements/test.txt | xargs -0 pip install - mypy --install-types + mypy --install-types --non-interactive pip list - run: mypy From 7e48c0e2c517bbc12d497e24fd90a905bb8bb914 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 05:09:16 +0200 Subject: [PATCH 63/77] Fix mypy --- .github/workflows/code-checks.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index 0411c54f23648..8bede9ea9ddda 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -18,6 +18,5 @@ jobs: - name: Install mypy run: | grep mypy requirements/test.txt | xargs -0 pip install - mypy --install-types --non-interactive pip list - - run: mypy + - run: mypy --install-types --non-interactive From 3e77e8e6115b38ca5b46408fa5f45939da38e627 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 15:57:59 +0200 Subject: [PATCH 64/77] Support shorthand notation to instantiate optimizers and learning rate schedulers --- CHANGELOG.md | 3 + docs/source/common/lightning_cli.rst | 135 +++++++++++----- pytorch_lightning/utilities/cli.py | 122 ++++++++++++++- requirements/extra.txt | 2 +- tests/utilities/test_cli.py | 225 +++++++++++++++++++++++++-- 5 files changed, 434 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
42a720e5c0d1a..4c793fe5558cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751)) * Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508)) * Allow easy trainer re-instantiation ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241)) + * Automatically register all optimizers and learning rate schedulers ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) + * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) + * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) - Fault-tolerant training: diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index b873c36168d1c..cd794112305ad 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -665,69 +665,131 @@ Optimizers and learning rate schedulers ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a -single optimizer and optionally a single learning rate scheduler. In this case the model's -:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the -:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code -snippet shows how to implement it: +single optimizer and optionally a single learning rate scheduler. In this case, the model's +:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is +normally always the same and just adds boilerplate. -.. testcode:: - - import torch - - - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(torch.optim.Adam) - parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR) +The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when +at most one of each is used. +Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments: +.. code-block:: bash - cli = MyLightningCLI(MyModel) + $ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1 -With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer` -and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and -:code:`ExponentialLR`. Therefore, the config file would be structured like: +A corresponding example of the config file would be: .. code-block:: yaml optimizer: - lr: 0.01 + class_path: torch.optim.Adam + init_args: + lr: 0.01 lr_scheduler: - gamma: 0.2 + class_path: torch.optim.lr_scheduler.ExponentialLR + init_args: + gamma: 0.1 model: ... trainer: ... -And any of these arguments could be passed directly through command line. For example: +.. 
note:: + + This short-hand notation is only supported in the shell and not inside a configuration file. The configuration file + generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation. + +Furthermore, you can register your own optimizers and/or learning rate schedulers as follows: + +.. code-block:: python + + from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY + + + @OPTIMIZER_REGISTRY + class CustomAdam(torch.optim.Adam): + ... + + + @LR_SCHEDULER_REGISTRY + class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): + ... + + + cli = LightningCLI(...) .. code-block:: bash - $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2 + $ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR + +If you need to customize the key names or link arguments together, you can choose from all available optimizers and +learning rate schedulers by accessing the registries. + +.. code-block:: + + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args( + OPTIMIZER_REGISTRY.classes, + nested_key="gen_optimizer", + link_to="model.optimizer1_init" + ) + parser.add_optimizer_args( + OPTIMIZER_REGISTRY.classes, + nested_key="gen_discriminator", + link_to="model.optimizer2_init" + ) + +.. code-block:: bash + + $ python trainer.py fit \ + --gen_optimizer=Adam \ + --gen_optimizer.lr=0.01 \ + --gen_discriminator=AdamW \ + --gen_discriminator.lr=0.0001 + +You can also use pass the class path directly, for example, if the optimizer hasn't been registered to the +``OPTIMIZER_REGISTRY``: + +.. code-block:: bash -There is also the possibility of selecting among multiple classes by giving them as a tuple. For example: + $ python trainer.py fit \ + --gen_optimizer.class_path=torch.optim.Adam \ + --gen_optimizer.init_args.lr=0.01 \ + --gen_discriminator.class_path=torch.optim.AdamW \ + --gen_discriminator.init_args.lr=0.0001 + +If you will not be changing the class, you can manually add the arguments for specific optimizers and/or +learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those +classes. The following code snippet shows how to implement it: .. testcode:: class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): - parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam)) + parser.add_optimizer_args(torch.optim.Adam) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR) -In this case in the config the :code:`optimizer` group instead of having directly init settings, it should specify -:code:`class_path` and optionally :code:`init_args`. Sub-classes of the classes in the tuple would also be accepted. -A corresponding example of the config file would be: +With this, in the config the :code:`optimizer` and :code:`lr_scheduler` groups would accept all of the options for the +given classes, in this example :code:`Adam` and :code:`ExponentialLR`. +Therefore, the config file would be structured like: .. code-block:: yaml optimizer: - class_path: torch.optim.Adam - init_args: - lr: 0.01 + lr: 0.01 + lr_scheduler: + gamma: 0.2 + model: + ... + trainer: + ... -And the same through command line: +Where the arguments can be passed directly through command line without specifying the class. For example: .. 
code-block:: bash - $ python trainer.py fit --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01 + $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2 The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. This would be: @@ -763,12 +825,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th cli = MyLightningCLI(MyModel) -For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with -a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including -:code:`class_path` and :code:`init_args` entries. The function -:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in -:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the -:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. +The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and +:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class` +takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments, +in this case :code:`self.parameters()`, and the :code:`init_args`. +Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. Notes related to reproducibility diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 1d437e69a9ef0..576bd4ae67554 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os +import sys from argparse import Namespace -from types import MethodType +from types import MethodType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from unittest import mock +import torch from torch.optim import Optimizer from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer @@ -35,6 +39,50 @@ ArgumentParser = object +class _Registry(dict): + def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None: + """Registers a class mapped to a name. + + Args: + cls: the class to be mapped. + key: the name that identifies the provided class. + """ + if key is None: + key = cls.__name__ + elif not isinstance(key, str): + raise TypeError(f"`key` must be a str, found {key}") + + if key in self and not override: + raise MisconfigurationException(f"'{key}' is already present in the registry. 
HINT: Use `override=True`.") + self[key] = cls + + def register_package(self, module: ModuleType, base_cls: Type) -> None: + """This function is an utility to register all classes from a module.""" + for _, cls in inspect.getmembers(module, predicate=inspect.isclass): + if issubclass(cls, base_cls) and cls != base_cls: + self(cls=cls) + + @property + def names(self) -> List[str]: + """Returns the registered names.""" + return list(self.keys()) + + @property + def classes(self) -> Tuple[Type, ...]: + """Returns the registered classes.""" + return tuple(self.values()) + + def __str__(self) -> str: + return f"Registered objects: {self.names}" + + +OPTIMIZER_REGISTRY = _Registry() +OPTIMIZER_REGISTRY.register_package(torch.optim, torch.optim.Optimizer) + +LR_SCHEDULER_REGISTRY = _Registry() +LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) + + class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" @@ -57,6 +105,8 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non # separate optimizers and lr schedulers to know which were added self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + # we need a mutable global argv copy in order to support `add_class_choices` + sys.__argv = sys.argv.copy() def add_lightning_class_args( self, @@ -117,7 +167,7 @@ def add_optimizer_args( assert issubclass(optimizer_class, Optimizer) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): - self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) + self.add_class_choices(optimizer_class, nested_key, **kwargs) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) self._optimizers[nested_key] = (optimizer_class, link_to) @@ -141,11 +191,72 @@ def add_lr_scheduler_args( assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): - self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) + self.add_class_choices(lr_scheduler_class, nested_key, **kwargs) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + with mock.patch("sys.argv", sys.__argv): + return super().parse_args(*args, **kwargs) + + def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: + """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. + + This should be removed once implemented. 
+ """ + if not any(arg.startswith(f"--{nested_key}") for arg in sys.__argv): # the key was passed + if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed + # parsing config files would be too difficult, fall back to what's available + self.add_subclass_arguments(classes, nested_key, *args, **kwargs) + elif kwargs.get("required", False): + raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") + else: + clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv) + self.add_subclass_arguments(classes, nested_key, *args, **kwargs) + sys.__argv = clean_argv + + @staticmethod + def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: + passed_args, clean_argv = {}, [] + argv_key = f"--{nested_key}" + # get the argv args for this nested key + i = 0 + while i < len(argv): + arg = argv[i] + if arg.startswith(argv_key): + if "=" in arg: + key, value = arg.split("=") + else: + key = arg + i += 1 + value = argv[i] + passed_args[key] = value + else: + clean_argv.append(arg) + i += 1 + # generate the associated config file + argv_class = passed_args.pop(argv_key, None) + if argv_class is None: + # the user passed a config as a str + class_path = passed_args[f"{argv_key}.class_path"] + init_args_key = f"{argv_key}.init_args" + init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)} + config = str({"class_path": class_path, "init_args": init_args}) + elif argv_class.startswith("{"): + # the user passed a config as a dict + config = argv_class + else: + # the user passed the shorthand format + init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period + for cls in classes: + if cls.__name__ == argv_class: + config = str(_global_add_class_path(cls, init_args)) + break + else: + raise ValueError(f"Could not generate a config for {repr(argv_class)}") + return clean_argv + [argv_key, config] + class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. 
@@ -328,6 +439,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None: self.add_default_arguments_to_parser(parser) self.add_core_arguments_to_parser(parser) self.add_arguments_to_parser(parser) + # add default optimizer args if necessary + if not parser._optimizers: # already added by the user in `add_arguments_to_parser` + parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes) + if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser` + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes) self.link_optimizers_and_lr_schedulers(parser) def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: diff --git a/requirements/extra.txt b/requirements/extra.txt index dfffc6fce8428..03fb29b29825a 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,6 +7,6 @@ torchtext>=0.7 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.19.0 +git+git://github.com/omni-us/jsonargparse@c3e5c21#egg=jsonargparse[signatures] gcsfs>=2021.5.0 rich>=10.2.2 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e526027708418..6a224be45258c 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -33,7 +33,14 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback +from pytorch_lightning.utilities.cli import ( + instantiate_class, + LightningArgumentParser, + LightningCLI, + LR_SCHEDULER_REGISTRY, + OPTIMIZER_REGISTRY, + SaveConfigCallback, +) from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -678,12 +685,20 @@ def add_arguments_to_parser(self, parser): assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 -def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): +@pytest.mark.parametrize("use_registries", [False, True]) +def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1") + parser.add_optimizer_args( + OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam, + nested_key="optim1", + link_to="model.optim1", + ) parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") - parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") + parser.add_lr_scheduler_args( + LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR, + link_to="model.scheduler", + ) class TestModel(BoringModel): def __init__(self, optim1: dict, optim2: dict, scheduler: dict): @@ -692,20 +707,26 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict): self.optim2 = instantiate_class(self.parameters(), optim2) self.scheduler = instantiate_class(self.optim1, scheduler) - cli_args = [ - "fit", - f"--trainer.default_root_dir={tmpdir}", - "--trainer.max_epochs=1", - "--optim2.class_path=torch.optim.SGD", - "--optim2.init_args.lr=0.01", - "--lr_scheduler.gamma=0.2", - ] + cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", 
"--lr_scheduler.gamma=0.2"] + if use_registries: + cli_args += [ + "--optim1", + "Adam", + "--optim1.weight_decay", + "0.001", + "--optim2=SGD", + "--optim2.lr=0.01", + "--lr_scheduler=ExponentialLR", + ] + else: + cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"] with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(TestModel) assert isinstance(cli.model.optim1, torch.optim.Adam) assert isinstance(cli.model.optim2, torch.optim.SGD) + assert cli.model.optim2.param_groups[0]["lr"] == 0.01 assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) @@ -829,6 +850,186 @@ def test_lightning_cli_run(): assert isinstance(cli.model, LightningModule) +@OPTIMIZER_REGISTRY +class CustomAdam(torch.optim.Adam): + pass + + +@LR_SCHEDULER_REGISTRY +class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): + pass + + +def test_registries(tmpdir): + assert "SGD" in OPTIMIZER_REGISTRY.names + assert "RMSprop" in OPTIMIZER_REGISTRY.names + assert "CustomAdam" in OPTIMIZER_REGISTRY.names + + assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names + assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names + assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names + + +def test_registries_resolution(): + """This test validates registries are used when simplified command line are being used.""" + cli_args = [ + "--optimizer", + "Adam", + "--optimizer.lr", + "0.0001", + "--lr_scheduler", + "StepLR", + "--lr_scheduler.step_size=50", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(BoringModel, run=False) + + optimizers, lr_scheduler = cli.model.configure_optimizers() + assert isinstance(optimizers[0], torch.optim.Adam) + assert optimizers[0].param_groups[0]["lr"] == 0.0001 + assert lr_scheduler[0].step_size == 50 + + +@pytest.mark.parametrize( + ["args", "expected", "nested_key", "registry"], + [ + ( + ["--optimizer", "Adadelta"], + {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}}, + "optimizer", + OPTIMIZER_REGISTRY, + ), + ( + ["--optimizer", "Adadelta", "--optimizer.lr", "10"], + {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}}, + "optimizer", + OPTIMIZER_REGISTRY, + ), + ( + ["--lr_scheduler", "OneCycleLR"], + {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}}, + "lr_scheduler", + LR_SCHEDULER_REGISTRY, + ), + ( + ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"], + {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}}, + "lr_scheduler", + LR_SCHEDULER_REGISTRY, + ), + ], +) +def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry): + base = ["any.py", "--trainer.max_epochs=1"] + argv = base + args + new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv) + assert new_argv == base + [f"--{nested_key}", str(expected)] + + +def test_optimizers_and_lr_schedulers_reload(tmpdir): + base = ["any.py", "--trainer.max_epochs=1"] + input = base + [ + "--lr_scheduler", + "OneCycleLR", + "--lr_scheduler.total_steps=10", + "--lr_scheduler.max_lr=1", + "--optimizer", + "Adam", + "--optimizer.lr=0.1", + ] + + # save config + out = StringIO() + with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): + LightningCLI(BoringModel, run=False) + + # validate yaml + yaml_config = out.getvalue() + dict_config = yaml.safe_load(yaml_config) 
+ assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam" + assert dict_config["optimizer"]["init_args"]["lr"] == 0.1 + assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + + # reload config + yaml_config_file = tmpdir / "config.yaml" + yaml_config_file.write_text(yaml_config, "utf-8") + with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]): + LightningCLI(BoringModel, run=False) + + +def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir): + class TestLightningCLI(LightningCLI): + def __init__(self, *args): + super().__init__(*args, run=False) + + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config") + parser.add_optimizer_args( + (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config" + ) + parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config") + parser.add_argument("--something", type=str, nargs="+") + + class TestModel(BoringModel): + def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict): + super().__init__() + self.opt1_config = opt1_config + self.opt2_config = opt2_config + self.sch_config = sch_config + opt1 = instantiate_class(self.parameters(), opt1_config) + assert isinstance(opt1, torch.optim.Adam) + opt2 = instantiate_class(self.parameters(), opt2_config) + assert isinstance(opt2, torch.optim.ASGD) + sch = instantiate_class(opt1, sch_config) + assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR) + + base = ["any.py", "--trainer.max_epochs=1"] + input = base + [ + "--lr_scheduler", + "OneCycleLR", + "--lr_scheduler.total_steps=10", + "--lr_scheduler.max_lr=1", + "--opt1", + "Adam", + "--opt2.lr=0.1", + "--opt2", + "ASGD", + "--lr_scheduler.anneal_strategy=linear", + "--something", + "a", + "b", + "c", + ] + + # save config + out = StringIO() + with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): + TestLightningCLI(TestModel) + + # validate yaml + yaml_config = out.getvalue() + dict_config = yaml.safe_load(yaml_config) + assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam" + assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD" + assert dict_config["opt2"]["init_args"]["lr"] == 0.1 + assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear" + assert dict_config["something"] == ["a", "b", "c"] + + # reload config + yaml_config_file = tmpdir / "config.yaml" + yaml_config_file.write_text(yaml_config, "utf-8") + with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]): + cli = TestLightningCLI(TestModel) + + assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam" + assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD" + assert cli.model.opt2_config["init_args"]["lr"] == 0.1 + assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear" + + @RunIf(min_python="3.7.3") # bpo-17185: `autospec=True` and `inspect.signature` do not play well def test_lightning_cli_config_with_subcommand(): config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} From 1512a8056b4a5db881d624d31f6cde2a36cad05a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: 
Thu, 16 Sep 2021 16:06:36 +0200 Subject: [PATCH 65/77] Update CHANGELOG --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c793fe5558cb..ea557e4a1d2f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,9 +53,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751)) * Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508)) * Allow easy trainer re-instantiation ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241)) - * Automatically register all optimizers and learning rate schedulers ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) - * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) - * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) + * Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565)) + * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565)) + * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565)) - Fault-tolerant training: From c6b86b147e329d80218a92f6d5eb791eb6620021 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 16:20:09 +0200 Subject: [PATCH 66/77] Fix install --- requirements/extra.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/extra.txt b/requirements/extra.txt index 03fb29b29825a..571e12b3ef65f 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,6 +7,7 @@ torchtext>=0.7 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 +jsonargparse[signatures] @https://github.com/omni-us/jsonargparse/archive/c3e5c210ed9a93bab11fff9da5e1281def893467.zip git+git://github.com/omni-us/jsonargparse@c3e5c21#egg=jsonargparse[signatures] gcsfs>=2021.5.0 rich>=10.2.2 From 6f1600c8f5721b388ac52fa329f79a658d3b2ce5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 16:20:45 +0200 Subject: [PATCH 67/77] Fix install --- requirements/extra.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 571e12b3ef65f..ee4110c87877e 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,6 +8,5 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 jsonargparse[signatures] @https://github.com/omni-us/jsonargparse/archive/c3e5c210ed9a93bab11fff9da5e1281def893467.zip -git+git://github.com/omni-us/jsonargparse@c3e5c21#egg=jsonargparse[signatures] gcsfs>=2021.5.0 rich>=10.2.2 From a3a791f3e80cc33732f4efa3743350257f0294eb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 18:02:39 +0200 Subject: [PATCH 68/77] Use release --- requirements/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt 
b/requirements/extra.txt index ee4110c87877e..e7a62d3071158 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,6 +7,6 @@ torchtext>=0.7 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures] @https://github.com/omni-us/jsonargparse/archive/c3e5c210ed9a93bab11fff9da5e1281def893467.zip +jsonargparse[signatures]>=3.19.3 gcsfs>=2021.5.0 rich>=10.2.2 From fedae467bbebad7a8af6eea0d83c84e0be8a9783 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Sep 2021 17:32:05 +0000 Subject: [PATCH 69/77] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/cli.py | 2 +- tests/utilities/test_cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index dfdf742af110f..54f4a5e02cf70 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -20,8 +20,8 @@ from unittest import mock import torch -from torch.optim import Optimizer import yaml +from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index cdb63ef710af1..5d9b7b392a673 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -874,7 +874,7 @@ def test_registries(tmpdir): assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names - + assert "EarlyStopping" in CALLBACK_REGISTRY.names assert "CustomCallback" in CALLBACK_REGISTRY.names From ee7a068a89453c64b865c26de0768993b0f935ff Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 21:08:24 +0200 Subject: [PATCH 70/77] Introduce set_choices --- pytorch_lightning/utilities/cli.py | 27 ++++++++++++++++++--------- tests/utilities/test_cli.py | 8 ++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 54f4a5e02cf70..148da53ca33f3 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -112,6 +112,7 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} # we need a mutable global argv copy in order to support `add_class_choices` sys.__argv = sys.argv.copy() + self._choices: Dict[str, Tuple[Type, ...]] = {} def add_lightning_class_args( self, @@ -202,7 +203,9 @@ def add_lr_scheduler_args( self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: - argv = self._convert_argv_issue_85("trainer.callbacks", sys.__argv, CALLBACK_REGISTRY) + argv = sys.__argv + for k, v in self._choices.items(): + argv = self._convert_argv_issue_85(v, k, argv) with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) @@ -263,12 +266,17 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis raise ValueError(f"Could not generate a config for {repr(argv_class)}") return clean_argv + [argv_key, config] - @staticmethod - def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry) -> List[str]: + def 
set_choices(self, *args: Dict[str, Tuple[Type, ...]], **kwargs: Tuple[Type, ...]) -> None: """Placeholder for https://github.com/omni-us/jsonargparse/issues/85. This should be removed once implemented. """ + for arg in args: + self._choices.update(arg) + self._choices.update(kwargs) + + @staticmethod + def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: passed_args, clean_argv = [], [] passed_configs = {} argv_key = f"--{nested_key}" @@ -297,12 +305,12 @@ def _convert_argv_issue_85(nested_key: str, argv: List[str], registry: _Registry while i < n - 1: ki, vi = passed_args[i] # convert class name to class path - cls_type = registry.get(vi) - if cls_type is None: - raise ValueError( - f"Passed the class `--{nested_key}={ki}` but it's not registered in the registry." - f"The available classes are: {registry.names}" - ) + for cls in classes: + if cls.__name__ == vi: + cls_type = cls + break + else: + raise ValueError(f"Could not generate a config for {repr(vi)}") config.append(_global_add_class_path(cls_type)) # get any init args j = i + 1 # in case the j-loop doesn't run @@ -492,6 +500,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """Adds arguments from the core classes to the parser.""" parser.add_lightning_class_args(self.trainer_class, "trainer") + parser.set_choices({"trainer.callbacks": CALLBACK_REGISTRY.classes}) trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"} parser.set_defaults(trainer_defaults) parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 5d9b7b392a673..5acc57137d56d 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -921,7 +921,7 @@ def test_registries_resolution(use_class_path_callbacks): def test_argv_transformation_noop(): base = ["any.py", "--trainer.max_epochs=1"] - argv = LightningArgumentParser._convert_argv_issue_85("trainer.callbacks", base, CALLBACK_REGISTRY) + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base) assert argv == base @@ -935,7 +935,7 @@ def test_argv_transformation_single_callback(): } ] expected = base + ["--trainer.callbacks", str(callbacks)] - argv = LightningArgumentParser._convert_argv_issue_85("trainer.callbacks", input, CALLBACK_REGISTRY) + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input) assert argv == expected @@ -958,7 +958,7 @@ def test_argv_transformation_multiple_callbacks(): }, ] expected = base + ["--trainer.callbacks", str(callbacks)] - argv = LightningArgumentParser._convert_argv_issue_85("trainer.callbacks", input, CALLBACK_REGISTRY) + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input) assert argv == expected @@ -985,7 +985,7 @@ def test_argv_transformation_multiple_callbacks_with_config(): ] expected = base + ["--trainer.callbacks", str(callbacks)] nested_key = "trainer.callbacks" - argv = LightningArgumentParser._convert_argv_issue_85(nested_key, input, CALLBACK_REGISTRY) + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input) assert argv == expected From 6e67617a0102ed625ddfbe9beeae0c1362bac842 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi 
Date: Thu, 16 Sep 2021 21:51:36 +0200 Subject: [PATCH 71/77] Undo change --- pytorch_lightning/utilities/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 148da53ca33f3..2f7f7ad958d11 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -642,7 +642,7 @@ def get_automatic( automatic.append(key) return automatic - optimizers = get_automatic(torch.optim.Optimizer, parser._optimizers) + optimizers = get_automatic(Optimizer, parser._optimizers) lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers) if len(optimizers) == 0: From e7f6d6171a9297eed226a65024f4fe2413288611 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 22:46:06 +0200 Subject: [PATCH 72/77] Replace add_class_choices with set_choices --- pytorch_lightning/utilities/cli.py | 55 ++++++++++++++---------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 2f7f7ad958d11..504b8a500caaf 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -91,6 +91,9 @@ def __str__(self) -> str: class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" + # use class attribute because `parse_args` is only called on the main parser + _choices: Dict[str, Tuple[Type, ...]] = {} + def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. @@ -110,9 +113,6 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non # separate optimizers and lr schedulers to know which were added self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - # we need a mutable global argv copy in order to support `add_class_choices` - sys.__argv = sys.argv.copy() - self._choices: Dict[str, Tuple[Type, ...]] = {} def add_lightning_class_args( self, @@ -173,7 +173,8 @@ def add_optimizer_args( assert issubclass(optimizer_class, Optimizer) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): - self.add_class_choices(optimizer_class, nested_key, **kwargs) + self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) + self.set_choices(nested_key, optimizer_class) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) self._optimizers[nested_key] = (optimizer_class, link_to) @@ -197,36 +198,37 @@ def add_lr_scheduler_args( assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): - self.add_class_choices(lr_scheduler_class, nested_key, **kwargs) + self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) + self.set_choices(nested_key, lr_scheduler_class) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: - argv = sys.__argv + argv = sys.argv for k, v in self._choices.items(): - argv = self._convert_argv_issue_85(v, k, argv) + if not any(arg.startswith(f"--{k}") for arg in argv): + # the key wasn't passed - maybe defined in a config, maybe it's 
optional + continue + classes, is_list = v + if is_list: + argv = self._convert_argv_issue_85(classes, k, argv) + else: + argv = self._convert_argv_issue_84(classes, k, argv) + self._choices.clear() # reset with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) - def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: + def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None: + # knowing whether the argument is a list type automatically would be too complex + self._choices[nested_key] = (classes, is_list) + + @staticmethod + def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. This should be removed once implemented. """ - if not any(arg.startswith(f"--{nested_key}") for arg in sys.__argv): # the key was passed - if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed - # parsing config files would be too difficult, fall back to what's available - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - elif kwargs.get("required", False): - raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") - else: - clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv) - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - sys.__argv = clean_argv - - @staticmethod - def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: passed_args, clean_argv = {}, [] argv_key = f"--{nested_key}" # get the argv args for this nested key @@ -266,17 +268,12 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis raise ValueError(f"Could not generate a config for {repr(argv_class)}") return clean_argv + [argv_key, config] - def set_choices(self, *args: Dict[str, Tuple[Type, ...]], **kwargs: Tuple[Type, ...]) -> None: + @staticmethod + def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: """Placeholder for https://github.com/omni-us/jsonargparse/issues/85. This should be removed once implemented. """ - for arg in args: - self._choices.update(arg) - self._choices.update(kwargs) - - @staticmethod - def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: passed_args, clean_argv = [], [] passed_configs = {} argv_key = f"--{nested_key}" @@ -500,7 +497,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """Adds arguments from the core classes to the parser.""" parser.add_lightning_class_args(self.trainer_class, "trainer") - parser.set_choices({"trainer.callbacks": CALLBACK_REGISTRY.classes}) + parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True) trainer_defaults = {"trainer." 
+ k: v for k, v in self.trainer_defaults.items() if k != "callbacks"} parser.set_defaults(trainer_defaults) parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model) From 8e873591d5cac65d973784b1291cd14636d14cb6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 22:50:03 +0200 Subject: [PATCH 73/77] Replace add_class_choices with set_choices --- pytorch_lightning/utilities/cli.py | 41 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 576bd4ae67554..7b73b97baf1cd 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -77,7 +77,7 @@ def __str__(self) -> str: OPTIMIZER_REGISTRY = _Registry() -OPTIMIZER_REGISTRY.register_package(torch.optim, torch.optim.Optimizer) +OPTIMIZER_REGISTRY.register_package(torch.optim, Optimizer) LR_SCHEDULER_REGISTRY = _Registry() LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) @@ -86,6 +86,9 @@ def __str__(self) -> str: class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" + # use class attribute because `parse_args` is only called on the main parser + _choices: Dict[str, Tuple[Type, ...]] = {} + def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. @@ -105,8 +108,6 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non # separate optimizers and lr schedulers to know which were added self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - # we need a mutable global argv copy in order to support `add_class_choices` - sys.__argv = sys.argv.copy() def add_lightning_class_args( self, @@ -167,7 +168,8 @@ def add_optimizer_args( assert issubclass(optimizer_class, Optimizer) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): - self.add_class_choices(optimizer_class, nested_key, **kwargs) + self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) + self.set_choices(nested_key, optimizer_class) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) self._optimizers[nested_key] = (optimizer_class, link_to) @@ -191,33 +193,32 @@ def add_lr_scheduler_args( assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): - self.add_class_choices(lr_scheduler_class, nested_key, **kwargs) + self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) + self.set_choices(nested_key, lr_scheduler_class) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: - with mock.patch("sys.argv", sys.__argv): + argv = sys.argv + for k, classes in self._choices.items(): + if not any(arg.startswith(f"--{k}") for arg in argv): + # the key wasn't passed - maybe defined in a config, maybe it's optional + continue + argv = self._convert_argv_issue_84(classes, k, argv) + self._choices.clear() + with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) - def 
add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: + def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None: + self._choices[nested_key] = classes + + @staticmethod + def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. This should be removed once implemented. """ - if not any(arg.startswith(f"--{nested_key}") for arg in sys.__argv): # the key was passed - if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed - # parsing config files would be too difficult, fall back to what's available - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - elif kwargs.get("required", False): - raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") - else: - clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv) - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - sys.__argv = clean_argv - - @staticmethod - def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: passed_args, clean_argv = {}, [] argv_key = f"--{nested_key}" # get the argv args for this nested key From 66cdb5280b373126cab21e618b0d226f7741a6f4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 23:02:10 +0200 Subject: [PATCH 74/77] Docstrings --- pytorch_lightning/utilities/cli.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index c5d784a944aac..cc15e9ea1d56f 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -211,6 +211,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # the key wasn't passed - maybe defined in a config, maybe it's optional continue classes, is_list = v + # knowing whether the argument is a list type automatically would be too complex if is_list: argv = self._convert_argv_issue_85(classes, k, argv) else: @@ -220,14 +221,20 @@ def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: return super().parse_args(*args, **kwargs) def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None: - # knowing whether the argument is a list type automatically would be too complex + """Adds support for shorthand notation for a particular nested key. + + Args: + nested_key: The key whose choices will be set. + classes: A tuple of classes to choose from. + is_list: Whether the argument is a ``List[object]`` type. + """ self._choices[nested_key] = (classes, is_list) @staticmethod def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. - This should be removed once implemented. + Adds support for shorthand notation for ``object`` arguments. """ passed_args, clean_argv = {}, [] argv_key = f"--{nested_key}" @@ -272,7 +279,7 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: """Placeholder for https://github.com/omni-us/jsonargparse/issues/85. - This should be removed once implemented. + Adds support for shorthand notation for ``List[object]`` arguments. 
""" passed_args, clean_argv = [], [] passed_configs = {} From 7b50401646559ee9f47a97486ef5d86f470a81d6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Sep 2021 17:08:18 +0000 Subject: [PATCH 75/77] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index dab90ee8c4f84..bff5d7e9111e4 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -878,7 +878,7 @@ def test_registries(tmpdir): assert "EarlyStopping" in CALLBACK_REGISTRY.names assert "CustomCallback" in CALLBACK_REGISTRY.names - + with pytest.raises(MisconfigurationException, match="is already present in the registry"): OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer) OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True) From 1406be991f8856678485f79bb1e54b2dc682354d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Sep 2021 19:11:42 +0200 Subject: [PATCH 76/77] Fix mypy --- pytorch_lightning/trainer/progress.py | 4 ++-- pytorch_lightning/utilities/cli.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 0f07c61999e1c..03d6a93b8ff07 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -234,5 +234,5 @@ def load_state_dict(self, state_dict: dict) -> None: self.optimizer_position = state_dict["optimizer_position"] def reset_on_restart(self) -> None: - self.optimizer.step.current.reset_on_restart() - self.optimizer.zero_grad.current.reset_on_restart() + self.optimizer.step.reset_on_restart() + self.optimizer.zero_grad.reset_on_restart() diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index ca8810ce6f78b..d97ef9ccddebb 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -93,7 +93,7 @@ class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" # use class attribute because `parse_args` is only called on the main parser - _choices: Dict[str, Tuple[Type, ...]] = {} + _choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {} def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. From a00044637fc9bcc9a88ea4fa8cb7edc73fceddd0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Sep 2021 19:21:27 +0200 Subject: [PATCH 77/77] Undo change --- pytorch_lightning/trainer/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 03d6a93b8ff07..0f07c61999e1c 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -234,5 +234,5 @@ def load_state_dict(self, state_dict: dict) -> None: self.optimizer_position = state_dict["optimizer_position"] def reset_on_restart(self) -> None: - self.optimizer.step.reset_on_restart() - self.optimizer.zero_grad.reset_on_restart() + self.optimizer.step.current.reset_on_restart() + self.optimizer.zero_grad.current.reset_on_restart()