
[Feat] Add Fault Tolerant Training for ValidationLoop. #9563

Merged 66 commits on Sep 24, 2021.

Commits
8091491  cleanup  (tchaton, Sep 10, 2021)
9ded53b  wip  (tchaton, Sep 13, 2021)
a15b60e  wip  (tchaton, Sep 13, 2021)
8be96ed  update  (tchaton, Sep 13, 2021)
1d74c45  update  (tchaton, Sep 13, 2021)
63e13e6  update  (tchaton, Sep 13, 2021)
2163185  update  (tchaton, Sep 13, 2021)
831368d  wip  (tchaton, Sep 13, 2021)
ec38605  update  (tchaton, Sep 13, 2021)
516cae7  resolve tests  (tchaton, Sep 14, 2021)
48dfb1e  add support validation  (tchaton, Sep 14, 2021)
138ba8d  tiny cleanup  (tchaton, Sep 14, 2021)
0d24120  update changelog  (tchaton, Sep 14, 2021)
528d9f4  update  (tchaton, Sep 14, 2021)
726614e  Merge branch 'master' into fault_tolerant_validation  (tchaton, Sep 14, 2021)
f8abada  update  (tchaton, Sep 14, 2021)
4f5db67  update  (tchaton, Sep 14, 2021)
9dcc5ef  Merge branch 'master' into fault_tolerant_validation  (tchaton, Sep 14, 2021)
5ef80de  update on comments  (tchaton, Sep 14, 2021)
8582f12  Merge branch 'master' into fault_tolerant_validation  (tchaton, Sep 14, 2021)
0ccbf07  update  (tchaton, Sep 14, 2021)
9fea22b  update  (tchaton, Sep 14, 2021)
63ac346  update  (tchaton, Sep 14, 2021)
11e38aa  update test  (tchaton, Sep 14, 2021)
5c820a9  update  (tchaton, Sep 15, 2021)
f2675a0  update  (tchaton, Sep 16, 2021)
262900d  move reset_on_restart in the loop  (tchaton, Sep 16, 2021)
3668fed  update changelog  (tchaton, Sep 16, 2021)
3defdb1  update  (tchaton, Sep 16, 2021)
e9a46d0  update  (tchaton, Sep 16, 2021)
c0c0c85  update  (tchaton, Sep 16, 2021)
e1ad52e  update  (tchaton, Sep 16, 2021)
c4b86ab  Minor test changes  (carmocca, Sep 16, 2021)
ad334f5  Update CHANGELOG.md  (tchaton, Sep 16, 2021)
c672030  updte  (tchaton, Sep 16, 2021)
4a110b5  Merge branch 'move_restart_with_loops' of https://github.com/PyTorchL…  (tchaton, Sep 16, 2021)
122766c  Merge branch 'move_restart_with_loops' into fault_tolerant_validation_2  (tchaton, Sep 16, 2021)
194e0cb  Merge branch 'master' into fault_tolerant_validation_2  (tchaton, Sep 17, 2021)
a454393  update on comments  (tchaton, Sep 17, 2021)
db01b85  resolve conflicts  (tchaton, Sep 17, 2021)
7e6dfaf  resolve mypy  (tchaton, Sep 17, 2021)
b540770  Bad merge  (carmocca, Sep 17, 2021)
a785789  update  (tchaton, Sep 17, 2021)
e95995e  upate  (tchaton, Sep 17, 2021)
f3db789  update  (tchaton, Sep 20, 2021)
6932f6e  Merge branch 'master' into fault_tolerant_validation_2  (tchaton, Sep 20, 2021)
905d6ae  remove out-dated comment  (tchaton, Sep 20, 2021)
b49583e  Merge branch 'fault_tolerant_validation_2' of https://github.com/PyTo…  (tchaton, Sep 20, 2021)
6c62118  update on comments  (tchaton, Sep 20, 2021)
9555eb9  update  (tchaton, Sep 20, 2021)
700be0d  resolve tests  (tchaton, Sep 21, 2021)
f88f43c  Merge branch 'master' into fault_tolerant_validation_2  (carmocca, Sep 21, 2021)
6b99afa  Merge branch 'master' into fault_tolerant_validation_2  (tchaton, Sep 22, 2021)
b03983e  resolve test  (tchaton, Sep 22, 2021)
d364528  Merge branch 'fault_tolerant_validation_2' of https://github.com/PyTo…  (carmocca, Sep 22, 2021)
c80ae90  Merge branch 'master' into fault_tolerant_validation_2  (carmocca, Sep 22, 2021)
941fd66  Merge branch 'master' into fault_tolerant_validation_2  (carmocca, Sep 24, 2021)
5b58673  Refactor and actually assert in test  (carmocca, Sep 24, 2021)
7eaad63  Fix loops test  (carmocca, Sep 24, 2021)
88a9c14  Passing loops tests  (carmocca, Sep 24, 2021)
302fbe3  Simplify auto restart test  (carmocca, Sep 24, 2021)
c16a5a4  Merge branch 'master' into fault_tolerant_validation_2  (carmocca, Sep 24, 2021)
5f310a6  Remove changes from other PR  (carmocca, Sep 24, 2021)
b1c6da9  Merge branch 'master' into fault_tolerant_validation_2  (carmocca, Sep 24, 2021)
bec607a  Merge branch 'master' into fault_tolerant_validation_2  (carmocca, Sep 24, 2021)
9fd83b8  Allclose  (carmocca, Sep 24, 2021)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -72,6 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
* Added support for restarting within Evaluation Loop ([#9563](https://github.com/PyTorchLightning/pytorch-lightning/pull/9563))
* Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))
* Support skipping to validation during fitting ([#9681](https://github.com/PyTorchLightning/pytorch-lightning/pull/9681))

28 changes: 26 additions & 2 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Union

from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.utilities.auto_restart import reload_dataloader_state_dict
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -34,6 +37,8 @@ def __init__(self):
self._results = ResultCollection(training=False)
self._max_batches: Optional[Union[int, Sequence[int]]] = None
self._has_run: bool = False
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._dataloader_state_dict: Optional[Dict[str, Any]] = None

@property
def num_dataloaders(self) -> int:
@@ -101,7 +106,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
self._data_fetcher = dataloader = self.trainer.data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
)

dl_max_batches = self._max_batches[dataloader_idx]

@@ -121,6 +128,9 @@ def on_run_end(self) -> List[_OUT_DICT]:
# free memory
self.outputs = []

# drop reference to iterator.
self._data_fetcher = None

# with a single dataloader don't pass a 2D list
if len(outputs) > 0 and self.num_dataloaders == 1:
outputs = outputs[0]
@@ -167,6 +177,10 @@ def _reload_evaluation_dataloaders(self) -> None:
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
self.trainer.reset_val_dataloader()

if not self.trainer.sanity_checking and self._dataloader_state_dict:
reload_dataloader_state_dict(self.dataloaders[self.current_dataloader_idx], self._dataloader_state_dict)
self._dataloader_state_dict = None

def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
assert self._results is not None
@@ -239,3 +253,13 @@ def _on_evaluation_epoch_end(self) -> None:
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
if self._data_fetcher is not None and self._data_fetcher.dataloader_iter is not None:
state_dict["dataloader_state_dict"] = asdict(self._data_fetcher.dataloader_iter.previous_state)
return state_dict

def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
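The two hooks above implement a cache-until-ready handoff: `on_load_checkpoint` runs before the validation dataloaders exist, so the loop only stashes the state; `_reload_evaluation_dataloaders` applies it once the loaders are rebuilt, then clears the cache. A minimal self-contained sketch of that pattern (class and field names here are illustrative, not the Lightning API):

```python
class StatefulLoop:
    """Sketch: cache a dataloader state dict until the dataloader exists."""

    def __init__(self):
        self._dataloader_state_dict = None  # filled by on_load_checkpoint
        self.batches_seen = 0

    def on_save_checkpoint(self):
        # capture enough progress for a restart to fast-forward past seen batches
        return {"dataloader_state_dict": {"num_batches_fetched": self.batches_seen}}

    def on_load_checkpoint(self, state_dict):
        # the dataloader is not built yet, so only cache the state here
        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})

    def run(self, data):
        start = 0
        if self._dataloader_state_dict:
            # apply the cached state exactly once, then drop the reference
            start = self._dataloader_state_dict["num_batches_fetched"]
            self._dataloader_state_dict = None
        processed = []
        for batch in data[start:]:
            processed.append(batch)
            self.batches_seen = start + len(processed)
        return processed
```

A restarted loop then resumes mid-dataset: loaded from a checkpoint taken after two batches, it processes only the remaining ones.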
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -249,7 +249,8 @@ def teardown(self) -> None:
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
# TODO: update has_completed to its proper value
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
if self.trainer.train_dataloader is not None:
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
return state_dict

def on_load_checkpoint(self, state_dict: Dict) -> None:
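The added `is not None` guard matters because not every trainer entry point builds a train dataloader (presumably `Trainer.validate()`/`Trainer.test()` runs); calling `state_dict()` unconditionally would raise on `None`. The guard in isolation, as a sketch with a hypothetical fake loader:

```python
def save_loop_state(train_dataloader):
    """Only capture dataloader state when a train loader actually exists."""
    state_dict = {}
    if train_dataloader is not None:
        # mirrors the guarded call above; the `has_completed` flag is omitted here
        state_dict["dataloader_state_dict"] = train_dataloader.state_dict()
    return state_dict


class FakeLoader:
    """Illustrative stand-in for a dataloader with a state_dict method."""

    def state_dict(self):
        return {"num_batches_fetched": 3}
```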
37 changes: 2 additions & 35 deletions pytorch_lightning/trainer/supporters.py
@@ -24,12 +24,9 @@

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_find_fast_forward_samplers,
CaptureIterableDataset,
CaptureMapDataset,
IteratorState,
MergedIteratorState,
patch_dataloader_iterator,
reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -400,37 +397,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader

dataset = dataloader.dataset

# We reload the states before creating the workers
# The specific type of dataset will then decide if the state should be applied before or after
# spawning the workers
if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]

if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)

# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)

elif isinstance(dataset, CaptureIterableDataset):
dataset_dict = {
sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
}
dataset.load_state_dict(dataset_dict)

else:
raise MisconfigurationException(
"This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
)
reload_dataloader_state_dict(dataloader, state_dict)

# We finally spawned the workers if any.
it = iter(dataloader_to_iter_on)
35 changes: 35 additions & 0 deletions pytorch_lightning/utilities/auto_restart.py
@@ -27,6 +27,7 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class FastForwardSampler(Sampler):
@@ -545,3 +546,37 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
)


def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
"""Utility to reload a dataloader's ``state_dict`` for fault-tolerant training."""

if not _fault_tolerant_training():
return

dataset = dataloader.dataset

if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]

if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)

# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)

# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)

elif isinstance(dataset, CaptureIterableDataset):
dataset.load_state_dict(
{sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
)

else:
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
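The utility dispatches on the concrete dataset type: a map-style dataset restores its fast-forward sampler plus per-worker dataset state, while an iterable dataset restores one state per named sampler, and anything else is rejected. A stripped-down sketch of that dispatch with toy dataset classes (all names are illustrative, not the real `CaptureMapDataset`/`CaptureIterableDataset` API):

```python
class MapStyleState:
    """Toy stand-in for a map-style dataset that can restore its position."""

    def __init__(self):
        self.current_index = 0

    def load_state_dict(self, state):
        self.current_index = state["index"]


class IterableStyleState:
    """Toy stand-in for an iterable dataset keyed by sampler name."""

    def __init__(self):
        self.sampler_states = {}

    def load_state_dict(self, state):
        self.sampler_states.update(state)


def reload_state(dataset, state_dict):
    # dispatch on the dataset type, mirroring the structure of the utility above
    if isinstance(dataset, MapStyleState):
        # map-style: state is stored per worker id; worker 0 holds the position
        dataset.load_state_dict({"index": state_dict["state"][0]["index"]})
    elif isinstance(dataset, IterableStyleState):
        # iterable-style: one sampler state per named sampler
        dataset.load_state_dict(
            {name: s[0]["sampler_state"] for name, s in state_dict["state"].items()}
        )
    else:
        raise TypeError(f"Unsupported dataset type: {type(dataset)!r}")
```

Centralizing this dispatch is what lets both `supporters.py` (shown above) and the evaluation loop reuse the same restore logic instead of duplicating it.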
90 changes: 90 additions & 0 deletions tests/utilities/test_auto_restart.py
@@ -15,6 +15,7 @@
import os
import random
import random as python_random
from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
@@ -970,3 +971,92 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
for w0, w1 in zip(weights0, weights1):
assert w0 is not w1
assert torch.allclose(w0, w1)


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize(
["train_datasets", "val_datasets"],
[
([RandomGetItemDataset], [RandomGetItemDataset]),
([RandomGetItemDataset], [RandomGetItemDataset, RandomGetItemDataset]),
],
)
@pytest.mark.parametrize(
"val_check_interval",
[
pytest.param(
0.5,
marks=pytest.mark.xfail(
reason=(
"TODO: the `train_dataloader` random state overrides the validation state when restarting training"
)
),
),
1.0,
],
)
def test_auto_restart_within_validation_loop(train_datasets, val_datasets, val_check_interval, tmpdir):
n_val_dataloaders = len(val_datasets)
stop_dataloader = n_val_dataloaders - 1
stop_batch = 1

class ValidationLoopTestModel(LightningModule):
def __init__(self, should_fail):
super().__init__()
self.layer = torch.nn.Linear(1, 2)
self.should_fail = should_fail
self.training_batches = []
self.validation_batches = defaultdict(list)

def step(self, batch):
return sum(self.layer(b).sum() for b in batch)

def training_step(self, batch, batch_idx):
self.training_batches.append(batch)
return self.step(batch)

def validation_step(self, batch, batch_idx, dataloader_idx=0):
if self.should_fail and stop_dataloader == dataloader_idx and batch_idx == stop_batch:
raise CustomException
self.validation_batches[dataloader_idx].append(batch)
return self.step(batch)

def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def train_dataloader(self):
return [DataLoader(cls(4, 1)) for cls in train_datasets]

def val_dataloader(self):
return [DataLoader(cls(4, 1)) for cls in val_datasets]

def run(should_fail, resume):
if not resume:
seed_everything(42)

model = ValidationLoopTestModel(should_fail)

resume_from_checkpoint = str(tmpdir / ".pl_auto_save.ckpt") if resume else None
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=val_check_interval,
num_sanity_val_steps=0,
resume_from_checkpoint=resume_from_checkpoint,
)
if should_fail:
with pytest.raises(CustomException):
trainer.fit(model)
else:
trainer.fit(model)

return model.training_batches, model.validation_batches

total_train_batches, total_val_batches = run(should_fail=False, resume=False)
pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False)
post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True)

torch.testing.assert_allclose(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
for k in total_val_batches:
torch.testing.assert_allclose(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])
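The test above exercises the standard fault-tolerance invariant in three runs: a clean baseline, a run that fails mid-validation (auto-saving a checkpoint), and a resumed run; the batches seen before the failure plus the batches seen after resuming must equal the baseline. That skeleton can be distilled into a toy sketch with no Lightning involved:

```python
def run_once(items, fail_at=None, start=0):
    """Process items from `start`; optionally 'fail' and return a checkpoint."""
    processed = []
    for i in range(start, len(items)):
        if fail_at is not None and i == fail_at:
            # simulate CustomException: progress so far becomes the auto-save
            return processed, {"next_index": i}
        processed.append(items[i])
    return processed, None


items = list(range(6))
baseline, _ = run_once(items)                                      # clean run
seen_before_fail, ckpt = run_once(items, fail_at=4)                # failing run
seen_after_resume, _ = run_once(items, start=ckpt["next_index"])   # resumed run

# the invariant the real test asserts (via allclose on tensor batches)
assert baseline == seen_before_fail + seen_after_resume
```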