From 09cee73e462a3ee4da5d4a3921a56f25c7aceca6 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 29 Jun 2021 13:32:30 +0100
Subject: [PATCH 01/14] add state_dict, load_state_dict on loops

---
 pytorch_lightning/loops/base.py | 17 +++++++-
 .../loops/batch/training_batch_loop.py | 8 ++++
 .../loops/dataloader/evaluation_loop.py | 10 ++++-
 .../loops/epoch/training_epoch_loop.py | 13 +++++-
 pytorch_lightning/loops/fit_loop.py | 24 ++++++++++-
 tests/loops/__init__.py | 0
 tests/loops/test_loops.py | 41 +++++++++++++++++++
 7 files changed, 107 insertions(+), 6 deletions(-)
 create mode 100644 tests/loops/__init__.py
 create mode 100644 tests/loops/test_loops.py

diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 1d976aa3cd079..e75910312ed76 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -13,11 +13,12 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 from deprecate import void

 import pytorch_lightning as pl
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class Loop(ABC):
@@ -46,6 +47,10 @@ def __init__(self) -> None:
         self.iteration_count: int = 0
         self.trainer: Optional['pl.Trainer'] = None

+    @property
+    def is_connected(self) -> bool:
+        return self.trainer is not None
+
     @property
     @abstractmethod
     def done(self) -> bool:
@@ -59,6 +64,10 @@ def skip(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects Loop with all the necessary things like connectors and accelerators."""
         # TODO(@justusschock): Make the trainer a weakref/proxy
+        if not isinstance(trainer, pl.Trainer):
+            raise MisconfigurationException(
+                f"Loop {self.__class__.__name__} should be connected to a :class:`~pytorch_lightning.Trainer` instance."
+            )
         self.trainer = trainer

     def on_skip(self) -> Optional[Any]:
@@ -128,3 +137,9 @@ def on_run_end(self) -> Any:

     def teardown(self) -> None:
         """The very last method called inside :meth:`run`. Use to release memory etc."""
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        """Restore the loop state from the provided state_dict."""
+
+    def state_dict(self) -> Dict:
+        """Return the loop's current state."""
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 76051fc3f1e94..4c4b01823db04 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -40,6 +40,8 @@
 class TrainingBatchLoop(Loop):
     """ Runs over a single batch of data. """

+    name = "batch_loop"
+
     def __init__(self) -> None:
         super().__init__()
         self.accumulated_loss: Optional[Tensor] = None
@@ -674,3 +676,9 @@ def _truncated_bptt_steps(self) -> int:
         if lightning_module.truncated_bptt_steps > 0:
             return lightning_module.truncated_bptt_steps
         return self.trainer.truncated_bptt_steps or 0
+
+    def state_dict(self) -> Dict:
+        return {}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        raise NotImplementedError
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index bdee54a174891..e76fc99ae6822 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader
@@ -29,6 +29,8 @@
 class EvaluationLoop(DataLoaderLoop):
     """Loops over all dataloaders for evaluation."""

+    name = "val_loop"
+
     def __init__(self):
         super().__init__()
         self._max_batches: Optional[Union[int, Sequence[int]]] = None
@@ -266,3 +268,9 @@ def on_evaluation_epoch_end(self) -> None:
             self.trainer.call_hook(hook_name)
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
+
+    def state_dict(self) -> Dict:
+        return {}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        raise NotImplementedError
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index cd8b992b09d45..e3cc0bd61c472 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -30,6 +30,8 @@
 class TrainingEpochLoop(loops.Loop):
     """
     Runs over all batches in a dataloader (one epoch).
     """

+    name = "epoch_loop"
+
     def __init__(self, min_steps: int, max_steps: int):
         super().__init__()
         self.min_steps: int = min_steps
@@ -47,8 +49,8 @@ def __init__(self, min_steps: int, max_steps: int):
         self.batches_seen: int = 0
         self.is_last_batch: Optional[bool] = None

-        self.batch_loop: Optional[TrainingBatchLoop] = None
-        self.val_loop: Optional[loops.EvaluationLoop] = None
+        self.batch_loop: Optional[TrainingBatchLoop] = TrainingBatchLoop()
+        self.val_loop: Optional[loops.EvaluationLoop] = loops.EvaluationLoop()

         self._dataloader_idx: Optional[int] = None
         self._warning_cache: WarningCache = WarningCache()
@@ -425,3 +427,10 @@ def _save_loggers_on_train_batch_end(self) -> None:
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()
+
+    def state_dict(self) -> Dict:
+        return {self.batch_loop.name: self.batch_loop.state_dict(), self.val_loop.name: self.val_loop.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self.batch_loop.load_state_dict(state_dict[self.batch_loop.name])
+        self.val_loop.load_state_dict(state_dict[self.val_loop.name])
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 655e102466931..109a18c5d8d3a 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

 import logging
 from contextlib import suppress
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import pytorch_lightning as pl
 from pytorch_lightning.loops import Loop
@@ -22,6 +22,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)

@@ -50,7 +51,9 @@ def __init__(
         super().__init__()
         self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
         self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
-        self.epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        self.epoch_loop: Optional[TrainingEpochLoop] = None
+        self._min_steps = min_steps
+        self._max_steps = max_steps

     @property
     def results(self) -> ResultCollection:
@@ -97,6 +100,13 @@ def min_steps(self) -> int:
         """Returns the minimum number of steps to run"""
         return self.epoch_loop.min_steps

+    @min_steps.setter
+    def min_steps(self, value: int) -> None:
+        """Sets the minimum number of steps (forwards to epoch_loop)"""
+        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
+        self.epoch_loop.min_steps = value
+        self._min_steps = value
+
     @property
     def max_steps(self) -> int:
         """Returns the maximum number of steps to run"""
@@ -107,6 +117,7 @@ def max_steps(self, value: int) -> None:
         """Sets the maximum number of steps (forwards to epoch_loop)"""
         # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
         self.epoch_loop.max_steps = value
+        self._max_steps = value

     @property
     def running_loss(self) -> TensorRunningAccum:
@@ -159,6 +170,7 @@ def skip(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with necessary arguments like the trainer"""
         super().connect(trainer, *args, **kwargs)
+        self.epoch_loop = TrainingEpochLoop(self._min_steps, self._max_steps)
         self.epoch_loop.connect(trainer)

     def reset(self) -> None:
@@ -274,3 +286,11 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False)

         for cb in callbacks:
             cb.on_validation_end(self.trainer, model)
+
+    def state_dict(self) -> Dict:
+        if not self.is_connected:
+            raise MisconfigurationException("The Trainer should be connected to the loop to retrieve the state_dict.")
+        return {self.epoch_loop.name: self.epoch_loop.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self.epoch_loop.load_state_dict(state_dict[self.epoch_loop.name])
diff --git a/tests/loops/__init__.py b/tests/loops/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
new file mode 100644
index 0000000000000..41479eeaecab9
--- /dev/null
+++ b/tests/loops/test_loops.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import ANY
+
+import pytest
+
+from pytorch_lightning.loops import FitLoop
+from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+def test_loops_state_dict_structure():
+
+    fit_loop = FitLoop()
+    with pytest.raises(
+        MisconfigurationException, match="The Trainer should be connected to the loop to retrieve the state_dict."
+    ):
+        state_dict = fit_loop.state_dict()
+    with pytest.raises(
+        MisconfigurationException,
+        match="Loop FitLoop should be connected to a :class:`~pytorch_lightning.Trainer` instance."
+    ):
+        fit_loop.connect(object())
+    fit_loop.connect(Trainer())
+    state_dict = fit_loop.state_dict()
+    expected = {'epoch_loop': {'batch_loop': ANY, 'val_loop': ANY}}
+    assert state_dict == expected
+
+    with pytest.raises(NotImplementedError):
+        fit_loop.load_state_dict(state_dict)

From efd2e50561f21795c8360bf392a8ab130e1c927c Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 29 Jun 2021 13:37:54 +0100
Subject: [PATCH 02/14] update

---
 pytorch_lightning/loops/dataloader/evaluation_loop.py | 1 +
 pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++--
 pytorch_lightning/loops/fit_loop.py | 8 ++------
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index e76fc99ae6822..e91feb3f8f453 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -72,6 +72,7 @@ def predictions(self):
     def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         """Connects the loop to everything necessary (like trainer and accelerators)"""
         super().connect(trainer, *args, **kwargs)
+        self.epoch_loop = EvaluationEpochLoop()
         self.epoch_loop.connect(trainer)

     @property
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index e3cc0bd61c472..71012f2518318 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -49,8 +49,8 @@ def __init__(self, min_steps: int, max_steps: int):
         self.batches_seen: int = 0
         self.is_last_batch: Optional[bool] = None

-        self.batch_loop: Optional[TrainingBatchLoop] = TrainingBatchLoop()
-        self.val_loop: Optional[loops.EvaluationLoop] = loops.EvaluationLoop()
+        self.batch_loop = TrainingBatchLoop()
+        self.val_loop = loops.EvaluationLoop()

         self._dataloader_idx: Optional[int] = None
         self._warning_cache: WarningCache = WarningCache()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 109a18c5d8d3a..997c12a896b9b 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -51,9 +51,7 @@ def __init__(
         super().__init__()
         self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
         self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
-        self.epoch_loop: Optional[TrainingEpochLoop] = None
-        self._min_steps = min_steps
-        self._max_steps = max_steps
+        self.epoch_loop: Optional[TrainingEpochLoop] = TrainingEpochLoop(min_steps, max_steps)

     @property
     def results(self) -> ResultCollection:
@@ -105,7 +103,6 @@ def min_steps(self, value: int) -> None:
         """Sets the minimum number of steps (forwards to epoch_loop)"""
         # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
         self.epoch_loop.min_steps = value
-        self._min_steps = value

     @property
     def max_steps(self) -> int:
@@ -117,7 +114,6 @@ def max_steps(self, value: int) -> None:
         """Sets the maximum number of steps (forwards to epoch_loop)"""
         # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
         self.epoch_loop.max_steps = value
-        self._max_steps = value

     @property
     def running_loss(self) -> TensorRunningAccum:
@@ -170,7 +166,7 @@ def skip(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with necessary arguments like the trainer"""
         super().connect(trainer, *args, **kwargs)
-        self.epoch_loop = TrainingEpochLoop(self._min_steps, self._max_steps)
+        self.epoch_loop = TrainingEpochLoop(self.min_steps, self.max_steps)
         self.epoch_loop.connect(trainer)

     def reset(self) -> None:

From 0430370d67573bd5710a87a3f8463b0d1d32e828 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 29 Jun 2021 13:40:37 +0100
Subject: [PATCH 03/14] update changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 158f89c128f63..c87de4d5ec956 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -121,6 +121,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))

+- Added `state_dict` and `load_state_dict` functions to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))
+

 ### Changed

From 21190e2810e4394ef5fcda55329aff863c174c00 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 30 Jun 2021 13:03:13 +0100
Subject: [PATCH 04/14] resolve on comments

---
 .../loops/batch/training_batch_loop.py | 4 +---
 .../loops/dataloader/evaluation_loop.py | 4 +---
 .../loops/epoch/training_epoch_loop.py | 8 +++----
 pytorch_lightning/loops/fit_loop.py | 6 ++---
 pytorch_lightning/trainer/properties.py | 9 +++++++-
 tests/loops/test_loops.py | 22 ++++++++++++++++---
 6 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 4c4b01823db04..6ef4035721d2d 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -40,8 +40,6 @@
 class TrainingBatchLoop(Loop):
     """ Runs over a single batch of data. """

-    name = "batch_loop"
-
     def __init__(self) -> None:
         super().__init__()
         self.accumulated_loss: Optional[Tensor] = None
@@ -681,4 +679,4 @@ def state_dict(self) -> Dict:
         return {}

     def load_state_dict(self, state_dict: Dict) -> None:
-        raise NotImplementedError
+        pass
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index e91feb3f8f453..c804394e68a95 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -29,8 +29,6 @@
 class EvaluationLoop(DataLoaderLoop):
     """Loops over all dataloaders for evaluation."""

-    name = "val_loop"
-
     def __init__(self):
         super().__init__()
         self._max_batches: Optional[Union[int, Sequence[int]]] = None
@@ -274,4 +272,4 @@ def state_dict(self) -> Dict:
         return {}

     def load_state_dict(self, state_dict: Dict) -> None:
-        raise NotImplementedError
+        pass
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 71012f2518318..3a660a9bb7004 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -30,6 +30,8 @@
 class TrainingEpochLoop(loops.Loop):
     """
     Runs over all batches in a dataloader (one epoch).
""" - name = "epoch_loop" - def __init__(self, min_steps: int, max_steps: int): super().__init__() self.min_steps: int = min_steps @@ -429,8 +427,8 @@ def _save_loggers_on_train_batch_end(self) -> None: self.trainer.logger.save() def state_dict(self) -> Dict: - return {self.batch_loop.name: self.batch_loop.state_dict(), self.val_loop.name: self.val_loop.state_dict()} + return {"batch_loop": self.batch_loop.state_dict(), "validation_loop": self.val_loop.state_dict()} def load_state_dict(self, state_dict: Dict) -> None: - self.batch_loop.load_state_dict(state_dict[self.batch_loop.name]) - self.val_loop.load_state_dict(state_dict[self.val_loop.name]) + self.batch_loop.load_state_dict(state_dict["batch_loop"]) + self.val_loop.load_state_dict(state_dict["validation_loop"]) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 997c12a896b9b..b19caa268d230 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -51,7 +51,7 @@ def __init__( super().__init__() self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.epoch_loop: Optional[TrainingEpochLoop] = TrainingEpochLoop(min_steps, max_steps) + self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) @property def results(self) -> ResultCollection: @@ -286,7 +286,7 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False) def state_dict(self) -> Dict: if not self.is_connected: raise MisconfigurationException("The Trainer should be connected to loop to retrieve the state_dict.") - return {self.epoch_loop.name: self.epoch_loop.state_dict()} + return {"epoch_loop": self.epoch_loop.state_dict()} def load_state_dict(self, state_dict: Dict) -> None: - self.epoch_loop.load_state_dict(state_dict[self.epoch_loop.name]) + self.epoch_loop.load_state_dict(state_dict["epoch_loop"]) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ea1164bdee861..ba053be1c06df 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -16,7 +16,7 @@ from abc import ABC from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import cast, List, Optional, Type, TypeVar, Union +from typing import Any, cast, Dict, List, Optional, Type, TypeVar, Union import torch from torch.optim import Optimizer @@ -555,6 +555,13 @@ def _results(self) -> Optional[ResultCollection]: if active_loop is not None: return active_loop.results + def get_loops_state_dict(self) -> Dict[str, Any]: + return { + "fit_loop": self.fit_loop.state_dict(), + "validate_loop": self.validation_loop.state_dict(), + "test_loop": self.test_loop.state_dict(), + } + """ Other """ diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 41479eeaecab9..bf8ca21e0df6d 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -34,8 +34,24 @@ def test_loops_state_dict_structure(): fit_loop.connect(object()) fit_loop.connect(Trainer()) state_dict = fit_loop.state_dict() - expected = {'epoch_loop': {'batch_loop': ANY, 'val_loop': ANY}} + expected = {'epoch_loop': {'batch_loop': ANY, 'validation_loop': ANY}} assert state_dict == expected - with pytest.raises(NotImplementedError): - fit_loop.load_state_dict(state_dict) + fit_loop.load_state_dict(state_dict) + + +def test_loops_state_dict_structure_with_trainer(): + + trainer = Trainer() + state_dict = 
+    expected = {
+        "fit_loop": {
+            'epoch_loop': {
+                'batch_loop': ANY,
+                'validation_loop': ANY
+            }
+        },
+        "validate_loop": ANY,
+        "test_loop": ANY
+    }
+    assert state_dict == expected

From b92c681db5e988366ce534681cf790ebea363776 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 30 Jun 2021 17:40:36 +0100
Subject: [PATCH 05/14] update on comments

---
 pytorch_lightning/loops/base.py | 7 ++-----
 .../loops/batch/training_batch_loop.py | 6 ------
 .../loops/dataloader/evaluation_loop.py | 9 +--------
 .../loops/epoch/training_epoch_loop.py | 6 ++----
 pytorch_lightning/loops/fit_loop.py | 4 ----
 pytorch_lightning/trainer/properties.py | 16 ++++++++--------
 pytorch_lightning/trainer/trainer.py | 4 ++--
 tests/loops/test_loops.py | 14 ++++----------
 8 files changed, 19 insertions(+), 47 deletions(-)

diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index e75910312ed76..1edc997e715ce 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -47,10 +47,6 @@ def __init__(self) -> None:
         self.iteration_count: int = 0
         self.trainer: Optional['pl.Trainer'] = None

-    @property
-    def is_connected(self) -> bool:
-        return self.trainer is not None
-
     @property
     @abstractmethod
     def done(self) -> bool:
@@ -66,7 +62,7 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         # TODO(@justusschock): Make the trainer a weakref/proxy
         if not isinstance(trainer, pl.Trainer):
             raise MisconfigurationException(
-                f"Loop {self.__class__.__name__} should be connected to a :class:`~pytorch_lightning.Trainer` instance."
+                f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
             )
         self.trainer = trainer

@@ -143,3 +139,4 @@ def load_state_dict(self, state_dict: Dict) -> None:
         """Restore the loop state from the provided state_dict."""

     def state_dict(self) -> Dict:
         """Return the loop's current state."""
+        return {}
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 6ef4035721d2d..76051fc3f1e94 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -674,9 +674,3 @@ def _truncated_bptt_steps(self) -> int:
         if lightning_module.truncated_bptt_steps > 0:
             return lightning_module.truncated_bptt_steps
         return self.trainer.truncated_bptt_steps or 0
-
-    def state_dict(self) -> Dict:
-        return {}
-
-    def load_state_dict(self, state_dict: Dict) -> None:
-        pass
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index c804394e68a95..bdee54a174891 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union

 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader
@@ -70,7 +70,6 @@ def predictions(self):
     def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         """Connects the loop to everything necessary (like trainer and accelerators)"""
         super().connect(trainer, *args, **kwargs)
-        self.epoch_loop = EvaluationEpochLoop()
         self.epoch_loop.connect(trainer)

     @property
@@ -267,9 +266,3 @@ def on_evaluation_epoch_end(self) -> None:
         self.trainer.call_hook(hook_name)
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
-
-    def state_dict(self) -> Dict:
-        return {}
-
-    def load_state_dict(self, state_dict: Dict) -> None:
-        pass
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 3a660a9bb7004..89891c0d6148a 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -80,9 +80,7 @@ def done(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with all necessary parts like trainer and accelerators"""
         super().connect(trainer, *args, **kwargs)
-        self.batch_loop = TrainingBatchLoop()
         self.batch_loop.connect(trainer)
-        self.val_loop = loops.EvaluationLoop()
         self.val_loop.connect(trainer)

     def reset(self) -> None:
@@ -427,8 +425,8 @@ def _save_loggers_on_train_batch_end(self) -> None:
             self.trainer.logger.save()

     def state_dict(self) -> Dict:
-        return {"batch_loop": self.batch_loop.state_dict(), "validation_loop": self.val_loop.state_dict()}
+        return {"batch_loop": self.batch_loop.state_dict(), "val_loop": self.val_loop.state_dict()}

     def load_state_dict(self, state_dict: Dict) -> None:
         self.batch_loop.load_state_dict(state_dict["batch_loop"])
-        self.val_loop.load_state_dict(state_dict["validation_loop"])
+        self.val_loop.load_state_dict(state_dict["val_loop"])
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index b19caa268d230..bf42663fd5c9e 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -22,7 +22,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_info
-from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)

@@ -166,7 +165,6 @@ def skip(self) -> bool:
     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with necessary arguments like the trainer"""
         super().connect(trainer, *args, **kwargs)
-        self.epoch_loop = TrainingEpochLoop(self.min_steps, self.max_steps)
         self.epoch_loop.connect(trainer)

     def reset(self) -> None:
@@ -284,8 +282,6 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False)
             cb.on_validation_end(self.trainer, model)

     def state_dict(self) -> Dict:
-        if not self.is_connected:
-            raise MisconfigurationException("The Trainer should be connected to the loop to retrieve the state_dict.")
         return {"epoch_loop": self.epoch_loop.state_dict()}

     def load_state_dict(self, state_dict: Dict) -> None:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index ba053be1c06df..d991bd058f342 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -63,7 +63,7 @@ class TrainerProperties(ABC):
     logger_connector: LoggerConnector
     state: TrainerState
     fit_loop: FitLoop
-    validation_loop: EvaluationLoop
+    validate_loop: EvaluationLoop
     test_loop: EvaluationLoop

     """
     Accelerator properties
@@ -555,13 +555,6 @@ def _results(self) -> Optional[ResultCollection]:
         if active_loop is not None:
             return active_loop.results

-    def get_loops_state_dict(self) -> Dict[str, Any]:
-        return {
-            "fit_loop": self.fit_loop.state_dict(),
-            "validate_loop": self.validation_loop.state_dict(),
-            "test_loop": self.test_loop.state_dict(),
-        }
-
     """
     Other
     """
@@ -575,6 +568,13 @@ def __getstate__(self):

     def __setstate__(self, state):
         self.__dict__ = state

+    def get_loops_state_dict(self) -> Dict[str, Any]:
+        return {
+            "fit_loop": self.fit_loop.state_dict(),
+            "validate_loop": self.validate_loop.state_dict(),
+            "test_loop": self.test_loop.state_dict(),
+        }
+

 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index dcf04760c636b..f322a72170a48 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -344,11 +344,11 @@ def __init__(
         self.tuner = Tuner(self)

         self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
-        self.validation_loop = EvaluationLoop()
+        self.validate_loop = EvaluationLoop()
         self.test_loop = EvaluationLoop()
         self.predict_loop = PredictionLoop()
         self.fit_loop.connect(self)
-        self.validation_loop.connect(self)
+        self.validate_loop.connect(self)
         self.test_loop.connect(self)
         self.predict_loop.connect(self)
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index bf8ca21e0df6d..1ac3e9b7573f4 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -23,18 +23,12 @@ def test_loops_state_dict_structure():

     fit_loop = FitLoop()
-    with pytest.raises(
-        MisconfigurationException, match="The Trainer should be connected to the loop to retrieve the state_dict."
-    ):
-        state_dict = fit_loop.state_dict()
-    with pytest.raises(
-        MisconfigurationException,
-        match="Loop FitLoop should be connected to a :class:`~pytorch_lightning.Trainer` instance."
-    ):
+    state_dict = fit_loop.state_dict()
+    with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
         fit_loop.connect(object())
     fit_loop.connect(Trainer())
     state_dict = fit_loop.state_dict()
-    expected = {'epoch_loop': {'batch_loop': ANY, 'validation_loop': ANY}}
+    expected = {'epoch_loop': {'batch_loop': ANY, 'val_loop': ANY}}
     assert state_dict == expected

     fit_loop.load_state_dict(state_dict)
@@ -48,7 +42,7 @@ def test_loops_state_dict_structure_with_trainer():
         "fit_loop": {
             'epoch_loop': {
                 'batch_loop': ANY,
-                'validation_loop': ANY
+                'val_loop': ANY
             }
         },
         "validate_loop": ANY,
         "test_loop": ANY

From 3d748cb83e614047db555bf5e2d72db8d8c6874d Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 30 Jun 2021 18:52:04 +0200
Subject: [PATCH 06/14] Update tests and CHANGELOG

---
 CHANGELOG.md | 10 ++++------
 tests/loops/test_loops.py | 22 +++++++++-------------
 2 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ffdc9318d6e3e..582d359d4d675 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -84,13 +84,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fault-tolerant training
-  * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+  * Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+  * Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))

-- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
+- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))

-- Add `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
+- Added `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))

 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
@@ -123,9 +124,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))

-- Added `state_dict` and `load_state_dict` functions to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))
-
-
 - Added `should_raise_exception` parameter to `parse_gpu_ids`, `parse_tpu_cores` and `_sanitize_gpu_ids` utility functions ([#8194](https://github.com/PyTorchLightning/pytorch-lightning/pull/8194))
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 1ac3e9b7573f4..c5e8c959e8509 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import ANY

 import pytest

 from pytorch_lightning.loops import FitLoop
@@ -21,31 +20,28 @@ def test_loops_state_dict_structure():
-
     fit_loop = FitLoop()
-    state_dict = fit_loop.state_dict()
     with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
-        fit_loop.connect(object())
+        fit_loop.connect(object())  # noqa
+
     fit_loop.connect(Trainer())
     state_dict = fit_loop.state_dict()
-    expected = {'epoch_loop': {'batch_loop': ANY, 'val_loop': ANY}}
-    assert state_dict == expected
-
-    fit_loop.load_state_dict(state_dict)
+    new_fit_loop = FitLoop()
+    new_fit_loop.load_state_dict(state_dict)
+    assert fit_loop.state_dict() == new_fit_loop.state_dict()


 def test_loops_state_dict_structure_with_trainer():
-
     trainer = Trainer()
     state_dict = trainer.get_loops_state_dict()
     expected = {
         "fit_loop": {
             'epoch_loop': {
-                'batch_loop': ANY,
-                'val_loop': ANY
+                'batch_loop': {},
+                'val_loop': {},
             }
         },
-        "validate_loop": ANY,
-        "test_loop": ANY
+        "validate_loop": {},
+        "test_loop": {},
     }
     assert state_dict == expected

From dad9f13e93a4ac1909ffc5d1551cd45724fec825 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 30 Jun 2021 18:55:51 +0200
Subject: [PATCH 07/14] Move code and rename

---
 pytorch_lightning/trainer/properties.py | 14 +++++++-------
 tests/loops/test_loops.py | 2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index d991bd058f342..e70b14c5a81af 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -484,6 +484,13 @@ def sanity_checking(self, val: bool) -> None:
         elif self.sanity_checking:
             self.state.stage = None

+    def loops_state_dict(self) -> Dict[str, Any]:
+        return {
+            "fit_loop": self.fit_loop.state_dict(),
+            "validate_loop": self.validate_loop.state_dict(),
+            "test_loop": self.test_loop.state_dict(),
+        }
+
     """
     Loop properties
     """
@@ -568,13 +575,6 @@ def __getstate__(self):

     def __setstate__(self, state):
         self.__dict__ = state

-    def get_loops_state_dict(self) -> Dict[str, Any]:
-        return {
-            "fit_loop": self.fit_loop.state_dict(),
-            "validate_loop": self.validate_loop.state_dict(),
-            "test_loop": self.test_loop.state_dict(),
-        }
-

 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index c5e8c959e8509..cb4a9398cf0e1 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -33,7 +33,7 @@ def test_loops_state_dict_structure():

 def test_loops_state_dict_structure_with_trainer():
     trainer = Trainer()
-    state_dict = trainer.get_loops_state_dict()
+    state_dict = trainer.loops_state_dict()
     expected = {
         "fit_loop": {
             'epoch_loop': {

From cf5a260e499daad7a2e9875d24e8866366ca2720 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 30 Jun 2021 19:41:27 +0200
Subject: [PATCH 08/14] Rename

---
 .../trainer/connectors/logger_connector/logger_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 88f089224ff2e..5a950d40f8f54 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -316,5 +316,5 @@ def progress_bar_metrics(self) -> Dict[str, float]:
     def teardown(self):
         self.trainer.fit_loop.epoch_loop._results.cpu()
         self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
-        self.trainer.validation_loop._results.cpu()
+        self.trainer.validate_loop._results.cpu()
         self.trainer.test_loop._results.cpu()

From c6d026a1c04f924dfbc7869723839e7960c7dde8 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 30 Jun 2021 19:56:47 +0200
Subject: [PATCH 09/14] Rename

---
 pytorch_lightning/trainer/progress.py | 2 +-
 pytorch_lightning/trainer/properties.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index caf4ab0bf1599..2d7a1d7e8f53a 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -239,7 +239,7 @@ class TrainingEpochProgress(EpochProgress):
         current: Tracks the current epoch progress.
         batch: Tracks batch progress.
         optim: Tracks optimization progress.
-        val: Tracks validation_loop progress.
+        val: Tracks val_loop progress.
     """

     optim: OptimizationProgress = field(default_factory=OptimizationProgress)
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index e70b14c5a81af..baed75c9c51ac 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -500,7 +500,7 @@ def evaluation_loop(self) -> EvaluationLoop:
         if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
             return self.fit_loop.epoch_loop.val_loop
         elif self.state.fn == TrainerFn.VALIDATING:
-            return self.validation_loop
+            return self.validate_loop
         if self.state.fn == TrainerFn.TESTING:
             return self.test_loop
         raise RuntimeError("The `Trainer.evaluation_loop` property isn't defined. Accessed outside of scope")

From 95a40730cb56f743af3445b91f319e290a81c696 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 30 Jun 2021 19:37:29 +0100
Subject: [PATCH 10/14] change test file name

---
 tests/loops/{test_loops.py => test_loop_state_dict.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/loops/{test_loops.py => test_loop_state_dict.py} (100%)

diff --git a/tests/loops/test_loops.py b/tests/loops/ test_loop_state_dict.py
similarity index 100%
rename from tests/loops/test_loops.py
rename to tests/loops/ test_loop_state_dict.py

From d2576fb826b246a421c1c4dc6e76330bd21fcd91 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 1 Jul 2021 10:24:32 +0100
Subject: [PATCH 11/14] rename file

---
 tests/loops/{ test_loop_state_dict.py => test_loop_state_dict.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/loops/{ test_loop_state_dict.py => test_loop_state_dict.py} (100%)

diff --git a/tests/loops/ test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
similarity index 100%
rename from tests/loops/ test_loop_state_dict.py
rename to tests/loops/test_loop_state_dict.py

From 90663d65f3aafb7ac18b673cb8d0e9ca9be070a4 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 1 Jul 2021 11:48:05 +0200
Subject: [PATCH 12/14] Address comments

---
 pytorch_lightning/trainer/properties.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index baed75c9c51ac..c3f65d258c8bf 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -484,13 +484,6 @@ def sanity_checking(self, val: bool) -> None:
         elif self.sanity_checking:
             self.state.stage = None

-    def loops_state_dict(self) -> Dict[str, Any]:
-        return {
-            "fit_loop": self.fit_loop.state_dict(),
-            "validate_loop": self.validate_loop.state_dict(),
-            "test_loop": self.test_loop.state_dict(),
-        }
-
     """
     Loop properties
     """

From d207caa1797f93fbdc293fb82c2293ca4c453e80 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 1 Jul 2021 11:49:28 +0200
Subject: [PATCH 13/14] Update pytorch_lightning/trainer/properties.py

---
 pytorch_lightning/trainer/properties.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index c3f65d258c8bf..b59066cb03b17 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -16,7 +16,7 @@
 from abc import ABC
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, cast, Dict, List, Optional, Type, TypeVar, Union
+from typing import cast, List, Optional, Type, TypeVar, Union

 import torch
 from torch.optim import Optimizer

From b814c191752fbdcbb77dc837da55573d03bb776e Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 1 Jul 2021 11:53:32 +0200
Subject: [PATCH 14/14] update test

---
 tests/loops/test_loop_state_dict.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index cb4a9398cf0e1..1930dc46566fd 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -19,7 +19,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


-def test_loops_state_dict_structure():
+def test_loops_state_dict():
     fit_loop = FitLoop()
     with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
         fit_loop.connect(object())  # noqa
@@ -31,9 +31,15 @@ def test_loops_state_dict_structure():
     assert fit_loop.state_dict() == new_fit_loop.state_dict()


-def test_loops_state_dict_structure_with_trainer():
+def test_loops_state_dict_structure():
     trainer = Trainer()
-    state_dict = trainer.loops_state_dict()
+    # structure saved by the checkpoint connector
+    state_dict = {
+        "fit_loop": trainer.fit_loop.state_dict(),
+        "validate_loop": trainer.validate_loop.state_dict(),
+        "test_loop": trainer.test_loop.state_dict(),
+        "predict_loop": trainer.predict_loop.state_dict(),
+    }
     expected = {
         "fit_loop": {
             'epoch_loop': {
@@ -43,5 +49,6 @@ def test_loops_state_dict_structure():
         },
         "validate_loop": {},
         "test_loop": {},
+        "predict_loop": {},
     }
    assert state_dict == expected
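
A note on the API the series converges to — a minimal sketch, assuming the fourteen patches above are applied; every attribute and key used here appears in the diffs. After PATCH 14, `state_dict`/`load_state_dict` exist on every `Loop`, the dictionaries nest along the loop composition, and the `Trainer` exposes the four top-level loops:

    from pytorch_lightning import Trainer

    trainer = Trainer()

    # Nesting mirrors the composition FitLoop -> TrainingEpochLoop -> {TrainingBatchLoop, EvaluationLoop}.
    # With the base implementations every leaf is an empty dict:
    # {"epoch_loop": {"batch_loop": {}, "val_loop": {}}}
    fit_state = trainer.fit_loop.state_dict()

    # The flat structure the checkpoint connector is expected to persist
    # (mirrors test_loops_state_dict_structure in PATCH 14):
    loops_state = {
        "fit_loop": trainer.fit_loop.state_dict(),
        "validate_loop": trainer.validate_loop.state_dict(),
        "test_loop": trainer.test_loop.state_dict(),
        "predict_loop": trainer.predict_loop.state_dict(),
    }

    # Restoring is symmetric; with the base no-op hooks this round-trips unchanged.
    trainer.fit_loop.load_state_dict(loops_state["fit_loop"])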
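The no-op base hooks exist so that individual loops can later persist their own progress for fault-tolerant training. A hypothetical sketch of what a subclass could layer on top of the nested structure — `CountingFitLoop` is illustrative and not part of this PR; only `FitLoop`, `iteration_count`, and the `"epoch_loop"` key come from the diffs:

    from typing import Dict

    from pytorch_lightning.loops import FitLoop


    class CountingFitLoop(FitLoop):
        def state_dict(self) -> Dict:
            # Keep the nested child state and add this loop's own counter.
            state = super().state_dict()  # {"epoch_loop": {...}}
            state["iteration_count"] = self.iteration_count
            return state

        def load_state_dict(self, state_dict: Dict) -> None:
            # Restore the counter first, then let the parent restore the children.
            self.iteration_count = state_dict["iteration_count"]
            super().load_state_dict(state_dict)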