diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3b86fdaab097d..1ae7b07e77b96 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -72,6 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
     * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
     * Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
+    * Added support for restarting within Evaluation Loop ([#9563](https://github.com/PyTorchLightning/pytorch-lightning/pull/9563))
     * Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))
 * Support skipping to validation during fitting ([#9681](https://github.com/PyTorchLightning/pytorch-lightning/pull/9681))
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 4f58889b42c4a..b3a8ef764efc6 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Sequence, Union
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Sequence, Union

 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader
@@ -19,6 +20,8 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
+from pytorch_lightning.utilities.auto_restart import reload_dataloader_state_dict
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -34,6 +37,8 @@ def __init__(self):
         self._results = ResultCollection(training=False)
         self._max_batches: Optional[Union[int, Sequence[int]]] = None
         self._has_run: bool = False
+        self._data_fetcher: Optional[AbstractDataFetcher] = None
+        self._dataloader_state_dict: Dict[str, Any] = None

     @property
     def num_dataloaders(self) -> int:
@@ -101,7 +106,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
-        dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
+        self._data_fetcher = dataloader = self.trainer.data_connector.get_profiled_dataloader(
+            dataloader, dataloader_idx=dataloader_idx
+        )

         dl_max_batches = self._max_batches[dataloader_idx]
@@ -121,6 +128,9 @@ def on_run_end(self) -> List[_OUT_DICT]:
         # free memory
         self.outputs = []
+        # drop reference to iterator.
+        self._data_fetcher = None
+
         # with a single dataloader don't pass a 2D list
         if len(outputs) > 0 and self.num_dataloaders == 1:
             outputs = outputs[0]
@@ -167,6 +177,10 @@ def _reload_evaluation_dataloaders(self) -> None:
         elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
             self.trainer.reset_val_dataloader()

+        if not self.trainer.sanity_checking and self._dataloader_state_dict:
+            reload_dataloader_state_dict(self.dataloaders[self.current_dataloader_idx], self._dataloader_state_dict)
+            self._dataloader_state_dict = None
+
     def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks."""
         assert self._results is not None
@@ -239,3 +253,13 @@ def _on_evaluation_epoch_end(self) -> None:
         self.trainer.call_hook(hook_name)
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
+
+    def on_save_checkpoint(self) -> Dict:
+        state_dict = super().on_save_checkpoint()
+        if self._data_fetcher is not None and self._data_fetcher.dataloader_iter is not None:
+            state_dict["dataloader_state_dict"] = asdict(self._data_fetcher.dataloader_iter.previous_state)
+        return state_dict
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        # cache the dataloader state dict until the dataloader objects are available
+        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 9a4f7c510f303..333ded17398b6 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -249,7 +249,8 @@ def teardown(self) -> None:
     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
         # TODO: update has_completed to its proper value
-        state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
+        if self.trainer.train_dataloader is not None:
+            state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
         return state_dict

     def on_load_checkpoint(self, state_dict: Dict) -> None:
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 0fa9e8c219df0..ffe12a03bd75d 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -24,12 +24,9 @@

 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
-    _find_fast_forward_samplers,
-    CaptureIterableDataset,
-    CaptureMapDataset,
-    IteratorState,
     MergedIteratorState,
     patch_dataloader_iterator,
+    reload_dataloader_state_dict,
 )
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -400,37 +397,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
             if isinstance(dataloader, CycleIterator):
                 dataloader = dataloader_to_iter_on.loader

-            dataset = dataloader.dataset
-
-            # We reload the states before creating the workers
-            # The specific type of dataset will then decide if the state should be applied before or after
-            # spawning the workers
-            if isinstance(dataset, CaptureMapDataset):
-                iterator_state = state_dict["state"][0]
-
-                if not isinstance(iterator_state, IteratorState):
-                    iterator_state = IteratorState.from_state_dict(iterator_state)
-
-                # reload sampler state
-                ff_sampler = _find_fast_forward_samplers(dataloader)
-                ff_sampler.load_state_dict(iterator_state.sampler_state)
-                # reload dataset state
-                dataset.load_state_dict(
-                    iterator_state.dataset_state,
-                    latest_worker_id=state_dict["latest_worker_id"],
-                    num_workers=iterator_state.num_workers,
-                )
-
-            elif isinstance(dataset, CaptureIterableDataset):
-                dataset_dict = {
-                    sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
-                }
-                dataset.load_state_dict(dataset_dict)
-
-            else:
-                raise MisconfigurationException(
-                    "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
-                )
+            reload_dataloader_state_dict(dataloader, state_dict)

             # We finally spawned the workers if any.
             it = iter(dataloader_to_iter_on)
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 36db9e986b777..25fca786da7f4 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -27,6 +27,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_training


 class FastForwardSampler(Sampler):
@@ -545,3 +546,37 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
     dataloader.collate_fn = partial(
         _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
     )
+
+
+def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    """Utility to reload state_dict within dataloader for fault tolerance."""
+
+    if not _fault_tolerant_training():
+        return
+
+    dataset = dataloader.dataset
+
+    if isinstance(dataset, CaptureMapDataset):
+        iterator_state = state_dict["state"][0]
+
+        if not isinstance(iterator_state, IteratorState):
+            iterator_state = IteratorState.from_state_dict(iterator_state)
+
+        # reload sampler state
+        ff_sampler = _find_fast_forward_samplers(dataloader)
+        ff_sampler.load_state_dict(iterator_state.sampler_state)
+
+        # reload dataset state
+        dataset.load_state_dict(
+            iterator_state.dataset_state,
+            latest_worker_id=state_dict["latest_worker_id"],
+            num_workers=iterator_state.num_workers,
+        )
+
+    elif isinstance(dataset, CaptureIterableDataset):
+        dataset.load_state_dict(
+            {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
+        )
+
+    else:
+        raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index 5500fe5393f27..249afcbbb6c64 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -15,6 +15,7 @@
 import os
 import random
 import random as python_random
+from collections import defaultdict
 from collections.abc import Iterable
 from contextlib import suppress
 from copy import deepcopy
@@ -970,3 +971,92 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
     for w0, w1 in zip(weights0, weights1):
         assert w0 is not w1
         assert torch.allclose(w0, w1)
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+@pytest.mark.parametrize(
+    ["train_datasets", "val_datasets"],
+    [
+        ([RandomGetItemDataset], [RandomGetItemDataset]),
+        ([RandomGetItemDataset], [RandomGetItemDataset, RandomGetItemDataset]),
+    ],
+)
+@pytest.mark.parametrize(
+    "val_check_interval",
+    [
+        pytest.param(
+            0.5,
+            marks=pytest.mark.xfail(
+                reason=(
+                    "TODO: the `train_dataloader` random state overrides the validation state when restarting training"
+                )
+            ),
+        ),
+        1.0,
+    ],
+)
+def test_auto_restart_within_validation_loop(train_datasets, val_datasets, val_check_interval, tmpdir):
+    n_val_dataloaders = len(val_datasets)
+    stop_dataloader = n_val_dataloaders - 1
+    stop_batch = 1
+
+    class ValidationLoopTestModel(LightningModule):
+        def __init__(self, should_fail):
+            super().__init__()
+            self.layer = torch.nn.Linear(1, 2)
+            self.should_fail = should_fail
+            self.training_batches = []
+            self.validation_batches = defaultdict(list)
+
+        def step(self, batch):
+            return sum(self.layer(b).sum() for b in batch)
+
+        def training_step(self, batch, batch_idx):
+            self.training_batches.append(batch)
+            return self.step(batch)
+
+        def validation_step(self, batch, batch_idx, dataloader_idx=0):
+            if self.should_fail and stop_dataloader == dataloader_idx and batch_idx == stop_batch:
+                raise CustomException
+            self.validation_batches[dataloader_idx].append(batch)
+            return self.step(batch)
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+        def train_dataloader(self):
+            return [DataLoader(cls(4, 1)) for cls in train_datasets]
+
+        def val_dataloader(self):
+            return [DataLoader(cls(4, 1)) for cls in val_datasets]
+
+    def run(should_fail, resume):
+        if not resume:
+            seed_everything(42)
+
+        model = ValidationLoopTestModel(should_fail)
+
+        resume_from_checkpoint = str(tmpdir / ".pl_auto_save.ckpt") if resume else None
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=1,
+            val_check_interval=val_check_interval,
+            num_sanity_val_steps=0,
+            resume_from_checkpoint=resume_from_checkpoint,
+        )
+        if should_fail:
+            with pytest.raises(CustomException):
+                trainer.fit(model)
+        else:
+            trainer.fit(model)
+
+        return model.training_batches, model.validation_batches
+
+    total_train_batches, total_val_batches = run(should_fail=False, resume=False)
+    pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False)
+    post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True)
+
+    torch.testing.assert_allclose(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
+    for k in total_val_batches:
+        torch.testing.assert_allclose(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])