Skip to content

Commit

Permalink
Fix yielding from iterator in LiteDataLoader (#10304)
Browse files Browse the repository at this point in the history
* fix yielding from iterator

* update description

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
awaelchli and pre-commit-ci[bot] authored Nov 2, 2021
1 parent f44b8a7 commit 373c32e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
6 changes: 3 additions & 3 deletions pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def device(self) -> Optional[torch.device]:
return self._device

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
dataloader_iter = iter(self._dataloader)
iterator = iter(self._dataloader)
if self._device is None:
return dataloader_iter
yield from iterator

for item in dataloader_iter:
for item in iterator:
yield move_data_to_device(item, self._device)
32 changes: 21 additions & 11 deletions tests/lite/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,27 @@ def check_autocast(forward_input):
assert out.dtype == torch.get_default_dtype()


def test_lite_dataloader_iterator():
    """Test that iterating a LiteDataLoader delegates to the underlying dataloader's iterator when no device is
    set (no automatic device placement)."""
    base_loader = DataLoader(range(5), batch_size=2)
    wrapped_loader = _LiteDataLoader(base_loader)
    # range(5) with batch_size=2 yields three batches: [0,1], [2,3], [4]
    assert len(wrapped_loader) == len(base_loader) == 3

    base_iter = iter(base_loader)
    wrapped_iter = iter(wrapped_loader)

    # The wrapped iterator must produce exactly the same batches as the raw one.
    for _ in range(3):
        assert torch.equal(next(base_iter), next(wrapped_iter))

    # Both iterators must be exhausted at the same point.
    with pytest.raises(StopIteration):
        next(base_iter)

    with pytest.raises(StopIteration):
        next(wrapped_iter)


@pytest.mark.parametrize(
"src_device, dest_device",
[
Expand All @@ -84,17 +105,6 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
batch1 = next(iterator)
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))

with pytest.raises(StopIteration):
batch1 = next(iterator)

lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
iterator = iter(lite_dataloader)

batch0 = next(iterator)
assert batch0 == 0

assert len(lite_dataloader) == 4


def test_lite_optimizer_wraps():
"""Test that the LiteOptimizer fully wraps the optimizer."""
Expand Down

0 comments on commit 373c32e

Please sign in to comment.