3/n integrate new LightningDataFetcher into loop #8953

Merged
Merged 54 commits into master from thomas/integrate-fetcher on Aug 17, 2021

Commits (54)
1488879  add LightningFetcher (tchaton, Aug 13, 2021)
9a5037a  add lightning fetcher (tchaton, Aug 13, 2021)
6e6e93c  update changelog (tchaton, Aug 13, 2021)
f4c99a8  typying (tchaton, Aug 13, 2021)
4412855  add fault tolerant (tchaton, Aug 13, 2021)
be899aa  Merge branch 'master' into add_lightning_prefetcher_2_n (tchaton, Aug 16, 2021)
5c54e95  bad merge (tchaton, Aug 16, 2021)
29c7938  remove prints (tchaton, Aug 16, 2021)
d1789c8  update (tchaton, Aug 16, 2021)
3d81454  remove random code (tchaton, Aug 16, 2021)
64ad33d  fix docstrings and typing (awaelchli, Aug 16, 2021)
9e1c8a6  resolve todo, rename metadata collate function (awaelchli, Aug 16, 2021)
91bd840  general cleanup (awaelchli, Aug 16, 2021)
3ae2a43  fix typo in comment (awaelchli, Aug 17, 2021)
3ad3afc  update changelog (awaelchli, Aug 17, 2021)
dd7fc13  remove unused code in apply_to_collection (awaelchli, Aug 17, 2021)
9579432  Merge branch 'master' into thomas/add_lightning_prefetcher_2_n (awaelchli, Aug 17, 2021)
4e8697e  random state (awaelchli, Aug 17, 2021)
e5bb75f  clean up (awaelchli, Aug 17, 2021)
e65f523  clean out non-global random state (will come in future PR) (awaelchli, Aug 17, 2021)
909f8ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2021)
c23e740  clean out debug statements (awaelchli, Aug 17, 2021)
dc9525c  fix import (awaelchli, Aug 17, 2021)
163c486  update data fetcher (awaelchli, Aug 17, 2021)
6267d01  update cycle iterator code (awaelchli, Aug 17, 2021)
ce41f56  remove unused function (awaelchli, Aug 17, 2021)
84aed29  remove unused test (awaelchli, Aug 17, 2021)
44424c9  remove redundant fix (awaelchli, Aug 17, 2021)
302e39d  remove state_dict set to None (awaelchli, Aug 17, 2021)
7af4273  revert type hint Any -> int (awaelchli, Aug 17, 2021)
7e76efc  rename lastest -> latest (awaelchli, Aug 17, 2021)
2041809  reword exception message (awaelchli, Aug 17, 2021)
09787bd  Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-… (awaelchli, Aug 17, 2021)
a3488ab  update type hint (awaelchli, Aug 17, 2021)
6414ced  remove my own todo (awaelchli, Aug 17, 2021)
ff6d5ca  remove my own todo (awaelchli, Aug 17, 2021)
f9b326d  Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-… (awaelchli, Aug 17, 2021)
4254558  fix latest worker id (awaelchli, Aug 17, 2021)
962ec6d  revert (awaelchli, Aug 17, 2021)
de4012b  update changelog (awaelchli, Aug 17, 2021)
15e9091  update init (awaelchli, Aug 17, 2021)
1b3e67d  fix import (awaelchli, Aug 17, 2021)
17de7a0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2021)
d996766  Merge branch 'master' into thomas/integrate-fetcher (awaelchli, Aug 17, 2021)
47b67b0  update changelog (awaelchli, Aug 17, 2021)
fd85bee  resolve tests (tchaton, Aug 17, 2021)
4fb8922  resolve failing test (tchaton, Aug 17, 2021)
39b0b52  update on comments (tchaton, Aug 17, 2021)
5ef1080  update (tchaton, Aug 17, 2021)
e893807  simplificition (tchaton, Aug 17, 2021)
3a70fb8  update (tchaton, Aug 17, 2021)
9483967  resolve typing (tchaton, Aug 17, 2021)
35e02f3  resolve typing (tchaton, Aug 17, 2021)
708df73  revert backa and resolve bug (tchaton, Aug 17, 2021)
Changes from all commits
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -41,10 +41,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
* Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `DataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))

- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))

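The entries above summarize the core of this PR: the loops stop iterating raw dataloaders and instead consume batches through a `DataFetcher`. The sketch below is not part of the diff; it shows the consumption pattern the changes rely on, with the `(batch, done)` pair inferred from the `batch_idx, (batch, _)` unpacking further down and from the contract of the removed `prefetch_iterator`.

```python
# Minimal sketch, assuming DataFetcher yields (batch, done) pairs where `done`
# flags the last batch (inferred from the loop changes in this PR).
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.fetching import DataFetcher

dataloader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)

fetcher = DataFetcher()
fetcher.setup(dataloader)  # wraps the dataloader and adds the metadata collate when needed

for batch_idx, (batch, done) in enumerate(fetcher):
    # `done` is True on the final batch, so loops can react before StopIteration
    print(batch_idx, done)
```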
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -20,6 +20,7 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -98,7 +99,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
data_fetcher = DataFetcher()
data_fetcher.setup(dataloader)
dataloader_iter = enumerate(data_fetcher)
dl_max_batches = self._max_batches[self.current_dataloader_idx]

dl_outputs = self.epoch_loop.run(
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -86,7 +86,7 @@ def advance(
"""
void(dl_max_batches, num_dataloaders)

batch_idx, batch = next(dataloader_iter)
batch_idx, (batch, _) = next(dataloader_iter)

if batch is None:
raise StopIteration
6 changes: 3 additions & 3 deletions pytorch_lightning/profiler/base.py
@@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TextIO, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union

from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -78,7 +78,7 @@ def __init__(
self._stage: Optional[str] = None

@contextmanager
def profile(self, action_name: str) -> None:
def profile(self, action_name: str) -> Generator:
"""
Yields a context manager to encapsulate the scope of a profiled action.

@@ -96,7 +96,7 @@ def profile(self, action_name: str) -> None:
finally:
self.stop(action_name)

def profile_iterable(self, iterable, action_name: str) -> None:
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
iterator = iter(iterable)
while True:
try:
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -24,7 +24,7 @@
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:
@@ -348,7 +348,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
}
if _fault_tolerant_enabled():
if _fault_tolerant_training():
checkpoint["loops"] = self._get_loops_state_dict()

if not weights_only:
@@ -451,7 +451,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
[m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
if _fault_tolerant_enabled()
if _fault_tolerant_training()
else []
)

14 changes: 8 additions & 6 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union
from typing import Callable, Iterable, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

@@ -26,6 +26,7 @@ class DataConnector:
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self.data_fetcher: Optional[DataFetcher] = None

def on_trainer_init(
self,
@@ -59,10 +60,11 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
)
def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
self.data_fetcher = DataFetcher()
self.data_fetcher.setup(train_dataloader)
prefetcher_iter = iter(self.data_fetcher)
profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
return profiled_dl

def prepare_data(self) -> None:
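On the training side, `get_profiled_train_dataloader` now composes the same fetcher with the profiler. A rough standalone sketch of that wiring, using `SimpleProfiler` in place of whatever profiler the trainer is actually configured with:

```python
# Sketch of the pattern above: time each batch fetch under the
# "get_train_batch" action while iterating through a DataFetcher.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.utilities.fetching import DataFetcher

fetcher = DataFetcher()
fetcher.setup(DataLoader(TensorDataset(torch.arange(8.0)), batch_size=2))

profiler = SimpleProfiler()
profiled_dl = profiler.profile_iterable(enumerate(iter(fetcher)), "get_train_batch")

for batch_idx, (batch, _) in profiled_dl:
    pass  # the training loop consumes (batch_idx, (batch, done)) tuples

print(profiler.summary())  # timing report for "get_train_batch"
```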
27 changes: 17 additions & 10 deletions pytorch_lightning/trainer/data_loading.py
@@ -34,12 +34,13 @@
from pytorch_lightning.utilities.auto_restart import (
_capture_metadata_collate,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
)
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function

@@ -168,7 +169,7 @@ def _resolve_batch_sampler(
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

if _fault_tolerant_enabled():
if _fault_tolerant_training():
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)

@@ -180,7 +181,7 @@
"drop_last": False,
}

if _fault_tolerant_enabled():
if _fault_tolerant_training():
fast_forward_sampler = sampler = FastForwardSampler(sampler)
fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

@@ -246,14 +247,20 @@ def _get_dataloader_init_kwargs(
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
if isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None

if isinstance(dl_kwargs["dataset"], IterableDataset):
del dl_kwargs["sampler"]
del dl_kwargs["batch_sampler"]
if _fault_tolerant_training():
if isinstance(dl_kwargs["dataset"], IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif len(dl_kwargs["dataset"]):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException(
"This shouldn't happen, please open an issue on Lightning Github repository."
)

return dl_kwargs

@@ -308,7 +315,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_enabled():
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
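The block above is the fault-tolerant half of the dataloader rebuild: when the flag is on, the dataset inside the rebuilt `DataLoader` kwargs is wrapped so that sampler and worker progress can be captured. A condensed, hedged illustration of that rule (not the PR code verbatim):

```python
# Condensed sketch of the wrapping rule added to `_get_dataloader_init_kwargs`:
# iterable-style datasets are wrapped in CaptureIterableDataset, map-style
# datasets (anything with a length) in CaptureMapDataset.
from torch.utils.data import Dataset, IterableDataset

from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset


def wrap_for_fault_tolerance(dataset: Dataset) -> Dataset:
    if isinstance(dataset, IterableDataset):
        # sampler state is captured from inside the dataset iteration itself
        return CaptureIterableDataset(dataset=dataset)
    # map-style: progress is tracked per fetched index and worker
    return CaptureMapDataset(dataset=dataset)


class RangeDataset(Dataset):
    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int) -> int:
        return index


wrapped = wrap_for_fault_tolerance(RangeDataset())
assert isinstance(wrapped, CaptureMapDataset)
```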
30 changes: 4 additions & 26 deletions pytorch_lightning/trainer/supporters.py
@@ -15,7 +15,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset
@@ -30,7 +30,7 @@
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class TensorRunningAccum:
@@ -375,7 +375,7 @@ def state_dict(self, num_batches_processed: int) -> Dict:
num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
may have already prefetched more batches by the time a state dict is requested.
"""
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return DataLoaderDict()

state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
@@ -541,7 +541,7 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any:

def next_fn(iterator: Iterator):
batch = next(iterator)
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return batch
# when fault tolerant is enabled, the iterator will return
# `FastForwardSampler` state_dict metadata
@@ -592,25 +592,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
new_data.append(x)

return compute_func(new_data)


def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
"""
Returns an iterator that pre-fetches and caches the next item.
The values are passed through from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
"""
it = iter(iterable)

try:
# the iterator may be empty from the beginning
last = next(it)
except StopIteration:
return

for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -77,7 +77,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
@@ -1344,7 +1344,7 @@ def _log_device_info(self) -> None:
)

def _on_exception(self):
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
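The `_on_exception` hook above is what makes a crash recoverable: with fault tolerance enabled it drops a `.pl_auto_save.ckpt` into the default root dir. A hedged sketch of picking that file up again; `resume_from_checkpoint` is the Trainer argument available around this release and is an assumption here, not something this PR adds:

```python
# Hedged sketch: resume from the checkpoint written by Trainer._on_exception.
import os

from pytorch_lightning import Trainer

auto_ckpt = os.path.join(os.getcwd(), ".pl_auto_save.ckpt")  # default_root_dir defaults to cwd

trainer = Trainer(
    max_epochs=5,
    resume_from_checkpoint=auto_ckpt if os.path.exists(auto_ckpt) else None,
)
# trainer.fit(model, train_dataloader)  # model and dataloader defined elsewhere
```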
8 changes: 6 additions & 2 deletions pytorch_lightning/utilities/auto_restart.py
@@ -21,6 +21,7 @@
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -515,7 +516,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:


def patch_dataloader_iterator(
dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0
dataloader: DataLoader,
iterator: Iterator,
data_fetcher: "pl.utilities.fetching.DataFetcher",
num_batches_fetched: int = 0,
) -> None:
assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

@@ -554,7 +558,7 @@ def wrapper():
num_batches_fetched=num_batches_fetched,
)
]
prefetcher._store_dataloader_iter_state(it, state)
data_fetcher._store_dataloader_iter_state(it, state)
return batch

return wrapper
29 changes: 19 additions & 10 deletions pytorch_lightning/utilities/fetching.py
@@ -29,10 +29,10 @@
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class AbstractFetcher(ABC):
class AbstractDataFetcher(ABC):

"""
This class is used to control batch fetching flow.
@@ -61,13 +61,22 @@ def __init__(
self.reset()

def setup(self, dataloader: DataLoader, **kwargs) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
raise MisconfigurationException(
"The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
)
self._add_capture_metadata_collate(dataloader)
self.dataloader = dataloader
if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)

@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
return

if isinstance(dataloader, CombinedLoader):
dataloader = dataloader.loaders

def add_capture_metadata_collate(dataloader: DataLoader):
if not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)

apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)

def add_batch(self, batch) -> None:
self.batches.append(batch)
@@ -82,7 +91,7 @@ def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
# cycle_iterator = iterator
iterator = iterator._loader_iter

if isinstance(loader, DataLoader) and _fault_tolerant_enabled():
if isinstance(loader, DataLoader) and _fault_tolerant_training():
loader._lightning_fetcher = self
patch_dataloader_iterator(loader, iterator, self)

@@ -161,7 +170,7 @@ def reset(self) -> None:
self.done: bool = False


class LightningDataFetcher(AbstractFetcher):
class DataFetcher(AbstractDataFetcher):

"""
This class is used to control batch fetching flow.
7 changes: 2 additions & 5 deletions pytorch_lightning/utilities/imports.py
@@ -103,9 +103,6 @@ def _compare_version(package: str, op, version) -> bool:
_IPU_AVAILABLE = False


def _fault_tolerant_enabled() -> bool:
"""
EXPERIMENTAL
the `reset` function from `_MultiProcessingDataLoaderIter` was introduced in PyTorch 1.7 but we need to mock it.
"""
# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
return _TORCH_GREATER_EQUAL_1_7 and int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0))
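For reference, the renamed gate above is driven purely by an environment variable plus a PyTorch version check, so everything in this PR stays opt-in. A minimal sketch of flipping it on (set the variable before the Trainer touches any dataloaders):

```python
# Opt in to fault-tolerant training; requires PyTorch >= 1.7 as checked above.
import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from pytorch_lightning.utilities.imports import _fault_tolerant_training

print(_fault_tolerant_training())  # truthy once the flag is set on a supported PyTorch
```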