Implement double optimizer closure for hook structure consistency #10167

Merged · 8 commits · Oct 29, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -665,6 +665,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))


- Fixed `on_before_optimizer_step` getting called before the optimizer closure (including backward) has run ([#10167](https://github.com/PyTorchLightning/pytorch-lightning/pull/10167))


- Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))


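To illustrate the `on_before_optimizer_step` entry above: with this change the hook fires only after the optimizer closure (and therefore `backward`) has run, so gradients are populated when the hook is called. A minimal, hedged sketch with a toy model and dataset (none of this comes from the PR; the two-argument hook signature matches the Lightning version this PR targets):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # with this fix, the closure (including backward) has already run,
        # so gradients exist and can be inspected here
        grads = [p.grad.norm() for p in self.parameters() if p.grad is not None]
        assert grads, "expected gradients to be populated before the optimizer step"


dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)
trainer.fit(LitModel(), DataLoader(dataset, batch_size=8))
```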
11 changes: 8 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -318,7 +318,7 @@ def optimizer_step(
self,
optimizer: Optimizer,
opt_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any
) -> None:
@@ -327,12 +327,17 @@ def optimizer_step(
Args:
optimizer: the optimizer performing the step
opt_idx: index of the current optimizer
lambda_closure: closure calculating the loss value
closure: closure calculating the loss value
model: reference to the model, optionally defining optimizer step related hooks
**kwargs: Any extra arguments to ``optimizer.step``
"""
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

if not isinstance(model, pl.LightningModule):
# gradient clipping and norm tracking are only available with a LightningModule/Trainer
return

trainer = model.trainer
assert isinstance(trainer, pl.Trainer)
# TODO: this is done for the entire model but should be changed to per-optimizer
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -101,14 +101,14 @@ def optimizer_step(
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = lambda_closure()
closure_result = closure()
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
skipped_backward = closure_result is None
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -55,14 +55,14 @@ def optimizer_step(
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = lambda_closure()
closure_result = closure()
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
skipped_backward = closure_result is None
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/ipu_precision.py
@@ -43,15 +43,15 @@ def optimizer_step(
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
"""IPUs handle the optimizer step internally."""
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = lambda_closure()
closure_result = closure()
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
skipped_backward = closure_result is None
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -72,17 +72,17 @@ def optimizer_step(
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = lambda_closure()
closure_result = closure()
# `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
self.scaler.unscale_(optimizer)
if isinstance(model, pl.LightningModule):
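The `# unscale after the closure is executed but before the on_before_optimizer_step hook` comment above follows the usual `torch.cuda.amp` recipe for working with unscaled gradients. A hedged, standalone sketch of that ordering with a toy model (this is not the plugin code itself, and it assumes a CUDA device):

```python
import torch

# illustrative model, data, and hyperparameters
model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
batch = torch.randn(8, 4, device="cuda")


def closure():
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).pow(2).mean()
    scaler.scale(loss).backward()  # gradients are still scaled after this call
    return loss


closure()                    # run forward + backward first
scaler.unscale_(optimizer)   # unscale so gradients can be inspected/clipped in true units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)       # skips the update if infs/NaNs were produced
scaler.update()
```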
26 changes: 22 additions & 4 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partial
from typing import Any, Callable, Generator, List, Optional, Tuple, Union

import torch
@@ -110,21 +111,38 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
"""
tensor.backward(*args, **kwargs)

def _wrap_closure(
self,
model: "pl.LightningModule",
optimizer: Optimizer,
optimizer_idx: int,
closure: Callable[[], Any],
) -> Any:
"""This double-closure allows makes sure the ``closure`` is executed before the
``on_before_optimizer_step`` hook is called.

The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
"""
closure_result = closure()
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
return closure_result

def optimizer_step(
self,
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
"""Hook to run the optimizer step."""
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
optimizer.step(closure=lambda_closure, **kwargs)
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
optimizer.step(closure=closure, **kwargs)

def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if float(trainer.track_grad_norm) == -1:
if trainer.track_grad_norm == -1:
return
grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
if grad_norm_dict:
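`_wrap_closure` above is the heart of this PR: the user closure is wrapped in a second closure so that `on_before_optimizer_step` runs after `backward` but before the parameters are updated, and the wrapped closure is what `optimizer.step` receives. A minimal sketch of the same pattern against a plain `torch.optim.SGD` optimizer (the helper names here are illustrative, not Lightning API):

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(8, 4)


def training_closure():
    # inner closure: zero grads, forward, backward (what the training loop builds per step)
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()
    loss.backward()
    return loss


def before_optimizer_step(opt):
    # stand-in for the `on_before_optimizer_step` hook: backward has already run here
    grads = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    print("grad norm before step:", torch.stack(grads).norm().item())


def wrapped_closure():
    # the "double closure": run the real closure first, then the hook,
    # then return the loss so `optimizer.step` can apply the update
    loss = training_closure()
    before_optimizer_step(optimizer)
    return loss


optimizer.step(closure=wrapped_closure)
```

Because standard PyTorch optimizers call the closure inside `step` before applying the update, wrapping the closure is enough to guarantee the hook ordering without touching `optimizer.step` itself.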
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/precision/tpu.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Callable, Union

from torch.nn import Module
@@ -31,12 +32,12 @@ def optimizer_step(
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
closure: Callable[[], Any],
**kwargs: Any
) -> None:
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
21 changes: 9 additions & 12 deletions tests/core/test_lightning_module.py
@@ -349,23 +349,20 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va

for pg in optimizer.param_groups:
for p in pg["params"]:
p.grad[p.grad > self.custom_gradient_clip_val] = self.custom_gradient_clip_val
p.grad[p.grad <= 0] = 0

def on_before_optimizer_step(self, optimizer, optimizer_idx):
for pg in optimizer.param_groups:
for p in pg["params"]:
if p.grad is not None and p.grad.abs().sum() > 0:
self.has_validated_gradients = True
assert p.grad.min() >= 0
assert p.grad.max() <= self.custom_gradient_clip_val
p.grad.clamp_(min=0, max=self.custom_gradient_clip_val)

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=0, gradient_clip_val=1e-4
)
trainer.fit(model)
assert model.has_validated_gradients

optimizer = model.optimizers()
for pg in optimizer.param_groups:
for p in pg["params"]:
if p.grad is not None:
assert p.grad.min() >= 0
assert p.grad.max() <= model.custom_gradient_clip_val


def test_lightning_module_configure_gradient_clipping_different_argument_values(tmpdir):
16 changes: 4 additions & 12 deletions tests/models/test_hooks.py
@@ -275,12 +275,7 @@ def _train_batch(self, *args, **kwargs):
def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
using_native_amp = kwargs.get("amp_backend") == "native"
using_deepspeed = kwargs.get("strategy") == "deepspeed"
using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
out = []
on_before_optimizer_step = [
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
dict(name="on_before_optimizer_step", args=(ANY, 0)),
]
for i in range(batches):
out.extend(
[
Expand All @@ -291,8 +286,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
dict(name="Callback.on_batch_start", args=(trainer, model)),
dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)),
dict(name="on_train_batch_start", args=(ANY, i)),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*([] if using_plugin else on_before_optimizer_step),
dict(name="forward", args=(ANY,)),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),
Expand All @@ -305,7 +298,9 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
*([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name="Callback.on_after_backward", args=(trainer, model)),
dict(name="on_after_backward"),
*(on_before_optimizer_step if using_plugin else []),
# note: unscaling happens here in the case of AMP
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
dict(name="on_before_optimizer_step", args=(ANY, 0)),
*([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
dict(
name="clip_gradients",
@@ -334,7 +329,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
@staticmethod
def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
using_deepspeed = kwargs.get("strategy") == "deepspeed"
using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
out = []
for i in range(batches):
out.extend(
Expand All @@ -355,11 +349,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
dict(name="on_after_backward"),
# `manual_backward` calls the previous 3
dict(name="manual_backward", args=(ANY,)),
*([dict(name="closure")] if using_plugin else []),
dict(name="closure"),
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
dict(name="on_before_optimizer_step", args=(ANY, 0)),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*([] if using_plugin else [dict(name="closure")]),
*([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),