diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0ec87f4448d93..1fa08639b0668 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fault-tolerant training:
     * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
+    * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
     * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py
index 3ff602566d2e4..ab3b6ebfb84f4 100644
--- a/benchmarks/test_basic_parity.py
+++ b/benchmarks/test_basic_parity.py
@@ -51,7 +51,7 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f
     "cls_model,max_diff_speed,max_diff_memory,num_epochs,num_runs",
     [
         (ParityModuleRNN, 0.05, 0.001, 4, 3),
-        (ParityModuleMNIST, 0.25, 0.001, 4, 3),  # todo: lower this thr
+        (ParityModuleMNIST, 0.3, 0.001, 4, 3),  # todo: lower this threshold
         pytest.param(ParityModuleCIFAR, 4.0, 0.0002, 2, 2, marks=_MARK_SHORT_BM),
     ],
 )
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
new file mode 100644
index 0000000000000..3e4822bbe665f
--- /dev/null
+++ b/pytorch_lightning/utilities/fetching.py
@@ -0,0 +1,153 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Iterator
+from typing import Any, Generator, List, Optional, Tuple
+
+from torch.utils.data.dataloader import DataLoader
+
+from pytorch_lightning.trainer.supporters import CombinedLoader
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class AbstractDataFetcher(ABC):
+    """Controls the batch fetching flow; subclasses implement ``fetching_function``."""
+
+    @abstractmethod
+    def fetching_function(self) -> Generator:
+        pass
+
+    def __init__(self, prefetch_batches: int = 0) -> None:
+        if not isinstance(prefetch_batches, int) or prefetch_batches < 0:
+            raise MisconfigurationException("`prefetch_batches` should be an integer greater than or equal to 0.")
+
+        # always fetch one batch beyond the requested window so the last batch can be flagged.
+        self.prefetch_batches = prefetch_batches + 1
+
+        self.dataloader: Optional[Iterable] = None
+        self.dataloader_iter: Optional[Iterator] = None
+
+        self.batches: List
+        self.fetched: int
+        self.done: bool
+
+        self.reset()
+
+    def setup(self, dataloader: DataLoader, **kwargs) -> None:
+        if not isinstance(dataloader, (DataLoader, CombinedLoader)):
+            raise MisconfigurationException(
+                "The `DataFetcher` should be set up with an instance of a PyTorch ``DataLoader``."
+            )
+        self.dataloader = dataloader
+
+    def add_batch(self, batch: Any) -> None:
+        self.batches.append(batch)
+
+    def fetch_batch(self) -> Any:
+        return self.batches.pop(0)
+
+    @property
+    def loaders(self) -> List[DataLoader]:
+        # use an identity check here: truthiness of a ``DataLoader`` calls ``__len__``,
+        # which raises for iterable-style datasets.
+        if self.dataloader is None:
+            raise MisconfigurationException(
+                "The `DataFetcher` should be set up with an instance of a PyTorch ``DataLoader``."
+            )
+        if isinstance(self.dataloader, CombinedLoader):
+            loaders = self.dataloader.loaders
+        elif isinstance(self.dataloader, (tuple, list)):
+            loaders = self.dataloader
+        else:
+            loaders = [self.dataloader]
+        return loaders
+
+    @property
+    def loader_iters(self) -> List[Iterator]:
+        if self.dataloader is None:
+            raise MisconfigurationException(
+                "The `DataFetcher` should be set up with an instance of a PyTorch ``DataLoader``."
+            )
+
+        if self.dataloader_iter is None:
+            raise MisconfigurationException("The dataloader_iter isn't available outside the __iter__ context.")
+
+        if isinstance(self.dataloader, CombinedLoader):
+            loader_iters = self.dataloader_iter.loader_iters
+        else:
+            loader_iters = [self.dataloader_iter]
+        return loader_iters
+
+    @property
+    def state(self) -> Any:
+        def collect_state(iterator: Iterator):
+            return iterator.state
+
+        return apply_to_collection(self.loader_iters, Iterator, collect_state)
+
+    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
+        if self.dataloader is None:
+            raise MisconfigurationException(
+                "The dataloader hasn't been provided. HINT: Did you call the `setup` function?"
+            )
+        self.reset()
+        self.dataloader_iter = iter(self.dataloader)
+        return self.fetching_function()
+
+    def reset(self) -> None:
+        self.batches: List = []
+        self.fetched: int = 0
+        self.done: bool = False
+
+
+class LightningDataFetcher(AbstractDataFetcher):
+    """Controls the batch fetching flow, yielding ``(batch, is_last)`` tuples."""
+
+    def fetching_function(self) -> Generator:
+        self.done = False
+        while not self.done:
+            self._prefetching(self.prefetch_batches)
+
+            for batch in self.dataloader_iter:
+                yield_batch = self.fetch_batch()
+                self.add_batch(batch)
+                self.fetched += 1
+                # yield the oldest prefetched batch; the buffer holds a newer one, so it is not the last.
+                yield yield_batch, False
+
+            yield from self._consume_prefetched_batches()
+
+    def _consume_prefetched_batches(self) -> Generator:
+        # the underlying iterator is exhausted; drain the buffer and flag the final batch.
+        self.done = True
+        while self.batches:
+            if len(self.batches) == 1:
+                yield self.batches.pop(0), True
+            else:
+                yield self.batches.pop(0), False
+
+    def _prefetching(self, prefetch_batches: int) -> None:
+        for _ in range(prefetch_batches):
+            try:
+                batch = next(self.dataloader_iter)
+                self.fetched += 1
+                self.add_batch(batch)
+            except StopIteration:
+                break
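Note (not part of the patch): the core of `fetching_function` is a look-ahead buffer. Because the fetcher keeps `prefetch_batches + 1` batches in flight, it can tag the final batch with `is_last=True` before the underlying iterator raises `StopIteration`. The sketch below illustrates that pattern in isolation; the `lookahead` helper is hypothetical and exists only for this example:

from typing import Any, Generator, Iterable, List, Tuple


def lookahead(iterable: Iterable, prefetch: int = 1) -> Generator[Tuple[Any, bool], None, None]:
    """Yield ``(item, is_last)`` tuples by keeping ``prefetch`` items buffered."""
    it = iter(iterable)
    buffer: List[Any] = []

    # Fill the look-ahead window, stopping early if the iterator is short.
    for _ in range(prefetch):
        try:
            buffer.append(next(it))
        except StopIteration:
            break

    # Steady state: each newly fetched item releases the oldest buffered one,
    # which therefore cannot be the last.
    for item in it:
        oldest = buffer.pop(0)
        buffer.append(item)
        yield oldest, False

    # Drain the buffer; only the final item is tagged ``is_last=True``.
    while buffer:
        yield buffer.pop(0), len(buffer) == 0


print(list(lookahead([1, 2, 3])))  # [(1, False), (2, False), (3, True)]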
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
new file mode 100644
index 0000000000000..323245094cd3a
--- /dev/null
+++ b/tests/utilities/test_fetching.py
@@ -0,0 +1,92 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from torch import tensor
+from torch.utils.data import DataLoader, IterableDataset
+
+from pytorch_lightning.trainer.supporters import CombinedLoader
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import LightningDataFetcher
+
+
+@pytest.mark.parametrize("use_combined_loader", [False, True])
+def test_prefetch_iterator(use_combined_loader):
+    """Test the LightningDataFetcher with a PyTorch IterableDataset."""
+
+    class IterDataset(IterableDataset):
+        def __iter__(self):
+            yield 1
+            yield 2
+            yield 3
+
+    for prefetch_batches in range(4):
+        if use_combined_loader:
+            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
+            expected = [
+                ([tensor([1]), tensor([1])], False),
+                ([tensor([2]), tensor([2])], False),
+                ([tensor([3]), tensor([3])], True),
+            ]
+        else:
+            loader = DataLoader(IterDataset())
+            expected = [(1, False), (2, False), (3, True)]
+        iterator = LightningDataFetcher(prefetch_batches=prefetch_batches)
+        # the fetcher always looks one batch ahead of the requested window.
+        prefetch_batches += 1
+        assert iterator.prefetch_batches == prefetch_batches
+        iterator.setup(loader)
+
+        def generate():
+            generated = []
+            for idx, data in enumerate(iterator, 1):
+                if iterator.done:
+                    assert iterator.fetched == 3
+                else:
+                    assert iterator.fetched == (idx + prefetch_batches)
+                generated.append(data)
+            return generated
+
+        assert generate() == expected
+        # validate that reset works properly.
+        assert generate() == expected
+        assert iterator.fetched == 3
+
+    class EmptyIterDataset(IterableDataset):
+        def __iter__(self):
+            return iter([])
+
+    dataloader = DataLoader(EmptyIterDataset())
+    iterator = LightningDataFetcher()
+    iterator.setup(dataloader)
+    assert list(iterator) == []
+
+
+def test_misconfiguration_error():
+    fetcher = LightningDataFetcher()
+    with pytest.raises(
+        MisconfigurationException, match="The `DataFetcher` should be set up with an instance of a PyTorch"
+    ):
+        fetcher.setup(range(10))
+
+    fetcher = LightningDataFetcher()
+    loader = DataLoader(range(10))
+    fetcher.setup(loader)
+    assert fetcher.loaders[0] == loader
+    with pytest.raises(
+        MisconfigurationException, match="The dataloader_iter isn't available outside the __iter__ context."
+    ):
+        fetcher.loader_iters
+
+    # after entering the __iter__ context, the loader iterators become available.
+    iter(fetcher)
+    assert fetcher.loader_iters
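Based on the API added above, usage looks roughly like the following sketch. `LightningDataFetcher` is an internal utility whose interface may change, and the `DataLoader` construction here is purely illustrative:

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.fetching import LightningDataFetcher

# Prefetch one extra batch beyond the implicit one-batch look-ahead.
fetcher = LightningDataFetcher(prefetch_batches=1)
fetcher.setup(DataLoader(range(5), batch_size=2))

# Iteration yields ``(batch, is_last)`` tuples, so a training loop can run
# end-of-epoch logic while processing the final batch instead of after it.
for batch, is_last in fetcher:
    print(batch, is_last)
# tensor([0, 1]) False
# tensor([2, 3]) False
# tensor([4]) True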