Skip to content

Commit

Permalink
Add LightningCLI(run=False|True) (#8751)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
carmocca and awaelchli authored Aug 10, 2021
1 parent 3096ab8 commit cb2a8ed
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))


- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))


- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))

Expand Down
35 changes: 25 additions & 10 deletions docs/source/common/lightning_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import torch
from unittest import mock
from typing import List
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning.utilities.cli import LightningCLI

original_fit = LightningCLI.fit
LightningCLI.fit = lambda self: None
cli_fit = LightningCLI.fit
LightningCLI.fit = lambda *_, **__: None
trainer_fit = Trainer.fit
Trainer.fit = lambda *_, **__: None

class MyModel(LightningModule):
Expand Down Expand Up @@ -47,7 +48,8 @@

.. testcleanup:: *

LightningCLI.fit = original_fit
LightningCLI.fit = cli_fit
Trainer.fit = trainer_fit
mock_argv.stop()


Expand Down Expand Up @@ -260,17 +262,30 @@ file. Loading a defaults file :code:`my_cli_defaults.yaml` in the current workin

.. testcode::

cli = LightningCLI(
MyModel,
MyDataModule,
parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]},
)
cli = LightningCLI(MyModel, MyDataModule, parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]})

To load a file in the user's home directory would be just changing to :code:`~/.my_cli_defaults.yaml`. Note that this
setting is given through :code:`parser_kwargs`. More parameters are supported. For details see the `ArgumentParser API
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.__init__>`_ documentation.


Instantiation only mode
^^^^^^^^^^^^^^^^^^^^^^^

The CLI is designed to start fitting with minimal code changes. On class instantiation, the CLI will automatically
call ``trainer.fit(...)`` internally so you don't have to do it. To avoid this, you can set the following argument:

.. testcode::

cli = LightningCLI(MyModel, run=False) # True by default
# you'll have to call fit yourself:
cli.trainer.fit(cli.model)


This can be useful to implement custom logic without having to subclass the CLI, but still using the CLI's instantiation
and argument parsing capabilities.


Trainer Callbacks and arguments with class type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
12 changes: 8 additions & 4 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def __init__(
parser_kwargs: Optional[Dict[str, Any]] = None,
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
run: bool = True,
) -> None:
"""
Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
Expand Down Expand Up @@ -258,6 +259,8 @@ def __init__(
subclass_mode_data: Whether datamodule can be any `subclass
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
of the given class.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
"""
self.model_class = model_class
self.datamodule_class = datamodule_class
Expand All @@ -284,10 +287,11 @@ def __init__(
self.instantiate_classes()
self.add_configure_optimizers_method_to_model()

self.prepare_fit_kwargs()
self.before_fit()
self.fit()
self.after_fit()
if run:
self.prepare_fit_kwargs()
self.before_fit()
self.fit()
self.after_fit()

def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
"""Method that instantiates the argument parser."""
Expand Down
9 changes: 9 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,12 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.optim1, torch.optim.Adam)
assert isinstance(cli.model.optim2, torch.optim.SGD)
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_disabled_run(run):
with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock:
cli = LightningCLI(BoringModel, run=run)
fit_mock.call_count == run
assert isinstance(cli.trainer, Trainer)
assert isinstance(cli.model, LightningModule)

0 comments on commit cb2a8ed

Please sign in to comment.