From ca9b25db80f08a3b9a3c448048949ec1adb845ba Mon Sep 17 00:00:00 2001
From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
Date: Thu, 23 Dec 2021 10:48:21 -0800
Subject: [PATCH] Remove `Strategy.init_optimizers` (#11236)

---
 CHANGELOG.md                              |  3 +++
 pytorch_lightning/strategies/deepspeed.py | 14 ++++++++++++--
 pytorch_lightning/strategies/strategy.py  | 10 ++--------
 pytorch_lightning/tuner/lr_finder.py      | 22 ++++++++++++----------
 4 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 47214bff258eb..a6c7bbe4c6b40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -343,6 +343,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed support for Python 3.6 ([#11117](https://github.com/PyTorchLightning/pytorch-lightning/pull/11117))
 
+
+- Removed `Strategy.init_optimizers` in favor of `Strategy.setup_optimizers` ([#11236](https://github.com/PyTorchLightning/pytorch-lightning/pull/11236))
+
 ### Fixed
 
 - Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py
index 8b34061d187d9..452f3c8e1a8b4 100644
--- a/pytorch_lightning/strategies/deepspeed.py
+++ b/pytorch_lightning/strategies/deepspeed.py
@@ -562,11 +562,21 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Tuple[List, List, List]:
+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        """Creates optimizers and schedulers.
+
+        Args:
+            trainer: the Trainer, these optimizers should be connected to
+        """
+        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+            return
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled
         # via `_initialize_deepspeed_train`
-        return [], [], []  # empty optimizers, schedulers and frequencies
+        # empty optimizers, schedulers and frequencies
+        self.optimizers = []
+        self.lr_schedulers = []
+        self.optimizer_frequencies = []
 
     @property
     def handles_gradient_accumulation(self) -> bool:
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py
index 2ee586a724ef3..fe9093838c157 100644
--- a/pytorch_lightning/strategies/strategy.py
+++ b/pytorch_lightning/strategies/strategy.py
@@ -104,12 +104,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         """
         if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
             return
-        optimizers, lr_schedulers, optimizer_frequencies = self.init_optimizers(
-            trainer=trainer, model=self.lightning_module
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
+            self.lightning_module
         )
-        self.optimizers = optimizers
-        self.lr_schedulers = lr_schedulers
-        self.optimizer_frequencies = optimizer_frequencies
 
     def setup(self, trainer: "pl.Trainer") -> None:
         """Setup plugins for the trainer fit and creates optimizers.
@@ -377,9 +374,6 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
         """
         return dataloader
 
-    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        return _init_optimizers_and_lr_schedulers(model)
-
     @property
     def restore_checkpoint_after_setup(self) -> bool:
         """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 5b5c6adf32182..bea07e4c36ffa 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -24,7 +24,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.optimizer import _get_default_scheduler_config
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -99,14 +99,14 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self._total_batch_idx = 0  # for debug purpose
 
     def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        """Decorate `trainer.strategy.init_optimizers` method such that it returns the user's originally specified
-        optimizer together with a new scheduler that that takes care of the learning rate search."""
-        init_optimizers = trainer.strategy.init_optimizers
+        """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
+        optimizer together with a new scheduler that takes care of the learning rate search."""
+        setup_optimizers = trainer.strategy.setup_optimizers
 
-        @wraps(init_optimizers)
-        def func(trainer, model):
-            # Decide the structure of the output from trainer.strategy.init_optimizers
-            optimizers, _, _ = init_optimizers(trainer, model)
+        @wraps(setup_optimizers)
+        def func(trainer):
+            # Decide the structure of the output from _init_optimizers_and_lr_schedulers
+            optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)
 
             if len(optimizers) != 1:
                 raise MisconfigurationException(
@@ -126,7 +126,9 @@ def func(trainer, model):
 
             sched_config = _get_default_scheduler_config()
             sched_config.update({"scheduler": scheduler, "interval": "step"})
-            return [optimizer], [sched_config], []
+            trainer.strategy.optimizers = [optimizer]
+            trainer.strategy.lr_schedulers = [sched_config]
+            trainer.strategy.optimizer_frequencies = []
 
         return func
 
@@ -232,7 +234,7 @@ def lr_find(
         trainer.save_checkpoint(str(save_path))
 
     # Configure optimizer and scheduler
-    trainer.strategy.init_optimizers = lr_finder._exchange_scheduler(trainer, model)
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
 
     # Fit, lr & loss logged in callback
    trainer.tuner._run(model)
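Note (not part of the patch): a minimal migration sketch for downstream code, assuming a custom strategy previously overrode the removed hook. The `LiteralSGDStrategy` class and its single SGD optimizer are hypothetical; only `setup_optimizers`, `TrainerFn`, and the `optimizers`/`lr_schedulers`/`optimizer_frequencies` attributes come from the patch itself. Where an override used to return `(optimizers, lr_schedulers, frequencies)` from `init_optimizers`, it would now assign those lists onto the strategy inside `setup_optimizers`:

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.strategies import SingleDeviceStrategy
    from pytorch_lightning.trainer.states import TrainerFn


    class LiteralSGDStrategy(SingleDeviceStrategy):
        """Hypothetical strategy illustrating the post-#11236 hook."""

        def setup_optimizers(self, trainer: "pl.Trainer") -> None:
            # Optimizers are only needed when fitting or tuning.
            if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
                return
            # Previously: `return [optimizer], [], []` from `init_optimizers`.
            optimizer = torch.optim.SGD(self.lightning_module.parameters(), lr=0.1)
            self.optimizers = [optimizer]
            self.lr_schedulers = []
            self.optimizer_frequencies = []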