From b6d210c990d28e85fdf63f3d77ad9969712d3fbf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 01:28:46 +0200 Subject: [PATCH 01/47] WIP --- .../loops/batch/training_batch_loop.py | 41 +++++----- pytorch_lightning/loops/closure.py | 72 ++++++++++++++---- .../loops/epoch/training_epoch_loop.py | 13 ++-- .../loops/optimizer/optimizer_loop.py | 38 ++++------ pytorch_lightning/loops/utilities.py | 74 ++++--------------- .../logger_connector/logger_connector.py | 3 - .../connectors/logger_connector/result.py | 65 +--------------- tests/trainer/loops/test_training_loop.py | 13 ---- .../loops/test_training_loop_flow_scalar.py | 13 ++-- 9 files changed, 121 insertions(+), 211 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index c2757e4035b48..058188a0dad6f 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from functools import partial from typing import Any, Callable, List, Optional, Tuple @@ -26,7 +25,7 @@ from pytorch_lightning.loops.utilities import ( _build_training_step_kwargs, _check_training_step_output, - _process_training_step_output, + check_finite_loss, ) from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AttributeDict @@ -134,8 +133,7 @@ def advance(self, batch, batch_idx): else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) - if result: - self.batch_outputs[0].append(deepcopy(result.result_collection)) + self.batch_outputs[0].append(result) def teardown(self) -> None: # release memory @@ -149,7 +147,7 @@ def _run_optimization( self, batch_idx: int, split_batch: Any, - ) -> Optional[ClosureResult]: + ) -> ClosureResult: """Runs closure (train step + backward) together with optimization if necessary. Args: @@ -161,7 +159,7 @@ def _run_optimization( closure() result = closure.get_result() - if result: + if result.loss: # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss self._update_running_loss(result.loss) @@ -174,24 +172,15 @@ def _make_closure( batch_idx: int, hiddens: Any, ) -> Closure: - """Build a closure object that captures the given arguments and runs the `training_step` function and - optionally other functions such as `backward` and `zero_grad`.""" + """Build a closure object that captures the given arguments and runs the `training_step` function.""" step_fn = self._make_step_fn(split_batch, batch_idx, hiddens) - backward_fn = None - zero_grad_fn = None + return Closure(step_fn=step_fn, profiler=self.trainer.profiler) - return Closure( - step_fn=step_fn, - backward_fn=backward_fn, - zero_grad_fn=zero_grad_fn, - profiler=self.trainer.profiler, - ) - - def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]: + def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], ClosureResult]: """Build the step function that runs the `training_step` and processes its output.""" return partial(self._training_step, split_batch, batch_idx, hiddens) - def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]: + def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> ClosureResult: """Performs the training step for manual optimization. Args: @@ -200,7 +189,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> O hiddens: the model's hidden state of the previous iteration Returns: - an AttributeDict containing the training step output. + A ``ClosureResult`` containing the training step output. """ # give the PL module a result for logging model_ref = self.trainer.lightning_module @@ -222,11 +211,15 @@ def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> O _check_training_step_output(self.trainer.lightning_module, training_step_output) - result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output) - if result_collection is None: - return + result = ClosureResult.from_training_step_output(training_step_output) + + if self.trainer.terminate_on_nan: + check_finite_loss(result.closure_loss) - return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection) + if self.trainer.move_metrics_to_cpu: + self.trainer._results.cpu() + + return result def _tbptt_split_batch(self, batch: Any) -> List[Any]: """Splits a single batch into a list of sequence steps for tbptt. diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 8097d6e15a5d7..a8dead99a0afa 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -13,13 +13,14 @@ # limitations under the License. from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable, Dict, Optional from torch import Tensor from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -30,12 +31,51 @@ class ClosureResult: Attributes: closure_loss: The loss with a graph attached. loss: A detached copy of the closure loss. - result_collection: A collection of results returned by the closure. + FIXME """ closure_loss: Optional[Tensor] - loss: Optional[Tensor] - result_collection: Optional[ResultCollection] + hiddens: Optional[Tensor] + loss: Optional[Tensor] = None + extra: Dict[str, Tensor] = field(default_factory=dict) + + def __post_init__(self): + self._set_loss() + + def _set_loss(self) -> None: + if self.closure_loss is not None: + # the loss will get scaled for amp. avoid any modifications to it + self.loss = self.closure_loss.detach().clone() + + @classmethod + def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ClosureResult": + loss = None + hiddens = None + extra = {} + + if isinstance(training_step_output, dict): + # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` + loss = training_step_output.get("loss") + hiddens = training_step_output.get("hiddens") + # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` + hiddens = recursive_detach(hiddens) + # use the setter instead of `dict.update` because it calls `detach` on the tensor items + # FIXME detach? + extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} + elif isinstance(training_step_output, Tensor): + loss = training_step_output + + # map to results under the hood + return cls(loss, hiddens, extra=extra) + + def apply_accumulation(self, value: int) -> None: + """Accumulate loss. + + If ``accumulate_grad_batches == 1``, no effect. + """ + if self.closure_loss is not None: + self.closure_loss /= value + self._set_loss() class AbstractClosure(ABC): @@ -53,25 +93,26 @@ def __init__(self) -> None: super().__init__() self._result: Optional[ClosureResult] = None - def get_result(self) -> Optional[ClosureResult]: + def get_result(self) -> ClosureResult: """The cached result from the last time the closure was called. Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long as necessary. """ + if self._result is None: + raise ValueError("Called `get_result` but the closure hasn't been executed yet") result = self._result self._result = None # free memory return result @abstractmethod - def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]: + def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: """Implements the behavior of the closure once it is getting called.""" pass def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: self._result = self.closure(*args, **kwargs) - if self._result is not None: - return self._result.loss + return self._result.loss class Closure(AbstractClosure): @@ -102,7 +143,7 @@ class Closure(AbstractClosure): def __init__( self, - step_fn: Callable[[], Optional[Dict]], + step_fn: Callable[[], ClosureResult], backward_fn: Optional[Callable[[Tensor], Tensor]] = None, zero_grad_fn: Optional[Callable[[], None]] = None, profiler: Optional[BaseProfiler] = None, @@ -113,19 +154,20 @@ def __init__( self._zero_grad_fn = zero_grad_fn self._profiler = PassThroughProfiler() if profiler is None else profiler - def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]: + def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: with self._profiler.profile("training_step_and_backward"): step_output = self._step_fn() - step_output = ClosureResult(**step_output) if step_output else None - if step_output is None: - self.warning_cache.warn("training_step returned None. If this was on purpose, ignore this warning...") + if step_output.closure_loss is None: + self.warning_cache.warn( + "`training_step` returned `None`. If this was on purpose, ignore this warning..." + ) if self._zero_grad_fn is not None: with self._profiler.profile("zero_grad"): self._zero_grad_fn() - if self._backward_fn is not None and step_output is not None and step_output.closure_loss is not None: + if self._backward_fn is not None and step_output.closure_loss is not None: with self._profiler.profile("backward"): step_output.closure_loss = self._backward_fn(step_output.closure_loss) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ae63e564511c9..f8040c9686aba 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -17,6 +17,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop +from pytorch_lightning.loops.closure import ClosureResult from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress @@ -281,18 +282,18 @@ def _track_epoch_end_reduce_metrics( @staticmethod def _prepare_outputs( - outputs: List[List[List["ResultCollection"]]], batch_mode: bool + outputs: List[List[List[ClosureResult]]], batch_mode: bool ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """Extract required information from batch or epoch end results. Args: - outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions: + outputs: A 3-dimensional list of ``ClosureResult`` objects with dimensions: ``[optimizer outs][batch outs][tbptt steps]``. batch_mode: If True, ignore the batch output dimension. Returns: - The cleaned outputs with ``ResultCollection`` objects converted to dictionaries. + The cleaned outputs with ``ClosureResult`` objects converted to dictionaries. All list dimensions of size one will be collapsed. """ processed_outputs = [] @@ -309,13 +310,13 @@ def _prepare_outputs( for batch_outputs in opt_outputs: processed_tbptt_outputs = [] - if isinstance(batch_outputs, ResultCollection): + if isinstance(batch_outputs, ClosureResult): batch_outputs = [batch_outputs] for tbptt_output in batch_outputs: out = {} - if tbptt_output.minimize is not None: - out["loss"] = tbptt_output.minimize.detach() + if tbptt_output.loss is not None: + out["loss"] = tbptt_output.loss out.update(tbptt_output.extra) processed_tbptt_outputs.append(out) diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 5ec476787aed8..76806cf551b8f 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple @@ -27,11 +26,11 @@ _block_parallel_sync_behavior, _build_training_step_kwargs, _check_training_step_output, - _process_training_step_output, + check_finite_loss, ) from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm +from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE @@ -76,6 +75,7 @@ def on_run_start( # type: ignore[override] self._optimizers = optimizers def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: # type: ignore[override] + # FIXME: remove hiddens? self._hiddens = hiddens result = self._run_optimization( batch, @@ -83,8 +83,7 @@ def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: # type: i self._optimizers[self.optim_progress.optimizer_idx], self.optim_progress.optimizer_idx, ) - if result: - self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection)) + self.outputs[self.optim_progress.optimizer_idx].append(result) self.optim_progress.optimizer_idx += 1 @@ -127,7 +126,7 @@ def _run_optimization( batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int, - ) -> Optional[ClosureResult]: + ) -> ClosureResult: """Runs closure (train step + backward) together with optimization if necessary. Args: @@ -160,14 +159,13 @@ def _run_optimization( result = closure.get_result() - if result: + if result.loss is not None: # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss # TODO: find proper way to handle updating running loss assert self.trainer.fit_loop is not None assert self.trainer.fit_loop.epoch_loop is not None assert self.trainer.fit_loop.epoch_loop.batch_loop is not None - assert result.loss is not None self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss) # untoggle model params @@ -197,7 +195,7 @@ def _make_closure( def _make_step_fn( self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any - ) -> Callable[[], Optional[AttributeDict]]: + ) -> Callable[[], ClosureResult]: """Build the step function that runs the `training_step` and processes its output.""" return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens) @@ -326,9 +324,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) self.optim_progress.optimizer.zero_grad.increment_completed() - def _training_step( - self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor - ) -> Optional[AttributeDict]: + def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> ClosureResult: """Performs the actual train step with the tied hooks. Args: @@ -361,18 +357,16 @@ def _training_step( _check_training_step_output(self.trainer.lightning_module, training_step_output) - result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output) - if result_collection is None: - return None + closure_result = ClosureResult.from_training_step_output(training_step_output) - # output validation already done, here loss can't be None - assert result_collection.minimize is not None + if self.trainer.terminate_on_nan: + check_finite_loss(closure_result.closure_loss) + + if self.trainer.move_metrics_to_cpu: + self.trainer._results.cpu() - # accumulate loss. if accumulate_grad_batches==1, no effect - closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches - # the loss will get scaled for amp. avoid any modifications to it - loss = closure_loss.detach().clone() - return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection) + closure_result.apply_accumulation(self.trainer.accumulate_grad_batches) + return closure_result def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]: """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 154680535ef73..bda3c04b3d76e 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple +from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence import torch from torch import Tensor @@ -21,8 +21,6 @@ import pytorch_lightning as pl from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -47,17 +45,18 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu model: a reference to the trainer training_step_output: the output of the training step (before wrapping in an AttributeDict) """ - if isinstance(training_step_output, torch.Tensor) and not model.automatic_optimization: - if training_step_output.grad_fn is None: - # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... - raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") - elif model.automatic_optimization: - if not any( - ( - isinstance(training_step_output, torch.Tensor), - (isinstance(training_step_output, Mapping) and "loss" in training_step_output), - training_step_output is None, - ) + if ( + isinstance(training_step_output, torch.Tensor) + and not model.automatic_optimization + and training_step_output.grad_fn is None + ): + # FIXME: do we consider it as an extra? + raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + if model.automatic_optimization: + if not ( + isinstance(training_step_output, torch.Tensor) + or (isinstance(training_step_output, Mapping) and "loss" in training_step_output) + or training_step_output is None ): raise MisconfigurationException( "In automatic optimization, `training_step` must either return a Tensor, " @@ -65,53 +64,6 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu ) -def _process_training_step_output( - trainer: "pl.Trainer", training_step_output: STEP_OUTPUT -) -> Tuple[Optional[ResultCollection], Optional[Any]]: - """Adds the :param:`training_step_output` to the trainer's results. - - Args: - trainer: a reference to the trainer - training_step_output: the output of the training step (before wrapping into an AttributeDict) - - Returns: - the updated results (None if the training_step's output was None) and hiddens exract from the results - """ - if training_step_output is None: - return None, None - - results = trainer._results - - loss = None - hiddens = None - - # handle dict return - if isinstance(training_step_output, dict): - # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` - loss = training_step_output.get("loss") - hiddens = training_step_output.get("hiddens") - # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` - hiddens = apply_to_collection(hiddens, torch.Tensor, lambda t: t.detach()) - # use the setter instead of `dict.update` because it calls `detach` on the tensor items - results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} - - # handle scalar return - elif isinstance(training_step_output, torch.Tensor): - loss = training_step_output - - if trainer.terminate_on_nan: - check_finite_loss(loss) - - # the loss shouldn't be moved to cpu. - if trainer.move_metrics_to_cpu: - results.cpu() - - # map to results under the hood - results.minimize = loss - - return results, hiddens - - def _build_training_step_kwargs( lightning_module: "pl.LightningModule", optimizers: Sequence[Optimizer], diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0ebbaa0e25fc9..e8f71dc80f5b3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -209,9 +209,6 @@ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) self._batch_idx = batch_idx self._split_idx = split_idx - # clear reference to this step's training loss so that it can be garbage collected before the next training step - self.trainer._results.minimize = None - def update_train_step_metrics(self) -> None: if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 71dd88aee5c7b..6f4b0c9b02a9a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -351,7 +351,6 @@ class ResultCollection(dict): def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None: super().__init__() self.training = training - self._minimize: Optional[torch.Tensor] = None self._batch_size = torch.tensor(1, device=device) self.device: Optional[Union[str, torch.device]] = device @@ -375,43 +374,6 @@ def batch_size(self) -> torch.Tensor: def batch_size(self, value: int) -> None: self._batch_size = torch.tensor(value, device=self.device) - @property - def minimize(self) -> Optional[torch.Tensor]: - """The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss will be saved as the - ``minimize`` attribute.""" - return self._minimize - - @minimize.setter - def minimize(self, loss: Optional[torch.Tensor]) -> None: - if loss is not None and not isinstance(loss, torch.Tensor): - raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") - self._minimize = loss - - @property - def extra(self) -> Dict[str, Any]: - """ - Extras are any keys other than the loss returned by - :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` - """ - self.setdefault("_extra", {}) - return self["_extra"] - - @extra.setter - def extra(self, extra: Dict[str, Any]) -> None: - def check_fn(v: torch.Tensor) -> torch.Tensor: - if v.grad_fn is not None: - warning_cache.deprecation( - f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" - " but this behaviour will change in v1.6. Please detach it manually:" - " `return {'loss': ..., 'something': something.detach()}`" - ) - return v.detach() - return v - - # update instead of replace to keep the extra dict reference. TODO: remove with v1.6 deprecation removal - extra.update(apply_to_collection(extra, torch.Tensor, check_fn)) - self["_extra"] = extra - def log( self, fx: str, @@ -518,9 +480,7 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" - return ( - (k, v) for k, v in self.items() if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset) - ) + return ((k, v) for k, v in self.items() if not (isinstance(v, ResultMetric) and v.has_reset)) def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: name = result_metric.meta.name @@ -600,8 +560,6 @@ def to(self, *args: Any, **kwargs: Any) -> "ResultCollection": """Move all data to the given device.""" self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) - if self.minimize is not None: - self.minimize = self.minimize.to(*args, **kwargs) self._batch_size = self._batch_size.to(*args, **kwargs) if "device" in kwargs: self.device = kwargs["device"] @@ -622,31 +580,16 @@ def unsync(self) -> None: result_metric.unsync() def __str__(self) -> str: - # sample output: `ResultCollection(minimize=1.23, {})` - minimize = f"minimize={self.minimize}, " if self.minimize is not None else "" # remove empty values - self_str = str({k: v for k, v in self.items() if v}) - return f"{self.__class__.__name__}({minimize}{self_str})" + return f"{self.__class__.__name__}({dict(self)})" def __repr__(self) -> str: - # sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=), {'_extra': {}}}` - minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else "" - return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}" + return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}" def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() - - # can't deepcopy tensors with grad_fn - minimize = d["_minimize"] - if minimize is not None: - d["_minimize"] = minimize.detach() - - extra = self.get("_extra") - if extra is not None: - d["_extra"] = extra - # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items() if k != "_extra"} + items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()} return {**d, "items": items} def __setstate__( diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index c37681e4831ca..d4652fe0d2c7f 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -188,16 +188,3 @@ def training_epoch_end(self, outputs) -> None: trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) trainer.fit(model) assert model.on_train_batch_end_called == 2 - - -def test_batch_loop_releases_loss(tmpdir): - """Test that loss/graph is released so that it can be garbage collected before the next training step""" - - class TestModel(BoringModel): - def training_step(self, batch, batch_idx): - assert self.trainer._results.minimize is None - return super().training_step(batch, batch_idx) - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) - trainer.fit(model) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 56674b0ff8e95..5a76e5e040b2e 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -153,8 +153,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.minimize, torch.Tensor) - assert train_step_out.minimize.item() == 171 + assert isinstance(train_step_out.loss, torch.Tensor) + assert train_step_out.loss.item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( @@ -227,8 +227,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.minimize, torch.Tensor) - assert train_step_out.minimize.item() == 171 + assert isinstance(train_step_out.loss, torch.Tensor) + assert train_step_out.loss.item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( @@ -249,6 +249,7 @@ def training_step(self, batch, batch_idx): self.log("a", loss, on_step=True, on_epoch=True) def training_epoch_end(self, outputs) -> None: + print(outputs) assert len(outputs) == 0 def validation_step(self, batch, batch_idx): @@ -264,7 +265,7 @@ def validation_epoch_end(self, outputs): Closure.warning_cache.clear() - with pytest.warns(UserWarning, match=r"training_step returned None.*"): + with pytest.warns(UserWarning, match=r"training_step` returned `None"): trainer.fit(model) assert model.training_step_called @@ -276,7 +277,7 @@ def validation_epoch_end(self, outputs): Closure.warning_cache.clear() - with no_warning_call(UserWarning, match=r"training_step returned None.*"): + with no_warning_call(UserWarning, match=r"training_step` returned `None"): trainer.fit(model) From f97f5f0e15120416b9de9c8a3102e0dbfcbcd11c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 01:45:54 +0200 Subject: [PATCH 02/47] WIP --- .../loops/batch/training_batch_loop.py | 3 +- pytorch_lightning/loops/closure.py | 35 +++++++++++++++---- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 058188a0dad6f..5a70004273f72 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy from functools import partial from typing import Any, Callable, List, Optional, Tuple @@ -133,7 +134,7 @@ def advance(self, batch, batch_idx): else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) - self.batch_outputs[0].append(result) + self.batch_outputs[0].append(deepcopy(result)) def teardown(self) -> None: # release memory diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index a8dead99a0afa..01cb8df5f9405 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -13,25 +13,29 @@ # limitations under the License. from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import Any, Callable, Dict, Optional from torch import Tensor from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.utilities.warnings import WarningCache +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache @dataclass class ClosureResult: """A container to hold the result of a :class:`AbstractClosure` call. + It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`. + Attributes: closure_loss: The loss with a graph attached. + hiddens: The hidden tensors if available. loss: A detached copy of the closure loss. - FIXME + extra: Any keys other than the loss returned. """ closure_loss: Optional[Tensor] @@ -39,7 +43,7 @@ class ClosureResult: loss: Optional[Tensor] = None extra: Dict[str, Tensor] = field(default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: self._set_loss() def _set_loss(self) -> None: @@ -59,13 +63,13 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) hiddens = training_step_output.get("hiddens") # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` hiddens = recursive_detach(hiddens) - # use the setter instead of `dict.update` because it calls `detach` on the tensor items - # FIXME detach? extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} + # TODO: remove with the deprecation removal in v1.6 + ClosureResult._check_extra_detach_deprecation(extra) + extra = recursive_detach(extra) elif isinstance(training_step_output, Tensor): loss = training_step_output - # map to results under the hood return cls(loss, hiddens, extra=extra) def apply_accumulation(self, value: int) -> None: @@ -77,6 +81,23 @@ def apply_accumulation(self, value: int) -> None: self.closure_loss /= value self._set_loss() + @staticmethod + def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None: + def check_fn(v: Tensor) -> Tensor: + if v.grad_fn is not None: + rank_zero_deprecation( + f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" + " but this behaviour will change in v1.6. Please detach it manually:" + " `return {'loss': ..., 'something': something.detach()}`" + ) + return v + + apply_to_collection(extra, Tensor, check_fn) + + def __getstate__(self) -> "ClosureResult": + # return a copy without the closure loss which could have a `grad_fn` + return replace(self, closure_loss=None) + class AbstractClosure(ABC): """Abstract base class for optimizer closures in Lightning. From 573f92a3cd0b493399bfd51599f13e83ea6cbf65 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 02:13:12 +0200 Subject: [PATCH 03/47] WIP --- .../loops/batch/training_batch_loop.py | 3 ++- pytorch_lightning/loops/closure.py | 15 ++++++++---- .../loops/optimizer/optimizer_loop.py | 5 ++-- pytorch_lightning/loops/utilities.py | 24 +++++++++---------- .../loops/test_evaluation_loop_flow.py | 8 +++---- .../loops/test_training_loop_flow_scalar.py | 2 +- 6 files changed, 31 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 5a70004273f72..1e51762e2b2bc 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -134,7 +134,8 @@ def advance(self, batch, batch_idx): else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) - self.batch_outputs[0].append(deepcopy(result)) + if result.loss is not None: + self.batch_outputs[0].append(deepcopy(result)) def teardown(self) -> None: # release memory diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 01cb8df5f9405..75eca084e9e52 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -20,6 +20,7 @@ from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache @@ -44,6 +45,10 @@ class ClosureResult: extra: Dict[str, Tensor] = field(default_factory=dict) def __post_init__(self) -> None: + if self.hiddens is not None and self.closure_loss is None: + raise MisconfigurationException( + "If `hiddens` are returned from `training_step`, the loss cannot be `None`." + ) self._set_loss() def _set_loss(self) -> None: @@ -53,13 +58,13 @@ def _set_loss(self) -> None: @classmethod def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ClosureResult": - loss = None + closure_loss = None hiddens = None extra = {} if isinstance(training_step_output, dict): # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` - loss = training_step_output.get("loss") + closure_loss = training_step_output.get("loss") hiddens = training_step_output.get("hiddens") # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` hiddens = recursive_detach(hiddens) @@ -68,9 +73,9 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) ClosureResult._check_extra_detach_deprecation(extra) extra = recursive_detach(extra) elif isinstance(training_step_output, Tensor): - loss = training_step_output + closure_loss = training_step_output - return cls(loss, hiddens, extra=extra) + return cls(closure_loss, hiddens, extra=extra) def apply_accumulation(self, value: int) -> None: """Accumulate loss. @@ -94,7 +99,7 @@ def check_fn(v: Tensor) -> Tensor: apply_to_collection(extra, Tensor, check_fn) - def __getstate__(self) -> "ClosureResult": + def __deepcopy__(self, *_: Any) -> "ClosureResult": # return a copy without the closure loss which could have a `grad_fn` return replace(self, closure_loss=None) diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 76806cf551b8f..b2fe5a7a23b8e 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple @@ -83,7 +83,8 @@ def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: # type: i self._optimizers[self.optim_progress.optimizer_idx], self.optim_progress.optimizer_idx, ) - self.outputs[self.optim_progress.optimizer_idx].append(result) + if result.loss is not None: + self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result)) self.optim_progress.optimizer_idx += 1 diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index bda3c04b3d76e..06b2ab8eb5d8a 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -38,8 +38,7 @@ def check_finite_loss(loss: Optional[torch.Tensor]) -> None: def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None: - """Sanity checks that training produced a valid output and optimizer step has already been called in manual - optimization. + """Sanity checks that training produced a valid output. Args: model: a reference to the trainer @@ -50,18 +49,17 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu and not model.automatic_optimization and training_step_output.grad_fn is None ): - # FIXME: do we consider it as an extra? + # TODO: in manual optimization, anything returned should be considered an `extra` raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") - if model.automatic_optimization: - if not ( - isinstance(training_step_output, torch.Tensor) - or (isinstance(training_step_output, Mapping) and "loss" in training_step_output) - or training_step_output is None - ): - raise MisconfigurationException( - "In automatic optimization, `training_step` must either return a Tensor, " - "a dict with key 'loss' or None (where the step will be skipped)." - ) + if model.automatic_optimization and not ( + isinstance(training_step_output, torch.Tensor) + or (isinstance(training_step_output, Mapping) and "loss" in training_step_output) + or training_step_output is None + ): + raise MisconfigurationException( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." + ) def _build_training_step_kwargs( diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 916da177c75d6..ec0479fb710aa 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -70,8 +70,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.minimize, torch.Tensor) - assert train_step_out.minimize.item() == 171 + assert isinstance(train_step_out.loss, torch.Tensor) + assert train_step_out.loss.item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( @@ -135,8 +135,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.minimize, torch.Tensor) - assert train_step_out.minimize.item() == 171 + assert isinstance(train_step_out.loss, torch.Tensor) + assert train_step_out.loss.item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 5a76e5e040b2e..579062591795d 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -304,7 +304,7 @@ def training_step(self, batch, batch_idx): Closure.warning_cache.clear() - with pytest.warns(UserWarning, match=r".*training_step returned None.*"): + with pytest.warns(UserWarning, match=r".*training_step` returned `None.*"): trainer.fit(model) trainer.state.stage = RunningStage.TRAINING From 718c9717cbd694046eb552f874a94090382b3a0a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 02:45:45 +0200 Subject: [PATCH 04/47] WIP --- pytorch_lightning/accelerators/accelerator.py | 37 ++----------------- pytorch_lightning/core/lightning.py | 9 ++--- pytorch_lightning/loops/closure.py | 8 ++-- .../loops/optimizer/optimizer_loop.py | 3 +- pytorch_lightning/loops/utilities.py | 3 +- .../connectors/logger_connector/result.py | 3 +- 6 files changed, 16 insertions(+), 47 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index f40dc9e1576cf..93915ac946ae9 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -173,15 +173,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual training step. - Args: - step_kwargs: the arguments for the models training step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): Integer displaying index of this batch - - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. + See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ with self.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) @@ -192,14 +184,7 @@ def post_training_step(self) -> None: def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual validation step. - Args: - step_kwargs: the arguments for the models validation step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): The index of this batch - - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple val dataloaders used) + See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ with self.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) @@ -207,14 +192,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual test step. - Args: - step_kwargs: the arguments for the models test step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): The index of this batch. - - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple test dataloaders used). + See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ with self.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) @@ -222,14 +200,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual predict step. - Args: - step_kwargs: the arguments for the models predict step. Can consist of the following: - - - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - - batch_idx (int): The index of this batch. - - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple predict dataloaders used). + See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ with self.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e1b4d1f3f7e58..b810a78d1be81 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -620,9 +620,9 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: Args: batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if + batch_idx (``int``): Integer displaying index of this batch + optimizer_idx (``int``): When using multiple optimizers, this argument will also be present. + hiddens (``Any``): Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. Return: @@ -669,8 +669,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): # Truncated back-propagation through time def training_step(self, batch, batch_idx, hiddens): # hiddens are the hidden states from the previous truncated backprop step - ... - out, hiddens = self.lstm(data, hiddens) + loss, hiddens = self.lstm(data, hiddens) ... return {"loss": loss, "hiddens": hiddens} diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 75eca084e9e52..e1aba4ffb3be0 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -40,16 +40,16 @@ class ClosureResult: """ closure_loss: Optional[Tensor] - hiddens: Optional[Tensor] + hiddens: Optional[Any] loss: Optional[Tensor] = None extra: Dict[str, Tensor] = field(default_factory=dict) def __post_init__(self) -> None: - if self.hiddens is not None and self.closure_loss is None: + self._set_loss() + if self.hiddens is not None and self.loss is None: raise MisconfigurationException( "If `hiddens` are returned from `training_step`, the loss cannot be `None`." ) - self._set_loss() def _set_loss(self) -> None: if self.closure_loss is not None: @@ -126,7 +126,7 @@ def get_result(self) -> ClosureResult: as necessary. """ if self._result is None: - raise ValueError("Called `get_result` but the closure hasn't been executed yet") + raise ValueError("The closure hasn't been executed yet") result = self._result self._result = None # free memory return result diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index b2fe5a7a23b8e..77f957f9c1147 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -28,14 +28,13 @@ _check_training_step_output, check_finite_loss, ) -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE -_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]] +_OUTPUTS_TYPE = List[List[ClosureResult]] class OptimizerLoop(Loop): diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 06b2ab8eb5d8a..364c821cd6d1e 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -16,7 +16,6 @@ from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence import torch -from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl @@ -68,7 +67,7 @@ def _build_training_step_kwargs( batch: Any, batch_idx: int, opt_idx: Optional[int], - hiddens: Optional[Tensor], + hiddens: Optional[Any], ) -> Dict[str, Any]: """Builds the keyword arguments for training_step. diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 6f4b0c9b02a9a..4246f7a7303ff 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -581,7 +581,8 @@ def unsync(self) -> None: def __str__(self) -> str: # remove empty values - return f"{self.__class__.__name__}({dict(self)})" + self_str = str({k: v for k, v in self.items() if v}) + return f"{self.__class__.__name__}({self_str})" def __repr__(self) -> str: return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}" From 096129e5fed478e1161aa0302271ab71f92ca19e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 03:45:48 +0200 Subject: [PATCH 05/47] Remove hiddens --- .../loops/batch/training_batch_loop.py | 10 ++++---- pytorch_lightning/loops/closure.py | 7 +++--- .../loops/optimizer/optimizer_loop.py | 25 ++++++++----------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 1e51762e2b2bc..77c7fbd95fcb8 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -47,7 +47,6 @@ def __init__(self) -> None: self.optimizer_loop = OptimizerLoop() self._warning_cache: WarningCache = WarningCache() - self._hiddens: Optional[Tensor] = None self._optimizer_freq_cumsum: Optional[int] = None self._remaining_splits: Optional[List[Any]] = None @@ -97,7 +96,6 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict: def reset(self) -> None: """Resets the loop state.""" - self._hiddens = None self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] def on_run_start(self, batch: Any, batch_idx: int): @@ -127,7 +125,7 @@ def advance(self, batch, batch_idx): if self.trainer.lightning_module.automatic_optimization: # in automatic optimization, hand over execution to the OptimizerLoop optimizers = [optimizer for _, optimizer in self.get_active_optimizers(batch_idx)] - batch_outputs, self._hiddens = self.optimizer_loop.run(split_batch, self._hiddens, optimizers, batch_idx) + batch_outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx) # combine outputs from each optimizer for k in range(len(batch_outputs)): self.batch_outputs[k].extend(batch_outputs[k]) @@ -137,6 +135,9 @@ def advance(self, batch, batch_idx): if result.loss is not None: self.batch_outputs[0].append(deepcopy(result)) + def on_run_end(self) -> Any: + self.optimizer_loop._hiddens = None + def teardown(self) -> None: # release memory self._remaining_splits = None @@ -156,8 +157,7 @@ def _run_optimization( batch_idx: the index of the current batch split_batch: the current tbptt split of the whole batch """ - # TODO: replace call through closure by direct call (manual optimization) - closure = self._make_closure(split_batch, batch_idx, self._hiddens) + closure = self._make_closure(split_batch, batch_idx, None) closure() result = closure.get_result() diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index e1aba4ffb3be0..2773941eb8216 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -47,9 +47,7 @@ class ClosureResult: def __post_init__(self) -> None: self._set_loss() if self.hiddens is not None and self.loss is None: - raise MisconfigurationException( - "If `hiddens` are returned from `training_step`, the loss cannot be `None`." - ) + raise MisconfigurationException("If `hiddens` are returned from `training_step`, the loss cannot be `None`") def _set_loss(self) -> None: if self.closure_loss is not None: @@ -101,7 +99,8 @@ def check_fn(v: Tensor) -> Tensor: def __deepcopy__(self, *_: Any) -> "ClosureResult": # return a copy without the closure loss which could have a `grad_fn` - return replace(self, closure_loss=None) + # and without `hiddens` which are not necessary + return replace(self, closure_loss=None, hiddens=None) class AbstractClosure(ABC): diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 77f957f9c1147..8663eb389ccd2 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -13,7 +13,7 @@ # limitations under the License. from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import torch from torch import Tensor @@ -67,15 +67,11 @@ def reset(self) -> None: self.optim_progress.optimizer_idx = 0 self.outputs = [[] for _ in range(len(self.trainer.optimizers))] - def on_run_start( # type: ignore[override] - self, batch: Any, hiddens: Any, optimizers: List[Optimizer], batch_idx: int - ) -> None: + def on_run_start(self, batch: Any, optimizers: List[Optimizer], batch_idx: int) -> None: # type: ignore[override] self._batch_idx = batch_idx self._optimizers = optimizers - def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: # type: ignore[override] - # FIXME: remove hiddens? - self._hiddens = hiddens + def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override] result = self._run_optimization( batch, self._batch_idx, @@ -83,17 +79,16 @@ def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: # type: i self.optim_progress.optimizer_idx, ) if result.loss is not None: + self._hiddens = result.hiddens self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result)) self.optim_progress.optimizer_idx += 1 - def on_run_end(self) -> Tuple[_OUTPUTS_TYPE, Optional[Any]]: + def on_run_end(self) -> _OUTPUTS_TYPE: outputs = self.outputs - hiddens = self._hiddens # free memory self.outputs = [] - self._hiddens = None - return outputs, hiddens + return outputs def backward( self, @@ -178,7 +173,7 @@ def _make_closure( batch_idx: int, opt_idx: int, optimizer: Optimizer, - hiddens: Any, + hiddens: Optional[Any], ) -> Closure: """Build a closure object that captures the given arguments and runs the `training_step` function and optionally other functions such as `backward` and `zero_grad`.""" @@ -194,7 +189,7 @@ def _make_closure( ) def _make_step_fn( - self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any + self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Optional[Any] ) -> Callable[[], ClosureResult]: """Build the step function that runs the `training_step` and processes its output.""" return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens) @@ -324,7 +319,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) self.optim_progress.optimizer.zero_grad.increment_completed() - def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> ClosureResult: + def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Optional[Any]) -> ClosureResult: """Performs the actual train step with the tied hooks. Args: @@ -334,7 +329,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens hiddens: the model's hidden state of the previous iteration Returns: - an AttributeDict containing the loss value and the training step output. + A ``ClosureResult`` containing the training step output. """ # give the PL module a result for logging model_ref = self.trainer.lightning_module From bd90a94726123e72e70d0ca9c82071159fa60763 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 04:15:40 +0200 Subject: [PATCH 06/47] Add closure result tests --- pytorch_lightning/loops/closure.py | 18 +++++++------ tests/loops/test_closure.py | 42 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 8 deletions(-) create mode 100644 tests/loops/test_closure.py diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 2773941eb8216..f3552e264cef1 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from dataclasses import dataclass, field, replace +from dataclasses import asdict, dataclass, field, replace from typing import Any, Callable, Dict, Optional from torch import Tensor @@ -45,6 +45,13 @@ class ClosureResult: extra: Dict[str, Tensor] = field(default_factory=dict) def __post_init__(self) -> None: + # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` + self.hiddens = recursive_detach(self.hiddens) + + # TODO: remove with the deprecation removal in v1.6 + ClosureResult._check_extra_detach_deprecation(self.extra) + self.extra = recursive_detach(self.extra) + self._set_loss() if self.hiddens is not None and self.loss is None: raise MisconfigurationException("If `hiddens` are returned from `training_step`, the loss cannot be `None`") @@ -64,12 +71,7 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` closure_loss = training_step_output.get("loss") hiddens = training_step_output.get("hiddens") - # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` - hiddens = recursive_detach(hiddens) extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} - # TODO: remove with the deprecation removal in v1.6 - ClosureResult._check_extra_detach_deprecation(extra) - extra = recursive_detach(extra) elif isinstance(training_step_output, Tensor): closure_loss = training_step_output @@ -97,10 +99,10 @@ def check_fn(v: Tensor) -> Tensor: apply_to_collection(extra, Tensor, check_fn) - def __deepcopy__(self, *_: Any) -> "ClosureResult": + def __getstate__(self) -> Dict[str, Any]: # return a copy without the closure loss which could have a `grad_fn` # and without `hiddens` which are not necessary - return replace(self, closure_loss=None, hiddens=None) + return asdict(replace(self, closure_loss=None, hiddens=None)) class AbstractClosure(ABC): diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py new file mode 100644 index 0000000000000..65b46a1463f2b --- /dev/null +++ b/tests/loops/test_closure.py @@ -0,0 +1,42 @@ +import pickle +from copy import deepcopy + +import pytest +import torch + +from pytorch_lightning.loops.closure import ClosureResult +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def test_closure_result_deepcopy(): + closure_loss = torch.tensor(123.45) + hiddens = torch.tensor(321.12, requires_grad=True) + result = ClosureResult(closure_loss, hiddens) + assert not result.hiddens.requires_grad + + assert closure_loss.data_ptr() == result.closure_loss.data_ptr() + # the `loss` is cloned so the storage is different + assert closure_loss.data_ptr() != result.loss.data_ptr() + + copy = deepcopy(result) + assert result.loss == copy.loss + assert copy.closure_loss is None + assert copy.hiddens is None + + assert id(result.loss) != id(copy.loss) + assert result.loss.data_ptr() != copy.loss.data_ptr() + + assert copy == pickle.loads(pickle.dumps(result)) + + +def test_closure_result_raises(): + with pytest.raises(MisconfigurationException, match="If `hiddens` are returned .* the loss cannot be `None`"): + ClosureResult(None, "something") + + +def test_closure_result_apply_accumulation(): + closure_loss = torch.tensor(25.0) + result = ClosureResult(closure_loss, None) + assert result.loss == 25 + result.apply_accumulation(5) + assert result.loss == 5 From bc8708105b04785cd6b3aee38b84bc697a8b9e0a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 04:19:08 +0200 Subject: [PATCH 07/47] Fix tests --- tests/core/test_lightning_optimizer.py | 1 + tests/core/test_metric_result_integration.py | 13 ++----------- tests/loops/test_loop_state_dict.py | 4 ---- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index 2348089ed37e5..ea683aaf74e57 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -218,6 +218,7 @@ def training_epoch_end(self, outputs): ... def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): + # FIXME assert isinstance(optimizer_closure, Closure) # not passing the closure to the optimizer because step is mocked # zero_grad is called inside the closure diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 1f4b72744c03b..a7572f9a77394 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -141,7 +141,6 @@ def test_result_metric_integration(): result.extra = {} assert str(result) == ( "ResultCollection(" - "minimize=1.0, " "{" "'h.a': ResultMetric('a', value=DummyMetric()), " "'h.b': ResultMetric('b', value=DummyMetric()), " @@ -152,12 +151,10 @@ def test_result_metric_integration(): "{" "True, " "device(type='cpu'), " - "minimize=tensor(1.), " "{'h.a': ResultMetric('a', value=DummyMetric()), " "'h.b': ResultMetric('b', value=DummyMetric()), " - "'h.c': ResultMetric('c', value=DummyMetric()), " - "'_extra': {}}" - "}" + "'h.c': ResultMetric('c', value=DummyMetric())" + "}}" ) @@ -352,12 +349,6 @@ def on_save_checkpoint(self, checkpoint) -> None: trainer.fit(model) -def test_result_collection_extra_reference(): - """Unit-test to check that the `extra` dict reference is properly set.""" - rc = ResultCollection(True) - assert rc.extra is rc["_extra"] - - class DummyMeanMetric(Metric): def __init__(self): super().__init__() diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 0010af32b4f99..e00e8aecc98ff 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -80,14 +80,12 @@ def test_loops_state_dict_structure(): }, "epoch_loop.val_loop._results": { "training": False, - "_minimize": None, "_batch_size": torch.tensor(1), "device": None, "items": {}, }, "epoch_loop._results": { "training": True, - "_minimize": None, "_batch_size": torch.tensor(1), "device": None, "items": {}, @@ -107,7 +105,6 @@ def test_loops_state_dict_structure(): }, "_results": { "training": False, - "_minimize": None, "_batch_size": torch.tensor(1), "device": None, "items": {}, @@ -123,7 +120,6 @@ def test_loops_state_dict_structure(): }, "_results": { "training": False, - "_minimize": None, "_batch_size": torch.tensor(1), "device": None, "items": {}, From 15c59723347a383fcc7bab35d28523047349cd41 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 16:42:02 +0200 Subject: [PATCH 08/47] Fail if closure is not executed --- pytorch_lightning/loops/closure.py | 5 +++- tests/core/test_lightning_optimizer.py | 23 +++++++++++-------- .../optimization/test_multiple_optimizers.py | 4 +++- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index f3552e264cef1..8ca748987285b 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -127,7 +127,10 @@ def get_result(self) -> ClosureResult: as necessary. """ if self._result is None: - raise ValueError("The closure hasn't been executed yet") + raise MisconfigurationException( + "The closure hasn't been executed." + " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook?" + ) result = self._result self._result = None # free memory return result diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index ea683aaf74e57..9d3b4d4dc88fe 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -218,16 +218,14 @@ def training_epoch_end(self, outputs): ... def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): - # FIXME assert isinstance(optimizer_closure, Closure) - # not passing the closure to the optimizer because step is mocked # zero_grad is called inside the closure + optimizer_closure() + # not passing the closure to the optimizer because step is mocked if isinstance(optimizer, SGD) and batch_idx % 2 == 0: - optimizer_closure() optimizer.step() if isinstance(optimizer, Adam) and batch_idx % 4 == 0: - optimizer_closure() - optimizer.step() # not passing the closure here because it's a mock + optimizer.step() def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -237,8 +235,13 @@ def configure_optimizers(self): model = TestModel() + limit_train_batches = 8 trainer = Trainer( - default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=1, max_epochs=1, weights_summary=None + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, ) with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple( @@ -246,11 +249,11 @@ def configure_optimizers(self): ) as adam: trainer.fit(model) - assert sgd["step"].call_count == 4 - assert adam["step"].call_count == 2 + assert sgd["step"].call_count == limit_train_batches // 2 + assert adam["step"].call_count == limit_train_batches // 4 - assert sgd["zero_grad"].call_count == 4 - assert adam["zero_grad"].call_count == 2 + assert sgd["zero_grad"].call_count == limit_train_batches + assert adam["zero_grad"].call_count == limit_train_batches def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir): diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index 603adb36d6981..352a5e7d937e5 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -177,6 +177,8 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_c if batch_idx % 2 == 0: self.optimizer_step_called[optimizer_idx] += 1 optimizer.step(closure=optimizer_closure) + else: + optimizer_closure() model = TestModel() model.val_dataloader = None @@ -185,5 +187,5 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_c default_root_dir=tmpdir, limit_train_batches=4, max_epochs=1, log_every_n_steps=1, weights_summary=None ) trainer.fit(model) - assert model.training_step_called == [4, 2] + assert model.training_step_called == [4, 4] assert model.optimizer_step_called == [4, 2] From 9803cec7ae0a29d0983cbea1876d09633ea8cdc4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 16:44:21 +0200 Subject: [PATCH 09/47] Remove deepcopy --- pytorch_lightning/loops/batch/training_batch_loop.py | 3 +-- pytorch_lightning/loops/optimizer/optimizer_loop.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 77c7fbd95fcb8..8f1f3169775bd 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from functools import partial from typing import Any, Callable, List, Optional, Tuple @@ -133,7 +132,7 @@ def advance(self, batch, batch_idx): # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) if result.loss is not None: - self.batch_outputs[0].append(deepcopy(result)) + self.batch_outputs[0].append(result) def on_run_end(self) -> Any: self.optimizer_loop._hiddens = None diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 8663eb389ccd2..1089822951fa1 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Optional @@ -80,7 +79,7 @@ def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override ) if result.loss is not None: self._hiddens = result.hiddens - self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result)) + self.outputs[self.optim_progress.optimizer_idx].append(result) self.optim_progress.optimizer_idx += 1 From 24a70b95aa059a818552e930b8406116f15eb502 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 16:56:16 +0200 Subject: [PATCH 10/47] Remove debugging statement --- tests/trainer/loops/test_training_loop_flow_scalar.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 579062591795d..f8ae5b1eb7912 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -249,7 +249,6 @@ def training_step(self, batch, batch_idx): self.log("a", loss, on_step=True, on_epoch=True) def training_epoch_end(self, outputs) -> None: - print(outputs) assert len(outputs) == 0 def validation_step(self, batch, batch_idx): From 8630761d9e06716ebba47619898d4c90c561fbd0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 17:47:23 +0200 Subject: [PATCH 11/47] Enforce that the optimizer closure is executed when `optimizer_step` is overridden --- .../loops/batch/training_batch_loop.py | 6 ++--- pytorch_lightning/loops/closure.py | 25 ++++++++++++------- .../loops/optimizer/optimizer_loop.py | 7 +++--- tests/core/test_lightning_optimizer.py | 22 +++++++++------- tests/loops/test_closure.py | 18 +++++++++++++ .../loops/test_training_loop_flow_scalar.py | 6 ++--- .../optimization/test_multiple_optimizers.py | 14 ++++++++--- 7 files changed, 67 insertions(+), 31 deletions(-) create mode 100644 tests/loops/test_closure.py diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index c2757e4035b48..20a1d93cf33af 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -134,7 +134,7 @@ def advance(self, batch, batch_idx): else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) - if result: + if result.result_collection is not None: self.batch_outputs[0].append(deepcopy(result.result_collection)) def teardown(self) -> None: @@ -149,7 +149,7 @@ def _run_optimization( self, batch_idx: int, split_batch: Any, - ) -> Optional[ClosureResult]: + ) -> ClosureResult: """Runs closure (train step + backward) together with optimization if necessary. Args: @@ -161,7 +161,7 @@ def _run_optimization( closure() result = closure.get_result() - if result: + if result.loss: # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss self._update_running_loss(result.loss) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 8097d6e15a5d7..508e4741879fd 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -20,6 +20,7 @@ from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -53,25 +54,29 @@ def __init__(self) -> None: super().__init__() self._result: Optional[ClosureResult] = None - def get_result(self) -> Optional[ClosureResult]: + def get_result(self) -> ClosureResult: """The cached result from the last time the closure was called. Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long as necessary. """ result = self._result + if self._result is None: + raise MisconfigurationException( + "The closure hasn't been executed." + " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook?" + ) self._result = None # free memory return result @abstractmethod - def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]: + def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: """Implements the behavior of the closure once it is getting called.""" pass def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: self._result = self.closure(*args, **kwargs) - if self._result is not None: - return self._result.loss + return self._result.loss class Closure(AbstractClosure): @@ -113,19 +118,21 @@ def __init__( self._zero_grad_fn = zero_grad_fn self._profiler = PassThroughProfiler() if profiler is None else profiler - def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]: + def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: with self._profiler.profile("training_step_and_backward"): step_output = self._step_fn() - step_output = ClosureResult(**step_output) if step_output else None + step_output = ClosureResult(**step_output) if step_output else ClosureResult(None, None, None) - if step_output is None: - self.warning_cache.warn("training_step returned None. If this was on purpose, ignore this warning...") + if step_output.closure_loss is None: + self.warning_cache.warn( + "`training_step` returned `None`. If this was on purpose, ignore this warning..." + ) if self._zero_grad_fn is not None: with self._profiler.profile("zero_grad"): self._zero_grad_fn() - if self._backward_fn is not None and step_output is not None and step_output.closure_loss is not None: + if self._backward_fn is not None and step_output.closure_loss is not None: with self._profiler.profile("backward"): step_output.closure_loss = self._backward_fn(step_output.closure_loss) diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 5ec476787aed8..4e233bb2c6bd8 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -83,7 +83,7 @@ def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: # type: i self._optimizers[self.optim_progress.optimizer_idx], self.optim_progress.optimizer_idx, ) - if result: + if result.result_collection is not None: self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection)) self.optim_progress.optimizer_idx += 1 @@ -127,7 +127,7 @@ def _run_optimization( batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int, - ) -> Optional[ClosureResult]: + ) -> ClosureResult: """Runs closure (train step + backward) together with optimization if necessary. Args: @@ -160,14 +160,13 @@ def _run_optimization( result = closure.get_result() - if result: + if result.loss: # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss # TODO: find proper way to handle updating running loss assert self.trainer.fit_loop is not None assert self.trainer.fit_loop.epoch_loop is not None assert self.trainer.fit_loop.epoch_loop.batch_loop is not None - assert result.loss is not None self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss) # untoggle model params diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index 2348089ed37e5..9d3b4d4dc88fe 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -219,14 +219,13 @@ def training_epoch_end(self, outputs): def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): assert isinstance(optimizer_closure, Closure) - # not passing the closure to the optimizer because step is mocked # zero_grad is called inside the closure + optimizer_closure() + # not passing the closure to the optimizer because step is mocked if isinstance(optimizer, SGD) and batch_idx % 2 == 0: - optimizer_closure() optimizer.step() if isinstance(optimizer, Adam) and batch_idx % 4 == 0: - optimizer_closure() - optimizer.step() # not passing the closure here because it's a mock + optimizer.step() def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -236,8 +235,13 @@ def configure_optimizers(self): model = TestModel() + limit_train_batches = 8 trainer = Trainer( - default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=1, max_epochs=1, weights_summary=None + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, ) with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple( @@ -245,11 +249,11 @@ def configure_optimizers(self): ) as adam: trainer.fit(model) - assert sgd["step"].call_count == 4 - assert adam["step"].call_count == 2 + assert sgd["step"].call_count == limit_train_batches // 2 + assert adam["step"].call_count == limit_train_batches // 4 - assert sgd["zero_grad"].call_count == 4 - assert adam["zero_grad"].call_count == 2 + assert sgd["zero_grad"].call_count == limit_train_batches + assert adam["zero_grad"].call_count == limit_train_batches def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir): diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py new file mode 100644 index 0000000000000..ed79c5d5a22e4 --- /dev/null +++ b/tests/loops/test_closure.py @@ -0,0 +1,18 @@ +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel + + +def test_optimizer_step_no_closure_raises(tmpdir): + class TestModel(BoringModel): + def optimizer_step( + self, epoch=None, batch_idx=None, optimizer=None, optimizer_idx=None, optimizer_closure=None, **_ + ): + pass + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + with pytest.raises(MisconfigurationException, match="The closure hasn't been executed"): + trainer.fit(model) diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 56674b0ff8e95..e71fd5dbe04da 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -264,7 +264,7 @@ def validation_epoch_end(self, outputs): Closure.warning_cache.clear() - with pytest.warns(UserWarning, match=r"training_step returned None.*"): + with pytest.warns(UserWarning, match=r"training_step` returned `None"): trainer.fit(model) assert model.training_step_called @@ -276,7 +276,7 @@ def validation_epoch_end(self, outputs): Closure.warning_cache.clear() - with no_warning_call(UserWarning, match=r"training_step returned None.*"): + with no_warning_call(UserWarning, match=r"training_step` returned `None"): trainer.fit(model) @@ -303,7 +303,7 @@ def training_step(self, batch, batch_idx): Closure.warning_cache.clear() - with pytest.warns(UserWarning, match=r".*training_step returned None.*"): + with pytest.warns(UserWarning, match=r".*training_step` returned `None.*"): trainer.fit(model) trainer.state.stage = RunningStage.TRAINING diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index 603adb36d6981..5c29bbbac2d7e 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -177,13 +177,21 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_c if batch_idx % 2 == 0: self.optimizer_step_called[optimizer_idx] += 1 optimizer.step(closure=optimizer_closure) + else: + optimizer_closure() model = TestModel() model.val_dataloader = None + limit_train_batches = 4 trainer = pl.Trainer( - default_root_dir=tmpdir, limit_train_batches=4, max_epochs=1, log_every_n_steps=1, weights_summary=None + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, ) trainer.fit(model) - assert model.training_step_called == [4, 2] - assert model.optimizer_step_called == [4, 2] + assert len(model.training_step_called) == len(model.optimizer_step_called) == len(model.optimizers()) + assert model.training_step_called == [limit_train_batches, limit_train_batches] + assert model.optimizer_step_called == [limit_train_batches, limit_train_batches // 2] From 46d8113a3484742d4ec589957640bc8ad4c27721 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 17:47:47 +0200 Subject: [PATCH 12/47] Add comment --- tests/loops/test_closure.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index ed79c5d5a22e4..7996eb3d8ff29 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -10,6 +10,7 @@ class TestModel(BoringModel): def optimizer_step( self, epoch=None, batch_idx=None, optimizer=None, optimizer_idx=None, optimizer_closure=None, **_ ): + # does not call `optimizer_closure()` pass model = TestModel() From 32492ca3ee7ea926f1571e8fea12218acb625dc6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 18:12:46 +0200 Subject: [PATCH 13/47] Docs and changelog --- CHANGELOG.md | 3 +++ docs/source/common/optimizers.rst | 2 +- pytorch_lightning/core/lightning.py | 9 ++++++--- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/loops/closure.py | 6 +++--- pytorch_lightning/loops/optimizer/optimizer_loop.py | 2 +- 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa3669c40ccf3..e5555fcaf2605 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192)) +- Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360)) + + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index d6b7eec29a5b3..39a583d9c94d8 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -443,7 +443,7 @@ For example, here step optimizer A every batch and optimizer B every 2 batches. # the closure (which includes the `training_step`) will be executed by `optimizer.step` optimizer.step(closure=optimizer_closure) else: - # optional: call the closure by itself to run `training_step` + `backward` without an optimizer step + # call the closure by itself to run `training_step` + `backward` without an optimizer step optimizer_closure() # ... diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8ca87f7b2bb00..e3c7402242a3b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1495,15 +1495,15 @@ def optimizer_step( Warning: If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter - to ``optimizer.step()`` function as shown in the examples. This ensures that - ``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within the training loop. + to ``optimizer.step()`` function as shown in the examples. Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer optimizer_idx: If you used multiple optimizers, this indexes into that list. - optimizer_closure: Closure for all optimizers + optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the + calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``. on_tpu: ``True`` if TPU backward is required using_native_amp: ``True`` if using native amp using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS` @@ -1526,6 +1526,9 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, if optimizer_idx == 1: if (batch_idx + 1) % 2 == 0 : optimizer.step(closure=optimizer_closure) + else: + # call the closure by itself to run `training_step` + `backward` without an optimizer step + optimizer_closure() # ... # add as many optimizers as you want diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 20a1d93cf33af..b3614bdae8331 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -159,7 +159,7 @@ def _run_optimization( # TODO: replace call through closure by direct call (manual optimization) closure = self._make_closure(split_batch, batch_idx, self._hiddens) closure() - result = closure.get_result() + result = closure.result if result.loss: # if no result, user decided to skip optimization diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 508e4741879fd..43f15f014d4f5 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -54,19 +54,19 @@ def __init__(self) -> None: super().__init__() self._result: Optional[ClosureResult] = None - def get_result(self) -> ClosureResult: + @property + def result(self) -> ClosureResult: """The cached result from the last time the closure was called. Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long as necessary. """ - result = self._result if self._result is None: raise MisconfigurationException( "The closure hasn't been executed." " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook?" ) - self._result = None # free memory + result, self._result = self._result, None # free memory return result @abstractmethod diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 4e233bb2c6bd8..aaed6870ee871 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -158,7 +158,7 @@ def _run_optimization( else: self._optimizer_step(optimizer, opt_idx, batch_idx, closure) - result = closure.get_result() + result = closure.result if result.loss: # if no result, user decided to skip optimization From bd54363c357e4a8934bc7bbe4087f4f4506da438 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 18:37:11 +0200 Subject: [PATCH 14/47] is not None --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/loops/optimizer/optimizer_loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index b3614bdae8331..622d94912437f 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -161,7 +161,7 @@ def _run_optimization( closure() result = closure.result - if result.loss: + if result.loss is not None: # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss self._update_running_loss(result.loss) diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index aaed6870ee871..3a295f7a42bcd 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -160,7 +160,7 @@ def _run_optimization( result = closure.result - if result.loss: + if result.loss is not None: # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss # TODO: find proper way to handle updating running loss From 426544936270d864086ce2286d7ba9f16bbf2073 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 18:42:24 +0200 Subject: [PATCH 15/47] Improve error message to cover broken optimizers --- pytorch_lightning/loops/closure.py | 3 ++- tests/loops/test_closure.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 43f15f014d4f5..dbed006333d4a 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -64,7 +64,8 @@ def result(self) -> ClosureResult: if self._result is None: raise MisconfigurationException( "The closure hasn't been executed." - " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook?" + " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook? It could also happen because" + " the `optimizer.step(optimizer_closure)` call did not execute it internally" ) result, self._result = self._result, None # free memory return result diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index 7996eb3d8ff29..c309b8f75fec4 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -1,4 +1,5 @@ import pytest +import torch from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -17,3 +18,17 @@ def optimizer_step( trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) with pytest.raises(MisconfigurationException, match="The closure hasn't been executed"): trainer.fit(model) + + class TestModel(BoringModel): + def configure_optimizers(self): + class BrokenSGD(torch.optim.SGD): + def step(self, closure=None): + # forgot to pass the closure + return super().step() + + return BrokenSGD(self.layer.parameters(), lr=0.1) + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + with pytest.raises(MisconfigurationException, match="The closure hasn't been executed"): + trainer.fit(model) From 8716d4cca0a869d759df699711e5b3ddc8acb17a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 7 Sep 2021 18:43:39 +0200 Subject: [PATCH 16/47] Undo getter --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/loops/closure.py | 3 +-- pytorch_lightning/loops/optimizer/optimizer_loop.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 622d94912437f..f6b78033e6e81 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -159,7 +159,7 @@ def _run_optimization( # TODO: replace call through closure by direct call (manual optimization) closure = self._make_closure(split_batch, batch_idx, self._hiddens) closure() - result = closure.result + result = closure.consume_result() if result.loss is not None: # if no result, user decided to skip optimization diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index dbed006333d4a..1686638160aba 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -54,8 +54,7 @@ def __init__(self) -> None: super().__init__() self._result: Optional[ClosureResult] = None - @property - def result(self) -> ClosureResult: + def consume_result(self) -> ClosureResult: """The cached result from the last time the closure was called. Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 3a295f7a42bcd..b7687f4aa5ff3 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -158,7 +158,7 @@ def _run_optimization( else: self._optimizer_step(optimizer, opt_idx, batch_idx, closure) - result = closure.result + result = closure.consume_result() if result.loss is not None: # if no result, user decided to skip optimization From 3f7ea809fefc6a1a180de76bb3e3f098a97576c6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 02:23:15 +0200 Subject: [PATCH 17/47] Add license and period --- pytorch_lightning/loops/closure.py | 2 +- tests/loops/test_closure.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 1686638160aba..001e71f464ec4 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -64,7 +64,7 @@ def consume_result(self) -> ClosureResult: raise MisconfigurationException( "The closure hasn't been executed." " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook? It could also happen because" - " the `optimizer.step(optimizer_closure)` call did not execute it internally" + " the `optimizer.step(optimizer_closure)` call did not execute it internally." ) result, self._result = self._result, None # free memory return result diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index c309b8f75fec4..996d9d83b9948 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import torch From 3140d9be4be162b26a943b89e7b9ce4374ed5e20 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 03:31:20 +0200 Subject: [PATCH 18/47] Clone loss --- pytorch_lightning/loops/closure.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 986a9204b0cf3..14f3cb26c1b2b 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -52,11 +52,11 @@ def __post_init__(self) -> None: ClosureResult._check_extra_detach_deprecation(self.extra) self.extra = recursive_detach(self.extra) - self._set_loss() + self._clone_loss() if self.hiddens is not None and self.loss is None: raise MisconfigurationException("If `hiddens` are returned from `training_step`, the loss cannot be `None`") - def _set_loss(self) -> None: + def _clone_loss(self) -> None: if self.closure_loss is not None: # the loss will get scaled for amp. avoid any modifications to it self.loss = self.closure_loss.detach().clone() @@ -84,7 +84,7 @@ def apply_accumulation(self, value: int) -> None: """ if self.closure_loss is not None: self.closure_loss /= value - self._set_loss() + self._clone_loss() @staticmethod def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None: From 6689f3d190a22e38bfddcfb0a3f78965352a48fd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 03:39:51 +0200 Subject: [PATCH 19/47] Remove apply_accumulation --- pytorch_lightning/loops/closure.py | 17 ++++++----------- .../loops/optimizer/optimizer_loop.py | 5 +++-- tests/loops/test_closure.py | 4 +--- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 14f3cb26c1b2b..96e7c7f18e346 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field, replace from typing import Any, Callable, Dict, Optional @@ -62,7 +61,9 @@ def _clone_loss(self) -> None: self.loss = self.closure_loss.detach().clone() @classmethod - def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ClosureResult": + def from_training_step_output( + cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1 + ) -> "ClosureResult": closure_loss = None hiddens = None extra = {} @@ -75,16 +76,10 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) elif isinstance(training_step_output, Tensor): closure_loss = training_step_output - return cls(closure_loss, hiddens, extra=extra) - - def apply_accumulation(self, value: int) -> None: - """Accumulate loss. + # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect + closure_loss /= normalize - If ``accumulate_grad_batches == 1``, no effect. - """ - if self.closure_loss is not None: - self.closure_loss /= value - self._clone_loss() + return cls(closure_loss, hiddens, extra=extra) @staticmethod def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None: diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index ed9d481e030b9..c886916a8144b 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -351,7 +351,9 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens _check_training_step_output(self.trainer.lightning_module, training_step_output) - closure_result = ClosureResult.from_training_step_output(training_step_output) + closure_result = ClosureResult.from_training_step_output( + training_step_output, self.trainer.accumulate_grad_batches + ) if self.trainer.terminate_on_nan: check_finite_loss(closure_result.closure_loss) @@ -359,7 +361,6 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens if self.trainer.move_metrics_to_cpu: self.trainer._results.cpu() - closure_result.apply_accumulation(self.trainer.accumulate_grad_batches) return closure_result def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]: diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index bec635c91497a..a362276a4a74d 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -79,7 +79,5 @@ def test_closure_result_raises(): def test_closure_result_apply_accumulation(): closure_loss = torch.tensor(25.0) - result = ClosureResult(closure_loss, None) - assert result.loss == 25 - result.apply_accumulation(5) + result = ClosureResult.from_training_step_output(closure_loss, 5) assert result.loss == 5 From 6f9f5af824ca650497ea20edd4c4c475d4516999 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 03:53:17 +0200 Subject: [PATCH 20/47] Move to cpu --- pytorch_lightning/loops/batch/manual.py | 1 + pytorch_lightning/loops/closure.py | 21 +++++++++++++++---- .../loops/optimizer/optimizer_loop.py | 1 + tests/loops/test_closure.py | 11 ++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index dd60d1098d29a..46335ec6801fd 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -83,6 +83,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] if self.trainer.move_metrics_to_cpu: self.trainer._results.cpu() + result.cpu() self._done = True self._output = result diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 96e7c7f18e346..1979ae5f35940 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass, field, replace +from dataclasses import dataclass, field from typing import Any, Callable, Dict, Optional from torch import Tensor from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -40,7 +40,7 @@ class ClosureResult: closure_loss: Optional[Tensor] hiddens: Optional[Any] - loss: Optional[Tensor] = None + loss: Optional[Tensor] = field(init=False, default=None) extra: Dict[str, Tensor] = field(default_factory=dict) def __post_init__(self) -> None: @@ -94,10 +94,23 @@ def check_fn(v: Tensor) -> Tensor: apply_to_collection(extra, Tensor, check_fn) + def to(self, *args: Any, **kwargs: Any) -> "ClosureResult": + """Move all data to the given device.""" + if self.closure_loss is not None: + self.closure_loss = self.closure_loss.to(*args, **kwargs) + self.loss = self.loss.to(*args, **kwargs) + self.hiddens = apply_to_collection(self.hiddens, Tensor, move_data_to_device, *args, **kwargs) + self.extra = apply_to_collection(self.extra, Tensor, move_data_to_device, *args, **kwargs) + return self + + def cpu(self) -> "ClosureResult": + """Move all data to CPU.""" + return self.to(device="cpu") + def __getstate__(self) -> Dict[str, Any]: # return a copy without the closure loss which could have a `grad_fn` # and without `hiddens` which are not necessary - return asdict(replace(self, closure_loss=None, hiddens=None)) + return {"loss": self.loss, "extra": self.extra, "closure_loss": None, "hiddens": None} class AbstractClosure(ABC): diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index c886916a8144b..8f11c0ebc35dd 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -360,6 +360,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens if self.trainer.move_metrics_to_cpu: self.trainer._results.cpu() + closure_result.cpu() return closure_result diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index a362276a4a74d..f478fffd0a166 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -61,6 +61,9 @@ def test_closure_result_deepcopy(): # the `loss` is cloned so the storage is different assert closure_loss.data_ptr() != result.loss.data_ptr() + # make sure `__getstate__` is not missing any keys + assert vars(result).keys() == result.__getstate__().keys() + copy = deepcopy(result) assert result.loss == copy.loss assert copy.closure_loss is None @@ -81,3 +84,11 @@ def test_closure_result_apply_accumulation(): closure_loss = torch.tensor(25.0) result = ClosureResult.from_training_step_output(closure_loss, 5) assert result.loss == 5 + + +def test_closure_to(): + result = ClosureResult(torch.tensor(1.0), (torch.tensor(2.0), torch.tensor(3.0)), extra={"foo": torch.tensor(4.0)}) + result.to(torch.half) + assert result.loss.dtype == torch.half + assert all(t.dtype == torch.half for t in result.hiddens) + assert result.extra["foo"].dtype == torch.half From 8108c1f53595c812186639dea01fcb300849776c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 04:02:53 +0200 Subject: [PATCH 21/47] mypy --- pytorch_lightning/loops/batch/manual.py | 6 +++--- pytorch_lightning/loops/closure.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index 46335ec6801fd..97c047c42c51e 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Optional from pytorch_lightning.loops import Loop @@ -21,7 +20,6 @@ _check_training_step_output, check_finite_loss, ) -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection class ManualOptimization(Loop): @@ -37,7 +35,7 @@ def __init__(self) -> None: super().__init__() self._done: bool = False self._hiddens: Optional[Any] = None - self._output: Optional[ResultCollection] = None + self._output: Optional[ClosureResult] = None @property def done(self) -> bool: @@ -82,6 +80,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] check_finite_loss(result.closure_loss) if self.trainer.move_metrics_to_cpu: + assert self.trainer._results is not None self.trainer._results.cpu() result.cpu() @@ -91,4 +90,5 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] def on_run_end(self) -> ClosureResult: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, None # free memory + assert output is not None, "`advance` should have been called" return output diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 1979ae5f35940..ac9251cc9eae0 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -76,8 +76,9 @@ def from_training_step_output( elif isinstance(training_step_output, Tensor): closure_loss = training_step_output - # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect - closure_loss /= normalize + if closure_loss is not None: + # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect + closure_loss /= normalize return cls(closure_loss, hiddens, extra=extra) From c4f360a5eb45154a28b7e60729f1f727fa5a7046 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 18:38:26 +0200 Subject: [PATCH 22/47] Handle hiddens with an utility --- pytorch_lightning/loops/batch/manual.py | 10 +++++++--- pytorch_lightning/loops/closure.py | 14 ++------------ .../loops/optimizer/optimizer_loop.py | 10 ++++++---- pytorch_lightning/loops/utilities.py | 8 ++++++++ 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index 97c047c42c51e..baae27e916fb8 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -19,6 +19,7 @@ _build_training_step_kwargs, _check_training_step_output, check_finite_loss, + _extract_hiddens, ) @@ -72,9 +73,12 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] _check_training_step_output(ligtning_module, training_step_output) - # TODO: do not use `ClosureResult` - result = ClosureResult.from_training_step_output(training_step_output) - self._hiddens = result.hiddens + self._hiddens = _extract_hiddens(training_step_output) + + # TODO: do not use `ClosureResult + result = ClosureResult.from_training_step_output( + training_step_output, self.trainer.accumulate_grad_batches + ) if self.trainer.terminate_on_nan: check_finite_loss(result.closure_loss) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index ac9251cc9eae0..992915228e26c 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -33,27 +33,20 @@ class ClosureResult: Attributes: closure_loss: The loss with a graph attached. - hiddens: The hidden tensors if available. loss: A detached copy of the closure loss. extra: Any keys other than the loss returned. """ closure_loss: Optional[Tensor] - hiddens: Optional[Any] loss: Optional[Tensor] = field(init=False, default=None) extra: Dict[str, Tensor] = field(default_factory=dict) def __post_init__(self) -> None: - # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` - self.hiddens = recursive_detach(self.hiddens) - # TODO: remove with the deprecation removal in v1.6 ClosureResult._check_extra_detach_deprecation(self.extra) self.extra = recursive_detach(self.extra) self._clone_loss() - if self.hiddens is not None and self.loss is None: - raise MisconfigurationException("If `hiddens` are returned from `training_step`, the loss cannot be `None`") def _clone_loss(self) -> None: if self.closure_loss is not None: @@ -64,14 +57,11 @@ def _clone_loss(self) -> None: def from_training_step_output( cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1 ) -> "ClosureResult": - closure_loss = None - hiddens = None - extra = {} + closure_loss, extra = None, {} if isinstance(training_step_output, dict): # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` closure_loss = training_step_output.get("loss") - hiddens = training_step_output.get("hiddens") extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} elif isinstance(training_step_output, Tensor): closure_loss = training_step_output @@ -80,7 +70,7 @@ def from_training_step_output( # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect closure_loss /= normalize - return cls(closure_loss, hiddens, extra=extra) + return cls(closure_loss, extra=extra) @staticmethod def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None: diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 47868a6808b84..21e1da5eacfe7 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -26,6 +26,7 @@ _build_training_step_kwargs, _check_training_step_output, check_finite_loss, + _extract_hiddens, ) from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm @@ -323,18 +324,19 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos _check_training_step_output(self.trainer.lightning_module, training_step_output) - closure_result = ClosureResult.from_training_step_output( + self._hiddens = _extract_hiddens(training_step_output) + + result = ClosureResult.from_training_step_output( training_step_output, self.trainer.accumulate_grad_batches ) if self.trainer.terminate_on_nan: - check_finite_loss(closure_result.closure_loss) + check_finite_loss(result.closure_loss) if self.trainer.move_metrics_to_cpu: self.trainer._results.cpu() - closure_result.cpu() - return closure_result + return result def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]: """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 364c821cd6d1e..731f7a92bae06 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -22,6 +22,7 @@ from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher +from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -61,6 +62,13 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu ) +def _extract_hiddens(training_step_output: STEP_OUTPUT) -> Optional[Any]: + hiddens = training_step_output.get("hiddens") + # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` + hiddens = recursive_detach(hiddens) + return hiddens + + def _build_training_step_kwargs( lightning_module: "pl.LightningModule", optimizers: Sequence[Optimizer], From bed7624b4c0a7da11d46ea6996072218a90f0679 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 18:54:19 +0200 Subject: [PATCH 23/47] Without closure and tests --- pytorch_lightning/loops/batch/manual.py | 10 +++---- .../loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/loops/closure.py | 9 +++--- .../loops/optimizer/optimizer_loop.py | 10 +++---- tests/loops/test_closure.py | 28 ++++--------------- 5 files changed, 19 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index baae27e916fb8..c643e4b101358 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -18,8 +18,8 @@ from pytorch_lightning.loops.utilities import ( _build_training_step_kwargs, _check_training_step_output, - check_finite_loss, _extract_hiddens, + check_finite_loss, ) @@ -76,17 +76,15 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] self._hiddens = _extract_hiddens(training_step_output) # TODO: do not use `ClosureResult - result = ClosureResult.from_training_step_output( - training_step_output, self.trainer.accumulate_grad_batches - ) + result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) if self.trainer.terminate_on_nan: check_finite_loss(result.closure_loss) if self.trainer.move_metrics_to_cpu: - assert self.trainer._results is not None + # hiddens and the training step output are not moved as they are not considered "metrics" + # the user might need them on the correct device for an operation in `training_epoch_end` self.trainer._results.cpu() - result.cpu() self._done = True self._output = result diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 7d80d4e61f059..08127dc4b8051 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -132,7 +132,7 @@ def advance(self, batch, batch_idx): # in manual optimization, hand over execution to the ManualOptimization loop result = self.manual_loop.run(split_batch, batch_idx) if result.loss is not None: - self.batch_outputs[0].append(result) + self.batch_outputs[0].append(result.without_closure()) def on_run_end(self) -> None: self.optimizer_loop._hiddens = None diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 992915228e26c..c316922565dc5 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -90,7 +90,6 @@ def to(self, *args: Any, **kwargs: Any) -> "ClosureResult": if self.closure_loss is not None: self.closure_loss = self.closure_loss.to(*args, **kwargs) self.loss = self.loss.to(*args, **kwargs) - self.hiddens = apply_to_collection(self.hiddens, Tensor, move_data_to_device, *args, **kwargs) self.extra = apply_to_collection(self.extra, Tensor, move_data_to_device, *args, **kwargs) return self @@ -98,10 +97,10 @@ def cpu(self) -> "ClosureResult": """Move all data to CPU.""" return self.to(device="cpu") - def __getstate__(self) -> Dict[str, Any]: - # return a copy without the closure loss which could have a `grad_fn` - # and without `hiddens` which are not necessary - return {"loss": self.loss, "extra": self.extra, "closure_loss": None, "hiddens": None} + def without_closure(self) -> "ClosureResult": + """Return itself without the closure loss which could have a `grad_fn`""" + self.closure_loss = None + return self class AbstractClosure(ABC): diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 21e1da5eacfe7..21f95733eff06 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -25,8 +25,8 @@ _block_parallel_sync_behavior, _build_training_step_kwargs, _check_training_step_output, - check_finite_loss, _extract_hiddens, + check_finite_loss, ) from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm @@ -78,9 +78,8 @@ def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override self._optimizers[self.optim_progress.optimizer_idx], self.optim_progress.optimizer_idx, ) - self._hiddens = result.hiddens if result.loss is not None: - self.outputs[self.optim_progress.optimizer_idx].append(result) + self.outputs[self.optim_progress.optimizer_idx].append(result.without_closure()) self.optim_progress.optimizer_idx += 1 @@ -326,14 +325,13 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos self._hiddens = _extract_hiddens(training_step_output) - result = ClosureResult.from_training_step_output( - training_step_output, self.trainer.accumulate_grad_batches - ) + result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) if self.trainer.terminate_on_nan: check_finite_loss(result.closure_loss) if self.trainer.move_metrics_to_cpu: + # hiddens and the training step output are not moved as they are not considered "metrics" self.trainer._results.cpu() return result diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index f478fffd0a166..42e3ad5a7346b 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pickle -from copy import deepcopy - import pytest import torch @@ -53,31 +50,19 @@ def step(self, closure=None): def test_closure_result_deepcopy(): closure_loss = torch.tensor(123.45) - hiddens = torch.tensor(321.12, requires_grad=True) - result = ClosureResult(closure_loss, hiddens) - assert not result.hiddens.requires_grad + result = ClosureResult(closure_loss) assert closure_loss.data_ptr() == result.closure_loss.data_ptr() # the `loss` is cloned so the storage is different assert closure_loss.data_ptr() != result.loss.data_ptr() - # make sure `__getstate__` is not missing any keys - assert vars(result).keys() == result.__getstate__().keys() - - copy = deepcopy(result) + copy = result.without_closure() assert result.loss == copy.loss assert copy.closure_loss is None - assert copy.hiddens is None - - assert id(result.loss) != id(copy.loss) - assert result.loss.data_ptr() != copy.loss.data_ptr() - - assert copy == pickle.loads(pickle.dumps(result)) - -def test_closure_result_raises(): - with pytest.raises(MisconfigurationException, match="If `hiddens` are returned .* the loss cannot be `None`"): - ClosureResult(None, "something") + # no copy + assert id(result.loss) == id(copy.loss) + assert result.loss.data_ptr() == copy.loss.data_ptr() def test_closure_result_apply_accumulation(): @@ -87,8 +72,7 @@ def test_closure_result_apply_accumulation(): def test_closure_to(): - result = ClosureResult(torch.tensor(1.0), (torch.tensor(2.0), torch.tensor(3.0)), extra={"foo": torch.tensor(4.0)}) + result = ClosureResult(torch.tensor(1.0), extra={"foo": torch.tensor(4.0)}) result.to(torch.half) assert result.loss.dtype == torch.half - assert all(t.dtype == torch.half for t in result.hiddens) assert result.extra["foo"].dtype == torch.half From 6a81c0cfb82e623a6ccd4627150c51f04e3859a9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 18:54:49 +0200 Subject: [PATCH 24/47] Remove code to move the ClosureResult --- pytorch_lightning/loops/closure.py | 14 +------------- tests/loops/test_closure.py | 7 ------- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index c316922565dc5..5264aef75009b 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -18,7 +18,7 @@ from torch import Tensor from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -85,18 +85,6 @@ def check_fn(v: Tensor) -> Tensor: apply_to_collection(extra, Tensor, check_fn) - def to(self, *args: Any, **kwargs: Any) -> "ClosureResult": - """Move all data to the given device.""" - if self.closure_loss is not None: - self.closure_loss = self.closure_loss.to(*args, **kwargs) - self.loss = self.loss.to(*args, **kwargs) - self.extra = apply_to_collection(self.extra, Tensor, move_data_to_device, *args, **kwargs) - return self - - def cpu(self) -> "ClosureResult": - """Move all data to CPU.""" - return self.to(device="cpu") - def without_closure(self) -> "ClosureResult": """Return itself without the closure loss which could have a `grad_fn`""" self.closure_loss = None diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index 42e3ad5a7346b..d83427c3ccda6 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -69,10 +69,3 @@ def test_closure_result_apply_accumulation(): closure_loss = torch.tensor(25.0) result = ClosureResult.from_training_step_output(closure_loss, 5) assert result.loss == 5 - - -def test_closure_to(): - result = ClosureResult(torch.tensor(1.0), extra={"foo": torch.tensor(4.0)}) - result.to(torch.half) - assert result.loss.dtype == torch.half - assert result.extra["foo"].dtype == torch.half From a09cfc8407d5e5487b0afeedaa9fa446e9a4afe1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 19:00:20 +0200 Subject: [PATCH 25/47] Tests --- pytorch_lightning/loops/batch/manual.py | 1 + tests/loops/test_utilities.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/loops/test_utilities.py diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index c643e4b101358..e6e8d46763e89 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -84,6 +84,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] if self.trainer.move_metrics_to_cpu: # hiddens and the training step output are not moved as they are not considered "metrics" # the user might need them on the correct device for an operation in `training_epoch_end` + assert self.trainer._results is not None self.trainer._results.cpu() self._done = True diff --git a/tests/loops/test_utilities.py b/tests/loops/test_utilities.py new file mode 100644 index 0000000000000..f4275222cefd6 --- /dev/null +++ b/tests/loops/test_utilities.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.loops.utilities import _extract_hiddens + + +def test_extract_hiddens(): + hiddens = torch.tensor(321.12, requires_grad=True) + training_step_output = {"hiddens": hiddens} + hiddens = _extract_hiddens(training_step_output) + assert "hiddens" in training_step_output + assert not hiddens.requires_grad From 78253e4cb344af681448086570412ee7d2e395b0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 19:04:02 +0200 Subject: [PATCH 26/47] Fix TODO --- pytorch_lightning/loops/batch/manual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index e6e8d46763e89..43370bc7d8b75 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -75,7 +75,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] self._hiddens = _extract_hiddens(training_step_output) - # TODO: do not use `ClosureResult + # TODO: do not use `ClosureResult` result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) if self.trainer.terminate_on_nan: From 26e96a15b40d0bf9d3679b18f4a12ddeb600b5fe Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 19:13:35 +0200 Subject: [PATCH 27/47] Undo docs changes --- pytorch_lightning/accelerators/accelerator.py | 37 +++++++++++++++++-- pytorch_lightning/core/lightning.py | 9 +++-- pytorch_lightning/loops/utilities.py | 3 +- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 93915ac946ae9..f40dc9e1576cf 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -173,7 +173,15 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual training step. - See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details + Args: + step_kwargs: the arguments for the models training step. Can consist of the following: + + - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + - batch_idx (int): Integer displaying index of this batch + - optimizer_idx (int): When using multiple optimizers, this argument will also be present. + - hiddens(:class:`~torch.Tensor`): Passed in if + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. """ with self.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) @@ -184,7 +192,14 @@ def post_training_step(self) -> None: def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual validation step. - See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details + Args: + step_kwargs: the arguments for the models validation step. Can consist of the following: + + - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + - batch_idx (int): The index of this batch + - dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple val dataloaders used) """ with self.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) @@ -192,7 +207,14 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual test step. - See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details + Args: + step_kwargs: the arguments for the models test step. Can consist of the following: + + - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + - batch_idx (int): The index of this batch. + - dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple test dataloaders used). """ with self.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) @@ -200,7 +222,14 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual predict step. - See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details + Args: + step_kwargs: the arguments for the models predict step. Can consist of the following: + + - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + - batch_idx (int): The index of this batch. + - dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple predict dataloaders used). """ with self.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7edd9b532096c..e3c7402242a3b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -620,9 +620,9 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: Args: batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (``int``): Integer displaying index of this batch - optimizer_idx (``int``): When using multiple optimizers, this argument will also be present. - hiddens (``Any``): Passed in if + batch_idx (int): Integer displaying index of this batch + optimizer_idx (int): When using multiple optimizers, this argument will also be present. + hiddens(:class:`~torch.Tensor`): Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. Return: @@ -667,7 +667,8 @@ def training_step(self, batch, batch_idx, optimizer_idx): # Truncated back-propagation through time def training_step(self, batch, batch_idx, hiddens): # hiddens are the hidden states from the previous truncated backprop step - loss, hiddens = self.lstm(data, hiddens) + ... + out, hiddens = self.lstm(data, hiddens) ... return {"loss": loss, "hiddens": hiddens} diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 731f7a92bae06..72e8f223a355d 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence import torch +from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl @@ -75,7 +76,7 @@ def _build_training_step_kwargs( batch: Any, batch_idx: int, opt_idx: Optional[int], - hiddens: Optional[Any], + hiddens: Optional[Tensor], ) -> Dict[str, Any]: """Builds the keyword arguments for training_step. From 5b50789db9c5904265149a52d9fa2cb57acd90df Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 19:15:44 +0200 Subject: [PATCH 28/47] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4290e4edec64f..56dd27e64fccc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -322,6 +322,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367)) +- Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349)) + + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) From bf08c2bc17d26aa6a151b3040e0754f5edcfb460 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 8 Sep 2021 20:12:04 +0200 Subject: [PATCH 29/47] Fix --- pytorch_lightning/loops/utilities.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 72e8f223a355d..58116e207d2f9 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -64,6 +64,8 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu def _extract_hiddens(training_step_output: STEP_OUTPUT) -> Optional[Any]: + if not isinstance(training_step_output, dict): + return hiddens = training_step_output.get("hiddens") # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` hiddens = recursive_detach(hiddens) From f110bc111fcf4e4530dd8b0936e1491c9d6f51ff Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 9 Sep 2021 17:28:19 +0200 Subject: [PATCH 30/47] Stricter hiddens returning --- pytorch_lightning/loops/batch/manual.py | 10 +++++----- .../loops/optimizer/optimizer_loop.py | 6 +++--- pytorch_lightning/loops/utilities.py | 18 ++++++++++++++++-- tests/loops/test_utilities.py | 17 ++++++++++++++++- 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index 43370bc7d8b75..50ea8113f2878 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -53,16 +53,16 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] batch_idx: the index of the current batch """ assert self.trainer is not None - ligtning_module = self.trainer.lightning_module + model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("model_forward"): step_kwargs = _build_training_step_kwargs( - ligtning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens + model_ref, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens ) # manually capture logged metrics - ligtning_module._current_fx_name = "training_step" + model_ref._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() @@ -71,9 +71,9 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(ligtning_module, training_step_output) + _check_training_step_output(model_ref, training_step_output) - self._hiddens = _extract_hiddens(training_step_output) + self._hiddens = _extract_hiddens(training_step_output, model_ref.truncated_bptt_steps) # TODO: do not use `ClosureResult` result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 21f95733eff06..74384bda84e29 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -308,7 +308,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos with self.trainer.profiler.profile("model_forward"): step_kwargs = _build_training_step_kwargs( - self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens + model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens ) # manually capture logged metrics @@ -321,9 +321,9 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(self.trainer.lightning_module, training_step_output) + _check_training_step_output(model_ref, training_step_output) - self._hiddens = _extract_hiddens(training_step_output) + self._hiddens = _extract_hiddens(training_step_output, model_ref.truncated_bptt_steps) result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 58116e207d2f9..a57bff35982eb 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -63,9 +63,23 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu ) -def _extract_hiddens(training_step_output: STEP_OUTPUT) -> Optional[Any]: - if not isinstance(training_step_output, dict): +def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]: + """Get the hidden embeddings if present from the training step output. + + Raises: + MisconfigurationException: If :attr:`~pytorch_lightning.core.Lightning.LightningModule.truncated_bptt_steps` is + not enabled and hiddens are returned or vice versa. + """ + if not truncated_bptt_steps: + if "hiddens" in training_step_output: + raise MisconfigurationException( + 'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled' + ) return + elif not isinstance(training_step_output, dict) or "hiddens" not in training_step_output: + raise MisconfigurationException( + 'You enabled `truncated_bptt_steps` but did not return "hiddens" in your `training_step`' + ) hiddens = training_step_output.get("hiddens") # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` hiddens = recursive_detach(hiddens) diff --git a/tests/loops/test_utilities.py b/tests/loops/test_utilities.py index f4275222cefd6..4f9e99323df69 100644 --- a/tests/loops/test_utilities.py +++ b/tests/loops/test_utilities.py @@ -11,14 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from pytorch_lightning.loops.utilities import _extract_hiddens +from pytorch_lightning.utilities.exceptions import MisconfigurationException def test_extract_hiddens(): + # tbptt not enabled, no hiddens return + training_step_output = "whatever" + hiddens = _extract_hiddens(training_step_output, 0) + assert hiddens is None + + # tbptt enabled, hiddens return hiddens = torch.tensor(321.12, requires_grad=True) training_step_output = {"hiddens": hiddens} - hiddens = _extract_hiddens(training_step_output) + hiddens = _extract_hiddens(training_step_output, 2) assert "hiddens" in training_step_output assert not hiddens.requires_grad + + # tbptt not enabled, hiddens return + with pytest.raises(MisconfigurationException, match='returned "hiddens" .* but `truncated_bptt_steps` is disabled'): + _extract_hiddens(training_step_output, 0) + # tbptt enabled, no hiddens return + with pytest.raises(MisconfigurationException, match='enabled `truncated_bptt_steps` but did not return "hiddens"'): + _extract_hiddens(None, 1) From 9d8d5877b97307cc642b9be3902e04b115210ec1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 9 Sep 2021 17:29:50 +0200 Subject: [PATCH 31/47] Drop closure loss --- pytorch_lightning/loops/closure.py | 2 +- pytorch_lightning/loops/optimizer/optimizer_loop.py | 2 +- tests/loops/test_closure.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 5264aef75009b..64ef343b151ca 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -85,7 +85,7 @@ def check_fn(v: Tensor) -> Tensor: apply_to_collection(extra, Tensor, check_fn) - def without_closure(self) -> "ClosureResult": + def drop_closure_loss(self) -> "ClosureResult": """Return itself without the closure loss which could have a `grad_fn`""" self.closure_loss = None return self diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 74384bda84e29..08e2f9842e00d 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -79,7 +79,7 @@ def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override self.optim_progress.optimizer_idx, ) if result.loss is not None: - self.outputs[self.optim_progress.optimizer_idx].append(result.without_closure()) + self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss()) self.optim_progress.optimizer_idx += 1 diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index d83427c3ccda6..ff9bd531183ce 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -56,7 +56,7 @@ def test_closure_result_deepcopy(): # the `loss` is cloned so the storage is different assert closure_loss.data_ptr() != result.loss.data_ptr() - copy = result.without_closure() + copy = result.drop_closure_loss() assert result.loss == copy.loss assert copy.closure_loss is None From 5138edfdc52de4345254000e64ee48e3d21db095 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 9 Sep 2021 17:40:05 +0200 Subject: [PATCH 32/47] Logic fix --- pytorch_lightning/loops/utilities.py | 8 ++++---- tests/loops/test_utilities.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index a57bff35982eb..0698fb7851708 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -70,19 +70,19 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in MisconfigurationException: If :attr:`~pytorch_lightning.core.Lightning.LightningModule.truncated_bptt_steps` is not enabled and hiddens are returned or vice versa. """ + is_dict = isinstance(training_step_output, dict) if not truncated_bptt_steps: - if "hiddens" in training_step_output: + if is_dict and "hiddens" in training_step_output: raise MisconfigurationException( 'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled' ) return - elif not isinstance(training_step_output, dict) or "hiddens" not in training_step_output: + elif not is_dict or "hiddens" not in training_step_output: raise MisconfigurationException( 'You enabled `truncated_bptt_steps` but did not return "hiddens" in your `training_step`' ) - hiddens = training_step_output.get("hiddens") # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` - hiddens = recursive_detach(hiddens) + hiddens = recursive_detach(training_step_output["hiddens"]) return hiddens diff --git a/tests/loops/test_utilities.py b/tests/loops/test_utilities.py index 4f9e99323df69..e970be0648fd6 100644 --- a/tests/loops/test_utilities.py +++ b/tests/loops/test_utilities.py @@ -20,7 +20,7 @@ def test_extract_hiddens(): # tbptt not enabled, no hiddens return - training_step_output = "whatever" + training_step_output = 1 # anything hiddens = _extract_hiddens(training_step_output, 0) assert hiddens is None From 3aaf9ddee8da1cac3a7516a7bd84a5e8f7549655 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 9 Sep 2021 21:35:17 +0200 Subject: [PATCH 33/47] Bad rename --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 08127dc4b8051..db9f792744ef1 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -132,7 +132,7 @@ def advance(self, batch, batch_idx): # in manual optimization, hand over execution to the ManualOptimization loop result = self.manual_loop.run(split_batch, batch_idx) if result.loss is not None: - self.batch_outputs[0].append(result.without_closure()) + self.batch_outputs[0].append(result.drop_closure_loss()) def on_run_end(self) -> None: self.optimizer_loop._hiddens = None From 1f32d522cadaf833f1ecd082b1c58a4886a2dbd1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 10 Sep 2021 00:28:16 +0200 Subject: [PATCH 34/47] Fix for raise StopIteration --- pytorch_lightning/loops/batch/manual.py | 5 +++-- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index 50ea8113f2878..d9388d068e0fd 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -90,8 +90,9 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] self._done = True self._output = result - def on_run_end(self) -> ClosureResult: + def on_run_end(self) -> Optional[ClosureResult]: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, None # free memory - assert output is not None, "`advance` should have been called" + # #9052 added support for raising `StopIteration` in the `training_step` this `advance` doesn't finish + # and the return value is `Optional`. If #9415 happens then this can be avoided. return output diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index db9f792744ef1..4a544b2b9b4db 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -131,7 +131,7 @@ def advance(self, batch, batch_idx): else: # in manual optimization, hand over execution to the ManualOptimization loop result = self.manual_loop.run(split_batch, batch_idx) - if result.loss is not None: + if result is not None and result.loss is not None: self.batch_outputs[0].append(result.drop_closure_loss()) def on_run_end(self) -> None: From 04a54a62bc512faefb3fc74ef9d7f96671cdaf77 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 10 Sep 2021 00:30:35 +0200 Subject: [PATCH 35/47] Fix comment --- pytorch_lightning/loops/batch/manual.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index d9388d068e0fd..c55c3e6f4e549 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -93,6 +93,6 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] def on_run_end(self) -> Optional[ClosureResult]: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, None # free memory - # #9052 added support for raising `StopIteration` in the `training_step` this `advance` doesn't finish - # and the return value is `Optional`. If #9415 happens then this can be avoided. + # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance` + # doesn't finish and `self._output` stays as `None`. If #9415 happens then this would always return a result return output From 6fb8758eccd0aa98a995204e6a38fa901bec3714 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 10 Sep 2021 01:26:23 +0200 Subject: [PATCH 36/47] Initial implementation --- pytorch_lightning/loops/batch/manual.py | 59 ++++++-- pytorch_lightning/loops/closure.py | 131 ++---------------- .../loops/epoch/training_epoch_loop.py | 19 +-- .../loops/optimizer/optimizer_loop.py | 131 +++++++++++++++++- pytorch_lightning/loops/utilities.py | 27 +--- tests/loops/test_closure.py | 2 +- .../loops/test_training_loop_flow_scalar.py | 2 +- 7 files changed, 191 insertions(+), 180 deletions(-) diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py index c55c3e6f4e549..ac4604d68e502 100644 --- a/pytorch_lightning/loops/batch/manual.py +++ b/pytorch_lightning/loops/batch/manual.py @@ -11,16 +11,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from torch import Tensor from pytorch_lightning.loops import Loop -from pytorch_lightning.loops.closure import ClosureResult -from pytorch_lightning.loops.utilities import ( - _build_training_step_kwargs, - _check_training_step_output, - _extract_hiddens, - check_finite_loss, -) +from pytorch_lightning.loops.closure import OutputResult +from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens, check_finite_loss +from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +@dataclass +class ManualResult(OutputResult): + """A container to hold the result returned by the ``ManualLoop``. + + It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`. + + Attributes: + extra: Anything returned by the ``training_step``. + """ + + extra: Dict[str, Tensor] = field(default_factory=dict) + + def __post_init__(self) -> None: + # TODO: remove with the deprecation removal in v1.6 + self._check_extra_detach_deprecation(self.extra) + self.extra = recursive_detach(self.extra) + + @classmethod + def from_training_step_output( + cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1 + ) -> "ManualResult": + extra = {} + if isinstance(training_step_output, dict): + extra = {k: v for k, v in training_step_output.items() if k != "hiddens"} + elif isinstance(training_step_output, Tensor): + extra["loss"] = training_step_output + + if "loss" in extra: + # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect + extra["loss"] = extra["loss"].detach().div(normalize) + + return cls(extra=extra) class ManualOptimization(Loop): @@ -36,7 +70,7 @@ def __init__(self) -> None: super().__init__() self._done: bool = False self._hiddens: Optional[Any] = None - self._output: Optional[ClosureResult] = None + self._output: Optional[ManualResult] = None @property def done(self) -> bool: @@ -71,12 +105,9 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(model_ref, training_step_output) - self._hiddens = _extract_hiddens(training_step_output, model_ref.truncated_bptt_steps) - # TODO: do not use `ClosureResult` - result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) + result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) if self.trainer.terminate_on_nan: check_finite_loss(result.closure_loss) @@ -90,7 +121,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] self._done = True self._output = result - def on_run_end(self) -> Optional[ClosureResult]: + def on_run_end(self) -> Optional[ManualResult]: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, None # free memory # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance` diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index 64ef343b151ca..8583f6900541e 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -12,68 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional +from dataclasses import dataclass +from typing import Any, Dict, Optional from torch import Tensor -from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache +from pytorch_lightning.utilities.warnings import rank_zero_deprecation @dataclass -class ClosureResult: - """A container to hold the result of a :class:`AbstractClosure` call. - - It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`. - - Attributes: - closure_loss: The loss with a graph attached. - loss: A detached copy of the closure loss. - extra: Any keys other than the loss returned. - """ - - closure_loss: Optional[Tensor] - loss: Optional[Tensor] = field(init=False, default=None) - extra: Dict[str, Tensor] = field(default_factory=dict) - - def __post_init__(self) -> None: - # TODO: remove with the deprecation removal in v1.6 - ClosureResult._check_extra_detach_deprecation(self.extra) - self.extra = recursive_detach(self.extra) - - self._clone_loss() - - def _clone_loss(self) -> None: - if self.closure_loss is not None: - # the loss will get scaled for amp. avoid any modifications to it - self.loss = self.closure_loss.detach().clone() - - @classmethod - def from_training_step_output( - cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1 - ) -> "ClosureResult": - closure_loss, extra = None, {} - - if isinstance(training_step_output, dict): - # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` - closure_loss = training_step_output.get("loss") - extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} - elif isinstance(training_step_output, Tensor): - closure_loss = training_step_output - - if closure_loss is not None: - # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect - closure_loss /= normalize - - return cls(closure_loss, extra=extra) - +class OutputResult: @staticmethod def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None: + # TODO: remove with the deprecation removal in v1.6 + # this is only here to avoid duplication def check_fn(v: Tensor) -> Tensor: if v.grad_fn is not None: rank_zero_deprecation( @@ -85,11 +39,6 @@ def check_fn(v: Tensor) -> Tensor: apply_to_collection(extra, Tensor, check_fn) - def drop_closure_loss(self) -> "ClosureResult": - """Return itself without the closure loss which could have a `grad_fn`""" - self.closure_loss = None - return self - class AbstractClosure(ABC): """Abstract base class for optimizer closures in Lightning. @@ -104,9 +53,9 @@ class AbstractClosure(ABC): def __init__(self) -> None: super().__init__() - self._result: Optional[ClosureResult] = None + self._result: Optional[OutputResult] = None - def consume_result(self) -> ClosureResult: + def consume_result(self) -> OutputResult: """The cached result from the last time the closure was called. Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long @@ -122,69 +71,9 @@ def consume_result(self) -> ClosureResult: return result @abstractmethod - def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: + def closure(self, *args: Any, **kwargs: Any) -> OutputResult: """Implements the behavior of the closure once it is getting called.""" pass - def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: + def __call__(self, *args: Any, **kwargs: Any) -> None: self._result = self.closure(*args, **kwargs) - return self._result.loss - - -class Closure(AbstractClosure): - """An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary - closures into one: ``training_step``, ``backward`` and ``zero_grad``. - - The Closure gets created by the training loop(s) and is then passed to the - :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally - do something with the output. - - Args: - step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step - wrapped with processing for its outputs - backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value. - Can be set to ``None`` to skip the backward operation. - zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example - when accumulating gradients. - profiler: A profiler for profiling the actions of the passed in closure functions. - - Example: - - closure = Closure() - optimizer = torch.optim.Adam(...) - optimizer.step(closure) - """ - - warning_cache = WarningCache() - - def __init__( - self, - step_fn: Callable[[], ClosureResult], - backward_fn: Optional[Callable[[Tensor], Tensor]] = None, - zero_grad_fn: Optional[Callable[[], None]] = None, - profiler: Optional[BaseProfiler] = None, - ): - super().__init__() - self._step_fn = step_fn - self._backward_fn = backward_fn - self._zero_grad_fn = zero_grad_fn - self._profiler = PassThroughProfiler() if profiler is None else profiler - - def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: - with self._profiler.profile("training_step_and_backward"): - step_output = self._step_fn() - - if step_output.closure_loss is None: - self.warning_cache.warn( - "`training_step` returned `None`. If this was on purpose, ignore this warning..." - ) - - if self._zero_grad_fn is not None: - with self._profiler.profile("zero_grad"): - self._zero_grad_fn() - - if self._backward_fn is not None and step_output.closure_loss is not None: - with self._profiler.profile("backward"): - step_output.closure_loss = self._backward_fn(step_output.closure_loss) - - return step_output diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index f8040c9686aba..e7c9bac69b86d 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -17,7 +17,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop -from pytorch_lightning.loops.closure import ClosureResult +from pytorch_lightning.loops.closure import OutputResult from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress @@ -282,18 +282,18 @@ def _track_epoch_end_reduce_metrics( @staticmethod def _prepare_outputs( - outputs: List[List[List[ClosureResult]]], batch_mode: bool + outputs: List[List[List[OutputResult]]], batch_mode: bool ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """Extract required information from batch or epoch end results. Args: - outputs: A 3-dimensional list of ``ClosureResult`` objects with dimensions: + outputs: A 3-dimensional list of ``OutputResult`` objects with dimensions: ``[optimizer outs][batch outs][tbptt steps]``. batch_mode: If True, ignore the batch output dimension. Returns: - The cleaned outputs with ``ClosureResult`` objects converted to dictionaries. + The cleaned outputs with ``OutputResult`` objects converted to dictionaries. All list dimensions of size one will be collapsed. """ processed_outputs = [] @@ -308,17 +308,10 @@ def _prepare_outputs( opt_outputs = [opt_outputs] for batch_outputs in opt_outputs: - processed_tbptt_outputs = [] - - if isinstance(batch_outputs, ClosureResult): + if isinstance(batch_outputs, OutputResult): batch_outputs = [batch_outputs] - for tbptt_output in batch_outputs: - out = {} - if tbptt_output.loss is not None: - out["loss"] = tbptt_output.loss - out.update(tbptt_output.extra) - processed_tbptt_outputs.append(out) + processed_tbptt_outputs = [tbptt_output.extra for tbptt_output in batch_outputs] # if there was only one tbptt step then we can collapse that dimension if len(processed_tbptt_outputs) == 1: diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index 08e2f9842e00d..c965d950d6759 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass, field from functools import partial from typing import Any, Callable, Dict, List, Optional @@ -20,19 +21,143 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops import Loop -from pytorch_lightning.loops.closure import Closure, ClosureResult +from pytorch_lightning.loops.closure import AbstractClosure, OutputResult from pytorch_lightning.loops.utilities import ( _block_parallel_sync_behavior, _build_training_step_kwargs, - _check_training_step_output, _extract_hiddens, check_finite_loss, ) +from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE +from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.warnings import WarningCache + + +@dataclass +class ClosureResult(OutputResult): + """A container to hold the result of a :class:`Closure` call. + + It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`. + + Attributes: + closure_loss: The loss with a graph attached. + loss: A detached copy of the closure loss. + extra: Any keys other than the loss returned. + """ + + closure_loss: Optional[Tensor] + loss: Optional[Tensor] = field(init=False, default=None) + extra: Dict[str, Tensor] = field(default_factory=dict) + + def __post_init__(self) -> None: + # TODO: remove with the deprecation removal in v1.6 + self._check_extra_detach_deprecation(self.extra) + self.extra = recursive_detach(self.extra) + + self._clone_loss() + + if self.loss is not None: + self.extra["loss"] = self.loss + + def _clone_loss(self) -> None: + if self.closure_loss is not None: + # the loss will get scaled for amp. avoid any modifications to it + self.loss = self.closure_loss.detach().clone() + + @classmethod + def from_training_step_output( + cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1 + ) -> "ClosureResult": + closure_loss, extra = None, {} + + if isinstance(training_step_output, dict): + # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` + closure_loss = training_step_output.get("loss") + extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} + elif isinstance(training_step_output, Tensor): + closure_loss = training_step_output + elif training_step_output is not None: + raise MisconfigurationException( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." + ) + + if closure_loss is not None: + # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect + closure_loss /= normalize + + return cls(closure_loss, extra=extra) + + def drop_closure_loss(self) -> "ClosureResult": + """Return itself without the closure loss which could have a `grad_fn`""" + self.closure_loss = None + return self + + +class Closure(AbstractClosure): + """An implementation of a :class:`AbstractClosure` for automatic optimization in Lightning that combines three + elementary closures into one: ``training_step``, ``backward`` and ``zero_grad``. + + The Closure gets created by the training loop(s) and is then passed to the + :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally + do something with the output. + + Args: + step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step + wrapped with processing for its outputs + backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value. + Can be set to ``None`` to skip the backward operation. + zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example + when accumulating gradients. + profiler: A profiler for profiling the actions of the passed in closure functions. + + Example: + + closure = Closure() + optimizer = torch.optim.Adam(...) + optimizer.step(closure) + """ + + warning_cache = WarningCache() + + def __init__( + self, + step_fn: Callable[[], ClosureResult], + backward_fn: Optional[Callable[[Tensor], Tensor]] = None, + zero_grad_fn: Optional[Callable[[], None]] = None, + profiler: Optional[BaseProfiler] = None, + ): + super().__init__() + self._step_fn = step_fn + self._backward_fn = backward_fn + self._zero_grad_fn = zero_grad_fn + self._profiler = PassThroughProfiler() if profiler is None else profiler + + def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: + with self._profiler.profile("training_step_and_backward"): + step_output = self._step_fn() + + if step_output.closure_loss is None: + self.warning_cache.warn( + "`training_step` returned `None`. If this was on purpose, ignore this warning..." + ) + + if self._zero_grad_fn is not None: + with self._profiler.profile("zero_grad"): + self._zero_grad_fn() + + if self._backward_fn is not None and step_output.closure_loss is not None: + with self._profiler.profile("backward"): + step_output.closure_loss = self._backward_fn(step_output.closure_loss) + + return step_output + _OUTPUTS_TYPE = List[List[ClosureResult]] @@ -321,8 +446,6 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(model_ref, training_step_output) - self._hiddens = _extract_hiddens(training_step_output, model_ref.truncated_bptt_steps) result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 0698fb7851708..21587869d5700 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence +from typing import Any, Dict, Generator, Iterator, Optional, Sequence import torch from torch import Tensor @@ -38,31 +38,6 @@ def check_finite_loss(loss: Optional[torch.Tensor]) -> None: raise ValueError(f"The loss returned in `training_step` is {loss}.") -def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None: - """Sanity checks that training produced a valid output. - - Args: - model: a reference to the trainer - training_step_output: the output of the training step (before wrapping in an AttributeDict) - """ - if ( - isinstance(training_step_output, torch.Tensor) - and not model.automatic_optimization - and training_step_output.grad_fn is None - ): - # TODO: in manual optimization, anything returned should be considered an `extra` - raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") - if model.automatic_optimization and not ( - isinstance(training_step_output, torch.Tensor) - or (isinstance(training_step_output, Mapping) and "loss" in training_step_output) - or training_step_output is None - ): - raise MisconfigurationException( - "In automatic optimization, `training_step` must either return a Tensor, " - "a dict with key 'loss' or None (where the step will be skipped)." - ) - - def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]: """Get the hidden embeddings if present from the training step output. diff --git a/tests/loops/test_closure.py b/tests/loops/test_closure.py index ff9bd531183ce..2a15d3a135dc3 100644 --- a/tests/loops/test_closure.py +++ b/tests/loops/test_closure.py @@ -15,7 +15,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.loops.closure import ClosureResult +from pytorch_lightning.loops.optimizer.optimizer_loop import ClosureResult from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index d5fe6ff6c11c5..f427cdb1245fc 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -18,7 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.loops.closure import Closure +from pytorch_lightning.loops.optimizer.optimizer_loop import Closure from pytorch_lightning.trainer.states import RunningStage from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel From fa80ef02b02f95eee2d9055cda5968ed81fe0f87 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 10 Sep 2021 02:12:44 +0200 Subject: [PATCH 37/47] Undo change --- pytorch_lightning/loops/optimizer/optimizer_loop.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index c965d950d6759..57c64b17e29ac 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -62,9 +62,6 @@ def __post_init__(self) -> None: self._clone_loss() - if self.loss is not None: - self.extra["loss"] = self.loss - def _clone_loss(self) -> None: if self.closure_loss is not None: # the loss will get scaled for amp. avoid any modifications to it From 93652ace532b74c7551c738a299bce0559eac3dc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 16:49:56 +0200 Subject: [PATCH 38/47] Bad merge --- pytorch_lightning/loops/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 628419b3f9ff1..6fdde18fa6bb2 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -38,7 +38,7 @@ def check_finite_loss(loss: Optional[torch.Tensor]) -> None: def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]: - """Get the hidden embeddings if present from the training step output. + """Get the hidden state if present from the training step output. Raises: MisconfigurationException: If :attr:`~pytorch_lightning.core.Lightning.LightningModule.truncated_bptt_steps` is From 21697a4e920c3384e64df6a6c0349578cb2e11ea Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 16:59:10 +0200 Subject: [PATCH 39/47] Progress --- .../loops/optimization/closure.py | 1 + .../loops/optimization/manual_loop.py | 15 ++++++++---- .../loops/optimization/optimizer_loop.py | 4 +--- tests/loops/optimization/test_closure.py | 24 ------------------- 4 files changed, 12 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/loops/optimization/closure.py b/pytorch_lightning/loops/optimization/closure.py index 1c211b16f890b..74805dedc58cd 100644 --- a/pytorch_lightning/loops/optimization/closure.py +++ b/pytorch_lightning/loops/optimization/closure.py @@ -37,6 +37,7 @@ def check_fn(v: Tensor) -> Tensor: " but this behaviour will change in v1.6. Please detach it manually:" " `return {'loss': ..., 'something': something.detach()}`" ) + return v.detach() return v apply_to_collection(extra, Tensor, check_fn) diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 3047dae14e17e..480c925c669b0 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -19,7 +19,7 @@ from pytorch_lightning.loops import Loop from pytorch_lightning.loops.optimization.closure import OutputResult from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens, check_finite_loss -from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -37,8 +37,7 @@ class ManualResult(OutputResult): def __post_init__(self) -> None: # TODO: remove with the deprecation removal in v1.6 - self._check_extra_detach_deprecation(self.extra) - self.extra = recursive_detach(self.extra) + self.extra = self._check_extra_detach_deprecation(self.extra) @classmethod def from_training_step_output( @@ -48,10 +47,16 @@ def from_training_step_output( if isinstance(training_step_output, dict): extra = {k: v for k, v in training_step_output.items() if k != "hiddens"} elif isinstance(training_step_output, Tensor): - extra["loss"] = training_step_output + extra = {"loss": training_step_output} + elif training_step_output is not None: + raise MisconfigurationException( + "In manual optimization, `training_step` must either return a Tensor, " + "a dict with extras to pass to `training_epoch_end` or have no return." + ) if "loss" in extra: - # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect + # accumulate the loss. If `accumulate_grad_batches == 1`, no effect. + # we detach manually as it's expected that it will have a `grad_fn` extra["loss"] = extra["loss"].detach().div(normalize) return cls(extra=extra) diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 33dc72f5dd80c..41047cb1b812f 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -34,7 +34,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE -from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -57,8 +56,7 @@ class ClosureResult(OutputResult): def __post_init__(self) -> None: # TODO: remove with the deprecation removal in v1.6 - self._check_extra_detach_deprecation(self.extra) - self.extra = recursive_detach(self.extra) + self.extra = self._check_extra_detach_deprecation(self.extra) self._clone_loss() diff --git a/tests/loops/optimization/test_closure.py b/tests/loops/optimization/test_closure.py index 2a15d3a135dc3..996d9d83b9948 100644 --- a/tests/loops/optimization/test_closure.py +++ b/tests/loops/optimization/test_closure.py @@ -15,7 +15,6 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.loops.optimizer.optimizer_loop import ClosureResult from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -46,26 +45,3 @@ def step(self, closure=None): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) with pytest.raises(MisconfigurationException, match="The closure hasn't been executed"): trainer.fit(model) - - -def test_closure_result_deepcopy(): - closure_loss = torch.tensor(123.45) - result = ClosureResult(closure_loss) - - assert closure_loss.data_ptr() == result.closure_loss.data_ptr() - # the `loss` is cloned so the storage is different - assert closure_loss.data_ptr() != result.loss.data_ptr() - - copy = result.drop_closure_loss() - assert result.loss == copy.loss - assert copy.closure_loss is None - - # no copy - assert id(result.loss) == id(copy.loss) - assert result.loss.data_ptr() == copy.loss.data_ptr() - - -def test_closure_result_apply_accumulation(): - closure_loss = torch.tensor(25.0) - result = ClosureResult.from_training_step_output(closure_loss, 5) - assert result.loss == 5 From b4a0204297c11dc92c0ed961222accaed54449fa Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 18:05:54 +0200 Subject: [PATCH 40/47] Progress --- .../loops/batch/training_batch_loop.py | 6 ++- .../loops/epoch/training_epoch_loop.py | 15 ++---- .../loops/optimization/closure.py | 7 ++- .../loops/optimization/manual_loop.py | 10 ++-- .../loops/optimization/optimizer_loop.py | 22 +++++---- .../loops/optimization/test_optimizer_loop.py | 32 +++++++++++-- tests/loops/test_evaluation_loop_flow.py | 8 ++-- tests/loops/test_training_loop.py | 46 ------------------- tests/loops/test_training_loop_flow_scalar.py | 13 +++--- 9 files changed, 69 insertions(+), 90 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 8516e0464b153..b4b027c9e7bbb 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -131,8 +131,10 @@ def advance(self, batch, batch_idx): else: # in manual optimization, hand over execution to the ManualOptimization loop result = self.manual_loop.run(split_batch, batch_idx) - if result is not None and result.loss is not None: - self.batch_outputs[0].append(result.drop_closure_loss()) + if result is not None: + output = result.asdict() + if output: + self.batch_outputs[0].append(output) def on_run_end(self) -> None: self.optimizer_loop._hiddens = None diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 4fdf48c3dafc7..5365fc94b6eba 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -314,16 +314,11 @@ def _prepare_outputs( for batch_outputs in opt_outputs: processed_tbptt_outputs = [] - if isinstance(batch_outputs, OutputResult): - batch_outputs = [batch_outputs] - - for tbptt_output in batch_outputs: - # FIXME - out = {} - if tbptt_output.loss is not None: - out["loss"] = tbptt_output.loss - out.update(tbptt_output.extra) - processed_tbptt_outputs.append(out) + if isinstance(batch_outputs, dict): + processed_tbptt_outputs.append(batch_outputs) + else: + for tbptt_output in batch_outputs: + processed_tbptt_outputs.append(tbptt_output) # if there was only one tbptt step then we can collapse that dimension if len(processed_tbptt_outputs) == 1: diff --git a/pytorch_lightning/loops/optimization/closure.py b/pytorch_lightning/loops/optimization/closure.py index 74805dedc58cd..6cdc319665c80 100644 --- a/pytorch_lightning/loops/optimization/closure.py +++ b/pytorch_lightning/loops/optimization/closure.py @@ -27,7 +27,7 @@ @dataclass class OutputResult: @staticmethod - def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None: + def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> Dict[str, Any]: # TODO: remove with the deprecation removal in v1.6 # this is only here to avoid duplication def check_fn(v: Tensor) -> Tensor: @@ -40,7 +40,10 @@ def check_fn(v: Tensor) -> Tensor: return v.detach() return v - apply_to_collection(extra, Tensor, check_fn) + return apply_to_collection(extra, Tensor, check_fn) + + def asdict(self) -> Dict[str, Any]: + return {} class AbstractClosure(ABC, Generic[T]): diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 480c925c669b0..1a6f45d47fb9b 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -18,7 +18,7 @@ from pytorch_lightning.loops import Loop from pytorch_lightning.loops.optimization.closure import OutputResult -from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens, check_finite_loss +from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -33,7 +33,7 @@ class ManualResult(OutputResult): extra: Anything returned by the ``training_step``. """ - extra: Dict[str, Tensor] = field(default_factory=dict) + extra: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: # TODO: remove with the deprecation removal in v1.6 @@ -61,6 +61,9 @@ def from_training_step_output( return cls(extra=extra) + def asdict(self) -> Dict[str, Any]: + return self.extra + class ManualOptimization(Loop): """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens @@ -114,9 +117,6 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches) - if self.trainer.terminate_on_nan: - check_finite_loss(result.closure_loss) - if self.trainer.move_metrics_to_cpu: # hiddens and the training step output are not moved as they are not considered "metrics" # the user might need them on the correct device for an operation in `training_epoch_end` diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 41047cb1b812f..79cd1ac636439 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -52,7 +52,7 @@ class ClosureResult(OutputResult): closure_loss: Optional[Tensor] loss: Optional[Tensor] = field(init=False, default=None) - extra: Dict[str, Tensor] = field(default_factory=dict) + extra: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: # TODO: remove with the deprecation removal in v1.6 @@ -74,13 +74,17 @@ def from_training_step_output( if isinstance(training_step_output, dict): # this should not modify the `training_step_output`, as the user could be using it after `training_step_end` closure_loss = training_step_output.get("loss") + if closure_loss is None: + raise MisconfigurationException( + "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present" + ) extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} elif isinstance(training_step_output, Tensor): closure_loss = training_step_output elif training_step_output is not None: raise MisconfigurationException( - "In automatic optimization, `training_step` must either return a Tensor, " - "a dict with key 'loss' or None (where the step will be skipped)." + "In automatic optimization, `training_step` must return a Tensor, " + "a dict, or None (where the step will be skipped)." ) if closure_loss is not None: @@ -89,10 +93,8 @@ def from_training_step_output( return cls(closure_loss, extra=extra) - def drop_closure_loss(self) -> "ClosureResult": - """Return itself without the closure loss which could have a `grad_fn`""" - self.closure_loss = None - return self + def asdict(self) -> Dict[str, Any]: + return {"loss": self.loss, **self.extra} class Closure(AbstractClosure[ClosureResult]): @@ -158,7 +160,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: return self._result.loss -_OUTPUTS_TYPE = List[List[ClosureResult]] +_OUTPUTS_TYPE = List[List[Dict[str, Any]]] class OptimizerLoop(Loop): @@ -203,7 +205,9 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignor self.optim_progress.optimizer_idx, ) if result.loss is not None: - self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss()) + # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch + # would be skipped otherwise + self.outputs[self.optim_progress.optimizer_idx].append(result.asdict()) self.optim_progress.optimizer_idx += 1 diff --git a/tests/loops/optimization/test_optimizer_loop.py b/tests/loops/optimization/test_optimizer_loop.py index e39f9e5367f1c..c0b532217b204 100644 --- a/tests/loops/optimization/test_optimizer_loop.py +++ b/tests/loops/optimization/test_optimizer_loop.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import pytest import torch +from pytorch_lightning import Trainer from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel def test_closure_result_deepcopy(): @@ -24,16 +29,33 @@ def test_closure_result_deepcopy(): # the `loss` is cloned so the storage is different assert closure_loss.data_ptr() != result.loss.data_ptr() - copy = result.drop_closure_loss() - assert result.loss == copy.loss - assert copy.closure_loss is None + copy = result.asdict() + assert result.loss == copy["loss"] + assert copy.keys() == {"loss"} # no copy - assert id(result.loss) == id(copy.loss) - assert result.loss.data_ptr() == copy.loss.data_ptr() + assert id(result.loss) == id(copy["loss"]) + assert result.loss.data_ptr() == copy["loss"].data_ptr() def test_closure_result_apply_accumulation(): closure_loss = torch.tensor(25.0) result = ClosureResult.from_training_step_output(closure_loss, 5) assert result.loss == 5 + + +@pytest.mark.parametrize( + "case", [(5.0, "must return a Tensor, a dict, or None"), ({"a": 5}, "the 'loss' key needs to be present")] +) +def test_warning_invalid_trainstep_output(tmpdir, case): + output, match = case + + class InvalidTrainStepModel(BoringModel): + def training_step(self, batch, batch_idx): + return output + + model = InvalidTrainStepModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + + with pytest.raises(MisconfigurationException, match=match): + trainer.fit(model) diff --git a/tests/loops/test_evaluation_loop_flow.py b/tests/loops/test_evaluation_loop_flow.py index 94f71907cd12f..c95f5197b74b7 100644 --- a/tests/loops/test_evaluation_loop_flow.py +++ b/tests/loops/test_evaluation_loop_flow.py @@ -70,8 +70,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.loss, torch.Tensor) - assert train_step_out.loss.item() == 171 + assert isinstance(train_step_out["loss"], torch.Tensor) + assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( @@ -135,8 +135,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.loss, torch.Tensor) - assert train_step_out.loss.item() == 171 + assert isinstance(train_step_out["loss"], torch.Tensor) + assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index d4652fe0d2c7f..ae36e56495f93 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re import pytest import torch from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -129,25 +127,6 @@ def validation_step(self, *args): assert model.validation_called_at == (0, 4) -@pytest.mark.parametrize(["output"], [(5.0,), ({"a": 5},)]) -def test_warning_invalid_trainstep_output(tmpdir, output): - class InvalidTrainStepModel(BoringModel): - def training_step(self, batch, batch_idx): - return output - - model = InvalidTrainStepModel() - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) - with pytest.raises( - MisconfigurationException, - match=re.escape( - "In automatic optimization, `training_step` must either return a Tensor, " - "a dict with key 'loss' or None (where the step will be skipped)." - ), - ): - trainer.fit(model) - - def test_warning_valid_train_step_end(tmpdir): class ValidTrainStepEndModel(BoringModel): def training_step(self, batch, batch_idx): @@ -163,28 +142,3 @@ def training_step_end(self, outputs): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) trainer.fit(model) - - -def test_prepare_outputs(tmpdir): - """Test that the `extra` field of the saved `ResultCollection` objects for `training_epoch_end` doesn't get - accidentally modified by reference.""" - - class TestModel(BoringModel): - on_train_batch_end_called = 0 - - def on_train_batch_end(self, outputs, *args, **kwargs): - epoch_outputs = self.trainer.fit_loop.epoch_loop._epoch_output - epoch_outputs = epoch_outputs[0] # 1 optimizer - assert len(epoch_outputs) == self.on_train_batch_end_called - # `extra` should be empty for all `ResultCollection` objects - assert all(not out.extra for out in epoch_outputs) - self.on_train_batch_end_called += 1 - - def training_epoch_end(self, outputs) -> None: - # override so epoch outputs get stored - pass - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) - trainer.fit(model) - assert model.on_train_batch_end_called == 2 diff --git a/tests/loops/test_training_loop_flow_scalar.py b/tests/loops/test_training_loop_flow_scalar.py index e389bd21c4ac2..fd40985ea2ad9 100644 --- a/tests/loops/test_training_loop_flow_scalar.py +++ b/tests/loops/test_training_loop_flow_scalar.py @@ -153,8 +153,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.loss, torch.Tensor) - assert train_step_out.loss.item() == 171 + assert isinstance(train_step_out["loss"], torch.Tensor) + assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( @@ -227,8 +227,8 @@ def backward(self, loss, optimizer, optimizer_idx): train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out.loss, torch.Tensor) - assert train_step_out.loss.item() == 171 + assert isinstance(train_step_out["loss"], torch.Tensor) + assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( @@ -249,17 +249,16 @@ def training_step(self, batch, batch_idx): self.log("a", loss, on_step=True, on_epoch=True) def training_epoch_end(self, outputs) -> None: - assert len(outputs) == 0 + assert len(outputs) == 0, outputs def validation_step(self, batch, batch_idx): self.validation_step_called = True def validation_epoch_end(self, outputs): - assert len(outputs) == 0 + assert len(outputs) == 0, outputs model = TestModel() trainer_args = dict(default_root_dir=tmpdir, fast_dev_run=2) - trainer = Trainer(**trainer_args) Closure.warning_cache.clear() From 237992e75338cccfaff6539bb257556404ae255f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 18:15:35 +0200 Subject: [PATCH 41/47] Progress --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 5365fc94b6eba..bf6f594e768ac 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -312,14 +312,7 @@ def _prepare_outputs( opt_outputs = [opt_outputs] for batch_outputs in opt_outputs: - processed_tbptt_outputs = [] - - if isinstance(batch_outputs, dict): - processed_tbptt_outputs.append(batch_outputs) - else: - for tbptt_output in batch_outputs: - processed_tbptt_outputs.append(tbptt_output) - + processed_tbptt_outputs = batch_outputs if isinstance(batch_outputs, list) else [batch_outputs] # if there was only one tbptt step then we can collapse that dimension if len(processed_tbptt_outputs) == 1: processed_tbptt_outputs = processed_tbptt_outputs[0] From 48a42dc0b7ea73195ac04a51b6a2ae4676730b1b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 18:19:11 +0200 Subject: [PATCH 42/47] Add test --- tests/loops/optimization/test_manual_loop.py | 34 ++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/loops/optimization/test_manual_loop.py diff --git a/tests/loops/optimization/test_manual_loop.py b/tests/loops/optimization/test_manual_loop.py new file mode 100644 index 0000000000000..4e9b2f3271e81 --- /dev/null +++ b/tests/loops/optimization/test_manual_loop.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel + + +def test_warning_invalid_trainstep_output(tmpdir): + class InvalidTrainStepModel(BoringModel): + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + return 5 + + model = InvalidTrainStepModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + + with pytest.raises(MisconfigurationException, match="return a Tensor, a dict with extras .* or have no return"): + trainer.fit(model) From e758c290db5b7c32524b825cd1e1b208cbf549c0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 18:33:17 +0200 Subject: [PATCH 43/47] Make conversion in the manual loop --- pytorch_lightning/loops/batch/training_batch_loop.py | 6 ++---- pytorch_lightning/loops/optimization/manual_loop.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index b4b027c9e7bbb..663bdb04b3162 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -131,10 +131,8 @@ def advance(self, batch, batch_idx): else: # in manual optimization, hand over execution to the ManualOptimization loop result = self.manual_loop.run(split_batch, batch_idx) - if result is not None: - output = result.asdict() - if output: - self.batch_outputs[0].append(output) + if result: + self.batch_outputs[0].append(result) def on_run_end(self) -> None: self.optimizer_loop._hiddens = None diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 1a6f45d47fb9b..c55b527441f0d 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -78,7 +78,7 @@ def __init__(self) -> None: super().__init__() self._done: bool = False self._hiddens: Optional[Any] = None - self._output: Optional[ManualResult] = None + self._output: Optional[Dict[str, Any]] = None @property def done(self) -> bool: @@ -124,9 +124,9 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] self.trainer._results.cpu() self._done = True - self._output = result + self._output = result.asdict() - def on_run_end(self) -> Optional[ManualResult]: + def on_run_end(self) -> Optional[Dict[str, Any]]: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, None # free memory # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance` From e9fdb959efe2dacc9c32d039926b404e32308a23 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 18:35:01 +0200 Subject: [PATCH 44/47] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71192a8ee61a1..75f299e1a47ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191)) * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265)) * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266)) - * Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437)) + * Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424)) - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187)) From 9c3d3febe91caddd539aa1d0aafe94d631cc1901 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 14 Sep 2021 18:42:49 +0200 Subject: [PATCH 45/47] Typing --- pytorch_lightning/loops/optimization/manual_loop.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index c55b527441f0d..9fbb6ca60e1d2 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -65,7 +65,10 @@ def asdict(self) -> Dict[str, Any]: return self.extra -class ManualOptimization(Loop): +_OUTPUTS_TYPE = Dict[str, Any] + + +class ManualOptimization(Loop[_OUTPUTS_TYPE]): """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user is responsible for back-propagating gradients and making calls to the optimizers. @@ -78,7 +81,7 @@ def __init__(self) -> None: super().__init__() self._done: bool = False self._hiddens: Optional[Any] = None - self._output: Optional[Dict[str, Any]] = None + self._output: _OUTPUTS_TYPE = {} @property def done(self) -> bool: @@ -126,9 +129,9 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] self._done = True self._output = result.asdict() - def on_run_end(self) -> Optional[Dict[str, Any]]: + def on_run_end(self) -> _OUTPUTS_TYPE: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" - output, self._output = self._output, None # free memory + output, self._output = self._output, {} # free memory # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance` # doesn't finish and `self._output` stays as `None`. If #9415 happens then this would always return a result return output From effc3c919ab9550f9d8ed1cb0fcb5f6822a53527 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 02:48:46 +0200 Subject: [PATCH 46/47] Add test. Raise NotImplementedError --- pytorch_lightning/loops/optimization/closure.py | 2 +- tests/loops/optimization/test_manual_loop.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/optimization/closure.py b/pytorch_lightning/loops/optimization/closure.py index 6cdc319665c80..48115208089e2 100644 --- a/pytorch_lightning/loops/optimization/closure.py +++ b/pytorch_lightning/loops/optimization/closure.py @@ -43,7 +43,7 @@ def check_fn(v: Tensor) -> Tensor: return apply_to_collection(extra, Tensor, check_fn) def asdict(self) -> Dict[str, Any]: - return {} + raise NotImplementedError class AbstractClosure(ABC, Generic[T]): diff --git a/tests/loops/optimization/test_manual_loop.py b/tests/loops/optimization/test_manual_loop.py index 4e9b2f3271e81..0a8a266bbc64c 100644 --- a/tests/loops/optimization/test_manual_loop.py +++ b/tests/loops/optimization/test_manual_loop.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import torch from pytorch_lightning import Trainer +from pytorch_lightning.loops.optimization.manual_loop import ManualResult from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel +def test_manual_result(): + training_step_output = {"loss": torch.tensor(25.0, requires_grad=True), "something": "jiraffe"} + result = ManualResult.from_training_step_output(training_step_output, normalize=5) + asdict = result.asdict() + assert not asdict["loss"].requires_grad + assert asdict["loss"] == 5 + assert result.extra == asdict + + def test_warning_invalid_trainstep_output(tmpdir): class InvalidTrainStepModel(BoringModel): def __init__(self): From fc91d13c248eabbd86cd3349148e75d930c27e34 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Sep 2021 13:23:18 +0200 Subject: [PATCH 47/47] Remove invalid test --- .../optimization/test_manual_optimization.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 6ba4aa3489c0a..38015c7c1138b 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -302,32 +302,6 @@ def test_manual_optimization_and_return_tensor(tmpdir): trainer.fit(model) -@RunIf(min_gpus=2) -def test_manual_optimization_and_return_detached_tensor(tmpdir): - """This test verify that in `manual_optimization` we don't add gradient when the user return loss in - `training_step` When the tensor is detached, return MisConfiguration Error.""" - - model = ManualOptimizationExtendedModel() - model.detach = True - model.training_step_end = None - model.training_epoch_end = None - - trainer = Trainer( - max_epochs=1, - default_root_dir=tmpdir, - limit_train_batches=10, - limit_test_batches=0, - limit_val_batches=0, - precision=16, - amp_backend="native", - accelerator="ddp_spawn", - gpus=2, - ) - expected_message = "In manual optimization, `training_step` should not return a Tensor" - with pytest.raises(Exception, match=expected_message): - trainer.fit(model) - - @RunIf(min_gpus=1) def test_manual_optimization_and_accumulated_gradient(tmpdir): """This test verify that in `automatic_optimization=False`, step is being called only when we shouldn't