diff --git a/CHANGELOG.md b/CHANGELOG.md
index 89e0738f43275..dfdc701385b13 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -542,6 +542,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
 
+- Removed `pytorch_lightning.trainer.connectors.OptimizerConnector` ([#10120](https://github.com/PyTorchLightning/pytorch-lightning/pull/10120))
+
+
 ### Fixed
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 1fe70d9d4e77c..d0bef40ae7872 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.loops.utilities import _get_active_optimizers, _update_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
@@ -436,12 +437,78 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -
         active_optimizers = _get_active_optimizers(
             self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
         )
-        self.trainer.optimizer_connector.update_learning_rates(
+        self._update_learning_rates(
             interval=interval,
             update_plateau_schedulers=update_plateau_schedulers,
             opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
         )
 
+    def _update_learning_rates(
+        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
+    ) -> None:
+        """Update learning rates.
+
+        Args:
+            interval: either 'epoch' or 'step'.
+            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
+                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
+                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
+                so they have to be updated separately.
+            opt_indices: indices of the optimizers to update.
+        """
+        if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
+            return
+
+        if opt_indices is None:
+            opt_indices = []
+
+        for lr_scheduler in self.trainer.lr_schedulers:
+            if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
+                continue
+
+            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
+                continue
+
+            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
+            current_idx += 1  # account for both batch and epoch starts from 0
+            # Take step if call to update_learning_rates matches the interval key and
+            # the current step modulo the schedulers frequency is zero
+            if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
+                monitor_val = None
+                if lr_scheduler["reduce_on_plateau"]:
+                    # If instance of ReduceLROnPlateau, we need a monitor
+                    monitor_key = lr_scheduler["monitor"]
+                    monitor_val = self._get_monitor_value(monitor_key)
+                    if monitor_val is None:
+                        if lr_scheduler.get("strict", True):
+                            avail_metrics = list(self.trainer.callback_metrics)
+                            raise MisconfigurationException(
+                                f"ReduceLROnPlateau conditioned on metric {monitor_key}"
+                                f" which is not available. Available metrics are: {avail_metrics}."
+                                " Condition can be set using `monitor` key in lr scheduler dict"
+                            )
+                        rank_zero_warn(
+                            f"ReduceLROnPlateau conditioned on metric {monitor_key}"
+                            " which is not available but strict is set to `False`."
+                            " Skipping learning rate update.",
+                            RuntimeWarning,
+                        )
+                        continue
+
+                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
+
+                # update LR
+                if lr_scheduler["reduce_on_plateau"]:
+                    lr_scheduler["scheduler"].step(monitor_val)
+                else:
+                    lr_scheduler["scheduler"].step()
+
+                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
+
+    def _get_monitor_value(self, key: str) -> Any:
+        # this is a separate method to aid in testing
+        return self.trainer.callback_metrics.get(key)
+
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """Decide if we should run validation."""
         if not self.trainer.enable_validation:
diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py
deleted file mode 100644
index e894d9df535d2..0000000000000
--- a/pytorch_lightning/trainer/connectors/optimizer_connector.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, List, Optional
-from weakref import proxy
-
-import pytorch_lightning as pl
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-
-class OptimizerConnector:
-    def __init__(self, trainer: "pl.Trainer") -> None:
-        self.trainer = proxy(trainer)
-
-    def on_trainer_init(self) -> None:
-        self.trainer.lr_schedulers = []
-        self.trainer.optimizers = []
-        self.trainer.optimizer_frequencies = []
-
-    def update_learning_rates(
-        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
-    ) -> None:
-        """Update learning rates.
-
-        Args:
-            interval: either 'epoch' or 'step'.
-            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
-                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
-                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
-                so they have to be updated separately.
-            opt_indices: indices of the optimizers to update.
-        """
-        if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
-            return
-
-        if opt_indices is None:
-            opt_indices = []
-
-        for lr_scheduler in self.trainer.lr_schedulers:
-            if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
-                continue
-
-            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
-                continue
-
-            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
-            current_idx += 1  # account for both batch and epoch starts from 0
-            # Take step if call to update_learning_rates matches the interval key and
-            # the current step modulo the schedulers frequency is zero
-            if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
-                monitor_val = None
-                if lr_scheduler["reduce_on_plateau"]:
-                    # If instance of ReduceLROnPlateau, we need a monitor
-                    monitor_key = lr_scheduler["monitor"]
-                    monitor_val = self._get_monitor_value(monitor_key)
-                    if monitor_val is None:
-                        if lr_scheduler.get("strict", True):
-                            avail_metrics = list(self.trainer.callback_metrics)
-                            raise MisconfigurationException(
-                                f"ReduceLROnPlateau conditioned on metric {monitor_key}"
-                                f" which is not available. Available metrics are: {avail_metrics}."
-                                " Condition can be set using `monitor` key in lr scheduler dict"
-                            )
-                        rank_zero_warn(
-                            f"ReduceLROnPlateau conditioned on metric {monitor_key}"
-                            " which is not available but strict is set to `False`."
-                            " Skipping learning rate update.",
-                            RuntimeWarning,
-                        )
-                        continue
-
-                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
-
-                # update LR
-                if lr_scheduler["reduce_on_plateau"]:
-                    lr_scheduler["scheduler"].step(monitor_val)
-                else:
-                    lr_scheduler["scheduler"].step()
-
-                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
-
-    def _get_monitor_value(self, key: str) -> Any:
-        # this is a separate method to aid in testing
-        return self.trainer.callback_metrics.get(key)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c4aaf630a29e3..455233a88eb1b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -58,7 +58,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
-from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
 from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -430,7 +429,6 @@ def __init__(
 
         # init connectors
         self._data_connector = DataConnector(self, multiple_trainloader_mode)
-        self.optimizer_connector = OptimizerConnector(self)
 
         self._accelerator_connector = AcceleratorConnector(
             num_processes,
@@ -516,7 +514,9 @@ def __init__(
         self.on_init_start()
 
         # init optimizer + lr scheduler related flags
-        self.optimizer_connector.on_trainer_init()
+        self.lr_schedulers = []
+        self.optimizers = []
+        self.optimizer_frequencies = []
 
         # init data flags
         self._data_connector.on_trainer_init(
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 97dd32ad4f278..518d67cf251f5 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -63,17 +63,17 @@ def validation_epoch_end(self, outputs):
         self.log("val_acc", outs)
 
 
-def mock_optimizer_connector(trainer):
+def mock_training_epoch_loop(trainer):
     # do not use `unittest.Mock` because we need to store the return value
     calls = {}
-    old_get_monitor_value = trainer.optimizer_connector._get_monitor_value
+    old_get_monitor_value = trainer.fit_loop.epoch_loop._get_monitor_value
 
     def mock(key):
         value = old_get_monitor_value(key)
         calls[trainer.current_epoch] = {key: value}
         return value
 
-    trainer.optimizer_connector._get_monitor_value = mock
+    trainer.fit_loop.epoch_loop._get_monitor_value = mock
     return calls
 
 
@@ -150,7 +150,7 @@ def on_validation_epoch_end(self):
         max_epochs=max_epochs,
         enable_progress_bar=False,
     )
-    calls = mock_optimizer_connector(trainer)
+    calls = mock_training_epoch_loop(trainer)
     trainer.fit(model)
 
     ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
@@ -248,7 +248,7 @@ def configure_optimizers(self):
         enable_progress_bar=False,
         num_sanity_val_steps=0,
     )
-    calls = mock_optimizer_connector(trainer)
+    calls = mock_training_epoch_loop(trainer)
     trainer.fit(model)
 
     def _make_assertions(epoch, ix):