Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a dataclass as the scheduler config #11443

Merged
merged 21 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#11000](https://github.com/PyTorchLightning/pytorch-lightning/pull/11000))


- Deprecated `Trainer.lr_schedulers` in favor of `Trainer.lr_scheduler_configs` which returns a list of dataclasses instead of dictionaries ([#11443](https://github.com/PyTorchLightning/pytorch-lightning/pull/11443))


- Deprecated `Trainer.verbose_evaluate` in favor of `EvaluationLoop(verbose=...)` ([#10931](https://github.com/PyTorchLightning/pytorch-lightning/pull/10931))


Expand Down
35 changes: 19 additions & 16 deletions pytorch_lightning/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import LRSchedulerConfig


class LearningRateMonitor(Callback):
Expand Down Expand Up @@ -111,8 +112,10 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No
if self.log_momentum:

def _check_no_key(key: str) -> bool:
if trainer.lr_schedulers:
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
if trainer.lr_scheduler_configs:
return any(
key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs
)

return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

Expand All @@ -129,7 +132,7 @@ def _check_no_key(key: str) -> bool:
sched_hparam_keys,
optimizers_with_scheduler,
optimizers_with_scheduler_types,
) = self._find_names_from_schedulers(trainer.lr_schedulers)
) = self._find_names_from_schedulers(trainer.lr_scheduler_configs)
names.extend(sched_hparam_keys)

# Find names for leftover optimizers
Expand Down Expand Up @@ -173,12 +176,12 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa
scheduler_hparam_keys,
optimizers_with_scheduler,
optimizers_with_scheduler_types,
) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
) = self._find_names_from_schedulers(trainer.lr_scheduler_configs, add_lr_sch_names=False)
self._remap_keys(scheduler_hparam_keys)

for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers):
if interval in [scheduler["interval"], "any"]:
opt = scheduler["scheduler"].optimizer
for name, config in zip(scheduler_hparam_keys, trainer.lr_scheduler_configs):
if interval in [config.interval, "any"]:
opt = config.scheduler.optimizer
current_stat = self._get_lr_momentum_stat(opt, name)
latest_stat.update(current_stat)

Expand Down Expand Up @@ -261,22 +264,22 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
return {n for n in names if names.count(n) > 1}

def _find_names_from_schedulers(
self, lr_schedulers: List, add_lr_sch_names: bool = True
self, lr_scheduler_configs: List[LRSchedulerConfig], add_lr_sch_names: bool = True
) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
# Create unique names in the case we have multiple of the same learning
# rate scheduler + multiple parameter groups
names = []
seen_optimizers: List[Optimizer] = []
seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int)
for scheduler in lr_schedulers:
sch = scheduler["scheduler"]
if scheduler["name"] is not None:
name = scheduler["name"]
for config in lr_scheduler_configs:
sch = config.scheduler
if config.name is not None:
name = config.name
else:
name = "lr-" + sch.optimizer.__class__.__name__

updated_names = self._check_duplicates_and_update_name(
sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
sch.optimizer, name, seen_optimizers, seen_optimizer_types, config, add_lr_sch_names
)
names.append(updated_names)

Expand Down Expand Up @@ -313,14 +316,14 @@ def _check_duplicates_and_update_name(
name: str,
seen_optimizers: List[Optimizer],
seen_optimizer_types: DefaultDict[Type[Optimizer], int],
scheduler: Dict[str, Any] = None,
lr_scheduler_config: Optional[LRSchedulerConfig],
add_lr_sch_names: bool = True,
) -> List[str]:
seen_optimizers.append(optimizer)
optimizer_cls = type(optimizer)
if scheduler is not None and scheduler["name"] is None:
if lr_scheduler_config is not None and lr_scheduler_config.name is None:
seen_optimizer_types[optimizer_cls] += 1
elif scheduler is None:
elif lr_scheduler_config is None:
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same optimizer
Expand Down
26 changes: 11 additions & 15 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.optimizer import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]

Expand Down Expand Up @@ -142,13 +142,10 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module:
self._average_model = deepcopy(pl_module)

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers

if len(optimizers) != 1:
if len(trainer.optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

if len(lr_schedulers) > 1:
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

if isinstance(self._swa_epoch_start, float):
Expand Down Expand Up @@ -182,21 +179,20 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
)
default_scheduler_cfg = _get_default_scheduler_config()
assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
default_scheduler_cfg["scheduler"] = self._swa_scheduler
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

if trainer.lr_schedulers:
scheduler_cfg = trainer.lr_schedulers[0]
if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
if trainer.lr_scheduler_configs:
scheduler_cfg = trainer.lr_scheduler_configs[0]
if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
rank_zero_info(
f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
f" for `{self._swa_scheduler.__class__.__name__}`"
)
trainer.lr_schedulers[0] = default_scheduler_cfg
trainer.lr_scheduler_configs[0] = default_scheduler_cfg
else:
trainer.lr_schedulers.append(default_scheduler_cfg)
trainer.lr_scheduler_configs.append(default_scheduler_cfg)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,19 @@ def optimizers(
# multiple opts
return opts

def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]:
def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
"""Returns the learning rate scheduler(s) that are being used during training. Useful for manual
optimization.

Returns:
A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no
schedulers were returned in :meth:`configure_optimizers`.
"""
if not self.trainer.lr_schedulers:
if not self.trainer.lr_scheduler_configs:
return None

# ignore other keys "interval", "frequency", etc.
lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers]
lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs]

# single scheduler
if len(lr_schedulers) == 1:
Expand Down
79 changes: 36 additions & 43 deletions pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy

Expand All @@ -23,7 +24,12 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerTypeTuple
from pytorch_lightning.utilities.types import (
_SupportsStateDict,
LRSchedulerConfig,
LRSchedulerTypeTuple,
ReduceLROnPlateau,
)


def do_nothing_closure() -> None:
Expand Down Expand Up @@ -167,7 +173,7 @@ def closure_dis():

def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]:
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

Expand All @@ -178,10 +184,11 @@ def _init_optimizers_and_lr_schedulers(
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
_configure_schedulers = (
_configure_schedulers_automatic_opt if model.automatic_optimization else _configure_schedulers_manual_opt
lr_schedulers = (
_configure_schedulers_automatic_opt(lr_schedulers, monitor)
if model.automatic_optimization
else _configure_schedulers_manual_opt(lr_schedulers)
)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor)
_set_scheduler_opt_idx(optimizers, lr_schedulers)
_validate_scheduler_api(lr_schedulers, model)
return optimizers, lr_schedulers, optimizer_frequencies
Expand Down Expand Up @@ -251,18 +258,21 @@ def _configure_optimizers(
return optimizers, lr_schedulers, optimizer_frequencies, monitor


def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]:
def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
"""Convert each scheduler into dict structure with relevant information, when using automatic optimization."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
supported_keys = {field.name for field in fields(LRSchedulerConfig)}
extra_keys = scheduler.keys() - supported_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
" HINT: remove them from the output of `configure_optimizers`.",
category=RuntimeWarning,
)
scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
if "scheduler" not in scheduler:
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
Expand All @@ -286,27 +296,24 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
scheduler = LRSchedulerConfig(**scheduler)
elif isinstance(scheduler, ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
scheduler = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})

scheduler = LRSchedulerConfig(scheduler)
lr_schedulers.append(scheduler)
return lr_schedulers


def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]:
def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
"""Convert each scheduler into dict structure with relevant information, when using manual optimization."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
Expand All @@ -319,17 +326,16 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -
category=RuntimeWarning,
)

scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
scheduler = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})

scheduler = LRSchedulerConfig(scheduler)
lr_schedulers.append(scheduler)
return lr_schedulers


def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None:
for scheduler_config in lr_schedulers:
scheduler = scheduler_config["scheduler"]
def _validate_scheduler_api(lr_schedulers: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
for config in lr_schedulers:
scheduler = config.scheduler
if not isinstance(scheduler, _SupportsStateDict):
raise TypeError(
f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
Expand All @@ -344,31 +350,18 @@ def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.Ligh
)


def _get_default_scheduler_config() -> Dict[str, Any]:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # opt_idx assigned internally if not assigned by user
}


def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None:
for sch in lr_schedulers:
def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None:
for config in lr_scheduler_configs:

for opt_idx, opt in enumerate(optimizers):
if sch["scheduler"].optimizer is opt:
if sch["opt_idx"] is not None and sch["opt_idx"] != opt_idx:
if config.scheduler.optimizer is opt:
if config.opt_idx is not None and config.opt_idx != opt_idx:
raise MisconfigurationException(
"`opt_idx` set inside scheduler config does not match with the index"
" of the respective optimizer returned from `configure_optimizers`."
)

sch["opt_idx"] = opt_idx
config.opt_idx = opt_idx
break
else:
raise MisconfigurationException(
Expand Down
Loading