Remove Strategy.init_optimizers (#11236)
daniellepintz authored Dec 23, 2021
1 parent ba6a8dd commit ca9b25d
Showing 4 changed files with 29 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -343,6 +343,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed support for Python 3.6 ([#11117](https://github.com/PyTorchLightning/pytorch-lightning/pull/11117))


- Removed `Strategy.init_optimizers` in favor of `Strategy.setup_optimizers` ([#11236](https://github.com/PyTorchLightning/pytorch-lightning/pull/11236))

### Fixed

- Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
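Migration note (not part of the diff): custom strategies that previously overrode `Strategy.init_optimizers` and returned three lists now override `Strategy.setup_optimizers` and assign those lists to the strategy in place. A minimal sketch of that change, assuming a hypothetical subclass named `MyStrategy` and that `SingleDeviceStrategy`, `TrainerFn`, and `_init_optimizers_and_lr_schedulers` are importable from the paths used elsewhere in this commit:

    # Hypothetical migration sketch; the class name and import paths are assumptions, not from this commit.
    import pytorch_lightning as pl
    from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
    from pytorch_lightning.strategies import SingleDeviceStrategy
    from pytorch_lightning.trainer.states import TrainerFn


    class MyStrategy(SingleDeviceStrategy):
        # Before this commit: `init_optimizers(trainer, model)` returned
        # (optimizers, lr_schedulers, optimizer_frequencies).
        # After: `setup_optimizers(trainer)` assigns them onto the strategy directly.
        def setup_optimizers(self, trainer: "pl.Trainer") -> None:
            if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
                return
            self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
                self.lightning_module
            )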
14 changes: 12 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
@@ -562,11 +562,21 @@ def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs

-    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Tuple[List, List, List]:
+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        """Creates optimizers and schedulers.
+
+        Args:
+            trainer: the Trainer, these optimizers should be connected to
+        """
+        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+            return
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled
         # via `_initialize_deepspeed_train`
-        return [], [], []  # empty optimizers, schedulers and frequencies
+        # empty optimizers, schedulers and frequencies
+        self.optimizers = []
+        self.lr_schedulers = []
+        self.optimizer_frequencies = []

    @property
    def handles_gradient_accumulation(self) -> bool:
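Background (not part of the diff): with DeepSpeed the optimizer and scheduler are normally built by the DeepSpeed engine, either from the `optimizer`/`scheduler` sections of the DeepSpeed config or from `configure_optimizers` inside `_initialize_deepspeed_train`, so `setup_optimizers` deliberately leaves the strategy's lists empty. A rough sketch of the config-driven path; the `DeepSpeedStrategy` name and its `config=` argument are assumptions about this revision:

    # Illustrative only; the strategy class name and constructor arguments are assumptions.
    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DeepSpeedStrategy

    # DeepSpeed builds this optimizer internally, which is why the Lightning strategy
    # keeps self.optimizers / self.lr_schedulers / self.optimizer_frequencies empty here.
    deepspeed_config = {
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "zero_optimization": {"stage": 2},
    }

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy=DeepSpeedStrategy(config=deepspeed_config),
    )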
10 changes: 2 additions & 8 deletions pytorch_lightning/strategies/strategy.py
@@ -104,12 +104,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """
        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
            return
-        optimizers, lr_schedulers, optimizer_frequencies = self.init_optimizers(
-            trainer=trainer, model=self.lightning_module
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
+            self.lightning_module
         )
-        self.optimizers = optimizers
-        self.lr_schedulers = lr_schedulers
-        self.optimizer_frequencies = optimizer_frequencies

    def setup(self, trainer: "pl.Trainer") -> None:
        """Setup plugins for the trainer fit and creates optimizers.
@@ -377,9 +374,6 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        """
        return dataloader

-    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        return _init_optimizers_and_lr_schedulers(model)

    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin
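For context (not part of the diff): `_init_optimizers_and_lr_schedulers` parses whatever `configure_optimizers` returns into the three lists that `setup_optimizers` now assigns onto the strategy. A small sketch of a `configure_optimizers` that feeds it; the model class and hyperparameters are made up for illustration:

    # Hypothetical LightningModule used only to show the structure being parsed.
    import torch
    import pytorch_lightning as pl


    class ExampleModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.lr = 1e-3

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
            # Roughly parsed into: optimizers=[optimizer],
            # lr_schedulers=[{"scheduler": scheduler, "interval": "epoch", ...}],
            # optimizer_frequencies=[]
            return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}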
22 changes: 12 additions & 10 deletions pytorch_lightning/tuner/lr_finder.py
@@ -24,7 +24,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.optimizer import _get_default_scheduler_config
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -99,14 +99,14 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
        self._total_batch_idx = 0  # for debug purpose

    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
"""Decorate `trainer.strategy.init_optimizers` method such that it returns the user's originally specified
optimizer together with a new scheduler that that takes care of the learning rate search."""
init_optimizers = trainer.strategy.init_optimizers
"""Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
optimizer together with a new scheduler that takes care of the learning rate search."""
setup_optimizers = trainer.strategy.setup_optimizers

@wraps(init_optimizers)
def func(trainer, model):
# Decide the structure of the output from trainer.strategy.init_optimizers
optimizers, _, _ = init_optimizers(trainer, model)
@wraps(setup_optimizers)
def func(trainer):
# Decide the structure of the output from _init_optimizers_and_lr_schedulers
optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)

            if len(optimizers) != 1:
                raise MisconfigurationException(
@@ -126,7 +126,9 @@ def func(trainer, model)
            sched_config = _get_default_scheduler_config()
            sched_config.update({"scheduler": scheduler, "interval": "step"})

-            return [optimizer], [sched_config], []
+            trainer.strategy.optimizers = [optimizer]
+            trainer.strategy.lr_schedulers = [sched_config]
+            trainer.strategy.optimizer_frequencies = []

        return func

@@ -232,7 +234,7 @@ def lr_find(
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
-    trainer.strategy.init_optimizers = lr_finder._exchange_scheduler(trainer, model)
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)
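Usage stays the same from the outside; only the internal patch target moves from `init_optimizers` to `setup_optimizers`. A sketch of the public learning-rate-finder entry point, reusing the hypothetical `ExampleModel` from the earlier sketch (to actually run, it would also need a `training_step` and a `train_dataloader`):

    # Sketch of the public API exercised by this change; ExampleModel is the
    # placeholder module from the earlier sketch, not something in this commit.
    import pytorch_lightning as pl

    model = ExampleModel()
    trainer = pl.Trainer(max_epochs=3)

    lr_finder = trainer.tuner.lr_find(model)   # patches strategy.setup_optimizers internally
    print(lr_finder.suggestion())              # suggested learning rate
    model.lr = lr_finder.suggestion()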