Remove optimizer_connector.py #10120

Merged (5 commits) on Oct 26, 2021
69 changes: 68 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -23,6 +23,7 @@
from pytorch_lightning.loops.utilities import _get_active_optimizers, _update_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
@@ -436,12 +437,78 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
        active_optimizers = _get_active_optimizers(
            self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
        )
-       self.trainer.optimizer_connector.update_learning_rates(
+       self._update_learning_rates(
            interval=interval,
            update_plateau_schedulers=update_plateau_schedulers,
            opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
        )

    def _update_learning_rates(
        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
    ) -> None:
        """Update learning rates.

        Args:
            interval: either 'epoch' or 'step'.
            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
                commonly saved during validation; however, on-plateau schedulers might monitor a validation metric,
                so they have to be updated separately.
            opt_indices: indices of the optimizers to update.
        """
        if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
            return

        if opt_indices is None:
            opt_indices = []

        for lr_scheduler in self.trainer.lr_schedulers:
            if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
                continue

            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
                continue

            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
            current_idx += 1  # account for both the batch index and the epoch starting from 0
            # Take a step if the call to update_learning_rates matches the interval key and
            # the current step modulo the scheduler's frequency is zero
            if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
                monitor_val = None
                if lr_scheduler["reduce_on_plateau"]:
                    # If instance of ReduceLROnPlateau, we need a monitor
                    monitor_key = lr_scheduler["monitor"]
                    monitor_val = self._get_monitor_value(monitor_key)
                    if monitor_val is None:
                        if lr_scheduler.get("strict", True):
                            avail_metrics = list(self.trainer.callback_metrics)
                            raise MisconfigurationException(
                                f"ReduceLROnPlateau conditioned on metric {monitor_key}"
                                f" which is not available. Available metrics are: {avail_metrics}."
                                " Condition can be set using `monitor` key in lr scheduler dict"
                            )
                        rank_zero_warn(
                            f"ReduceLROnPlateau conditioned on metric {monitor_key}"
                            " which is not available but strict is set to `False`."
                            " Skipping learning rate update.",
                            RuntimeWarning,
                        )
                        continue

                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()

                # update LR
                if lr_scheduler["reduce_on_plateau"]:
                    lr_scheduler["scheduler"].step(monitor_val)
                else:
                    lr_scheduler["scheduler"].step()

                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()

    def _get_monitor_value(self, key: str) -> Any:
        # this is a separate method to aid in testing
        return self.trainer.callback_metrics.get(key)

    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
        """Decide if we should run validation."""
        if not self.trainer.enable_validation:
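Note: `_update_learning_rates` reads the scheduler dictionaries that Lightning builds from `configure_optimizers`. As a reference for the keys it consumes (`monitor`, `interval`, `frequency`, `strict`, plus the internally derived plateau flag), here is a minimal sketch of a `configure_optimizers` that exercises the `ReduceLROnPlateau` path; the module name, layer, optimizer settings, and the `val_loss` metric are illustrative and not part of this PR.

import torch
from pytorch_lightning import LightningModule


class PlateauSketch(LightningModule):  # illustrative module, not from this PR
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",  # looked up through `_get_monitor_value`
                "interval": "epoch",    # compared against the `interval` argument
                "frequency": 1,         # stepped when `current_idx % frequency == 0`
                "strict": True,         # missing metric raises MisconfigurationException
            },
        }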
95 changes: 0 additions & 95 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py

This file was deleted.
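For context, this is roughly what the deleted connector provided, reconstructed from the replacements shown above and in the trainer.py hunks below (not the verbatim file contents): `on_trainer_init` only initialized three optimizer-related attributes on the trainer, and `update_learning_rates` held the logic that now lives in `TrainingEpochLoop._update_learning_rates`.

# Approximate shape of the removed connector (a reconstruction, not a verbatim copy):
class OptimizerConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self):
        # now inlined in Trainer.__init__ (see the trainer.py hunk below)
        self.trainer.lr_schedulers = []
        self.trainer.optimizers = []
        self.trainer.optimizer_frequencies = []

    def update_learning_rates(self, interval, update_plateau_schedulers, opt_indices=None):
        # body moved to TrainingEpochLoop._update_learning_rates (see the diff above)
        ...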

6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -58,7 +58,6 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
-from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -430,7 +429,6 @@ def __init__(

        # init connectors
        self._data_connector = DataConnector(self, multiple_trainloader_mode)
-       self.optimizer_connector = OptimizerConnector(self)

        self._accelerator_connector = AcceleratorConnector(
            num_processes,
@@ -516,7 +514,9 @@ def __init__(
        self.on_init_start()

        # init optimizer + lr scheduler related flags
-       self.optimizer_connector.on_trainer_init()
+       self.lr_schedulers = []
+       self.optimizers = []
+       self.optimizer_frequencies = []

        # init data flags
        self._data_connector.on_trainer_init(
10 changes: 5 additions & 5 deletions tests/checkpointing/test_model_checkpoint.py
@@ -63,17 +63,17 @@ def validation_epoch_end(self, outputs):
        self.log("val_acc", outs)


-def mock_optimizer_connector(trainer):
+def mock_training_epoch_loop(trainer):
    # do not use `unittest.Mock` because we need to store the return value
    calls = {}
-   old_get_monitor_value = trainer.optimizer_connector._get_monitor_value
+   old_get_monitor_value = trainer.fit_loop.epoch_loop._get_monitor_value

    def mock(key):
        value = old_get_monitor_value(key)
        calls[trainer.current_epoch] = {key: value}
        return value

-   trainer.optimizer_connector._get_monitor_value = mock
+   trainer.fit_loop.epoch_loop._get_monitor_value = mock
    return calls


@@ -150,7 +150,7 @@ def on_validation_epoch_end(self):
        max_epochs=max_epochs,
        enable_progress_bar=False,
    )
-   calls = mock_optimizer_connector(trainer)
+   calls = mock_training_epoch_loop(trainer)
    trainer.fit(model)

    ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
@@ -248,7 +248,7 @@ def configure_optimizers(self):
        enable_progress_bar=False,
        num_sanity_val_steps=0,
    )
-   calls = mock_optimizer_connector(trainer)
+   calls = mock_training_epoch_loop(trainer)
    trainer.fit(model)

    def _make_assertions(epoch, ix):
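As the comment in the helper notes, a hand-rolled wrapper is used instead of `unittest.mock` so the per-epoch return value of `_get_monitor_value` can be recorded; the test keeps the returned `calls` dict and inspects it after `trainer.fit`. A minimal sketch of that usage follows; the monitored key `val_acc` comes from the test model above, while the assertion itself is illustrative and not the real test's.

calls = mock_training_epoch_loop(trainer)
trainer.fit(model)

# `calls` maps each epoch index to {monitored_key: value} captured when the
# plateau scheduler resolved its monitor during that epoch, e.g.:
assert all("val_acc" in monitored for monitored in calls.values())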