Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix missing call to untoggle_optimizer when accumulating gradients #8284

Merged
merged 6 commits into from
Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))


- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))


## [1.3.8] - 2021-07-01

### Fixed
Expand Down
20 changes: 11 additions & 9 deletions pytorch_lightning/loops/batch/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,17 @@ def _run_optimization(
else:
if self.trainer.lightning_module.automatic_optimization:
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
if len(self.trainer.optimizers) > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes len() == 2, consider 3 optimizers - it should be a cyclic toggle instead of boolean

# revert back to previous state
self.trainer.lightning_module.untoggle_optimizer(opt_idx)
else:
result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)

if not result:
# user decided to skip optimization
return result

# update running loss + reset accumulated loss
if result:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
self._update_running_loss(result.loss)
self._process_closure_result(result)

self._process_closure_result(result)
# untoggle model params
self._run_optimization_end(opt_idx)
return result

def _training_step_and_backward_closure(
Expand Down Expand Up @@ -490,6 +487,11 @@ def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer
model = self.trainer.lightning_module
model.toggle_optimizer(optimizer, opt_idx)

def _run_optimization_end(self, opt_idx: int) -> None:
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.untoggle_optimizer(opt_idx)

@contextmanager
def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def optimizer_step(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
accumulate_grad_batches=1,
accumulate_grad_batches=2,
limit_val_batches=0,
)
trainer.fit(model)
Expand Down Expand Up @@ -331,7 +331,7 @@ def configure_optimizers(self):
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
accumulate_grad_batches=1,
accumulate_grad_batches=2,
)

trainer.fit(model)
Expand Down