[1/n] Add LightningFetcher (#8890)
tchaton authored Aug 16, 2021
1 parent d0efb55 commit 89156b7
Showing 4 changed files with 247 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -39,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training:
    * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
+   * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
    * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))


2 changes: 1 addition & 1 deletion benchmarks/test_basic_parity.py
@@ -51,7 +51,7 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f
"cls_model,max_diff_speed,max_diff_memory,num_epochs,num_runs",
[
(ParityModuleRNN, 0.05, 0.001, 4, 3),
(ParityModuleMNIST, 0.25, 0.001, 4, 3), # todo: lower this thr
(ParityModuleMNIST, 0.3, 0.001, 4, 3), # todo: lower this thr
pytest.param(ParityModuleCIFAR, 4.0, 0.0002, 2, 2, marks=_MARK_SHORT_BM),
],
)
153 changes: 153 additions & 0 deletions pytorch_lightning/utilities/fetching.py
@@ -0,0 +1,153 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from typing import Any, Generator, List, Optional, Tuple

from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class AbstractDataFetcher(ABC):
    """Abstract base class used to control the batch fetching flow."""

    @abstractmethod
    def fetching_function(self) -> Generator:
        pass

    def __init__(
        self,
        prefetch_batches: int = 0,
    ) -> None:
        if not isinstance(prefetch_batches, int) or prefetch_batches < 0:
            raise MisconfigurationException("`prefetch_batches` should be at least 0.")

        # fetch one batch beyond the requested count, so the end of the iterator
        # can be detected and the last batch flagged.
        self.prefetch_batches = prefetch_batches + 1

        self.dataloader: Optional[Iterable] = None
        self.dataloader_iter: Optional[Iterator] = None

        self.batches: List
        self.fetched: int
        self.done: bool

        self.reset()

    def setup(self, dataloader: DataLoader, **kwargs) -> None:
        if not isinstance(dataloader, (DataLoader, CombinedLoader)):
            raise MisconfigurationException(
                "The `DataFetcher` should be set up with an instance of a PyTorch ``DataLoader``."
            )
        self.dataloader = dataloader

    def add_batch(self, batch: Any) -> None:
        self.batches.append(batch)

    def fetch_batch(self) -> Any:
        return self.batches.pop(0)

    @property
    def loaders(self) -> List[DataLoader]:
        if not self.dataloader:
            raise MisconfigurationException(
                "The `DataFetcher` should be set up with an instance of a PyTorch ``DataLoader``."
            )
        if isinstance(self.dataloader, CombinedLoader):
            loaders = self.dataloader.loaders
        elif isinstance(self.dataloader, (tuple, list)):
            loaders = self.dataloader
        else:
            loaders = [self.dataloader]
        return loaders

    @property
    def loader_iters(self) -> List[Iterator]:
        if not self.dataloader:
            raise MisconfigurationException(
                "The `DataFetcher` should be set up with an instance of a PyTorch ``DataLoader``."
            )

        if not self.dataloader_iter:
            raise MisconfigurationException("The dataloader_iter isn't available outside the __iter__ context.")

        if isinstance(self.dataloader, CombinedLoader):
            loader_iters = self.dataloader_iter.loader_iters
        else:
            loader_iters = [self.dataloader_iter]
        return loader_iters

    @property
    def state(self) -> Any:
        # collect the state of each loader iterator (used by fault-tolerant training)
        def collect_state(iterator: Iterator):
            return iterator.state

        return apply_to_collection(self.loader_iters, Iterator, collect_state)

    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
        if self.dataloader is None:
            raise MisconfigurationException("The dataloader hasn't been provided. HINT: Did you call the `setup` function?")
        self.reset()
        self.dataloader_iter = iter(self.dataloader)
        return self.fetching_function()

    def reset(self) -> None:
        self.batches: List = []
        self.fetched: int = 0
        self.done: bool = False


class LightningDataFetcher(AbstractDataFetcher):
    """
    This class is used to control the batch fetching flow: it prefetches batches ahead of
    the consumer, so each yielded batch can be paired with a boolean flag telling whether
    it is the last one.
    """

    def fetching_function(self) -> Generator:
        self.done = False
        while not self.done:
            self._prefetching(self.prefetch_batches)

            for batch in self.dataloader_iter:
                yield_batch = self.fetch_batch()
                self.add_batch(batch)
                self.fetched += 1
                # yield the oldest prefetched batch; `False` signals that more batches follow
                yield yield_batch, False

            # the iterator is exhausted: drain the remaining prefetched batches
            yield from self._consume_prefetched_batches()

    def _consume_prefetched_batches(self) -> Generator:
        self.done = True
        while self.batches:
            if len(self.batches) == 1:
                yield self.batches.pop(0), True
            else:
                yield self.batches.pop(0), False

    def _prefetching(self, prefetch_batches: int) -> None:
        for _ in range(prefetch_batches):
            try:
                batch = next(self.dataloader_iter)
                self.fetched += 1
                self.add_batch(batch)
            except StopIteration:
                break
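
The following is a minimal usage sketch of the new fetcher (mirroring the tests below); `ToyDataset` is a hypothetical helper, not part of this commit:

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import LightningDataFetcher


class ToyDataset(IterableDataset):
    # hypothetical three-element dataset, for illustration only
    def __iter__(self):
        yield from (1, 2, 3)


fetcher = LightningDataFetcher(prefetch_batches=1)
fetcher.setup(DataLoader(ToyDataset()))

# each iteration yields `(batch, is_last)`; `is_last` is True only for the final batch
for batch, is_last in fetcher:
    print(batch, is_last)  # tensor([1]) False, tensor([2]) False, tensor([3]) True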
92 changes: 92 additions & 0 deletions tests/utilities/test_fetching.py
@@ -0,0 +1,92 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch import tensor
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import LightningDataFetcher


@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_prefetch_iterator(use_combined_loader):
    """Test the LightningDataFetcher with a PyTorch IterableDataset."""

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(0, 4):
        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
            expected = [
                ([tensor([1]), tensor([1])], False),
                ([tensor([2]), tensor([2])], False),
                ([tensor([3]), tensor([3])], True),
            ]
        else:
            loader = DataLoader(IterDataset())
            expected = [(1, False), (2, False), (3, True)]
        iterator = LightningDataFetcher(prefetch_batches=prefetch_batches)
        # the fetcher stores the requested count plus one, so it can flag the last batch
        prefetch_batches += 1
        assert iterator.prefetch_batches == prefetch_batches
        iterator.setup(loader)

        def generate():
            generated = []
            for idx, data in enumerate(iterator, 1):
                if iterator.done:
                    assert iterator.fetched == 3
                else:
                    assert iterator.fetched == (idx + prefetch_batches)
                generated.append(data)
            return generated

        assert generate() == expected
        # validate that reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataloader = DataLoader(EmptyIterDataset())
    iterator = LightningDataFetcher()
    iterator.setup(dataloader)
    assert list(iterator) == []


def test_misconfiguration_error():
    fetcher = LightningDataFetcher()
    with pytest.raises(
        MisconfigurationException, match="The `DataFetcher` should be set up with an instance of a PyTorch"
    ):
        fetcher.setup(range(10))

    fetcher = LightningDataFetcher()
    with pytest.raises(
        MisconfigurationException, match="The dataloader_iter isn't available outside the __iter__ context."
    ):
        loader = DataLoader(range(10))
        fetcher.setup(loader)
        assert fetcher.loaders[0] == loader
        fetcher.loader_iters

    iter(fetcher)
    assert fetcher.loader_iters
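
As a sanity check on the counter arithmetic asserted above, a short illustrative trace (reusing the hypothetical `ToyDataset` sketched after fetching.py; not part of this commit):

# with prefetch_batches=1 the fetcher keeps two batches in flight (requested + 1),
# so `fetched` already reads 3 when the first batch of a three-element dataset is yielded
fetcher = LightningDataFetcher(prefetch_batches=1)
fetcher.setup(DataLoader(ToyDataset()))
for batch, is_last in fetcher:
    print(fetcher.fetched, is_last)  # -> 3 False, 3 False, 3 True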
