Add typing to data fetching
carmocca committed Jan 17, 2022
1 parent d50b48e commit b8b5e72
Showing 5 changed files with 59 additions and 74 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -95,7 +95,6 @@ module = [
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
"pytorch_lightning.utilities.distributed",
"pytorch_lightning.utilities.fetching",
"pytorch_lightning.utilities.memory",
"pytorch_lightning.utilities.meta",
]
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -18,6 +18,7 @@
from typing import Any, Dict, Iterator, Optional, Union

from deprecate import void
from torch.utils.data import DataLoader

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import BatchProgress
@@ -194,7 +195,7 @@ def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
"Reloading support hasn't been implemented for `CombinedLoader`. You can request it by opening an issue"
" in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
)
assert dataloader is not None
assert isinstance(dataloader, DataLoader)
_reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
self._dataloader_state_dict = {}

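For context on the strengthened assert: after this commit `AbstractDataFetcher.dataloader` is typed `Optional[Iterable]`, while `_reload_dataloader_state_dict` accepts only a `DataLoader`, so a bare `None` check no longer narrows the type enough for mypy. A minimal sketch of the narrowing, using stand-in classes rather than the real ones:

    from typing import Iterable, Optional

    class DataLoader:
        """Stand-in for torch.utils.data.DataLoader."""

    def _reload_dataloader_state_dict(dataloader: DataLoader) -> None:
        """Stand-in for the helper that accepts only a DataLoader."""

    def reload(dataloader: Optional[Iterable]) -> None:
        # `assert dataloader is not None` would only narrow the type to Iterable;
        # isinstance narrows it all the way to DataLoader, so the call passes mypy.
        assert isinstance(dataloader, DataLoader)
        _reload_dataloader_state_dict(dataloader)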
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/auto_restart.py
@@ -487,7 +487,7 @@ def wrapper() -> Any:
def patch_dataloader_iterator(
dataloader: DataLoader,
iterator: Iterator,
data_fetcher: "pl.utilities.fetching.DataFetcher",
data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
num_batches_fetched: int = 0,
) -> None:
"""Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is
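The widened annotation above is a string forward reference: `auto_restart.py` can name the base fetcher type without importing `fetching.py` at runtime. A hedged sketch of the pattern (whether this module actually guards the import behind `TYPE_CHECKING` is an assumption, not something shown in this diff):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # needed only by the type checker, so no circular import at runtime
        import pytorch_lightning as pl

    def patch_dataloader_iterator(data_fetcher: "pl.utilities.fetching.AbstractDataFetcher") -> None:
        ...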
123 changes: 54 additions & 69 deletions pytorch_lightning/utilities/fetching.py
@@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from copy import deepcopy
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.data.dataloader import DataLoader
@@ -58,28 +58,28 @@ def fetching_function(self) -> Any:
def prefetching(self) -> None:
"""Override with your own pre-fetching logic."""

def on_fetch_start(self) -> Any:
"""Hook to override to handle the logic before fetching a batch."""

def on_fetch_end(self, batch: Any, start_output: Any) -> None:
"""Hook to extend which handles the logic after fetching a batch."""

def wait(self) -> None:
"""Hook to override to indicate the `DataFetcher` to wait for an event."""

def __init__(self, prefetch_batches: int = 0) -> None:
if prefetch_batches < 0:
raise MisconfigurationException("`prefetch_batches` should at least be 0.")
self.prefetch_batches = prefetch_batches

self.dataloader: Optional[Union[DataLoader, CombinedLoader]] = None
self.dataloader: Optional[Iterable] = None
self.dataloader_iter: Optional[Iterator] = None

self.batch_to_device: Optional[Callable]

self.batches: List
self.fetched: int
self.done: bool

self.batch_to_device: Optional[Callable] = None
self.reset()

def setup(self, dataloader: Iterable, batch_to_device: Optional[Callable] = None) -> None:
self._add_capture_metadata_collate(dataloader)

self.dataloader = dataloader
self.batch_to_device = batch_to_device

self._attach_data_fetcher()

@staticmethod
@@ -92,8 +92,8 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None:

apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)

def _apply_patch(self):
def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
def _apply_patch(self) -> None:
def _apply_patch_fn(loader: DataLoader, iterator: Iterator) -> None:
if isinstance(loader, CycleIterator):
loader = loader.loader
# cycle_iterator = iterator
@@ -158,13 +158,13 @@ def loader_iters(self) -> List[Iterator]:

@property
def state(self) -> Any:
def collect_state(iterator: Iterator):
def collect_state(iterator: Iterator) -> Any:
return iterator.state

return apply_to_collection(self.loader_iters, Iterator, collect_state)

def _attach_data_fetcher(self):
def _attach_data_fetcher_fn(loader: DataLoader):
def _attach_data_fetcher(self) -> None:
def _attach_data_fetcher_fn(loader: DataLoader) -> None:
if isinstance(loader, CycleIterator):
loader = loader.loader

@@ -173,7 +173,7 @@ def _attach_data_fetcher_fn(loader: DataLoader):

apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
def __iter__(self) -> "AbstractDataFetcher":
if self.dataloader is None:
raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
self.reset()
@@ -184,11 +184,11 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
self.prefetching()
return self

def __next__(self):
def __next__(self) -> Any:
return self.fetching_function()

def reset(self) -> None:
self.batches: List = []
self.batches: List[Any] = []
self.fetched: int = 0
self.done: bool = False

@@ -217,55 +217,43 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
super().__init__(prefetch_batches=prefetch_batches)
self.store_on_device = store_on_device

def on_fetch_start(self) -> None:
"""Hook to override to handle the logic before fetching a batch."""

def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None:
def on_fetch_end(self, batch: Any, start_output: Any) -> None:
"""Hook to extend which handles the logic after fetching a batch."""
self.batches.append(batch)

def wait(self) -> None:
"""Hook to override to indicate the `DataFetcher` to wait for an event."""

def prefetching(self) -> None:
iterator = self.dataloader_iter
assert iterator is not None
for _ in range(self.prefetch_batches):
try:
self._fetch_next_batch()
self._fetch_next_batch(iterator)
except StopIteration:
break

def fetching_function(self) -> Optional[Tuple[Any, bool]]:
if self.done:
while self.batches:
return self._get_queued_batch()
raise StopIteration
else:
def fetching_function(self) -> Tuple[Any, bool]:
if self.done and self.batches:
# no more batches to fetch, return all prefetched
batch = self.batches.pop(0)
self.wait()
return self.move_to_device(batch), len(self.batches) == 0
if self.batches:
batch = self.batches.pop(0)
assert self.dataloader_iter is not None
try:
yield_batch = self.batches.pop(0)
self._fetch_next_batch()
# wait for batch to be available.
self.wait()
# yield last and has next
return self.move_to_device(yield_batch), False
self._fetch_next_batch(self.dataloader_iter)
except StopIteration:
self.batches.insert(0, yield_batch)
self.done = True
return self._get_queued_batch()

except IndexError:
raise StopIteration
finally:
self.wait()
return self.move_to_device(batch), len(self.batches) == 0
# break empty iterators
raise StopIteration

def _fetch_next_batch(self):
data = self.on_fetch_start()
batch = next(self.dataloader_iter)
def _fetch_next_batch(self, iterator: Iterator) -> None:
start_output = self.on_fetch_start()
batch = next(iterator)
self.fetched += 1
self.on_fetch_end(batch, data)

def _get_queued_batch(self) -> Tuple[Any, bool]:
batch = self.batches.pop(0)
is_last = len(self.batches) == 0
self.wait()
return self.move_to_device(batch), is_last
self.on_fetch_end(batch, start_output)

def move_to_device(self, batch: Any) -> Any:
if self.store_on_device and self.batch_to_device is not None:
@@ -293,23 +281,21 @@ class InterBatchParallelDataFetcher(DataFetcher):
batch 2: [HtoD] [forward][backward]
"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.cuda_stream = torch.cuda.Stream()
self.events: List[torch.cuda.Event] = []

def move_to_device(self, batch):
def move_to_device(self, batch: Any) -> Any:
with torch.cuda.stream(self.cuda_stream):
return super().move_to_device(batch)

def on_fetch_start(self) -> "torch.cuda.Event":
# create a cuda event used to record the async stream of data to device.
return torch.cuda.Event()

def on_fetch_end(self, batch, event: torch.cuda.Event) -> None:
super().on_fetch_end(batch)

# record event and store the event
def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
self.batches.append(batch)
event.record()
self.events.append(event)

@@ -319,7 +305,7 @@ def wait(self) -> None:
event.wait()


class StepFuncDataLoaderIter:
class StepFuncDataLoaderIter(Iterator):

"""This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
control."""
@@ -328,9 +314,6 @@ def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
self.iterator = iterator
self.data_fetcher = data_fetcher

def __iter__(self) -> "StepFuncDataLoaderIter":
return self

def __next__(self) -> Any:
try:
data = next(self.iterator)
@@ -360,14 +343,16 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
...
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.store_on_device = False

def prefetching(self) -> None:
self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
iterator = self.dataloader_iter
assert iterator is not None
self.iterator = iter(StepFuncDataLoaderIter(iterator, self))

def fetching_function(self):
while not self.done:
def fetching_function(self) -> Tuple[int, Tuple[Iterator, bool]]:
if not self.done:
return self.fetched, (self.iterator, self.done)
raise StopIteration
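
Taken together, the refactor keeps the public `DataFetcher` contract: `setup()` accepts any iterable and iteration yields `(batch, is_last)` tuples. A minimal usage sketch under that contract; the tiny `range` dataset is an illustrative assumption, and exact behavior may differ in other versions:

    from torch.utils.data import DataLoader
    from pytorch_lightning.utilities.fetching import DataFetcher

    fetcher = DataFetcher(prefetch_batches=1)
    fetcher.setup(DataLoader(range(3)))  # range(3) works as a tiny map-style dataset
    for batch, is_last in fetcher:
        print(batch, is_last)  # is_last is True only on the final batch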
4 changes: 2 additions & 2 deletions tests/utilities/test_fetching.py
@@ -57,11 +57,11 @@ def __iter__(self):

def generate():
generated = []
for idx, data in enumerate(iterator, 1):
for idx, data in enumerate(iterator, prefetch_batches + 1):
if iterator.done:
assert iterator.fetched == 3
else:
assert iterator.fetched == (idx + prefetch_batches)
assert iterator.fetched == idx
generated.append(data)
return generated

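The new `enumerate` start encodes that the fetcher stays `prefetch_batches` ahead of the consumer until the source iterator is exhausted. A self-contained simulation of the counting this test asserts, assuming the test's 3 batches and a `prefetch_batches` value between 1 and 3 (the bounds are an assumption for the sketch):

    prefetch_batches, total = 1, 3
    fetched = prefetch_batches  # batches pulled during prefetching()
    done = False
    for idx in range(prefetch_batches + 1, prefetch_batches + 1 + total):
        if fetched < total:
            fetched += 1  # _fetch_next_batch() succeeded after popping a batch
        else:
            done = True  # the extra fetch raised StopIteration
        # mirrors the assertions in generate() above
        assert fetched == (total if done else idx)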
