Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log LR using LearningRateMonitor even when LR Scheduler is not defined. #9786

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Add support for monitoring the learning rate monitor without schedulers in `LearningRateMonitor` ([#9786](https://github.com/PyTorchLightning/pytorch-lightning/issues/9786))


- Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944))


Expand Down
154 changes: 111 additions & 43 deletions pytorch_lightning/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,13 @@ def on_train_start(self, trainer, *args, **kwargs):
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)

if not trainer.lr_schedulers:
rank_zero_warn(
"You are using `LearningRateMonitor` callback with models that"
" have no learning rate schedulers. Please see documentation"
" for `configure_optimizers` method.",
RuntimeWarning,
)

if self.log_momentum:

def _check_no_key(key):
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
if trainer.lr_schedulers:
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)

return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

if _check_no_key("momentum") and _check_no_key("betas"):
rank_zero_warn(
Expand All @@ -127,7 +122,21 @@ def _check_no_key(key):
)

# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)
names = []
(
sched_hparam_keys,
optimizers_with_scheduler,
optimizers_with_scheduler_types,
) = self._find_names_from_schedulers(trainer.lr_schedulers)
names.extend(sched_hparam_keys)

# Find names for leftover optimizers
optimizer_hparam_keys, _ = self._find_names_from_optimizers(
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
trainer.optimizers,
seen_optimizers=optimizers_with_scheduler,
seen_optimizer_types=optimizers_with_scheduler_types,
)
names.extend(optimizer_hparam_keys)

# Initialize for storing values
self.lrs = {name: [] for name in names}
Expand Down Expand Up @@ -155,26 +164,49 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
latest_stat = {}

names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(names)
(
scheduler_hparam_keys,
optimizers_with_scheduler,
optimizers_with_scheduler_types,
) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(scheduler_hparam_keys)

for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler["interval"] == interval or interval == "any":
if interval in [scheduler["interval"], "any"]:
opt = scheduler["scheduler"].optimizer
param_groups = opt.param_groups
use_betas = "betas" in opt.defaults

for i, pg in enumerate(param_groups):
name_and_suffix = self._add_suffix(name, param_groups, i)
lr = self._extract_lr(pg, name_and_suffix)
latest_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
)
latest_stat.update(momentum)
current_stat = self._get_lr_momentum_stat(opt, name)
latest_stat.update(current_stat)

optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers(
trainer.optimizers,
seen_optimizers=optimizers_with_scheduler,
seen_optimizer_types=optimizers_with_scheduler_types,
add_lr_sch_names=False,
)
self._remap_keys(optimizer_hparam_keys)

for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
current_stat = self._get_lr_momentum_stat(opt, name)
latest_stat.update(current_stat)

return latest_stat

def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> None:
lr_momentum_stat = {}
param_groups = optimizer.param_groups
use_betas = "betas" in optimizer.defaults

for i, pg in enumerate(param_groups):
name_and_suffix = self._add_suffix(name, param_groups, i)
lr = self._extract_lr(pg, name_and_suffix)
lr_momentum_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
)
lr_momentum_stat.update(momentum)

return lr_momentum_stat

def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
lr = param_group.get("lr")
self.lrs[name].append(lr)
Expand Down Expand Up @@ -223,7 +255,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
return set()
return {n for n in names if names.count(n) > 1}

def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
def _find_names_from_schedulers(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
# Create unique names in the case we have multiple of the same learning
# rate scheduler + multiple parameter groups
names = []
Expand All @@ -236,28 +268,64 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> Lis
else:
name = "lr-" + sch.optimizer.__class__.__name__

seen_optimizers.append(sch.optimizer)
optimizer_cls = type(sch.optimizer)
if scheduler["name"] is None:
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same scheduler
param_groups = sch.optimizer.param_groups
duplicates = self._duplicate_param_group_names(param_groups)
if duplicates:
raise MisconfigurationException(
"A single `Optimizer` cannot have multiple parameter groups with identical "
f"`name` values. {name} has duplicated parameter group names {duplicates}"
)
updated_name = self._check_duplicates_and_update_name(
sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
)
names.extend(updated_name)
return names, seen_optimizers, seen_optimizer_types

def _find_names_from_optimizers(
self, optimizers, seen_optimizers, seen_optimizer_types, add_lr_sch_names: bool = True
) -> List[str]:
names = []
optimizers_without_scheduler = []

name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
for optimizer in optimizers:
# Deepspeed optimizer wraps the native optimizer
optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
if optimizer in seen_optimizers:
continue

name = "lr-" + optimizer.__class__.__name__
updated_name = self._check_duplicates_and_update_name(
optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
)
names.extend(updated_name)
optimizers_without_scheduler.append(optimizer)
return names, optimizers_without_scheduler

def _check_duplicates_and_update_name(
self,
optimizer: Optimizer,
name: str,
seen_optimizers: List,
seen_optimizer_types: List,
scheduler: Dict[str, Any] = None,
add_lr_sch_names: bool = True,
) -> List:
seen_optimizers.append(optimizer)
optimizer_cls = type(optimizer)
if scheduler is not None and scheduler["name"] is None:
seen_optimizer_types[optimizer_cls] += 1
elif scheduler is None:
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same optimizer
param_groups = optimizer.param_groups
duplicates = self._duplicate_param_group_names(param_groups)
if duplicates:
raise MisconfigurationException(
"A single `Optimizer` cannot have multiple parameter groups with identical "
f"`name` values. {name} has duplicated parameter group names {duplicates}"
)

names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))
name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]

if add_lr_sch_names:
self.lr_sch_names.append(name)
if add_lr_sch_names:
self.lr_sch_names.append(name)

return names
return name_list

@staticmethod
def _should_log(trainer) -> bool:
Expand Down
Loading