Commit

Update optimizer_step methods in accelerator and plugins (#10023)
Co-authored-by: Carlos Mocholi <[email protected]>
awaelchli and carmocca authored Oct 20, 2021
1 parent 5b30a65 commit d419028
Showing 9 changed files with 32 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -213,6 +213,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
 * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+* Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
+

 ### Changed

13 changes: 11 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -314,16 +314,25 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:

         return closure_loss

-    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        lambda_closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any
+    ) -> None:
         """performs the actual optimizer step.

         Args:
             optimizer: the optimizer performing the step
             opt_idx: index of the current optimizer
             lambda_closure: closure calculating the loss value
+            model: reference to the model, optionally defining optimizer step related hooks
         """
+        model = model or self.lightning_module
         make_optimizer_step = self.precision_plugin.pre_optimizer_step(
-            self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
+            model, optimizer, opt_idx, lambda_closure, **kwargs
         )
         if make_optimizer_step:
             self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
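For orientation, here is a minimal, self-contained sketch of the dispatch pattern this change enables; the classes below are simplified, hypothetical stand-ins, not Lightning's actual `Accelerator` and `PrecisionPlugin`. The key line is the fallback `model = model or self.lightning_module`, which lets a caller route an explicit `torch.nn.Module` into the precision plugin while preserving the old behaviour when no model is passed.

```python
# Simplified stand-ins for the classes touched by this commit (hypothetical names).
from typing import Any, Callable, Optional

import torch
from torch import nn
from torch.optim import Optimizer


class ToyPrecisionPlugin:
    def pre_optimizer_step(
        self, model: nn.Module, optimizer: Optimizer, opt_idx: int, closure: Callable[[], Any]
    ) -> bool:
        closure()    # run forward/backward
        return True  # the caller should still perform the optimizer step


class ToyAccelerator:
    def __init__(self, lightning_module: nn.Module, precision_plugin: ToyPrecisionPlugin) -> None:
        self.lightning_module = lightning_module
        self.precision_plugin = precision_plugin

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[nn.Module] = None,
    ) -> None:
        # Same fallback as the patch: prefer the explicitly passed model.
        model = model or self.lightning_module
        if self.precision_plugin.pre_optimizer_step(model, optimizer, opt_idx, closure):
            optimizer.step()


module = nn.Linear(2, 1)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
accelerator = ToyAccelerator(module, ToyPrecisionPlugin())


def closure() -> torch.Tensor:
    optimizer.zero_grad()
    loss = module(torch.randn(4, 2)).sum()
    loss.backward()
    return loss


# An explicitly passed model (e.g. a wrapped nn.Module) now reaches the precision plugin.
accelerator.optimizer_step(optimizer, opt_idx=0, closure=closure, model=module)
```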
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union

 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer

 import pytorch_lightning as pl
@@ -97,7 +98,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -112,7 +113,7 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
             optimizer.step(**kwargs)
         return False
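For context, this is the standard NVIDIA Apex AMP training pattern that the plugin wraps; it is generic Apex usage (requires the `apex` package and a GPU), not Lightning code.

```python
# Generic NVIDIA Apex mixed-precision loop, shown only for context.
import torch
from apex import amp

model = torch.nn.Linear(2, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# amp.initialize patches the model/optimizer for mixed precision ("O1" = conservative mixed precision).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for _ in range(2):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 2, device="cuda")).sum()
    # Scale the loss so gradients stay in a numerically safe range, then step as usual.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```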
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -48,7 +48,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -63,12 +63,12 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine = model.trainer.model
+        deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
         deepspeed_engine.step()
         return False

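For context, the generic DeepSpeed pattern the plugin defers to: the DeepSpeed engine, not the raw optimizer, owns `backward` and `step`. This is only a sketch, assuming the `deepspeed` package and a supported accelerator; the exact `initialize` arguments and config keys vary by DeepSpeed version.

```python
# Generic DeepSpeed usage, shown only for context (not Lightning code).
import deepspeed
import torch

model = torch.nn.Linear(2, 1)
# deepspeed.initialize wraps the model and optimizer in an engine object.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={"train_batch_size": 4, "optimizer": {"type": "SGD", "params": {"lr": 0.1}}},
)

loss = engine(torch.randn(4, 2, device=engine.device)).sum()
engine.backward(loss)  # gradient scaling/partitioning handled by the engine
engine.step()          # the engine performs the optimizer step internally
```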
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/ipu_precision.py
@@ -40,7 +40,7 @@ def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> No

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable[[], Any],
@@ -55,7 +55,7 @@ def pre_optimizer_step(
         closure_result = lambda_closure()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             # we lack coverage here and IPUs are (currently) limited - something to explore if there's demand
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not implemented for IPUs."
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -66,7 +66,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -84,7 +84,7 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()
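For context, the plain `torch.cuda.amp` pattern this plugin drives: `GradScaler.step` unscales the gradients and skips the optimizer update if any are non-finite, and `update` adjusts the scale factor. This is generic PyTorch usage (requires a CUDA device), not Lightning internals.

```python
# Plain PyTorch native-AMP loop, shown only for context.
import torch

model = torch.nn.Linear(2, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(4, 2, device="cuda")).sum()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscale; skip the step if gradients are non-finite
    scaler.update()                # adapt the loss scale for the next iteration
```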
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -99,14 +99,15 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        if isinstance(model, pl.LightningModule):
+            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         return True

     def clip_gradients(
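A minimal sketch of why the `isinstance` guards above are needed (toy classes with hypothetical names, not the real ones): a plain `nn.Module` has no `trainer` or `automatic_optimization` attribute, so the hook call and the manual-optimization checks must only run when the model really is a LightningModule-like object.

```python
# Toy illustration of the isinstance guard; ToyLightningModule is a stand-in, not the real class.
from torch import nn


class ToyLightningModule(nn.Module):
    automatic_optimization = True

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(2, 1)
        self.trainer = object()  # the real class holds a reference to its Trainer


def pre_optimizer_step(model: nn.Module) -> bool:
    # Only LightningModule-like objects expose the trainer/hook machinery.
    if isinstance(model, ToyLightningModule):
        print("would call trainer.call_hook('on_before_optimizer_step', ...)")
    return True


pre_optimizer_step(ToyLightningModule())  # guarded branch runs
pre_optimizer_step(nn.Linear(2, 1))       # plain module passes through without an AttributeError
```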
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/precision/tpu.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable
+from typing import Any, Callable, Union

+from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
@@ -27,7 +28,7 @@
 class TPUPrecisionPlugin(PrecisionPlugin):
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable[[], Any],
@@ -37,7 +38,7 @@ def pre_optimizer_step(
         closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             # we lack coverage here so disable this - something to explore if there's demand
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs."
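For context, the generic `torch_xla` pattern behind `xm.optimizer_step`: it reduces gradients across TPU cores and then calls the optimizer's `step`, optionally forwarding extra arguments (such as a closure) via `optimizer_args`. A sketch assuming the `torch_xla` package and an XLA device, not Lightning code.

```python
# Generic torch_xla usage, shown only for context.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(2, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
loss = model(torch.randn(4, 2, device=device)).sum()
loss.backward()
# Reduce gradients across TPU cores, then run optimizer.step().
xm.optimizer_step(optimizer)
```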
@@ -250,7 +250,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
         return trainer.init_optimizers(model)

-    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
+    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs: Any) -> None:
         optimizer.step(closure=lambda_closure, **kwargs)

     @property
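For reference, the plain PyTorch closure protocol that `optimizer.step(closure=lambda_closure)` relies on: the closure re-evaluates the loss and runs backward, and optimizers such as LBFGS may invoke it several times per step. This is standard PyTorch usage, not Lightning code.

```python
# Standard closure-based optimizer step; LBFGS is the canonical optimizer that requires it.
import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
inputs, targets = torch.randn(8, 2), torch.randn(8, 1)


def closure() -> torch.Tensor:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss


# The optimizer may call the closure multiple times to re-evaluate the loss.
optimizer.step(closure)
```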
