3/n integrate new LightningDataFetcher into loop #8953

Merged on Aug 17, 2021 (54 commits)
Commits (the diff below shows changes from 46 of the 54 commits)
1488879  add LightningFetcher (tchaton, Aug 13, 2021)
9a5037a  add lightning fetcher (tchaton, Aug 13, 2021)
6e6e93c  update changelog (tchaton, Aug 13, 2021)
f4c99a8  typying (tchaton, Aug 13, 2021)
4412855  add fault tolerant (tchaton, Aug 13, 2021)
be899aa  Merge branch 'master' into add_lightning_prefetcher_2_n (tchaton, Aug 16, 2021)
5c54e95  bad merge (tchaton, Aug 16, 2021)
29c7938  remove prints (tchaton, Aug 16, 2021)
d1789c8  update (tchaton, Aug 16, 2021)
3d81454  remove random code (tchaton, Aug 16, 2021)
64ad33d  fix docstrings and typing (awaelchli, Aug 16, 2021)
9e1c8a6  resolve todo, rename metadata collate function (awaelchli, Aug 16, 2021)
91bd840  general cleanup (awaelchli, Aug 16, 2021)
3ae2a43  fix typo in comment (awaelchli, Aug 17, 2021)
3ad3afc  update changelog (awaelchli, Aug 17, 2021)
dd7fc13  remove unused code in apply_to_collection (awaelchli, Aug 17, 2021)
9579432  Merge branch 'master' into thomas/add_lightning_prefetcher_2_n (awaelchli, Aug 17, 2021)
4e8697e  random state (awaelchli, Aug 17, 2021)
e5bb75f  clean up (awaelchli, Aug 17, 2021)
e65f523  clean out non-global random state (will come in future PR) (awaelchli, Aug 17, 2021)
909f8ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2021)
c23e740  clean out debug statements (awaelchli, Aug 17, 2021)
dc9525c  fix import (awaelchli, Aug 17, 2021)
163c486  update data fetcher (awaelchli, Aug 17, 2021)
6267d01  update cycle iterator code (awaelchli, Aug 17, 2021)
ce41f56  remove unused function (awaelchli, Aug 17, 2021)
84aed29  remove unused test (awaelchli, Aug 17, 2021)
44424c9  remove redundant fix (awaelchli, Aug 17, 2021)
302e39d  remove state_dict set to None (awaelchli, Aug 17, 2021)
7af4273  revert type hint Any -> int (awaelchli, Aug 17, 2021)
7e76efc  rename lastest -> latest (awaelchli, Aug 17, 2021)
2041809  reword exception message (awaelchli, Aug 17, 2021)
09787bd  Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-… (awaelchli, Aug 17, 2021)
a3488ab  update type hint (awaelchli, Aug 17, 2021)
6414ced  remove my own todo (awaelchli, Aug 17, 2021)
ff6d5ca  remove my own todo (awaelchli, Aug 17, 2021)
f9b326d  Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-… (awaelchli, Aug 17, 2021)
4254558  fix latest worker id (awaelchli, Aug 17, 2021)
962ec6d  revert (awaelchli, Aug 17, 2021)
de4012b  update changelog (awaelchli, Aug 17, 2021)
15e9091  update init (awaelchli, Aug 17, 2021)
1b3e67d  fix import (awaelchli, Aug 17, 2021)
17de7a0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2021)
d996766  Merge branch 'master' into thomas/integrate-fetcher (awaelchli, Aug 17, 2021)
47b67b0  update changelog (awaelchli, Aug 17, 2021)
fd85bee  resolve tests (tchaton, Aug 17, 2021)
4fb8922  resolve failing test (tchaton, Aug 17, 2021)
39b0b52  update on comments (tchaton, Aug 17, 2021)
5ef1080  update (tchaton, Aug 17, 2021)
e893807  simplificition (tchaton, Aug 17, 2021)
3a70fb8  update (tchaton, Aug 17, 2021)
9483967  resolve typing (tchaton, Aug 17, 2021)
35e02f3  resolve typing (tchaton, Aug 17, 2021)
708df73  revert backa and resolve bug (tchaton, Aug 17, 2021)
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -41,10 +41,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fault-tolerant training:
     * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
-    * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
+    * Added `DataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
     * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
     * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
-    * Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
+    * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
+    * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))

 - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -20,6 +20,7 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -98,7 +99,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         """Performs evaluation on one single dataloader"""
         void(*args, **kwargs)
         dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
-        dataloader_iter = enumerate(dataloader)
+        data_fetcher = DataFetcher()
+        data_fetcher.setup(dataloader)
+        dataloader_iter = enumerate(data_fetcher)
         dl_max_batches = self._max_batches[self.current_dataloader_idx]

         dl_outputs = self.epoch_loop.run(
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -86,7 +86,7 @@ def advance(
         """
         void(dl_max_batches, num_dataloaders)

-        batch_idx, batch = next(dataloader_iter)
+        batch_idx, (batch, _) = next(dataloader_iter)

         if batch is None:
             raise StopIteration
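The unpacking change above reflects the fetcher's contract: iterating a `DataFetcher` yields `(batch, is_last)` pairs rather than raw batches (see the expected values in `tests/utilities/test_fetching.py` further down). A minimal sketch of the consuming side, with `dummy_fetcher` as a hypothetical stand-in for the real fetcher:

```python
from typing import Any, Iterator, List, Tuple


def dummy_fetcher(batches: List[Any]) -> Iterator[Tuple[Any, bool]]:
    """Hypothetical stand-in for DataFetcher: yields (batch, is_last) pairs."""
    for i, batch in enumerate(batches):
        yield batch, i == len(batches) - 1


# the epoch loop enumerates the fetcher, so each item is (batch_idx, (batch, is_last))
dataloader_iter = enumerate(dummy_fetcher([1, 2, 3]))
batch_idx, (batch, _) = next(dataloader_iter)
assert (batch_idx, batch) == (0, 1)
```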
4 changes: 2 additions & 2 deletions pytorch_lightning/profiler/base.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, TextIO, Union
+from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union

 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.cloud_io import get_filesystem

@@ -96,7 +96,7 @@ def profile(self, action_name: str) -> None:
         finally:
             self.stop(action_name)

-    def profile_iterable(self, iterable, action_name: str) -> None:
+    def profile_iterable(self, iterable, action_name: str) -> Generator:
         iterator = iter(iterable)
         while True:
             try:
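The annotation change from `None` to `Generator` matches what the body actually does: `profile_iterable` is a generator that wraps each `next()` call in the profiler's start/stop timing. A minimal sketch of that pattern, assuming only the `start`/`stop` hooks visible in the diff (`SketchProfiler` is hypothetical, not the Lightning `BaseProfiler` API):

```python
from typing import Any, Generator, Iterable


class SketchProfiler:
    """Hypothetical profiler exposing start/stop hooks like those used above."""

    def start(self, action_name: str) -> None:
        print(f"start: {action_name}")

    def stop(self, action_name: str) -> None:
        print(f"stop: {action_name}")

    def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator[Any, None, None]:
        # time each fetch: start before next(), stop after it returns or raises
        iterator = iter(iterable)
        while True:
            try:
                self.start(action_name)
                value = next(iterator)
                self.stop(action_name)
                yield value
            except StopIteration:
                self.stop(action_name)
                return
```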
10 changes: 6 additions & 4 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -15,9 +15,9 @@
 from typing import Callable, Optional, Union

 import pytorch_lightning as pl
-from pytorch_lightning.trainer.supporters import prefetch_iterator
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

@@ -26,6 +26,7 @@ class DataConnector:
     def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
         self.trainer = trainer
         self.multiple_trainloader_mode = multiple_trainloader_mode
+        self.prefetcher: Optional[DataFetcher] = None

     def on_trainer_init(
         self,

@@ -60,9 +61,10 @@ def on_trainer_init(
         self.trainer._is_data_prepared = False

     def get_profiled_train_dataloader(self, train_dataloader):
-        profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
-        )
+        self.prefetcher = DataFetcher()
+        self.prefetcher.setup(train_dataloader)
+        prefetcher_iter = iter(self.prefetcher)
+        profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
         return profiled_dl

     def prepare_data(self) -> None:
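`get_profiled_train_dataloader` now layers the profiler over `enumerate` over the fetcher. A runnable sketch of that layering, using hypothetical no-op stand-ins for both the fetcher and the profiler:

```python
from typing import Any, Generator, Iterable, Iterator, List, Tuple


def dummy_fetcher(batches: List[Any]) -> Iterator[Tuple[Any, bool]]:
    # hypothetical stand-in for DataFetcher().setup(train_dataloader)
    for i, batch in enumerate(batches):
        yield batch, i == len(batches) - 1


def profile_iterable(iterable: Iterable, action_name: str) -> Generator[Any, None, None]:
    # no-op stand-in for trainer.profiler.profile_iterable
    yield from iterable


# mirrors the new get_profiled_train_dataloader: profiler(enumerate(fetcher))
profiled_dl = profile_iterable(enumerate(dummy_fetcher(["a", "b"])), "get_train_batch")
for batch_idx, (batch, is_last) in profiled_dl:
    print(batch_idx, batch, is_last)
```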
20 changes: 12 additions & 8 deletions pytorch_lightning/trainer/data_loading.py
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.auto_restart import (
     _capture_metadata_collate,
     CaptureIterableDataset,
+    CaptureMapDataset,
     FastForwardSampler,
 )
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len

@@ -246,14 +247,17 @@ def _get_dataloader_init_kwargs(
             f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
         )

-    # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
-    if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
-        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
-        dl_kwargs["sampler"] = None
-
-    if isinstance(dl_kwargs["dataset"], IterableDataset):
-        del dl_kwargs["sampler"]
-        del dl_kwargs["batch_sampler"]
+    if _fault_tolerant_enabled():
+        if isinstance(dl_kwargs["dataset"], IterableDataset):
+            # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
+            dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
+            dl_kwargs["sampler"] = None
+        elif len(dl_kwargs["dataset"]):
+            dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
+        else:
+            raise MisconfigurationException(
+                "This shouldn't happen, please open an issue on Lightning Github repository."
+            )

     return dl_kwargs
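To make the new branching concrete: under fault-tolerant training, iterable datasets keep the `CaptureIterableDataset` wrapper (and drop the sampler), while non-empty map-style datasets are wrapped in `CaptureMapDataset`. A runnable sketch of the dispatch, with hypothetical no-op stand-ins for the capture wrappers (the real ones record iteration state):

```python
from torch.utils.data import Dataset, IterableDataset


class CapturedIterable(IterableDataset):
    """No-op stand-in for CaptureIterableDataset."""

    def __init__(self, dataset: IterableDataset) -> None:
        self.dataset = dataset

    def __iter__(self):
        return iter(self.dataset)


class CapturedMap(Dataset):
    """No-op stand-in for CaptureMapDataset."""

    def __init__(self, dataset: Dataset) -> None:
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self) -> int:
        return len(self.dataset)


def wrap_dataset(dataset):
    # mirrors the fault-tolerant branch in _get_dataloader_init_kwargs
    if isinstance(dataset, IterableDataset):
        return CapturedIterable(dataset)
    if len(dataset):
        return CapturedMap(dataset)
    raise ValueError("empty dataset: nothing to wrap")
```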
22 changes: 0 additions & 22 deletions pytorch_lightning/trainer/supporters.py
@@ -592,25 +592,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
         new_data.append(x)

     return compute_func(new_data)
-
-
-def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
-    """
-    Returns an iterator that pre-fetches and caches the next item.
-    The values are passed through from the given iterable with an added boolean indicating if this is the last item.
-    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
-    """
-    it = iter(iterable)
-
-    try:
-        # the iterator may be empty from the beginning
-        last = next(it)
-    except StopIteration:
-        return
-
-    for val in it:
-        # yield last and has next
-        yield last, False
-        last = val
-    # yield last, no longer has next
-    yield last, True
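The removed helper's contract is preserved by the new fetcher: per the updated test in `tests/utilities/test_fetching.py` below, a `DataFetcher` set up on a loader yields the same `(value, is_last)` pairs. A short usage sketch mirroring that test:

```python
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import DataFetcher


class IterDataset(IterableDataset):
    def __iter__(self):
        yield 1
        yield 2
        yield 3


fetcher = DataFetcher()
fetcher.setup(DataLoader(IterDataset()))
# same (value, is_last_batch) contract as the removed prefetch_iterator;
# int(...) unwraps the single-element batch tensor the DataLoader produces
assert [(int(batch), is_last) for batch, is_last in fetcher] == [(1, False), (2, False), (3, True)]
```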
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/fetching.py
@@ -32,7 +32,7 @@
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled


-class AbstractFetcher(ABC):
+class AbstractDataFetcher(ABC):

     """
     This class is used to control batch fetching flow.

@@ -161,7 +161,7 @@ def reset(self) -> None:
         self.done: bool = False


-class LightningDataFetcher(AbstractFetcher):
+class DataFetcher(AbstractDataFetcher):

     """
     This class is used to control batch fetching flow.
23 changes: 0 additions & 23 deletions tests/trainer/test_supporters.py
@@ -29,7 +29,6 @@
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
-    prefetch_iterator,
     TensorRunningAccum,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection

@@ -80,28 +79,6 @@ def test_none_length_cycle_iterator():
     assert item == 0


-def test_prefetch_iterator():
-    """Test the prefetch_iterator with PyTorch IterableDataset."""
-
-    class IterDataset(IterableDataset):
-        def __iter__(self):
-            yield 1
-            yield 2
-            yield 3
-
-    dataset = IterDataset()
-    iterator = prefetch_iterator(dataset)
-    assert list(iterator) == [(1, False), (2, False), (3, True)]
-
-    class EmptyIterDataset(IterableDataset):
-        def __iter__(self):
-            return iter([])
-
-    dataset = EmptyIterDataset()
-    iterator = prefetch_iterator(dataset)
-    assert list(iterator) == []
-
-
 @pytest.mark.parametrize(
     ["dataset_1", "dataset_2"],
     [
12 changes: 6 additions & 6 deletions tests/utilities/test_fetching.py
@@ -17,12 +17,12 @@

 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import LightningDataFetcher
+from pytorch_lightning.utilities.fetching import DataFetcher


 @pytest.mark.parametrize("use_combined_loader", [False, True])
 def test_prefetch_iterator(use_combined_loader):
-    """Test the LightningDataFetcher with PyTorch IterableDataset."""
+    """Test the DataFetcher with PyTorch IterableDataset."""

     class IterDataset(IterableDataset):
         def __iter__(self):

@@ -41,7 +41,7 @@ def __iter__(self):
         else:
             loader = DataLoader(IterDataset())
         expected = [(1, False), (2, False), (3, True)]
-        iterator = LightningDataFetcher(prefetch_batches=prefetch_batches)
+        iterator = DataFetcher(prefetch_batches=prefetch_batches)
         prefetch_batches += 1
         assert iterator.prefetch_batches == prefetch_batches
         iterator.setup(loader)

@@ -66,21 +66,21 @@ def __iter__(self):
             return iter([])

     dataloader = DataLoader(EmptyIterDataset())
-    iterator = LightningDataFetcher()
+    iterator = DataFetcher()
     iterator.setup(dataloader)
     assert list(iterator) == []


 def test_misconfiguration_error():

-    fetcher = LightningDataFetcher()
+    fetcher = DataFetcher()
     with pytest.raises(
         MisconfigurationException,
         match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.",
     ):
         fetcher.setup(range(10))

-    fetcher = LightningDataFetcher()
+    fetcher = DataFetcher()
     with pytest.raises(
         MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
     ):