[bug] [docs] Clearer optimizer_step override instructions #4455

Merged · 4 commits · Nov 2, 2020
11 changes: 6 additions & 5 deletions pytorch_lightning/accelerators/accelerator.py
@@ -113,11 +113,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
 
         # model hook
         model_ref.optimizer_step(
-            self.trainer.current_epoch,
-            batch_idx,
-            optimizer,
-            opt_idx,
-            lambda_closure,
+            epoch=self.trainer.current_epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=opt_idx,
+            optimizer_closure=lambda_closure,
+            on_tpu=False,  # TPUAccelerator class sets this as True
             using_native_amp=native_amp,
             using_lbfgs=is_lbfgs
         )
10 changes: 6 additions & 4 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -242,11 +242,13 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
 
         # model hook
         model_ref.optimizer_step(
-            self.trainer.current_epoch,
-            batch_idx, optimizer,
-            opt_idx,
-            lambda_closure,
+            epoch=self.trainer.current_epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=opt_idx,
+            optimizer_closure=lambda_closure,
+            on_tpu=True,
             using_native_amp=False,
             using_lbfgs=is_lbfgs
         )

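Because both accelerators now call the `model_ref.optimizer_step` hook with keyword arguments, a user override has to expose exactly these parameter names. A minimal sketch of a matching override (the `LitModel` class is made up for illustration and is not part of this PR):

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical module, for illustration only
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # Forwarding the closure lets the training step and backward pass run
        # inside optimizer.step(), matching the default behaviour documented
        # in the lightning.py docstring below.
        optimizer.step(closure=optimizer_closure)
```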
30 changes: 18 additions & 12 deletions pytorch_lightning/core/lightning.py
@@ -1143,17 +1143,23 @@ def optimizer_step(
         batch_idx: int,
         optimizer: Optimizer,
         optimizer_idx: int,
-        optimizer_closure: Optional[Callable] = None,
-        on_tpu: bool = False,
-        using_native_amp: bool = False,
-        using_lbfgs: bool = False,
+        optimizer_closure: Optional[Callable],
+        on_tpu: bool,
+        using_native_amp: bool,
+        using_lbfgs: bool,
     ) -> None:
         r"""
         Override this method to adjust the default way the
         :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
         By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
         once per optimizer.
 
+        Warning:
+            If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
+            to ``optimizer.step()`` function as shown in the examples. This ensures that
+            ``train_step_and_backward_closure`` is called within
+            :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`.
+
         Args:
             epoch: Current epoch
             batch_idx: Index of current batch
@@ -1168,23 +1174,23 @@ def optimizer_step(
             .. code-block:: python
 
                 # DEFAULT
-                def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+                def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                    optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
-                    optimizer.step()
+                    optimizer.step(closure=optimizer_closure)
 
                 # Alternating schedule for optimizer steps (i.e.: GANs)
-                def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+                def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                    optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                     # update generator opt every 2 steps
                     if optimizer_idx == 0:
                         if batch_idx % 2 == 0 :
-                            optimizer.step()
+                            optimizer.step(closure=optimizer_closure)
                             optimizer.zero_grad()
 
                     # update discriminator opt every 4 steps
                     if optimizer_idx == 1:
                         if batch_idx % 4 == 0 :
-                            optimizer.step()
+                            optimizer.step(closure=optimizer_closure)
                             optimizer.zero_grad()
 
                     # ...
@@ -1197,16 +1203,16 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
             .. code-block:: python
 
                 # learning rate warm-up
-                def optimizer_step(self, current_epoch, batch_idx, optimizer,
-                                   optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+                def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
+                                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                     # warm up lr
                     if self.trainer.global_step < 500:
                         lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
                         for pg in optimizer.param_groups:
                             pg['lr'] = lr_scale * self.learning_rate
 
                     # update params
-                    optimizer.step()
+                    optimizer.step(closure=optimizer_closure)
                     optimizer.zero_grad()
 
         Note:
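As a usage reference, here is the docstring's learning-rate warm-up example embedded in a self-contained module. This is a sketch under assumptions not stated in the PR: the toy `WarmupModel`, its linear layer, and the `learning_rate` attribute are illustrative only.

```python
import torch
import pytorch_lightning as pl


class WarmupModel(pl.LightningModule):  # illustrative model, not part of this PR
    def __init__(self, learning_rate=0.02):
        super().__init__()
        self.learning_rate = learning_rate
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # linear warm-up of the learning rate over the first 500 steps,
        # as in the updated docstring example
        if self.trainer.global_step < 500:
            lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.learning_rate

        # pass the closure through so train_step_and_backward_closure still runs
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()
```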