diff --git a/CHANGELOG.md b/CHANGELOG.md
index 88c093e04184f..a6f5a15f9349d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -665,6 +665,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))
 
+- Fixed `on_before_optimizer_step` getting called before the optimizer closure (including backward) has run ([#10167](https://github.com/PyTorchLightning/pytorch-lightning/pull/10167))
+
+
 - Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index a8acec23c6ed3..058f3fc3fb01f 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -318,7 +318,7 @@ def optimizer_step(
         self,
         optimizer: Optimizer,
         opt_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         model: Optional[Union["pl.LightningModule", Module]] = None,
         **kwargs: Any
     ) -> None:
@@ -327,12 +327,17 @@ def optimizer_step(
         Args:
             optimizer: the optimizer performing the step
             opt_idx: index of the current optimizer
-            lambda_closure: closure calculating the loss value
+            closure: closure calculating the loss value
             model: reference to the model, optionally defining optimizer step related hooks
             **kwargs: Any extra arguments to ``optimizer.step``
         """
         model = model or self.lightning_module
-        self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
+        self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
+
+        if not isinstance(model, pl.LightningModule):
+            # gradient clipping and norm tracking only available with a LightningModule/Trainer
+            return
+
         trainer = model.trainer
         assert isinstance(trainer, pl.Trainer)
         # TODO: this is done for the entire model but should be changed to per-optimizer
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index cc2d0301f5ccc..6447c65ecd131 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -101,14 +101,14 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         if isinstance(model, pl.LightningModule):
             model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         skipped_backward = closure_result is None
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 704e8fe3f5c69..01c8661232f93 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -55,14 +55,14 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         if isinstance(model, pl.LightningModule):
             model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         skipped_backward = closure_result is None
diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py
index 76db29ecff212..51e80afed3609 100644
--- a/pytorch_lightning/plugins/precision/ipu_precision.py
+++ b/pytorch_lightning/plugins/precision/ipu_precision.py
@@ -43,7 +43,7 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         """IPUs handle the optimizer step internally."""
@@ -51,7 +51,7 @@ def optimizer_step(
             raise MisconfigurationException(
                 f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         if isinstance(model, pl.LightningModule):
             model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         skipped_backward = closure_result is None
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 603e84a476bee..c80e648eff5c4 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -72,17 +72,17 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         if self.scaler is None:
             # skip scaler logic, as bfloat16 does not require scaler
-            return super().optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+            return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
         self.scaler.unscale_(optimizer)
         if isinstance(model, pl.LightningModule):
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index dcbaa76b48559..8ba8b1a1872ae 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from functools import partial
 from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 
 import torch
@@ -110,21 +111,38 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
         """
         tensor.backward(*args, **kwargs)
 
+    def _wrap_closure(
+        self,
+        model: "pl.LightningModule",
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        closure: Callable[[], Any],
+    ) -> Any:
+        """This double-closure makes sure the ``closure`` is executed before the
+        ``on_before_optimizer_step`` hook is called.
+
+        The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
+        consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
+        """
+        closure_result = closure()
+        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        return closure_result
+
     def optimizer_step(
         self,
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         """Hook to run the optimizer step."""
         if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
-        optimizer.step(closure=lambda_closure, **kwargs)
+            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
+        optimizer.step(closure=closure, **kwargs)
 
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
-        if float(trainer.track_grad_norm) == -1:
+        if trainer.track_grad_norm == -1:
             return
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
         if grad_norm_dict:
diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py
index efa1e6e398414..53224709bfc08 100644
--- a/pytorch_lightning/plugins/precision/tpu.py
+++ b/pytorch_lightning/plugins/precision/tpu.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Any, Callable, Union
 
 from torch.nn import Module
@@ -31,12 +32,12 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any
     ) -> None:
         if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
+            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
+        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index ff8ffa3c50acd..9bf6bc7e899d5 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -349,23 +349,20 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va
             for pg in optimizer.param_groups:
                 for p in pg["params"]:
-                    p.grad[p.grad > self.custom_gradient_clip_val] = self.custom_gradient_clip_val
-                    p.grad[p.grad <= 0] = 0
-
-        def on_before_optimizer_step(self, optimizer, optimizer_idx):
-            for pg in optimizer.param_groups:
-                for p in pg["params"]:
-                    if p.grad is not None and p.grad.abs().sum() > 0:
-                        self.has_validated_gradients = True
-                        assert p.grad.min() >= 0
-                        assert p.grad.max() <= self.custom_gradient_clip_val
+                    p.grad.clamp_(min=0, max=self.custom_gradient_clip_val)
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=0, gradient_clip_val=1e-4
     )
     trainer.fit(model)
-    assert model.has_validated_gradients
+
+    optimizer = model.optimizers()
+    for pg in optimizer.param_groups:
+        for p in pg["params"]:
+            if p.grad is not None:
+                assert p.grad.min() >= 0
+                assert p.grad.max() <= model.custom_gradient_clip_val
 
 
 def test_lightning_module_configure_gradient_clipping_different_argument_values(tmpdir):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index ccd31af4dbd3e..6b34553ff313b 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -275,12 +275,7 @@ def _train_batch(self, *args, **kwargs):
     def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get("amp_backend") == "native"
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
-        using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
         out = []
-        on_before_optimizer_step = [
-            dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
-            dict(name="on_before_optimizer_step", args=(ANY, 0)),
-        ]
         for i in range(batches):
             out.extend(
                 [
@@ -291,8 +286,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     dict(name="Callback.on_batch_start", args=(trainer, model)),
                     dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)),
                     dict(name="on_train_batch_start", args=(ANY, i)),
-                    # without a precision plugin, we execute the closure inside the `optimizer.step`
-                    *([] if using_plugin else on_before_optimizer_step),
                     dict(name="forward", args=(ANY,)),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -305,7 +298,9 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                     dict(name="Callback.on_after_backward", args=(trainer, model)),
                     dict(name="on_after_backward"),
-                    *(on_before_optimizer_step if using_plugin else []),
+                    # note: unscaling happens here in the case of AMP
+                    dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
+                    dict(name="on_before_optimizer_step", args=(ANY, 0)),
                     *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(
                         name="clip_gradients",
@@ -334,7 +329,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
     @staticmethod
     def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
-        using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
         out = []
         for i in range(batches):
             out.extend(
                 [
@@ -355,11 +349,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
                     dict(name="on_after_backward"),
                     # `manual_backward` calls the previous 3
                     dict(name="manual_backward", args=(ANY,)),
-                    *([dict(name="closure")] if using_plugin else []),
+                    dict(name="closure"),
                     dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
                     dict(name="on_before_optimizer_step", args=(ANY, 0)),
-                    # without a precision plugin, we execute the closure inside the `optimizer.step`
-                    *([] if using_plugin else [dict(name="closure")]),
                     *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
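
Illustrative note (not part of the patch): the sketch below mirrors the double-closure ordering that `PrecisionPlugin.optimizer_step` adopts in the diff above — the closure handed to `optimizer.step` is wrapped with `functools.partial` so that the original closure (forward, loss, `backward`) runs first and `on_before_optimizer_step` fires only afterwards, guaranteeing gradients exist inside the hook while the parameter update still happens last. `ToyPrecisionPlugin`, `ToyModel`, and the bare `on_before_optimizer_step` method are hypothetical stand-ins for the Lightning classes, used only to show the ordering.

# Minimal, self-contained sketch (assumed stand-ins, not Lightning code) of the
# closure -> hook -> parameter-update ordering described above.
from functools import partial

import torch
from torch import nn
from torch.optim import SGD


class ToyPrecisionPlugin:
    def _wrap_closure(self, model, optimizer, optimizer_idx, closure):
        # Run the user closure first so gradients exist when the hook fires.
        closure_result = closure()
        model.on_before_optimizer_step(optimizer, optimizer_idx)
        return closure_result

    def optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs):
        # Hand the wrapped closure to the optimizer; `optimizer.step` executes it
        # before applying the update, so the hook still runs pre-step.
        wrapped = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        optimizer.step(closure=wrapped, **kwargs)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # Gradients are already populated because the closure ran backward().
        assert all(p.grad is not None for p in self.parameters())
        print("hook ran after backward, before the optimizer update")


model = ToyModel()
optimizer = SGD(model.parameters(), lr=0.1)


def training_closure():
    optimizer.zero_grad()
    loss = model.layer(torch.randn(4, 2)).sum()
    loss.backward()
    return loss


ToyPrecisionPlugin().optimizer_step(model, optimizer, optimizer_idx=0, closure=training_closure)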