Log LR using LearningRateMonitor even when LR Scheduler is not defined. (Lightning-AI#9786)

* LR logging works even with no LR scheduler; wrote a few extra tests as well (a usage sketch follows the changed-files summary below)

* updated changelog

* modified code as suggested by DeepSource

* added helper functions

* opt with no scheduler

* rename

* chlog

* update test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rohitgr7 <[email protected]>
3 people committed Oct 18, 2021
1 parent f03147b commit 24556e6
Showing 3 changed files with 253 additions and 82 deletions.
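
To make the behaviour change concrete, here is a minimal usage sketch, assuming the pytorch-lightning API of this era (~1.4/1.5). `PlainOptimizerModel`, the toy dataset, and the hyperparameters are illustrative and not taken from this commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor


class PlainOptimizerModel(pl.LightningModule):
    """Toy module whose `configure_optimizers` returns only an optimizer, no scheduler."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        # No LR scheduler here: before this commit, LearningRateMonitor only warned
        # and logged nothing; after it, the learning rate is still tracked.
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
trainer = pl.Trainer(
    max_epochs=1,
    log_every_n_steps=1,
    callbacks=[LearningRateMonitor(logging_interval="step")],
)
trainer.fit(PlainOptimizerModel(), DataLoader(dataset, batch_size=8))
# The learning rate appears in the logger under a key derived from the optimizer
# class, e.g. "lr-SGD".
```

Per the new `_find_names_from_optimizers` helper in the diff below, the key is `lr-` plus the optimizer class name, with the usual parameter-group suffixes and a numeric disambiguator when the same optimizer type appears more than once.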
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Add support for monitoring the learning rate without schedulers in `LearningRateMonitor` ([#9786](https://github.com/PyTorchLightning/pytorch-lightning/issues/9786))


- Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944))


154 changes: 111 additions & 43 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -106,18 +106,13 @@ def on_train_start(self, trainer, *args, **kwargs):
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)

if not trainer.lr_schedulers:
rank_zero_warn(
"You are using `LearningRateMonitor` callback with models that"
" have no learning rate schedulers. Please see documentation"
" for `configure_optimizers` method.",
RuntimeWarning,
)

if self.log_momentum:

def _check_no_key(key):
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
if trainer.lr_schedulers:
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)

return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

if _check_no_key("momentum") and _check_no_key("betas"):
rank_zero_warn(
@@ -127,7 +122,21 @@ def _check_no_key(key):
)

# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)
names = []
(
sched_hparam_keys,
optimizers_with_scheduler,
optimizers_with_scheduler_types,
) = self._find_names_from_schedulers(trainer.lr_schedulers)
names.extend(sched_hparam_keys)

# Find names for leftover optimizers
optimizer_hparam_keys, _ = self._find_names_from_optimizers(
trainer.optimizers,
seen_optimizers=optimizers_with_scheduler,
seen_optimizer_types=optimizers_with_scheduler_types,
)
names.extend(optimizer_hparam_keys)

# Initialize for storing values
self.lrs = {name: [] for name in names}
@@ -155,26 +164,49 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
latest_stat = {}

names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(names)
(
scheduler_hparam_keys,
optimizers_with_scheduler,
optimizers_with_scheduler_types,
) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(scheduler_hparam_keys)

for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler["interval"] == interval or interval == "any":
if interval in [scheduler["interval"], "any"]:
opt = scheduler["scheduler"].optimizer
param_groups = opt.param_groups
use_betas = "betas" in opt.defaults

for i, pg in enumerate(param_groups):
name_and_suffix = self._add_suffix(name, param_groups, i)
lr = self._extract_lr(pg, name_and_suffix)
latest_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
)
latest_stat.update(momentum)
current_stat = self._get_lr_momentum_stat(opt, name)
latest_stat.update(current_stat)

optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers(
trainer.optimizers,
seen_optimizers=optimizers_with_scheduler,
seen_optimizer_types=optimizers_with_scheduler_types,
add_lr_sch_names=False,
)
self._remap_keys(optimizer_hparam_keys)

for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
current_stat = self._get_lr_momentum_stat(opt, name)
latest_stat.update(current_stat)

return latest_stat

def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> Dict[str, float]:
lr_momentum_stat = {}
param_groups = optimizer.param_groups
use_betas = "betas" in optimizer.defaults

for i, pg in enumerate(param_groups):
name_and_suffix = self._add_suffix(name, param_groups, i)
lr = self._extract_lr(pg, name_and_suffix)
lr_momentum_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
)
lr_momentum_stat.update(momentum)

return lr_momentum_stat

def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
lr = param_group.get("lr")
self.lrs[name].append(lr)
@@ -223,7 +255,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
return set()
return {n for n in names if names.count(n) > 1}

def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
def _find_names_from_schedulers(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
# Create unique names in the case we have multiple of the same learning
# rate scheduler + multiple parameter groups
names = []
@@ -236,28 +268,64 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
else:
name = "lr-" + sch.optimizer.__class__.__name__

seen_optimizers.append(sch.optimizer)
optimizer_cls = type(sch.optimizer)
if scheduler["name"] is None:
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same scheduler
param_groups = sch.optimizer.param_groups
duplicates = self._duplicate_param_group_names(param_groups)
if duplicates:
raise MisconfigurationException(
"A single `Optimizer` cannot have multiple parameter groups with identical "
f"`name` values. {name} has duplicated parameter group names {duplicates}"
)
updated_name = self._check_duplicates_and_update_name(
sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
)
names.extend(updated_name)
return names, seen_optimizers, seen_optimizer_types

def _find_names_from_optimizers(
self, optimizers, seen_optimizers, seen_optimizer_types, add_lr_sch_names: bool = True
) -> List[str]:
names = []
optimizers_without_scheduler = []

name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
for optimizer in optimizers:
# Deepspeed optimizer wraps the native optimizer
optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
if optimizer in seen_optimizers:
continue

name = "lr-" + optimizer.__class__.__name__
updated_name = self._check_duplicates_and_update_name(
optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
)
names.extend(updated_name)
optimizers_without_scheduler.append(optimizer)
return names, optimizers_without_scheduler

def _check_duplicates_and_update_name(
self,
optimizer: Optimizer,
name: str,
seen_optimizers: List,
seen_optimizer_types: List,
scheduler: Dict[str, Any] = None,
add_lr_sch_names: bool = True,
) -> List:
seen_optimizers.append(optimizer)
optimizer_cls = type(optimizer)
if scheduler is not None and scheduler["name"] is None:
seen_optimizer_types[optimizer_cls] += 1
elif scheduler is None:
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same optimizer
param_groups = optimizer.param_groups
duplicates = self._duplicate_param_group_names(param_groups)
if duplicates:
raise MisconfigurationException(
"A single `Optimizer` cannot have multiple parameter groups with identical "
f"`name` values. {name} has duplicated parameter group names {duplicates}"
)

names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))
name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]

if add_lr_sch_names:
self.lr_sch_names.append(name)
if add_lr_sch_names:
self.lr_sch_names.append(name)

return names
return name_list

@staticmethod
def _should_log(trainer) -> bool:
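
The diff for the updated tests (the third changed file) is not rendered on this page. Purely as an illustration of the new "leftover optimizer" path — not the actual test added by this commit — a check along these lines could be written; the test name, `TwoOptimizerModel`, and the toy data are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor


def test_lr_monitor_optimizer_without_scheduler(tmp_path):
    """One optimizer has a scheduler, the other does not; both should be tracked."""

    class TwoOptimizerModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx, optimizer_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            opt_with_sched = torch.optim.Adam(self.parameters(), lr=1e-3)
            opt_without_sched = torch.optim.SGD(self.parameters(), lr=0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(opt_with_sched, step_size=1)
            return [opt_with_sched, opt_without_sched], [scheduler]

    dataset = TensorDataset(torch.randn(16, 32), torch.randint(0, 2, (16,)))
    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer = pl.Trainer(
        default_root_dir=tmp_path,
        limit_train_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        callbacks=[lr_monitor],
    )
    trainer.fit(TwoOptimizerModel(), DataLoader(dataset, batch_size=4))

    # "lr-Adam" comes from the scheduler-backed optimizer (already supported before);
    # "lr-SGD" exists only because of the new optimizer scan added in this commit.
    assert set(lr_monitor.lrs) == {"lr-Adam", "lr-SGD"}
    assert len(lr_monitor.lrs["lr-SGD"]) > 0
```

The keys asserted here follow the naming logic readable in the diff above: schedulers are named first via `_find_names_from_schedulers`, then any optimizer not yet seen is picked up by `_find_names_from_optimizers` and logged under `lr-<OptimizerClass>`.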
