Skip to content

Commit

Permalink
Deprecate lr_sch_names from LearningRateMonitor (#10066)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 authored Oct 28, 2021
1 parent b8ac176 commit 9af1dd7
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106))



- Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))


- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
15 changes: 13 additions & 2 deletions pytorch_lightning/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException


Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool =
self.logging_interval = logging_interval
self.log_momentum = log_momentum
self.lrs: Dict[str, List[float]] = {}
self.lr_sch_names: List[str] = []
self._lr_sch_names: List[str] = []

def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Called before training, determines unique names for all lr schedulers in the case of multiple of the
Expand Down Expand Up @@ -334,6 +335,16 @@ def _check_duplicates_and_update_name(
name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]

if add_lr_sch_names:
self.lr_sch_names.append(name)
self._lr_sch_names.append(name)

return name_list

@property
def lr_sch_names(self) -> List[str]:
# TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0
rank_zero_deprecation(
"`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7."
" Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return"
" the names of all the optimizers, even those without a scheduler."
)
return self._lr_sch_names
34 changes: 17 additions & 17 deletions tests/callbacks/test_lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_lr_monitor_single_lr(tmpdir):
assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"]
assert list(lr_monitor.lrs) == ["lr-SGD"]


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
Expand Down Expand Up @@ -77,7 +77,7 @@ def configure_optimizers(self):

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)


def test_log_momentum_no_momentum_optimizer(tmpdir):
Expand All @@ -104,7 +104,7 @@ def configure_optimizers(self):

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)


def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
Expand All @@ -127,7 +127,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(trainer.optimizers)
assert lr_monitor.lr_sch_names == ["lr-SGD"]
assert list(lr_monitor.lrs) == ["lr-SGD"]


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
Expand Down Expand Up @@ -162,7 +162,7 @@ def configure_optimizers(self):

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)


def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir):
Expand All @@ -188,7 +188,7 @@ def configure_optimizers(self):

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)


def test_lr_monitor_no_logger(tmpdir):
Expand Down Expand Up @@ -238,7 +238,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
Expand Down Expand Up @@ -281,7 +281,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(trainer.optimizers)
assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
Expand Down Expand Up @@ -317,8 +317,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers)
assert lr_monitor.lr_sch_names == ["lr-Adam"]
assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"


def test_lr_monitor_custom_name(tmpdir):
Expand All @@ -339,7 +338,7 @@ def configure_optimizers(self):
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
assert list(lr_monitor.lrs) == ["my_logging_name"]


def test_lr_monitor_custom_pg_name(tmpdir):
Expand All @@ -360,7 +359,6 @@ def configure_optimizers(self):
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == ["lr-SGD"]
assert list(lr_monitor.lrs) == ["lr-SGD/linear"]


Expand Down Expand Up @@ -434,7 +432,7 @@ def configure_optimizers(self):
class Check(Callback):
def on_train_epoch_start(self, trainer, pl_module) -> None:
num_param_groups = sum(len(opt.param_groups) for opt in trainer.optimizers)
assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1", "lr-Adam-2"]

if trainer.current_epoch == 0:
assert num_param_groups == 3
elif trainer.current_epoch == 1:
Expand Down Expand Up @@ -512,7 +510,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected


def test_lr_monitor_multiple_param_groups_no_scheduler(tmpdir):
def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
"""Test that the `LearningRateMonitor` is able to log correct keys with multiple param groups and no
lr_scheduler."""

class TestModel(BoringModel):
def __init__(self, lr, momentum):
super().__init__()
Expand Down Expand Up @@ -550,8 +551,7 @@ def configure_optimizers(self):
trainer.fit(model)

assert len(lr_monitor.lrs) == len(trainer.optimizers[0].param_groups)
assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"]
assert lr_monitor.lr_sch_names == ["lr-Adam"]
assert list(lr_monitor.last_momentum_values.keys()) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"]
assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
11 changes: 11 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from tests.callbacks.test_callbacks import OldStatefulCallback
Expand Down Expand Up @@ -438,3 +439,13 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")


def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
model = BoringModel()
lr_monitor = LearningRateMonitor()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[lr_monitor])
trainer.fit(model)

with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
assert lr_monitor.lr_sch_names == ["lr-SGD"]

0 comments on commit 9af1dd7

Please sign in to comment.