Add custom logic to each OutputResult subclass [2/2] #9424

Merged: 53 commits, merged Sep 15, 2021
Changes from 52 commits
Commits (53)
b6d210c
WIP
carmocca Sep 6, 2021
f97f5f0
WIP
carmocca Sep 6, 2021
573f92a
WIP
carmocca Sep 7, 2021
718c971
WIP
carmocca Sep 7, 2021
096129e
Remove hiddens
carmocca Sep 7, 2021
bd90a94
Add closure result tests
carmocca Sep 7, 2021
bc87081
Fix tests
carmocca Sep 7, 2021
15c5972
Fail if closure is not executed
carmocca Sep 7, 2021
9803cec
Remove deepcopy
carmocca Sep 7, 2021
9f5f13a
Merge branch 'master' into bugfix/closure-result
carmocca Sep 7, 2021
24a70b9
Remove debugging statement
carmocca Sep 7, 2021
8630761
Enforce that the optimizer closure is executed when `optimizer_step` …
carmocca Sep 7, 2021
46d8113
Add comment
carmocca Sep 7, 2021
32492ca
Docs and changelog
carmocca Sep 7, 2021
bd54363
is not None
carmocca Sep 7, 2021
4265449
Improve error message to cover broken optimizers
carmocca Sep 7, 2021
8716d4c
Undo getter
carmocca Sep 7, 2021
3f7ea80
Add license and period
carmocca Sep 8, 2021
ad0cc5a
Merge branch 'master' into refactor/enforce-closure-execution
carmocca Sep 8, 2021
4174e1f
Merge manual loop PR
carmocca Sep 8, 2021
3140d9b
Clone loss
carmocca Sep 8, 2021
6689f3d
Remove apply_accumulation
carmocca Sep 8, 2021
6f9f5af
Move to cpu
carmocca Sep 8, 2021
8108c1f
mypy
carmocca Sep 8, 2021
35006e5
Merge branch 'master' into bugfix/closure-result
carmocca Sep 8, 2021
c4f360a
Handle hiddens with an utility
carmocca Sep 8, 2021
bed7624
Without closure and tests
carmocca Sep 8, 2021
6a81c0c
Remove code to move the ClosureResult
carmocca Sep 8, 2021
a09cfc8
Tests
carmocca Sep 8, 2021
78253e4
Fix TODO
carmocca Sep 8, 2021
26e96a1
Undo docs changes
carmocca Sep 8, 2021
5b50789
Update CHANGELOG
carmocca Sep 8, 2021
bf08c2b
Fix
carmocca Sep 8, 2021
f110bc1
Stricter hiddens returning
carmocca Sep 9, 2021
9d8d587
Drop closure loss
carmocca Sep 9, 2021
5138edf
Logic fix
carmocca Sep 9, 2021
3aaf9dd
Bad rename
carmocca Sep 9, 2021
1f32d52
Fix for raise StopIteration
carmocca Sep 9, 2021
04a54a6
Fix comment
carmocca Sep 9, 2021
6fb8758
Initial implementation
carmocca Sep 9, 2021
fa80ef0
Undo change
carmocca Sep 10, 2021
41a14cf
Merge branch 'master' into refactor/poc-output-result
carmocca Sep 14, 2021
93652ac
Bad merge
carmocca Sep 14, 2021
21697a4
Progress
carmocca Sep 14, 2021
b4a0204
Progress
carmocca Sep 14, 2021
237992e
Progress
carmocca Sep 14, 2021
48a42dc
Add test
carmocca Sep 14, 2021
e758c29
Make conversion in the manual loop
carmocca Sep 14, 2021
e9fdb95
Update CHANGELOG
carmocca Sep 14, 2021
9c3d3fe
Typing
carmocca Sep 14, 2021
9fb2d8c
Merge branch 'master' into refactor/poc-output-result
carmocca Sep 15, 2021
effc3c9
Add test. Raise NotImplementedError
carmocca Sep 15, 2021
fc91d13
Remove invalid test
carmocca Sep 15, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
* Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437))
* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))


- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -130,8 +130,8 @@ def advance(self, batch, batch_idx):
else:
# in manual optimization, hand over execution to the ManualOptimization loop
result = self.manual_loop.run(split_batch, batch_idx)
if result is not None and result.loss is not None:
self.batch_outputs[0].append(result.drop_closure_loss())
if result:
self.batch_outputs[0].append(result)

def on_run_end(self) -> None:
self.optimizer_loop._hiddens = None
13 changes: 1 addition & 12 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -312,18 +312,7 @@ def _prepare_outputs(
opt_outputs = [opt_outputs]

for batch_outputs in opt_outputs:
processed_tbptt_outputs = []

if isinstance(batch_outputs, OutputResult):
batch_outputs = [batch_outputs]

for tbptt_output in batch_outputs:
out = {}
if tbptt_output.loss is not None:
out["loss"] = tbptt_output.loss
out.update(tbptt_output.extra)
processed_tbptt_outputs.append(out)

processed_tbptt_outputs = batch_outputs if isinstance(batch_outputs, list) else [batch_outputs]
# if there was only one tbptt step then we can collapse that dimension
if len(processed_tbptt_outputs) == 1:
processed_tbptt_outputs = processed_tbptt_outputs[0]
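
For context, a minimal sketch of the shapes the simplified block above now relies on: every tbptt step contributes the plain dict produced by `OutputResult.asdict()`, so the only remaining work is wrapping a lone dict into a list before collapsing. The helper name `_collapse_tbptt` and the values are hypothetical, for illustration only; this is not code from the PR.

from typing import Any, Dict, List, Union

def _collapse_tbptt(batch_outputs: Union[Dict[str, Any], List[Dict[str, Any]]]) -> Any:
    # mirrors the replaced block: wrap a single dict, then collapse when there was only one tbptt step
    processed = batch_outputs if isinstance(batch_outputs, list) else [batch_outputs]
    return processed[0] if len(processed) == 1 else processed

assert _collapse_tbptt({"loss": 0.5}) == {"loss": 0.5}
assert _collapse_tbptt([{"loss": 0.5}, {"loss": 0.4}]) == [{"loss": 0.5}, {"loss": 0.4}]
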
27 changes: 24 additions & 3 deletions pytorch_lightning/loops/optimization/closure.py
@@ -13,16 +13,37 @@
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from typing import Any, Dict, Generic, Optional, TypeVar

from torch import Tensor

from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException

T = TypeVar("T")


@dataclass
class OutputResult:
...
@staticmethod
def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> Dict[str, Any]:
# TODO: remove with the deprecation removal in v1.6
# this is only here to avoid duplication
def check_fn(v: Tensor) -> Tensor:
if v.grad_fn is not None:
rank_zero_deprecation(
f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
" but this behaviour will change in v1.6. Please detach it manually:"
" `return {'loss': ..., 'something': something.detach()}`"
)
return v.detach()
return v

return apply_to_collection(extra, Tensor, check_fn)

def asdict(self) -> Dict[str, Any]:
raise NotImplementedError


class AbstractClosure(ABC, Generic[T]):
@@ -33,7 +54,7 @@ class AbstractClosure(ABC, Generic[T]):
object which later can call it like a function but without requiring to pass in any arguments.

This class provides a simple abstraction making the instance of this class callable like a function while capturing
the :class:`OutputResult` and caching it.
the closure result and caching it.
"""

def __init__(self) -> None:
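
To illustrate how the new base-class hooks are meant to be used, here is a minimal sketch of a subclass. The name `MyResult` is hypothetical and this snippet is not part of the PR; it only mirrors the pattern that `ManualResult` and `ClosureResult` follow below.

from dataclasses import dataclass, field
from typing import Any, Dict

from pytorch_lightning.loops.optimization.closure import OutputResult


@dataclass
class MyResult(OutputResult):  # hypothetical subclass, for illustration only
    extra: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # shared helper from the base class: warns (until v1.6) and detaches
        # any returned tensor that still carries a `grad_fn`
        self.extra = self._check_extra_detach_deprecation(self.extra)

    def asdict(self) -> Dict[str, Any]:
        # each subclass decides what the downstream hooks receive
        return self.extra
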
85 changes: 26 additions & 59 deletions pytorch_lightning/loops/optimization/manual_loop.py
@@ -18,16 +18,9 @@

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import (
_build_training_step_kwargs,
_check_training_step_output,
_extract_hiddens,
check_finite_loss,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_deprecation


@dataclass
@@ -37,66 +30,45 @@ class ManualResult(OutputResult):
It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.

Attributes:
closure_loss: The loss with a graph attached.
loss: A detached copy of the closure loss.
extra: Any keys other than the loss returned.
extra: Anything returned by the ``training_step``.
"""

closure_loss: Optional[Tensor]
loss: Optional[Tensor] = field(init=False, default=None)
extra: Dict[str, Tensor] = field(default_factory=dict)
extra: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
# TODO: remove with the deprecation removal in v1.6
self._check_extra_detach_deprecation(self.extra)
self.extra = recursive_detach(self.extra)

self._clone_loss()

def _clone_loss(self) -> None:
if self.closure_loss is not None:
# the loss will get scaled for amp. avoid any modifications to it
self.loss = self.closure_loss.detach().clone()
self.extra = self._check_extra_detach_deprecation(self.extra)

@classmethod
def from_training_step_output(
cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
) -> "ManualResult":
closure_loss, extra = None, {}

extra = {}
if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
closure_loss = training_step_output.get("loss")
extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
elif isinstance(training_step_output, Tensor):
closure_loss = training_step_output
extra = {"loss": training_step_output}
elif training_step_output is not None:
raise MisconfigurationException(
"In manual optimization, `training_step` must either return a Tensor, "
"a dict with extras to pass to `training_epoch_end` or have no return."
)

if closure_loss is not None:
# accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
closure_loss /= normalize
if "loss" in extra:
# accumulate the loss. If `accumulate_grad_batches == 1`, no effect.
# we detach manually as it's expected that it will have a `grad_fn`
extra["loss"] = extra["loss"].detach().div(normalize)

return cls(closure_loss, extra=extra)
return cls(extra=extra)

@staticmethod
def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
def check_fn(v: Tensor) -> Tensor:
if v.grad_fn is not None:
rank_zero_deprecation(
f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
" but this behaviour will change in v1.6. Please detach it manually:"
" `return {'loss': ..., 'something': something.detach()}`"
)
return v
def asdict(self) -> Dict[str, Any]:
return self.extra

apply_to_collection(extra, Tensor, check_fn)

def drop_closure_loss(self) -> "ManualResult":
"""Return itself without the closure loss which could have a `grad_fn`"""
self.closure_loss = None
return self
_OUTPUTS_TYPE = Dict[str, Any]


class ManualOptimization(Loop):
class ManualOptimization(Loop[_OUTPUTS_TYPE]):
"""A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user
is responsible for back-propagating gradients and making calls to the optimizers.
@@ -109,7 +81,7 @@ def __init__(self) -> None:
super().__init__()
self._done: bool = False
self._hiddens: Optional[Any] = None
self._output: Optional[ManualResult] = None
self._output: _OUTPUTS_TYPE = {}

@property
def done(self) -> bool:
@@ -144,27 +116,22 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(lightning_module, training_step_output)

self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

if self.trainer.terminate_on_nan:
check_finite_loss(result.closure_loss)

if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
# the user might need them on the correct device for an operation in `training_epoch_end`
assert self.trainer._results is not None
self.trainer._results.cpu()

self._done = True
self._output = result
self._output = result.asdict()

def on_run_end(self) -> Optional[ManualResult]:
def on_run_end(self) -> _OUTPUTS_TYPE:
"""Returns the result of this loop, i.e., the post-processed outputs from the training step."""
output, self._output = self._output, None # free memory
output, self._output = self._output, {} # free memory
# #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance`
# doesn't finish and `self._output` stays as `None`. If #9415 happens then this would always return a result
return output
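
A short usage sketch of the resulting `ManualResult` behaviour (illustrative only, not taken from the PR; it assumes the post-merge API shown above): a plain tensor return is wrapped under the "loss" key, a dict return is forwarded as `extra` with "hiddens" dropped, and any "loss" entry is detached and divided by `normalize` (i.e. `accumulate_grad_batches`).

import torch

from pytorch_lightning.loops.optimization.manual_loop import ManualResult

loss = torch.tensor(4.0, requires_grad=True) * 1.0  # has a `grad_fn`, as a real loss would

# tensor return: wrapped under "loss", detached, divided by `normalize`
result = ManualResult.from_training_step_output(loss, normalize=2)
assert result.asdict()["loss"].item() == 2.0
assert result.asdict()["loss"].grad_fn is None

# dict return: forwarded as-is, except that "hiddens" is dropped
result = ManualResult.from_training_step_output({"loss": loss, "foo": 123, "hiddens": None}, normalize=1)
assert set(result.asdict()) == {"loss", "foo"}
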
46 changes: 18 additions & 28 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -25,20 +25,17 @@
from pytorch_lightning.loops.utilities import (
_block_parallel_sync_behavior,
_build_training_step_kwargs,
_check_training_step_output,
_extract_hiddens,
check_finite_loss,
)
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
from pytorch_lightning.utilities.warnings import WarningCache


@dataclass
@@ -55,12 +52,11 @@ class ClosureResult(OutputResult):

closure_loss: Optional[Tensor]
loss: Optional[Tensor] = field(init=False, default=None)
extra: Dict[str, Tensor] = field(default_factory=dict)
extra: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
# TODO: remove with the deprecation removal in v1.6
ClosureResult._check_extra_detach_deprecation(self.extra)
self.extra = recursive_detach(self.extra)
self.extra = self._check_extra_detach_deprecation(self.extra)

self._clone_loss()

@@ -78,33 +74,27 @@ def from_training_step_output(
if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
closure_loss = training_step_output.get("loss")
if closure_loss is None:
raise MisconfigurationException(
"In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
)
extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
elif isinstance(training_step_output, Tensor):
closure_loss = training_step_output
elif training_step_output is not None:
raise MisconfigurationException(
"In automatic optimization, `training_step` must return a Tensor, "
"a dict, or None (where the step will be skipped)."
)

if closure_loss is not None:
# accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
closure_loss /= normalize

return cls(closure_loss, extra=extra)

@staticmethod
def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
def check_fn(v: Tensor) -> Tensor:
if v.grad_fn is not None:
rank_zero_deprecation(
f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
" but this behaviour will change in v1.6. Please detach it manually:"
" `return {'loss': ..., 'something': something.detach()}`"
)
return v

apply_to_collection(extra, Tensor, check_fn)

def drop_closure_loss(self) -> "ClosureResult":
"""Return itself without the closure loss which could have a `grad_fn`"""
self.closure_loss = None
return self
def asdict(self) -> Dict[str, Any]:
return {"loss": self.loss, **self.extra}


class Closure(AbstractClosure[ClosureResult]):
@@ -170,7 +160,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
return self._result.loss


_OUTPUTS_TYPE = List[List[ClosureResult]]
_OUTPUTS_TYPE = List[List[Dict[str, Any]]]


class OptimizerLoop(Loop):
@@ -222,7 +212,9 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignor
self.optimizer_idx,
)
if result.loss is not None:
self.outputs[self.optimizer_idx].append(result.drop_closure_loss())
# automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
# would be skipped otherwise
self.outputs[self.optimizer_idx].append(result.asdict())
self.optim_progress.optimizer_position += 1

def on_run_end(self) -> _OUTPUTS_TYPE:
Expand Down Expand Up @@ -467,8 +459,6 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(lightning_module, training_step_output)

self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
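
And a corresponding sketch for `ClosureResult` in automatic optimization (illustrative only, not taken from the PR): the closure loss keeps its graph for the backward call, `loss` is a detached clone, `asdict()` is what lands in `OptimizerLoop.outputs`, and the "loss"-key validation previously done in `_check_training_step_output` now happens in `from_training_step_output`.

import torch

from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException

loss = torch.tensor(1.0, requires_grad=True) * 3.0

result = ClosureResult.from_training_step_output({"loss": loss, "acc": 0.9})
assert result.closure_loss.grad_fn is not None  # graph kept for `backward`
assert result.loss.grad_fn is None  # detached clone, safe to store
assert set(result.asdict()) == {"loss", "acc"}

# a dict without a 'loss' key is now rejected here rather than in the removed check
try:
    ClosureResult.from_training_step_output({"acc": 0.9})
    raise AssertionError("expected a MisconfigurationException")
except MisconfigurationException:
    pass
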
27 changes: 1 addition & 26 deletions pytorch_lightning/loops/utilities.py
@@ -13,7 +13,7 @@
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence
from typing import Any, Dict, Generator, Iterator, Optional, Sequence

import torch
from torch.optim import Optimizer
@@ -37,31 +37,6 @@ def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
raise ValueError(f"The loss returned in `training_step` is {loss}.")


def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None:
"""Sanity checks that training produced a valid output.

Args:
model: a reference to the trainer
training_step_output: the output of the training step (before wrapping in an AttributeDict)
"""
if (
isinstance(training_step_output, torch.Tensor)
and not model.automatic_optimization
and training_step_output.grad_fn is None
):
# TODO: in manual optimization, anything returned should be considered an `extra`
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
if model.automatic_optimization and not (
isinstance(training_step_output, torch.Tensor)
or (isinstance(training_step_output, Mapping) and "loss" in training_step_output)
or training_step_output is None
):
raise MisconfigurationException(
"In automatic optimization, `training_step` must either return a Tensor, "
"a dict with key 'loss' or None (where the step will be skipped)."
)


def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]:
"""Get the hidden state if present from the training step output.
