Merge branch 'master' into metrics/fix_states
SkafteNicki authored Nov 10, 2020
2 parents c9e3923 + 4f3160b commit e1e7935
Showing 4 changed files with 35 additions and 1 deletion.
CHANGELOG.md (1 addition, 1 deletion)

@@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed


- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

### Deprecated

pytorch_lightning/tuner/batch_size_scaling.py (4 additions, 0 deletions)

@@ -69,6 +69,10 @@ def scale_batch_size(trainer,
        **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
            or datamodule.
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(
            f'Field {batch_arg_name} not found in both `model` and `model.hparams`')

pytorch_lightning/tuner/lr_finder.py (9 additions, 0 deletions)

@@ -29,6 +29,7 @@
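
The guard above makes tuning a no-op for batch-size scaling whenever `fast_dev_run=True`. A minimal sketch of how that surfaces to users, assuming a hypothetical `ToyModel` LightningModule (not part of this commit) that exposes the `batch_size` attribute the scaler would normally tune:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ToyModel(LightningModule):
    """Hypothetical module; `batch_size` is the attribute the scaler targets."""

    def __init__(self, batch_size=2, learning_rate=1e-3):
        super().__init__()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=self.batch_size)


# With fast_dev_run=True the scaler warns and returns early, so
# ToyModel.batch_size is left untouched instead of being overwritten.
trainer = Trainer(auto_scale_batch_size=True, fast_dev_run=True)
trainer.tune(ToyModel())
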
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem

# check if ipywidgets is installed before importing tqdm.auto
@@ -42,6 +43,10 @@
def _run_lr_finder_internally(trainer, model: LightningModule):
    """ Call lr finder internally during Trainer.fit() """
    lr_finder = lr_find(trainer, model)

    if lr_finder is None:
        return

    lr = lr_finder.suggestion()

    # TODO: log lr.results to self.logger
@@ -131,6 +136,10 @@ def lr_find(
        trainer.fit(model)
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
        return

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

tests/trainer/flags/test_fast_dev_run.py (21 additions, 0 deletions)

@@ -0,0 +1,21 @@
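
The learning rate finder gets the same early return, and `_run_lr_finder_internally` now checks for the `None` result before calling `.suggestion()` on it. A matching sketch, reusing the hypothetical `ToyModel` from the previous example:

# lr_find() returns None under fast_dev_run, so tuning warns and skips the
# suggestion step instead of failing on a missing result.
trainer = Trainer(auto_lr_find=True, fast_dev_run=True)
trainer.tune(ToyModel())  # UserWarning: Skipping learning rate finder since `fast_dev_run=True`
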
import pytest
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
    """ Test that tuner algorithms are skipped if fast dev run is enabled """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_scale_batch_size=(tuner_alg == 'batch size scaler'),
        auto_lr_find=(tuner_alg == 'learning rate finder'),
        fast_dev_run=True,
    )
    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
    with pytest.warns(UserWarning, match=expected_message):
        trainer.tune(model)
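
To exercise just the new test locally, one option (assuming a source checkout of the repository with its test dependencies installed) is to invoke pytest on the added file from the repository root:

# Runs only the new parametrized test; both tuner algorithms are covered.
import pytest

pytest.main(["tests/trainer/flags/test_fast_dev_run.py", "-v"])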
