extract manual optimization loop #9266

Merged
83 commits merged on Sep 8, 2021
Changes from 66 commits

Commits (83)
fea744f
wip
awaelchli Jul 24, 2021
81d3797
extract optimizer loop
awaelchli Aug 27, 2021
b39ddbf
handle restart
awaelchli Aug 29, 2021
97eabcf
update running loss
awaelchli Aug 29, 2021
69a29d9
add changelog
awaelchli Aug 29, 2021
bec5341
update tests
awaelchli Aug 29, 2021
58404aa
refactor block parallel sync behavior
awaelchli Aug 29, 2021
9226271
remove automatic opt specific logic
awaelchli Aug 29, 2021
d376c91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2021
38acf9a
fix circular import
awaelchli Aug 30, 2021
36c8692
fix swa tests
awaelchli Aug 30, 2021
5682a6b
fix state dict test
awaelchli Aug 30, 2021
0b59c37
add connect
awaelchli Aug 30, 2021
574c11e
fix reset
awaelchli Aug 30, 2021
ce4f08a
fix imports
awaelchli Aug 30, 2021
39fb458
add license
awaelchli Aug 30, 2021
9f1880e
fix test_loops.py
awaelchli Aug 30, 2021
d4dfd54
remove commented code
awaelchli Aug 30, 2021
0b68a11
add docstrings
awaelchli Aug 30, 2021
624394f
fix typing in constructor
awaelchli Aug 30, 2021
07fbafe
update hidden state management
awaelchli Aug 30, 2021
7572deb
extract build_kwargs method
awaelchli Aug 30, 2021
9d4f459
remove todo
awaelchli Aug 30, 2021
b4962ef
isort
awaelchli Aug 30, 2021
42cdc14
update init files
awaelchli Aug 30, 2021
fde9e9e
Merge branch 'master' into refactor/optimizer-loop
awaelchli Aug 30, 2021
bd128f0
fix loop state dict test
awaelchli Aug 30, 2021
a304ac5
fix tbtt tests
awaelchli Aug 30, 2021
eb00a4c
fix imports
awaelchli Aug 30, 2021
d346be2
no longer duplicated
awaelchli Aug 30, 2021
b0c997e
remove unused optimiizer arguments for the manual opt path
awaelchli Aug 30, 2021
86743aa
update typehint
awaelchli Aug 30, 2021
e130da5
update docs
awaelchli Aug 30, 2021
586ead7
Merge branch 'master' into refactor/optimizer-loop
awaelchli Aug 30, 2021
6b123be
Merge branch 'master' into refactor/optimizer-loop
awaelchli Aug 30, 2021
fa4c788
remove unused argument
awaelchli Aug 30, 2021
9701c89
Merge branch 'master' into refactor/optimizer-loop
awaelchli Aug 30, 2021
ef05def
update typing
awaelchli Aug 31, 2021
87f3002
Merge branch 'master' into refactor/optimizer-loop
awaelchli Aug 31, 2021
73c60d3
remove redundant process closure result
awaelchli Aug 31, 2021
fc03c90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2021
6ca1933
add todo
awaelchli Aug 31, 2021
e5454ff
remove list copy for optimizers
awaelchli Aug 31, 2021
2fadaac
undo skip_backward changes in swa
awaelchli Aug 31, 2021
8475526
Merge branch 'master' into refactor/optimizer-loop
awaelchli Sep 1, 2021
9fdd03b
clean up manual optimization logic after optimizer loop
awaelchli Sep 2, 2021
848ee53
manual loop
awaelchli Sep 2, 2021
303411a
update output handling of None
awaelchli Sep 2, 2021
cb4c178
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 2, 2021
9fb7c3d
Merge branch 'master' into refactor/clean-up-manual-opt
carmocca Sep 2, 2021
0467f6b
Merge branch 'master' into refactor/manual-loop
awaelchli Sep 3, 2021
dc80242
init imports
awaelchli Sep 3, 2021
d48d183
add license
awaelchli Sep 3, 2021
9275fc0
add docs
awaelchli Sep 3, 2021
b2e50ee
rename model reference
awaelchli Sep 3, 2021
8c297f0
mypy non sense
awaelchli Sep 3, 2021
65491a7
fix loop structure test
awaelchli Sep 3, 2021
c3febd7
update changelog
awaelchli Sep 3, 2021
cc25477
fix argument passing
awaelchli Sep 3, 2021
18efdea
update loop structure
awaelchli Sep 3, 2021
de53c4b
Merge branch 'master' into refactor/manual-loop
awaelchli Sep 3, 2021
bca26cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2021
8404d07
fix state dict structure
awaelchli Sep 3, 2021
df7f240
update connect()
awaelchli Sep 3, 2021
273dd8d
Fix mypy
carmocca Sep 3, 2021
eddd77c
update type ignore
awaelchli Sep 3, 2021
5e3358d
Merge branch 'master' into refactor/clean-up-manual-opt
tchaton Sep 6, 2021
4bb11db
Merge branch 'master' into refactor/manual-loop
awaelchli Sep 6, 2021
04c68cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2021
1eb01f1
notebooks
awaelchli Sep 6, 2021
2f6e5ea
notebook
awaelchli Sep 6, 2021
b4b4f04
_notebooks
awaelchli Sep 6, 2021
3a474ec
make loops optional in connect()
awaelchli Sep 6, 2021
c9b72f4
return tuple and update hidden state regardless
awaelchli Sep 7, 2021
99a9206
parametrize test with manual opt
awaelchli Sep 7, 2021
3649511
update docstring
awaelchli Sep 7, 2021
4cc2a45
update model attributes
awaelchli Sep 7, 2021
5924231
update return type hint
awaelchli Sep 7, 2021
bc93fc4
extract model class
awaelchli Sep 7, 2021
4cdfa9a
create manual model class
awaelchli Sep 7, 2021
ce1f9a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
998ab0f
add assertion for grad fn
awaelchli Sep 7, 2021
f5f8339
rename classes
awaelchli Sep 7, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -68,6 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
* Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
* Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))

- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

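For context, a minimal sketch of the user-facing code path that the extracted `ManualOptimization` loop is responsible for running. This is the standard Lightning manual-optimization pattern rather than code from this PR, and the model itself (layer size, loss) is made up for illustration:

import torch
from torch import nn
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        # opt out of automatic optimization, i.e. of the OptimizerLoop
        self.automatic_optimization = False
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        # per the changelog entry above, this now calls directly into Accelerator.backward()
        self.manual_backward(loss)
        opt.step()
        # returning nothing is allowed; the loop then produces no ResultCollection

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
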
1 change: 1 addition & 0 deletions pyproject.toml
@@ -63,6 +63,7 @@ ignore_errors = "True"
module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.loops.closure",
"pytorch_lightning.loops.batch.manual",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector",
"pytorch_lightning.utilities.apply_func",
1 change: 1 addition & 0 deletions pytorch_lightning/loops/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from pytorch_lightning.loops.base import Loop # noqa: F401
from pytorch_lightning.loops.batch import ManualOptimization # noqa: F401
from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
1 change: 1 addition & 0 deletions pytorch_lightning/loops/batch/__init__.py
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.batch.manual import ManualOptimization # noqa: F401
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
93 changes: 93 additions & 0 deletions pytorch_lightning/loops/batch/manual.py
@@ -0,0 +1,93 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.utilities import (
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


class ManualOptimization(Loop):
"""
A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user
is responsible for back-propagating gradients and making calls to the optimizers.

This loop is a trivial case because it performs only a single iteration (calling directly into the module's
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`) and passes through the output(s).
"""

def __init__(self) -> None:
super().__init__()
self._done: bool = False
self._hiddens: Optional[Any] = None
self._output: Optional[ResultCollection] = None

@property
def done(self) -> bool:
return self._done

def reset(self) -> None:
self._done = False

def advance(self, batch: Any, batch_idx: int, hiddens: Optional[Any] = None) -> None: # type: ignore[override]
"""Performs the training step for manual optimization.

Args:
batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
hiddens: the model's hidden state of the previous iteration

Returns:
post-processed outputs from the training step, or ``None`` if training step returned nothing
"""
assert self.trainer is not None
lightning_module = self.trainer.lightning_module

with self.trainer.profiler.profile("model_forward"):

step_kwargs = _build_training_step_kwargs(
lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=hiddens
)

# manually capture logged metrics
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()

del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(lightning_module, training_step_output)

result_collection, hiddens = _process_training_step_output(self.trainer, training_step_output)

self._done = True
self._hiddens = hiddens
self._output = result_collection

def on_run_end(self) -> Optional[Tuple[ResultCollection, Optional[Any]]]:
hiddens = self._hiddens
output = self._output
self._hiddens, self._output = None, None # free memory
if output is None:
return None
return output, hiddens
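To make the contract above concrete, here is a small sketch of how a parent loop consumes `ManualOptimization`. It mirrors the call made in `TrainingBatchLoop.advance()` further down in this diff; the helper function and its name are purely illustrative, and it relies on the base `Loop.run()` behaviour of calling `reset()`, then `advance()`, then `on_run_end()`, whose return value (a `(ResultCollection, hiddens)` tuple or `None`) is what `run()` hands back:

from typing import Any, Optional, Tuple

from pytorch_lightning.loops import ManualOptimization
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


def run_manual_split(
    manual_loop: ManualOptimization, split_batch: Any, batch_idx: int, hiddens: Optional[Any]
) -> Tuple[Optional[ResultCollection], Optional[Any]]:
    # assumes `manual_loop.trainer` has already been set by the Trainer/parent loop
    # run() returns whatever on_run_end() returns: None if training_step produced
    # no output, otherwise a (ResultCollection, hiddens) tuple
    output = manual_loop.run(split_batch, batch_idx, hiddens)
    if output is None:
        return None, hiddens
    result_collection, new_hiddens = output
    return result_collection, new_hiddens
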
108 changes: 10 additions & 98 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -12,22 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import numpy as np
from deprecate import void
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import (
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
)
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -45,6 +39,7 @@ def __init__(self) -> None:
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx: Optional[int] = None
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()

self._warning_cache: WarningCache = WarningCache()
self._hiddens: Optional[Tensor] = None
@@ -63,8 +58,9 @@ def optimizer_freq_cumsum(self) -> int:
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def connect(self, optimizer_loop: "Loop") -> None:
def connect(self, optimizer_loop: "Loop", manual_loop: ManualOptimization) -> None:
self.optimizer_loop = optimizer_loop
self.manual_loop = manual_loop

def run(self, batch: Any, batch_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
@@ -132,10 +128,11 @@ def advance(self, batch, batch_idx):
for k in range(len(batch_outputs)):
self.batch_outputs[k].extend(batch_outputs[k])
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
if result:
self.batch_outputs[0].append(deepcopy(result.result_collection))
# in manual optimization, hand over execution to the ManualOptimization loop
output = self.manual_loop.run(split_batch, batch_idx, self._hiddens)
if output:
result, self._hiddens = output
self.batch_outputs[0].append(deepcopy(result))

def teardown(self) -> None:
# release memory
@@ -145,91 +142,6 @@ def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""Gets the number of active optimizers based on their frequency"""
return len(self.get_active_optimizers(batch_idx))

def _run_optimization(
self,
batch_idx: int,
split_batch: Any,
) -> Optional[ClosureResult]:
"""Runs closure (train step + backward) together with optimization if necessary.

Args:
batch_idx: the index of the current batch
split_batch: the current tbptt split of the whole batch
"""
# TODO: replace call through closure by direct call (manual optimization)
closure = self._make_closure(split_batch, batch_idx, self._hiddens)
closure()
result = closure.get_result()

if result:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
self._update_running_loss(result.loss)
@carmocca (Contributor), Sep 8, 2021:
Noticed this while resolving conflicts: this is not called anymore, is it? Is that intentional?
(personally lean towards removing the running loss, so this would be okay): see #9372

@awaelchli (Contributor, author):
technically this should not have been removed, ... sigh
but yeah, we can see what happens with #9372

return result

def _make_closure(
self,
split_batch: Any,
batch_idx: int,
hiddens: Any,
) -> Closure:
"""
Build a closure object that captures the given arguments and runs the `training_step` function and optionally
other functions such as `backward` and `zero_grad`.
"""
step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
backward_fn = None
zero_grad_fn = None

return Closure(
step_fn=step_fn,
backward_fn=backward_fn,
zero_grad_fn=zero_grad_fn,
profiler=self.trainer.profiler,
)

def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
"""Build the step function that runs the `training_step` and processes its output."""
return partial(self._training_step, split_batch, batch_idx, hiddens)

def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
"""Performs the training step for manual optimization.

Args:
split_batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
hiddens: the model's hidden state of the previous iteration

Returns:
an AttributeDict containing the training step output.
"""
# give the PL module a result for logging
model_ref = self.trainer.lightning_module

with self.trainer.profiler.profile("model_forward"):
step_kwargs = _build_training_step_kwargs(
model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
)

# manually capture logged metrics
model_ref._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()

del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(self.trainer.lightning_module, training_step_output)

result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
if result_collection is None:
return

return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)

def _tbptt_split_batch(self, batch: Any) -> List[Any]:
"""Splits a single batch into a list of sequence steps for tbptt.

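A minimal sketch of wiring the two sub-loops into the batch loop via the updated connect() signature above. The import paths follow this PR's __init__.py changes and the imports in training_batch_loop.py; in normal use the Trainer constructs and connects these loops itself, so treat this as illustrative only:

from pytorch_lightning.loops import ManualOptimization, TrainingBatchLoop
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop

# replaces the default sub-loops with freshly constructed (or customized) ones
batch_loop = TrainingBatchLoop()
batch_loop.connect(optimizer_loop=OptimizerLoop(), manual_loop=ManualOptimization())
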
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/properties.py
@@ -163,7 +163,7 @@ def lightning_module(self) -> "pl.LightningModule":
return self.accelerator.lightning_module

@property
def optimizers(self) -> Optional[List[Optimizer]]:
def optimizers(self) -> List[Optimizer]:
return self.accelerator.optimizers

@optimizers.setter
3 changes: 2 additions & 1 deletion tests/loops/test_loop_state_dict.py
@@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {
1 change: 1 addition & 0 deletions tests/loops/test_loops.py
@@ -461,6 +461,7 @@ def configure_optimizers_multiple(self):
"current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_idx": stop_optimizer,