[bug] [docs] Clearer optimizer_step override instructions (#4455)
* fix

* flags

* remove defaults
ananyahjha93 authored and rohitgr7 committed Nov 21, 2020
1 parent 047ded0 commit 7b019b8
Showing 3 changed files with 30 additions and 21 deletions.
11 changes: 6 additions & 5 deletions pytorch_lightning/accelerators/accelerator.py
@@ -113,11 +113,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):

         # model hook
         model_ref.optimizer_step(
-            self.trainer.current_epoch,
-            batch_idx,
-            optimizer,
-            opt_idx,
-            lambda_closure,
+            epoch=self.trainer.current_epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=opt_idx,
+            optimizer_closure=lambda_closure,
+            on_tpu=False,  # TPUAccelerator class sets this as True
             using_native_amp=native_amp,
             using_lbfgs=is_lbfgs
         )
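The accelerator now calls the ``optimizer_step`` hook with keyword arguments, so a user override must use the documented parameter names and forward the closure to ``optimizer.step()``. Below is a minimal sketch of such an override; ``LitModel`` and its training/optimizer setup are illustrative and not part of this commit.

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        """Illustrative module; only optimizer_step is relevant to this commit."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.02)

        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
            # Parameter names must match the keyword arguments passed by the
            # accelerator above. Forwarding the closure makes optimizer.step()
            # run the training step and backward pass, as the updated docstring requires.
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()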
10 changes: 6 additions & 4 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -242,11 +242,13 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):

         # model hook
         model_ref.optimizer_step(
-            self.trainer.current_epoch,
-            batch_idx, optimizer,
-            opt_idx,
-            lambda_closure,
+            epoch=self.trainer.current_epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=opt_idx,
+            optimizer_closure=lambda_closure,
+            on_tpu=True,
             using_native_amp=False,
             using_lbfgs=is_lbfgs
         )

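The two call sites now pin the flags per accelerator: ``on_tpu`` is True only from ``TPUAccelerator``, and ``using_native_amp`` is hard-coded to False there. An override can therefore branch on these flags. The sketch below, meant to sit in a LightningModule like the ``LitModel`` sketch above, only inspects them (the logging is an assumption, not part of the commit) while always forwarding the closure, which ``torch.optim.LBFGS`` in particular requires.

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # The flags mirror what the accelerators pass in: on_tpu=True only from
        # TPUAccelerator, using_native_amp=False on TPU, using_lbfgs when the
        # configured optimizer is torch.optim.LBFGS.
        if batch_idx == 0:
            print(f"optimizer_step: on_tpu={on_tpu}, "
                  f"native_amp={using_native_amp}, lbfgs={using_lbfgs}")
        # Always forward the closure: LBFGS needs it to re-evaluate the loss,
        # and other optimizers simply call it once inside step().
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()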
30 changes: 18 additions & 12 deletions pytorch_lightning/core/lightning.py
@@ -1143,17 +1143,23 @@ def optimizer_step(
         batch_idx: int,
         optimizer: Optimizer,
         optimizer_idx: int,
-        optimizer_closure: Optional[Callable] = None,
-        on_tpu: bool = False,
-        using_native_amp: bool = False,
-        using_lbfgs: bool = False,
+        optimizer_closure: Optional[Callable],
+        on_tpu: bool,
+        using_native_amp: bool,
+        using_lbfgs: bool,
     ) -> None:
         r"""
         Override this method to adjust the default way the
         :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
         By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
         once per optimizer.

+        Warning:
+            If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
+            to ``optimizer.step()`` function as shown in the examples. This ensures that
+            ``train_step_and_backward_closure`` is called within
+            :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`.
+
         Args:
             epoch: Current epoch
             batch_idx: Index of current batch
@@ -1168,23 +1174,23 @@ def optimizer_step(
         .. code-block:: python

             # DEFAULT
-            def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
-                optimizer.step()
+                optimizer.step(closure=optimizer_closure)

             # Alternating schedule for optimizer steps (i.e.: GANs)
-            def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                 # update generator opt every 2 steps
                 if optimizer_idx == 0:
                     if batch_idx % 2 == 0 :
-                        optimizer.step()
+                        optimizer.step(closure=optimizer_closure)
                         optimizer.zero_grad()

                 # update discriminator opt every 4 steps
                 if optimizer_idx == 1:
                     if batch_idx % 4 == 0 :
-                        optimizer.step()
+                        optimizer.step(closure=optimizer_closure)
                         optimizer.zero_grad()

                 # ...
@@ -1197,16 +1203,16 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
         .. code-block:: python

             # learning rate warm-up
-            def optimizer_step(self, current_epoch, batch_idx, optimizer,
-                               optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
+                               optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                 # warm up lr
                 if self.trainer.global_step < 500:
                     lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
                     for pg in optimizer.param_groups:
                         pg['lr'] = lr_scale * self.learning_rate

                 # update params
-                optimizer.step()
+                optimizer.step(closure=optimizer_closure)
                 optimizer.zero_grad()

         Note:
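For intuition on why the Warning insists on forwarding the closure: ``optimizer_closure`` wraps the training step and the backward pass, so an override that calls ``optimizer.step()`` without it never computes gradients for the batch. The standalone sketch below mimics that mechanism with plain PyTorch; it is an illustration, not the actual ``TrainLoop.run_training_batch`` code.

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch = (torch.randn(8, 4), torch.randn(8, 1))

    def training_closure():
        # Stand-in for Lightning's train_step_and_backward_closure: compute the
        # loss and backpropagate, so gradients exist only if the closure runs.
        optimizer.zero_grad()
        x, y = batch
        loss = F.mse_loss(model(x), y)
        loss.backward()
        return loss

    # torch optimizers call the closure inside step(); dropping it here would
    # update parameters with stale (or missing) gradients.
    optimizer.step(closure=training_closure)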
