Deprecate lr_sch_names from LearningRateMonitor #10066

Merged 6 commits on Oct 28, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -445,10 +445,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106))



- Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))


- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
15 changes: 13 additions & 2 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -28,6 +28,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -93,7 +94,7 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool =
self.logging_interval = logging_interval
self.log_momentum = log_momentum
self.lrs: Dict[str, List[float]] = {}
- self.lr_sch_names: List[str] = []
+ self._lr_sch_names: List[str] = []

def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Called before training, determines unique names for all lr schedulers in the case of multiple of the
@@ -334,6 +335,16 @@ def _check_duplicates_and_update_name(
name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]

if add_lr_sch_names:
- self.lr_sch_names.append(name)
+ self._lr_sch_names.append(name)

return name_list

@property
def lr_sch_names(self) -> List[str]:
# TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0
rank_zero_deprecation(
"`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7."
" Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return"
" the names of all the optimizers, even those without a scheduler."
)
return self._lr_sch_names
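
For downstream code that reads the scheduler names off the callback, the replacement suggested by the deprecation message is the `lrs` dictionary. A minimal migration sketch follows; the `TinyModel` module is hypothetical and only there to drive the callback, any `LightningModule` with a configured optimizer works the same way:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class TinyModel(LightningModule):
    # Hypothetical minimal module, just enough to exercise the callback.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]


lr_monitor = LearningRateMonitor(logging_interval="epoch")
trainer = Trainer(max_epochs=1, callbacks=[lr_monitor])
trainer.fit(TinyModel())

# Deprecated in v1.5, scheduled for removal in v1.7 -- emits a deprecation warning:
# names = lr_monitor.lr_sch_names

# Replacement: the keys of `lrs` cover every logged learning-rate series,
# including optimizers that have no scheduler attached.
names = list(lr_monitor.lrs)  # e.g. ["lr-SGD"]
```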
34 changes: 17 additions & 17 deletions tests/callbacks/test_lr_monitor.py
@@ -41,7 +41,7 @@ def test_lr_monitor_single_lr(tmpdir):
assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
- assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"]
+ assert list(lr_monitor.lrs) == ["lr-SGD"]


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -77,7 +77,7 @@ def configure_optimizers(self):

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)


def test_log_momentum_no_momentum_optimizer(tmpdir):
@@ -104,7 +104,7 @@ def configure_optimizers(self):

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)


def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
@@ -127,7 +127,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(trainer.optimizers)
- assert lr_monitor.lr_sch_names == ["lr-SGD"]
+ assert list(lr_monitor.lrs) == ["lr-SGD"]


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -162,7 +162,7 @@ def configure_optimizers(self):

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)


def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir):
@@ -188,7 +188,7 @@ def configure_optimizers(self):

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)


def test_lr_monitor_no_logger(tmpdir):
@@ -238,7 +238,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
- assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
+ assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
@@ -281,7 +281,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(trainer.optimizers)
- assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
+ assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
@@ -317,8 +317,7 @@ def configure_optimizers(self):

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers)
- assert lr_monitor.lr_sch_names == ["lr-Adam"]
- assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
+ assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"


def test_lr_monitor_custom_name(tmpdir):
@@ -339,7 +338,7 @@ def configure_optimizers(self):
enable_model_summary=False,
)
trainer.fit(TestModel())
- assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
+ assert list(lr_monitor.lrs) == ["my_logging_name"]


def test_lr_monitor_custom_pg_name(tmpdir):
@@ -360,7 +359,6 @@ def configure_optimizers(self):
enable_model_summary=False,
)
trainer.fit(TestModel())
- assert lr_monitor.lr_sch_names == ["lr-SGD"]
assert list(lr_monitor.lrs) == ["lr-SGD/linear"]


@@ -434,7 +432,7 @@ def configure_optimizers(self):
class Check(Callback):
def on_train_epoch_start(self, trainer, pl_module) -> None:
num_param_groups = sum(len(opt.param_groups) for opt in trainer.optimizers)
- assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1", "lr-Adam-2"]

if trainer.current_epoch == 0:
assert num_param_groups == 3
elif trainer.current_epoch == 1:
@@ -512,7 +510,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected


- def test_lr_monitor_multiple_param_groups_no_scheduler(tmpdir):
+ def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
"""Test that the `LearningRateMonitor` is able to log correct keys with multiple param groups and no
lr_scheduler."""

class TestModel(BoringModel):
def __init__(self, lr, momentum):
super().__init__()
@@ -550,8 +551,7 @@ def configure_optimizers(self):
trainer.fit(model)

assert len(lr_monitor.lrs) == len(trainer.optimizers[0].param_groups)
- assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"]
- assert lr_monitor.lr_sch_names == ["lr-Adam"]
- assert list(lr_monitor.last_momentum_values.keys()) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
+ assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"]
+ assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
11 changes: 11 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -19,6 +19,7 @@

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from tests.callbacks.test_callbacks import OldStatefulCallback
@@ -438,3 +439,13 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")


def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
model = BoringModel()
lr_monitor = LearningRateMonitor()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[lr_monitor])
trainer.fit(model)

with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
assert lr_monitor.lr_sch_names == ["lr-SGD"]
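
The pattern the PR uses, moving the data to a private attribute and exposing the old public name through a property that warns on access, is a common way to phase out a public attribute over one release cycle. Below is a generic, hypothetical sketch of that idiom using the standard-library `warnings` module in place of Lightning's `rank_zero_deprecation`; `Monitor` and `names` are made-up names, not Lightning API:

```python
import warnings
from typing import Dict, List


class Monitor:
    """Hypothetical class illustrating the attribute-to-property deprecation shim."""

    def __init__(self) -> None:
        self.lrs: Dict[str, List[float]] = {}
        # The data the old public attribute exposed now lives on a private name.
        self._names: List[str] = []

    @property
    def names(self) -> List[str]:
        # Old reads keep working for one deprecation cycle, but warn on every access.
        warnings.warn(
            "`Monitor.names` is deprecated; use `Monitor.lrs.keys()` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._names
```

With this shape, a `pytest.deprecated_call` block like the one in the new test above can assert both the warning and the returned value in one place.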