From 9af1dd744302c52ce5972a30dd12a15808a13a9f Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 28 Oct 2021 18:27:04 +0530 Subject: [PATCH] Deprecate `lr_sch_names` from `LearningRateMonitor` (#10066) --- CHANGELOG.md | 4 ++- pytorch_lightning/callbacks/lr_monitor.py | 15 ++++++++-- tests/callbacks/test_lr_monitor.py | 34 +++++++++++------------ tests/deprecated_api/test_remove_1-7.py | 11 ++++++++ 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81cd06e43c4c8..8442457b40feb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -445,10 +445,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106)) - - Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105)) +- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 26bf973bb71ae..c9875cae83e62 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -28,6 +28,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -93,7 +94,7 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = self.logging_interval = logging_interval self.log_momentum = log_momentum self.lrs: Dict[str, List[float]] = {} - self.lr_sch_names: List[str] = [] + self._lr_sch_names: List[str] = [] def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Called before training, determines unique names for all lr schedulers in the case of multiple of the @@ -334,6 +335,16 @@ def _check_duplicates_and_update_name( name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))] if add_lr_sch_names: - self.lr_sch_names.append(name) + self._lr_sch_names.append(name) return name_list + + @property + def lr_sch_names(self) -> List[str]: + # TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0 + rank_zero_deprecation( + "`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7." + " Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return" + " the names of all the optimizers, even those without a scheduler." + ) + return self._lr_sch_names diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index 8299bec2bdf59..d35b1e8eefc38 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -41,7 +41,7 @@ def test_lr_monitor_single_lr(tmpdir): assert lr_monitor.lrs, "No learning rates logged" assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default" assert len(lr_monitor.lrs) == len(trainer.lr_schedulers) - assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"] + assert list(lr_monitor.lrs) == ["lr-SGD"] @pytest.mark.parametrize("opt", ["SGD", "Adam"]) @@ -77,7 +77,7 @@ def configure_optimizers(self): assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers) - assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys()) + assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) def test_log_momentum_no_momentum_optimizer(tmpdir): @@ -104,7 +104,7 @@ def configure_optimizers(self): assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers) - assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys()) + assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values) def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir): @@ -127,7 +127,7 @@ def configure_optimizers(self): assert lr_monitor.lrs, "No learning rates logged" assert len(lr_monitor.lrs) == len(trainer.optimizers) - assert lr_monitor.lr_sch_names == ["lr-SGD"] + assert list(lr_monitor.lrs) == ["lr-SGD"] @pytest.mark.parametrize("opt", ["SGD", "Adam"]) @@ -162,7 +162,7 @@ def configure_optimizers(self): assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers) - assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys()) + assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir): @@ -188,7 +188,7 @@ def configure_optimizers(self): assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers) - assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys()) + assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values) def test_lr_monitor_no_logger(tmpdir): @@ -238,7 +238,7 @@ def configure_optimizers(self): assert lr_monitor.lrs, "No learning rates logged" assert len(lr_monitor.lrs) == len(trainer.lr_schedulers) - assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" + assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": expected_number_logged = trainer.global_step // log_every_n_steps @@ -281,7 +281,7 @@ def configure_optimizers(self): assert lr_monitor.lrs, "No learning rates logged" assert len(lr_monitor.lrs) == len(trainer.optimizers) - assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" + assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": expected_number_logged = trainer.global_step // log_every_n_steps @@ -317,8 +317,7 @@ def configure_optimizers(self): assert lr_monitor.lrs, "No learning rates logged" assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers) - assert lr_monitor.lr_sch_names == ["lr-Adam"] - assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly" + assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly" def test_lr_monitor_custom_name(tmpdir): @@ -339,7 +338,7 @@ def configure_optimizers(self): enable_model_summary=False, ) trainer.fit(TestModel()) - assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"] + assert list(lr_monitor.lrs) == ["my_logging_name"] def test_lr_monitor_custom_pg_name(tmpdir): @@ -360,7 +359,6 @@ def configure_optimizers(self): enable_model_summary=False, ) trainer.fit(TestModel()) - assert lr_monitor.lr_sch_names == ["lr-SGD"] assert list(lr_monitor.lrs) == ["lr-SGD/linear"] @@ -434,7 +432,7 @@ def configure_optimizers(self): class Check(Callback): def on_train_epoch_start(self, trainer, pl_module) -> None: num_param_groups = sum(len(opt.param_groups) for opt in trainer.optimizers) - assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1", "lr-Adam-2"] + if trainer.current_epoch == 0: assert num_param_groups == 3 elif trainer.current_epoch == 1: @@ -512,7 +510,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected -def test_lr_monitor_multiple_param_groups_no_scheduler(tmpdir): +def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir): + """Test that the `LearningRateMonitor` is able to log correct keys with multiple param groups and no + lr_scheduler.""" + class TestModel(BoringModel): def __init__(self, lr, momentum): super().__init__() @@ -550,8 +551,7 @@ def configure_optimizers(self): trainer.fit(model) assert len(lr_monitor.lrs) == len(trainer.optimizers[0].param_groups) - assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"] - assert lr_monitor.lr_sch_names == ["lr-Adam"] - assert list(lr_monitor.last_momentum_values.keys()) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"] + assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"] + assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"] assert all(val == momentum for val in lr_monitor.last_momentum_values.values()) assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 1c99b94a765ce..d3328d7a3ad45 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -19,6 +19,7 @@ from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger from tests.callbacks.test_callbacks import OldStatefulCallback @@ -438,3 +439,13 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): trainer = Trainer(resume_from_checkpoint="trainer_arg_path") with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."): trainer.fit(model, ckpt_path="fit_arg_ckpt_path") + + +def test_v1_7_0_deprecate_lr_sch_names(tmpdir): + model = BoringModel() + lr_monitor = LearningRateMonitor() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[lr_monitor]) + trainer.fit(model) + + with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"): + assert lr_monitor.lr_sch_names == ["lr-SGD"]