[CLI] Shorthand notation to instantiate datamodules #10011

Merged · 5 commits · Oct 20, 2021
CHANGELOG.md (2 additions, 0 deletions)
@@ -64,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815))
* Support shorthand notation to instantiate models ([#9588](https://github.com/PyTorchLightning/pytorch-lightning/pull/9588))
* Support shorthand notation to instantiate datamodules ([#10011](https://github.com/PyTorchLightning/pytorch-lightning/pull/10011))


- Fault-tolerant training:
docs/source/common/lightning_cli.rst (9 additions, 4 deletions)
@@ -436,20 +436,25 @@ In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI`
datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for
multiple models and datasets.

The model argument can be left unset if a model has been registered first, this is particularly interesting for library
authors who want to provide their users a range of models to choose from:
The model and datamodule arguments can be left unset if a class has been registered first.
This is particularly interesting for library authors who want to provide their users a range of models to choose from:

.. code-block:: python

import flash.image
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from pytorch_lightning.utilities.cli import MODEL_REGISTRY, DATAMODULE_REGISTRY


@MODEL_REGISTRY
class MyModel(LightningModule):
...


@DATAMODULE_REGISTRY
class MyData(LightningDataModule):
...


# register all `LightningModule` subclasses from a package
MODEL_REGISTRY.register_classes(flash.image, LightningModule)
# print(MODEL_REGISTRY)
@@ -459,7 +464,7 @@ authors who want to provide their users a range of models to choose from:

.. code-block:: bash

$ python trainer.py fit --model=MyModel --model.feat_dim=64
$ python trainer.py fit --model=MyModel --model.feat_dim=64 --data=MyData
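Bulk registration works the same way for datamodules; a minimal sketch, assuming the package used in the model example above (flash.image) also contains LightningDataModule subclasses:

import flash.image
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY

# register all `LightningDataModule` subclasses from a package
DATAMODULE_REGISTRY.register_classes(flash.image, LightningDataModule)
# print(DATAMODULE_REGISTRY) to inspect the registered names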

.. note::

pytorch_lightning/utilities/cli.py (19 additions, 6 deletions)
@@ -91,6 +91,8 @@ def __str__(self) -> str:

MODEL_REGISTRY = _Registry()

DATAMODULE_REGISTRY = _Registry()


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
@@ -129,13 +131,15 @@ def add_lightning_class_args(
],
nested_key: str,
subclass_mode: bool = False,
required: bool = True,
) -> List[str]:
"""Adds arguments from a lightning class to a nested key of the parser.

Args:
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
nested_key: Name of the nested namespace to store arguments.
subclass_mode: Whether allow any subclass of the given class.
required: Whether the argument group is required.

Returns:
A list with the names of the class arguments added.
@@ -149,7 +153,7 @@ def add_lightning_class_args(
if issubclass(lightning_class, Callback):
self.callback_keys.append(nested_key)
if subclass_mode:
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=True)
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
return self.add_class_arguments(
lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer)
)
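A minimal sketch of how the new required flag could be used when building a LightningArgumentParser by hand; the group is added in subclass mode without forcing a value to be provided:

from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.cli import LightningArgumentParser

parser = LightningArgumentParser()
# add an optional "data" group that accepts any LightningDataModule subclass
parser.add_lightning_class_args(LightningDataModule, "data", subclass_mode=True, required=False)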
@@ -432,7 +436,7 @@ def __init__(
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
called.
called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
save_config_callback: A callback class to save the training config.
save_config_filename: Filename for the config file.
save_config_overwrite: Whether to overwrite an existing config file.
@@ -455,21 +459,24 @@ def __init__(
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
"""
self.datamodule_class = datamodule_class
self.save_config_callback = save_config_callback
self.save_config_filename = save_config_filename
self.save_config_overwrite = save_config_overwrite
self.save_config_multifile = save_config_multifile
self.trainer_class = trainer_class
self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
self.subclass_mode_data = subclass_mode_data

self.model_class = model_class
# used to differentiate between the original value and the processed value
self._model_class = model_class or LightningModule
self.subclass_mode_model = (model_class is None) or subclass_mode_model

self.datamodule_class = datamodule_class
# used to differentiate between the original value and the processed value
self._datamodule_class = datamodule_class or LightningDataModule
self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463
{"description": description, "env_prefix": env_prefix, "default_env": env_parse},
@@ -531,12 +538,18 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.set_defaults(trainer_defaults)

parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
if self.model_class is None and MODEL_REGISTRY:
if self.model_class is None and len(MODEL_REGISTRY):
# did not pass a model and there are models registered
parser.set_choices("model", MODEL_REGISTRY.classes)

if self.datamodule_class is not None:
parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
elif len(DATAMODULE_REGISTRY):
# this should not be required because the user might want to use the `LightningModule` dataloaders
parser.add_lightning_class_args(
self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
)
parser.set_choices("data", DATAMODULE_REGISTRY.classes)

def _add_arguments(self, parser: LightningArgumentParser) -> None:
# default + core + custom arguments
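When no datamodule class is passed but the registry is non-empty, the data group becomes optional; a minimal sketch of the resulting behaviour (illustrative class names), mirroring the test added below:

from unittest import mock

from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY, LightningCLI


class DemoModel(LightningModule):
    ...


@DATAMODULE_REGISTRY
class DemoData(LightningDataModule):
    ...


# no --data argument: the optional "data" group is still added because a datamodule
# is registered, but no class is selected, so the LightningModule's own dataloaders
# would be used
with mock.patch("sys.argv", ["any.py"]):
    cli = LightningCLI(DemoModel, run=False)
    assert "data" in cli.parser.groups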
tests/utilities/test_cli.py (60 additions, 0 deletions)
@@ -36,6 +36,7 @@
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
CALLBACK_REGISTRY,
DATAMODULE_REGISTRY,
instantiate_class,
LightningArgumentParser,
LightningCLI,
@@ -915,6 +916,65 @@ def test_lightning_cli_model_choices():
assert cli.model.bar == 5


@DATAMODULE_REGISTRY
class MyDataModule(BoringDataModule):
def __init__(self, foo, bar=5):
super().__init__()
self.foo = foo
self.bar = bar


DATAMODULE_REGISTRY(cls=BoringDataModule)


def test_lightning_cli_datamodule_choices():
# with set model
with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
) as run:
cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.datamodule, BoringDataModule)
run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule)

with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]):
cli = LightningCLI(BoringModel, run=False)
assert isinstance(cli.datamodule, MyDataModule)
assert cli.datamodule.foo == 123
assert cli.datamodule.bar == 5

# with configurable model
with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
) as run:
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.model, BoringModel)
assert isinstance(cli.datamodule, BoringDataModule)
run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule)

with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]):
cli = LightningCLI(run=False)
assert isinstance(cli.model, BoringModel)
assert isinstance(cli.datamodule, MyDataModule)

assert len(DATAMODULE_REGISTRY) # needs a value initially added
with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(BoringModel, run=False)
# data was not passed but we are adding it automatically because there are datamodules registered
assert "data" in cli.parser.groups
assert not hasattr(cli.parser.groups["data"], "group_class")

with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True):
cli = LightningCLI(BoringModel, run=False)
# no registered classes so not added automatically
assert "data" not in cli.parser.groups
assert len(DATAMODULE_REGISTRY) # check state was not modified

with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(BoringModel, BoringDataModule, run=False)
# since we are passing the DataModule, that's what's added to the parser
assert cli.parser.groups["data"].group_class is BoringDataModule


@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
"""This test validates registries are used when simplified command line are being used."""