From 3a0e42d590b0c61c7280128ccbbc0c67cbb79174 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 03:03:44 +0200 Subject: [PATCH 1/8] Implement double-closure --- CHANGELOG.md | 3 ++ pytorch_lightning/accelerators/accelerator.py | 5 ++++ .../plugins/precision/precision_plugin.py | 29 ++++++++++++++++--- pytorch_lightning/plugins/precision/tpu.py | 5 ++-- tests/models/test_hooks.py | 16 +++------- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88c093e04184f..a6f5a15f9349d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -665,6 +665,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287)) +- Fixed `on_before_optimizer_step` getting called before the optimizer closure (including backward) has run ([#10167](https://github.com/PyTorchLightning/pytorch-lightning/pull/10167)) + + - Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index a8acec23c6ed3..d2b44fc3fca1c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -333,6 +333,11 @@ def optimizer_step( """ model = model or self.lightning_module self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs) + + if not isinstance(model, pl.LightningModule): + # gradient clipping and norm tracking only available with a LightingModule/Trainer + return + trainer = model.trainer assert isinstance(trainer, pl.Trainer) # TODO: this is done for the entire model but should be changed to per-optimizer diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index dcbaa76b48559..99104c05d16b5 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -110,6 +110,28 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k """ tensor.backward(*args, **kwargs) + def _wrap_closure( + self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + lambda_closure: Callable[[], Any], + ) -> Callable[[], Any]: + """This double-closure allows makes sure the ``lambda_closure`` is executed before the + ``on_before_optimizer_step`` hook is called. + + The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is + consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(lambda_closure)`` directly. 
+ """ + + def inner() -> Any: + closure_result = lambda_closure() + if isinstance(model, pl.LightningModule): + model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) + return closure_result + + return inner + def optimizer_step( self, model: Union["pl.LightningModule", Module], @@ -119,12 +141,11 @@ def optimizer_step( **kwargs: Any, ) -> None: """Hook to run the optimizer step.""" - if isinstance(model, pl.LightningModule): - model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) - optimizer.step(closure=lambda_closure, **kwargs) + closure = self._wrap_closure(model, optimizer, optimizer_idx, lambda_closure) + optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: - if float(trainer.track_grad_norm) == -1: + if trainer.track_grad_norm == -1: return grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator) if grad_norm_dict: diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py index efa1e6e398414..2bdc6775b20f2 100644 --- a/pytorch_lightning/plugins/precision/tpu.py +++ b/pytorch_lightning/plugins/precision/tpu.py @@ -34,9 +34,8 @@ def optimizer_step( lambda_closure: Callable[[], Any], **kwargs: Any ) -> None: - if isinstance(model, pl.LightningModule): - model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) - closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs}) + closure = self._wrap_closure(model, optimizer, optimizer_idx, lambda_closure) + closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index ccd31af4dbd3e..6b34553ff313b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -275,12 +275,7 @@ def _train_batch(self, *args, **kwargs): def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs): using_native_amp = kwargs.get("amp_backend") == "native" using_deepspeed = kwargs.get("strategy") == "deepspeed" - using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy") out = [] - on_before_optimizer_step = [ - dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), - dict(name="on_before_optimizer_step", args=(ANY, 0)), - ] for i in range(batches): out.extend( [ @@ -291,8 +286,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre dict(name="Callback.on_batch_start", args=(trainer, model)), dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), dict(name="on_train_batch_start", args=(ANY, i)), - # without a precision plugin, we execute the closure inside the `optimizer.step` - *([] if using_plugin else on_before_optimizer_step), dict(name="forward", args=(ANY,)), dict(name="training_step", args=(ANY, i)), dict(name="training_step_end", args=(dict(loss=ANY),)), @@ -305,7 +298,9 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []), dict(name="Callback.on_after_backward", args=(trainer, model)), dict(name="on_after_backward"), - *(on_before_optimizer_step if using_plugin else 
[]), + # note: unscaling happens here in the case of AMP + dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), + dict(name="on_before_optimizer_step", args=(ANY, 0)), *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []), dict( name="clip_gradients", @@ -334,7 +329,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre @staticmethod def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs): using_deepspeed = kwargs.get("strategy") == "deepspeed" - using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy") out = [] for i in range(batches): out.extend( @@ -355,11 +349,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k dict(name="on_after_backward"), # `manual_backward` calls the previous 3 dict(name="manual_backward", args=(ANY,)), - *([dict(name="closure")] if using_plugin else []), + dict(name="closure"), dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), dict(name="on_before_optimizer_step", args=(ANY, 0)), - # without a precision plugin, we execute the closure inside the `optimizer.step` - *([] if using_plugin else [dict(name="closure")]), *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []), dict(name="training_step", args=(ANY, i)), dict(name="training_step_end", args=(dict(loss=ANY),)), From 15510881257fe63c8a9b366787ee138eba8a9782 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 04:33:17 +0200 Subject: [PATCH 2/8] Fix test --- tests/core/test_lightning_module.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index ff8ffa3c50acd..9bf6bc7e899d5 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -349,23 +349,20 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va for pg in optimizer.param_groups: for p in pg["params"]: - p.grad[p.grad > self.custom_gradient_clip_val] = self.custom_gradient_clip_val - p.grad[p.grad <= 0] = 0 - - def on_before_optimizer_step(self, optimizer, optimizer_idx): - for pg in optimizer.param_groups: - for p in pg["params"]: - if p.grad is not None and p.grad.abs().sum() > 0: - self.has_validated_gradients = True - assert p.grad.min() >= 0 - assert p.grad.max() <= self.custom_gradient_clip_val + p.grad.clamp_(min=0, max=self.custom_gradient_clip_val) model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4 + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=0, gradient_clip_val=1e-4 ) trainer.fit(model) - assert model.has_validated_gradients + + optimizer = model.optimizers() + for pg in optimizer.param_groups: + for p in pg["params"]: + if p.grad is not None: + assert p.grad.min() >= 0 + assert p.grad.max() <= model.custom_gradient_clip_val def test_lightning_module_configure_gradient_clipping_different_argument_values(tmpdir): From d8874b13e4758fc265cad67de1dd9eeb50623e85 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 13:45:09 +0200 Subject: [PATCH 3/8] Use partial --- .../plugins/precision/precision_plugin.py | 17 +++++++---------- pytorch_lightning/plugins/precision/tpu.py | 4 +++- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py 
b/pytorch_lightning/plugins/precision/precision_plugin.py index 99104c05d16b5..64cbe0c9e76c0 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +from functools import partial from typing import Any, Callable, Generator, List, Optional, Tuple, Union import torch @@ -116,21 +117,16 @@ def _wrap_closure( optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable[[], Any], - ) -> Callable[[], Any]: + ) -> Any: """This double-closure allows makes sure the ``lambda_closure`` is executed before the ``on_before_optimizer_step`` hook is called. The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(lambda_closure)`` directly. """ - - def inner() -> Any: - closure_result = lambda_closure() - if isinstance(model, pl.LightningModule): - model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) - return closure_result - - return inner + closure_result = lambda_closure() + model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) + return closure_result def optimizer_step( self, @@ -141,7 +137,8 @@ def optimizer_step( **kwargs: Any, ) -> None: """Hook to run the optimizer step.""" - closure = self._wrap_closure(model, optimizer, optimizer_idx, lambda_closure) + if isinstance(model, pl.LightningModule): + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, lambda_closure) optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py index 2bdc6775b20f2..0a3522b1a9b7d 100644 --- a/pytorch_lightning/plugins/precision/tpu.py +++ b/pytorch_lightning/plugins/precision/tpu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import partial from typing import Any, Callable, Union from torch.nn import Module @@ -34,7 +35,8 @@ def optimizer_step( lambda_closure: Callable[[], Any], **kwargs: Any ) -> None: - closure = self._wrap_closure(model, optimizer, optimizer_idx, lambda_closure) + if isinstance(model, pl.LightningModule): + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, lambda_closure) closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value From 4f4f387c310bdc3d656ae9ff0c66cbc03c9f52e7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 13:47:03 +0200 Subject: [PATCH 4/8] Rename --- pytorch_lightning/plugins/precision/precision_plugin.py | 8 ++++---- pytorch_lightning/plugins/precision/tpu.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 64cbe0c9e76c0..d40582535397e 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -116,7 +116,7 @@ def _wrap_closure( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], ) -> Any: """This double-closure allows makes sure the ``lambda_closure`` is executed before the ``on_before_optimizer_step`` hook is called. @@ -124,7 +124,7 @@ def _wrap_closure( The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(lambda_closure)`` directly. 
""" - closure_result = lambda_closure() + closure_result = closure() model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) return closure_result @@ -133,12 +133,12 @@ def optimizer_step( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], **kwargs: Any, ) -> None: """Hook to run the optimizer step.""" if isinstance(model, pl.LightningModule): - closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, lambda_closure) + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py index 0a3522b1a9b7d..53224709bfc08 100644 --- a/pytorch_lightning/plugins/precision/tpu.py +++ b/pytorch_lightning/plugins/precision/tpu.py @@ -32,11 +32,11 @@ def optimizer_step( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], **kwargs: Any ) -> None: if isinstance(model, pl.LightningModule): - closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, lambda_closure) + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value From 7c3ddc69d7550a8493c014e91382d6b776409016 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 14:05:09 +0200 Subject: [PATCH 5/8] Rename closure - fix mypy --- pytorch_lightning/plugins/precision/apex_amp.py | 4 ++-- pytorch_lightning/plugins/precision/deepspeed_precision.py | 4 ++-- pytorch_lightning/plugins/precision/ipu_precision.py | 4 ++-- pytorch_lightning/plugins/precision/native_amp.py | 6 +++--- pytorch_lightning/plugins/precision/precision_plugin.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index cc2d0301f5ccc..6447c65ecd131 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -101,14 +101,14 @@ def optimizer_step( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], **kwargs: Any, ) -> None: if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." 
) - closure_result = lambda_closure() + closure_result = closure() if isinstance(model, pl.LightningModule): model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) skipped_backward = closure_result is None diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 704e8fe3f5c69..01c8661232f93 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -55,14 +55,14 @@ def optimizer_step( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], **kwargs: Any, ) -> None: if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." ) - closure_result = lambda_closure() + closure_result = closure() if isinstance(model, pl.LightningModule): model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) skipped_backward = closure_result is None diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py index 76db29ecff212..51e80afed3609 100644 --- a/pytorch_lightning/plugins/precision/ipu_precision.py +++ b/pytorch_lightning/plugins/precision/ipu_precision.py @@ -43,7 +43,7 @@ def optimizer_step( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], **kwargs: Any, ) -> None: """IPUs handle the optimizer step internally.""" @@ -51,7 +51,7 @@ def optimizer_step( raise MisconfigurationException( f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." ) - closure_result = lambda_closure() + closure_result = closure() if isinstance(model, pl.LightningModule): model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) skipped_backward = closure_result is None diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 603e84a476bee..c80e648eff5c4 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -72,17 +72,17 @@ def optimizer_step( model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], **kwargs: Any, ) -> None: if self.scaler is None: # skip scaler logic, as bfloat16 does not require scaler - return super().optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) + return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs) if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." ) - closure_result = lambda_closure() + closure_result = closure() # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook. 
self.scaler.unscale_(optimizer) if isinstance(model, pl.LightningModule): diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index d40582535397e..8aa0cccdfb86c 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -118,11 +118,11 @@ def _wrap_closure( optimizer_idx: int, closure: Callable[[], Any], ) -> Any: - """This double-closure allows makes sure the ``lambda_closure`` is executed before the + """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step`` hook is called. The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is - consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(lambda_closure)`` directly. + consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly. """ closure_result = closure() model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx) From f8d76b4332f31168418df6024fe7250ecacebe81 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 14:07:06 +0200 Subject: [PATCH 6/8] mypy --- pytorch_lightning/plugins/precision/precision_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 8aa0cccdfb86c..507fa47987e18 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -113,7 +113,7 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k def _wrap_closure( self, - model: Union["pl.LightningModule", Module], + model: pl.LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable[[], Any], From 848cda457d26309a0f5c242c2db66769f9ad7808 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 14:10:52 +0200 Subject: [PATCH 7/8] forward reference --- pytorch_lightning/plugins/precision/precision_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 507fa47987e18..8ba8b1a1872ae 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -113,7 +113,7 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k def _wrap_closure( self, - model: pl.LightningModule, + model: "pl.LightningModule", optimizer: Optimizer, optimizer_idx: int, closure: Callable[[], Any], From 2fd93908537c088a30487de9971e36559cc35ed4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 29 Oct 2021 14:16:17 +0200 Subject: [PATCH 8/8] Finish rename --- pytorch_lightning/accelerators/accelerator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d2b44fc3fca1c..058f3fc3fb01f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -318,7 +318,7 @@ def optimizer_step( self, optimizer: Optimizer, opt_idx: int, - lambda_closure: Callable[[], Any], + closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any ) -> None: @@ 
-327,12 +327,12 @@ def optimizer_step( Args: optimizer: the optimizer performing the step opt_idx: index of the current optimizer - lambda_closure: closure calculating the loss value + closure: closure calculating the loss value model: reference to the model, optionally defining optimizer step related hooks **kwargs: Any extra arguments to ``optimizer.step`` """ model = model or self.lightning_module - self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs) + self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) if not isinstance(model, pl.LightningModule): # gradient clipping and norm tracking only available with a LightingModule/Trainer
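
The net effect of this series, reduced to a minimal standalone sketch: the training closure (which runs forward and backward) is wrapped so that an ``on_before_optimizer_step``-style hook fires only after gradients exist, and the wrapped callable is what gets handed to ``optimizer.step``. Everything below other than ``functools.partial`` and the ``torch.optim`` API is an illustrative stand-in, not Lightning's actual Trainer/LightningModule interface.

    from functools import partial
    from typing import Callable

    import torch
    from torch import nn


    def on_before_optimizer_step(model: nn.Module, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
        # By the time this runs, backward has populated the gradients, so they can
        # be inspected, logged, or clipped here (mirroring the hook in the patches).
        total = sum(p.grad.abs().sum() for p in model.parameters() if p.grad is not None)
        print(f"grad abs-sum before stepping optimizer {optimizer_idx}: {total:.4f}")


    def wrap_closure(
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        optimizer_idx: int,
        closure: Callable[[], torch.Tensor],
    ) -> torch.Tensor:
        # The "double closure": run the original closure first (it performs backward),
        # then the hook, and return the closure's result so ``optimizer.step`` still
        # receives the loss value it expects.
        result = closure()
        on_before_optimizer_step(model, optimizer, optimizer_idx)
        return result


    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch = torch.randn(8, 4)


    def training_closure() -> torch.Tensor:
        optimizer.zero_grad()
        loss = model(batch).pow(2).mean()
        loss.backward()
        return loss


    # As in patches 3-8, ``partial`` turns the wrapper into the zero-argument
    # callable that ``optimizer.step(closure=...)`` requires.
    optimizer.step(closure=partial(wrap_closure, model, optimizer, 0, training_closure))

Patches 3 and 4 replace the nested ``inner`` function from patch 1 with this ``functools.partial`` form, which keeps ``_wrap_closure`` a plain method and, judging by the follow-up "mypy" commits, was simpler to annotate; the sketch mirrors that final shape.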