From a8f33512ae5debec01871740a5ea6c72e906e6b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Aug 2023 15:49:03 +0200 Subject: [PATCH] Avoid expensive `iter()` call to dataloader in dataloader checks (#18415) (cherry picked from commit d77132b20edd4fa61ad5f62e1b4d383b702201d9) --- src/lightning/pytorch/CHANGELOG.md | 3 +++ .../trainer/connectors/data_connector.py | 4 ++++ tests/tests_pytorch/loops/test_loops.py | 8 ------- .../trainer/connectors/test_data_connector.py | 21 ++++++++++++++++++- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index de5c769cacc8b..42ae22dd12d6c 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed setting the tracking uri in `MLFlowLogger` for logging artifacts to the MLFlow server ([#18395](https://github.com/Lightning-AI/lightning/pull/18395)) +- Fixed redundant `iter()` call to dataloader when checking dataloading configuration ([#18415](https://github.com/Lightning-AI/lightning/pull/18415)) + + ## [2.0.5] - 2023-07-07 ### Fixed diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index d8aea27831e87..b78203dd5c2b8 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -387,6 +387,10 @@ def _check_dataloader_iterable( source: _DataLoaderSource, trainer_fn: TrainerFn, ) -> None: + if isinstance(dataloader, DataLoader): + # Fast path: `torch.utils.data.DataLoader` is always iterable, calling iter() would be expensive + return + try: iter(dataloader) # type: ignore[call-overload] except TypeError: diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 
df263db1b00f8..e8aa8de9f19ff 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -818,8 +818,6 @@ def _get_iterator(self): expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown elif should_fail: expected = [ - # iterable check - 0, # epoch ends 1, # teardown @@ -827,8 +825,6 @@ def _get_iterator(self): ] else: expected = [ - # iterable check - 0, # epoch ends 1, 2, @@ -843,8 +839,6 @@ def _get_iterator(self): expected = [ # sanity check 0, - # iterable check - 0, # epoch ends 0, 1, @@ -853,8 +847,6 @@ def _get_iterator(self): expected = [ # sanity check 0, - # iterable check - 0, # epoch ends 0, 1, diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 91335d176e52d..c596a457bf7e9 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -26,7 +26,12 @@ from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache +from lightning.pytorch.trainer.connectors.data_connector import ( + _check_dataloader_iterable, + _DataHookSelector, + _DataLoaderSource, + warning_cache, +) from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import _update_dataloader @@ -651,3 +656,17 @@ def test_non_iterables_raise(tmp_path, trainer_fn_name, dataloader_name, stage, setattr(model, dl_method, lambda: dataloader) with pytest.raises(TypeError, match=f"invalid dataloader was returned from `BoringModel.{dl_method}"): trainer_fn(model) + + +def 
test_iterable_check_on_known_iterators(): + """Test that we only call the `iter()` on the dataloader object if it isn't a known type.""" + iterable = Mock() + iterable.__iter__ = Mock(return_value=iter(range(3))) + _check_dataloader_iterable(iterable, Mock(), Mock()) + iterable.__iter__.assert_called_once() + + # If it's a dataloader, we don't call the expensive `__iter__` method + dataloader = Mock(spec=DataLoader) + dataloader.__iter__ = Mock() + _check_dataloader_iterable(dataloader, Mock(), Mock()) + dataloader.__iter__.assert_not_called()