diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9c9316bba449..7e49a59c79f94 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-
+- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
 
 ### Deprecated
 
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index b468209171393..67a4704b628fc 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -69,6 +69,10 @@ def scale_batch_size(trainer,
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
+        return
+
     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(
             f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 54e7e082bc0a8..b6d8c8178093b 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -29,6 +29,7 @@
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 
 # check if ipywidgets is installed before importing tqdm.auto
@@ -42,6 +43,10 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
     lr_finder = lr_find(trainer, model)
+
+    if lr_finder is None:
+        return
+
     lr = lr_finder.suggestion()
 
     # TODO: log lr.results to self.logger
@@ -131,6 +136,10 @@ def lr_find(
         trainer.fit(model)
 
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+        return
+
     save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')
 
     __lr_finder_dump_params(trainer, model)
diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
new file mode 100644
index 0000000000000..cbe4d4012227a
--- /dev/null
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -0,0 +1,21 @@
+import pytest
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
+def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
+    """ Test that tuner algorithms are skipped if fast dev run is enabled. """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_scale_batch_size=(tuner_alg == 'batch size scaler'),
+        auto_lr_find=(tuner_alg == 'learning rate finder'),
+        fast_dev_run=True
+    )
+    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)
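
For context, a minimal usage sketch (not part of the patch) of how the new guard surfaces to a user: with `fast_dev_run=True`, `Trainer.tune()` now emits a `UserWarning` for each enabled tuner algorithm and skips it instead of running it. `TinyModel` below is a hypothetical stand-in; any `LightningModule` exposing `lr` and `batch_size` attributes would behave the same way.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Hypothetical stand-in exposing `lr` and `batch_size` for the tuner algorithms."""

    def __init__(self, lr=1e-3, batch_size=2):
        super().__init__()
        self.lr = lr
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=self.batch_size)


model = TinyModel()
trainer = Trainer(auto_lr_find=True, auto_scale_batch_size=True, fast_dev_run=True)

# With this patch, both tuner algorithms warn ("Skipping ... since `fast_dev_run=True`")
# and return immediately; without `fast_dev_run=True` they run as before.
trainer.tune(model)
```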