From 373c32e34b3402bd1aa322cf22039669c72c685c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 2 Nov 2021 11:40:35 +0100
Subject: [PATCH] Fix yielding from iterator in LiteDataLoader (#10304)

* fix yielding from iterator

* update description

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 pytorch_lightning/lite/wrappers.py |  6 +++---
 tests/lite/test_wrappers.py        | 32 ++++++++++++++++++++----------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 8b6f072c57adc..ad01b44ef30f4 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -174,9 +174,9 @@ def device(self) -> Optional[torch.device]:
         return self._device

     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
-        dataloader_iter = iter(self._dataloader)
+        iterator = iter(self._dataloader)
         if self._device is None:
-            return dataloader_iter
+            yield from iterator

-        for item in dataloader_iter:
+        for item in iterator:
             yield move_data_to_device(item, self._device)
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index 8fc4f7e9c6e53..6741bf59b4dca 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -60,6 +60,27 @@ def check_autocast(forward_input):
     assert out.dtype == torch.get_default_dtype()


+def test_lite_dataloader_iterator():
+    """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
+    device placement)."""
+    dataloader = DataLoader(range(5), batch_size=2)
+    lite_dataloader = _LiteDataLoader(dataloader)
+    assert len(lite_dataloader) == len(dataloader) == 3
+
+    iterator = iter(dataloader)
+    lite_iterator = iter(lite_dataloader)
+
+    assert torch.equal(next(iterator), next(lite_iterator))
+    assert torch.equal(next(iterator), next(lite_iterator))
+    assert torch.equal(next(iterator), next(lite_iterator))
+
+    with pytest.raises(StopIteration):
+        next(iterator)
+
+    with pytest.raises(StopIteration):
+        next(lite_iterator)
+
+
 @pytest.mark.parametrize(
     "src_device, dest_device",
     [
@@ -84,17 +105,6 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
     batch1 = next(iterator)
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))

-    with pytest.raises(StopIteration):
-        batch1 = next(iterator)
-
-    lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
-    iterator = iter(lite_dataloader)
-
-    batch0 = next(iterator)
-    assert batch0 == 0
-
-    assert len(lite_dataloader) == 4
-

 def test_lite_optimizer_wraps():
     """Test that the LiteOptimizer fully wraps the optimizer."""
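
Note on the root cause (not part of the patch): because `__iter__` contains a `yield`, the whole method is a generator function, so the old `return dataloader_iter` never handed the underlying iterator to the caller; with `device=None` the loader produced no batches at all. The sketch below is a minimal, self-contained illustration of that behaviour, assuming hypothetical `BrokenLoader`/`FixedLoader` classes that stand in for `_LiteDataLoader` (they are not Lightning code).

def _make_batches():
    # Stands in for iter(self._dataloader).
    return iter(["batch0", "batch1", "batch2"])


class BrokenLoader:
    # Mimics the pre-fix __iter__: the `yield` below makes the whole method a
    # generator function, so `return iterator` does NOT hand the iterator to
    # the caller -- it merely ends an empty generator.
    def __init__(self, device=None):
        self._device = device

    def __iter__(self):
        iterator = _make_batches()
        if self._device is None:
            return iterator  # bug: nothing is ever yielded
        for item in iterator:
            yield item  # device transfer would happen here


class FixedLoader(BrokenLoader):
    # Mimics the patched __iter__: `yield from` re-yields every element.
    def __iter__(self):
        iterator = _make_batches()
        if self._device is None:
            yield from iterator
        for item in iterator:  # already exhausted when device is None, so a no-op
            yield item


assert list(BrokenLoader()) == []                             # the reported bug
assert list(FixedLoader()) == ["batch0", "batch1", "batch2"]  # behaviour after the fix

As in the patched `_LiteDataLoader.__iter__`, no explicit `return` is needed after `yield from`: when the device is `None` the subsequent `for` loop runs over an already exhausted iterator and yields nothing further.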