From c55bc433ce9b24ac774b9fddf702fbafa3e0677f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 2 Dec 2021 11:36:10 +0100
Subject: [PATCH] Fix retrieval of batch indices when dataloader num_workers > 0 (#10870)

Co-authored-by: Rohit Gupta
---
 CHANGELOG.md                                  |   7 +-
 .../loops/epoch/prediction_epoch_loop.py      |  36 ++++---
 pytorch_lightning/overrides/distributed.py    |  26 ++++-
 tests/callbacks/test_prediction_writer.py     | 102 +++++++++++++-----
 tests/deprecated_api/test_remove_1-7.py       |  11 ++
 tests/deprecated_api/test_remove_1-8.py       |   1 +
 tests/overrides/test_distributed.py           |   4 +-
 7 files changed, 133 insertions(+), 54 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e965454486a2a..9ed0408bfbb15 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -103,7 +103,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622))
 
-
+- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
+
 
 ### Removed
 
@@ -227,12 +228,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))
 
-
+- Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
+
 
 -
+
 
 ## [1.5.4] - 2021-11-30
 
 ### Fixed
 
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index a5e885efc4b29..985779a17c54d 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -26,7 +26,7 @@ def __init__(self) -> None:
         self._dl_max_batches = 0
         self._num_dataloaders = 0
         self._warning_cache = WarningCache()
-        self._all_batch_indices: List[int] = []
+        self._seen_batch_indices: List[List[int]] = []
 
     @property
     def done(self) -> bool:
@@ -44,7 +44,7 @@ def connect(self, **kwargs: "Loop") -> None:
 
     def reset(self) -> None:
         """Resets the loops internal state."""
-        self._all_batch_indices = []
+        self._seen_batch_indices = []
         self.predictions = []
         self.batch_progress.reset_on_run()
 
@@ -68,6 +68,7 @@ def on_run_start(  # type: ignore[override]
         void(dataloader_iter, dataloader_idx)
         self._dl_max_batches = dl_max_batches
         self._num_dataloaders = num_dataloaders
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
         self.return_predictions = return_predictions
 
     def advance(  # type: ignore[override]
@@ -88,6 +89,10 @@ def advance(  # type: ignore[override]
             return_predictions: whether to return the obtained predictions
         """
         batch_idx, batch = next(dataloader_iter)
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
+        # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
+        self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]
+
         if batch is None:
             raise StopIteration
 
@@ -99,13 +104,10 @@ def advance(  # type: ignore[override]
         with self.trainer.profiler.profile("predict_step"):
             self._predict_step(batch, batch_idx, dataloader_idx)
 
-    def on_run_end(self) -> Tuple[List[Any], List[int]]:
+    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
         """Returns the predictions and the corresponding batch indices."""
-        predictions = self.predictions
-        all_batch_indices = self._all_batch_indices
-        # free memory
-        self.predictions = []
-        self._all_batch_indices = []
+        predictions, all_batch_indices = self.predictions, self._seen_batch_indices
+        self.predictions, self._seen_batch_indices = [], []  # free memory
         return predictions, all_batch_indices
 
     def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
@@ -121,7 +123,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
         step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
 
         # extract batch_indices and store them
-        self._store_batch_indices(dataloader_idx)
+        self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []
 
         model_ref = self.trainer.lightning_module
 
@@ -160,12 +162,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
             step_kwargs["dataloader_idx"] = dataloader_idx
         return step_kwargs
 
-    def _store_batch_indices(self, dataloader_idx: int) -> None:
-        """Stores the batch indices if the predictions should be stored."""
+    def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
+        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
+        :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
         batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
-        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
-            self.current_batch_indices = batch_sampler.batch_indices
-            if self.should_store_predictions:
-                self._all_batch_indices.append(batch_sampler.batch_indices)
-        else:
-            warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
+        if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
+            return batch_sampler.seen_batch_indices
+
+        warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
+        return []
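Background for the hunks above: with `num_workers > 0` the dataloader prefetches, so the batch sampler has already yielded the indices of the *next* batch by the time the current one reaches `advance`. The standalone sketch below is illustrative only and not part of the patch; `RecordingBatchSampler` and `prefetching_iter` are hypothetical stand-ins for `IndexBatchSamplerWrapper` and a multi-worker `DataLoader`. It shows why reading only the last-yielded indices goes stale, while indexing into the recorded history (truncated at the current batch, as the fixed loop does) stays correct:

    class RecordingBatchSampler:
        """Mimics IndexBatchSamplerWrapper: records every batch of indices it yields."""

        def __init__(self, indices, batch_size):
            self.indices = indices
            self.batch_size = batch_size
            self.batch_indices = []  # last yielded batch (the deprecated attribute)
            self.seen_batch_indices = []  # full history (the new attribute)

        def __iter__(self):
            self.seen_batch_indices = []
            for start in range(0, len(self.indices), self.batch_size):
                batch = self.indices[start : start + self.batch_size]
                self.batch_indices = batch
                self.seen_batch_indices.append(batch)
                yield batch


    def prefetching_iter(iterable, prefetch=1):
        """Mimics a DataLoader with num_workers > 0: it pulls batches ahead of the consumer."""
        it = iter(iterable)
        buffer = [next(it) for _ in range(prefetch)]
        for item in it:
            buffer.append(item)
            yield buffer.pop(0)
        yield from buffer


    sampler = RecordingBatchSampler(list(range(16)), batch_size=4)
    stale = 0
    for batch_idx, batch in enumerate(prefetching_iter(sampler)):
        if sampler.batch_indices != batch:  # what the old loop effectively read
            stale += 1
        # what the fixed loop reads: the recorded history, cut off at the current batch
        assert sampler.seen_batch_indices[: batch_idx + 1][-1] == batch
    print(f"last-yielded indices were stale for {stale} of 4 batches")  # 3 of 4 here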
diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py
index f7c2a71b4978d..66644a91d5eea 100644
--- a/pytorch_lightning/overrides/distributed.py
+++ b/pytorch_lightning/overrides/distributed.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterator, List, Optional, Sized, Union
+from typing import Any, cast, Iterator, List, Sized, Union
 
 import torch
 from torch import Tensor
@@ -21,6 +21,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.utilities import rank_zero_deprecation
 
 
 class LightningDistributedModule(_LightningModuleWrapperBase):
@@ -123,12 +124,31 @@ class IndexBatchSamplerWrapper:
     """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""
 
     def __init__(self, sampler: BatchSampler) -> None:
+        self.seen_batch_indices: List[List[int]] = []
         self._sampler = sampler
-        self.batch_indices: Optional[List[int]] = None
+        self._batch_indices: List[int] = []
+
+    @property
+    def batch_indices(self) -> List[int]:
+        rank_zero_deprecation(
+            "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
+            " v1.7. Access the full list `seen_batch_indices` instead."
+        )
+        return self._batch_indices
+
+    @batch_indices.setter
+    def batch_indices(self, indices: List[int]) -> None:
+        rank_zero_deprecation(
+            "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
+            " v1.7. Access the full list `seen_batch_indices` instead."
+        )
+        self._batch_indices = indices
 
     def __iter__(self) -> Iterator[List[int]]:
+        self.seen_batch_indices = []
         for batch in self._sampler:
-            self.batch_indices = batch
+            self._batch_indices = batch
+            self.seen_batch_indices.append(batch)
             yield batch
 
     def __len__(self) -> int:
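For reference, a minimal sketch of how the wrapper behaves after this change (assuming the patched `pytorch_lightning.overrides.distributed` module is importable; the sampler objects come from `torch.utils.data`):

    from torch.utils.data import BatchSampler, SequentialSampler

    from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

    batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
    wrapper = IndexBatchSamplerWrapper(batch_sampler)

    batches = list(wrapper)
    assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    # The full history now survives the iteration, so it can still be read after prefetching ...
    assert wrapper.seen_batch_indices == batches
    # ... while `wrapper.batch_indices` keeps returning only the last batch and now emits a
    # deprecation warning on access (slated for removal in v1.7).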
diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py
index 75e0dbd31ec79..2cd3738ca875f 100644
--- a/tests/callbacks/test_prediction_writer.py
+++ b/tests/callbacks/test_prediction_writer.py
@@ -11,54 +11,98 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import ANY, call, Mock
+
 import pytest
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import BasePredictionWriter
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
+from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
-def test_prediction_writer(tmpdir):
-    class CustomPredictionWriter(BasePredictionWriter):
-        def __init__(self, writer_interval: str):
-            super().__init__(writer_interval)
+class DummyPredictionWriter(BasePredictionWriter):
+    def write_on_batch_end(self, *args, **kwargs):
+        pass
 
-            self.write_on_batch_end_called = False
-            self.write_on_epoch_end_called = False
+    def write_on_epoch_end(self, *args, **kwargs):
+        pass
 
-        def write_on_batch_end(self, *args, **kwargs):
-            self.write_on_batch_end_called = True
-
-        def write_on_epoch_end(self, *args, **kwargs):
-            self.write_on_epoch_end_called = True
 
+def test_prediction_writer_invalid_write_interval():
+    """Test that configuring an unknown interval name raises an error."""
     with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"):
-        CustomPredictionWriter("something")
+        DummyPredictionWriter("something")
+
+
+def test_prediction_writer_hook_call_intervals(tmpdir):
+    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
+    interval."""
+    DummyPredictionWriter.write_on_batch_end = Mock()
+    DummyPredictionWriter.write_on_epoch_end = Mock()
+
+    dataloader = DataLoader(RandomDataset(32, 64))
 
     model = BoringModel()
-    cb = CustomPredictionWriter("batch_and_epoch")
+    cb = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    results = trainer.predict(model, dataloaders=model.train_dataloader())
+    results = trainer.predict(model, dataloaders=dataloader)
     assert len(results) == 4
-    assert cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 1
 
-    cb = CustomPredictionWriter("batch_and_epoch")
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()
+
+    cb = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 1
+
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()
 
-    cb = CustomPredictionWriter("batch")
+    cb = DummyPredictionWriter("batch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert cb.write_on_batch_end_called
-    assert not cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 0
+
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()
 
-    cb = CustomPredictionWriter("epoch")
+    cb = DummyPredictionWriter("epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert not cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 0
+    assert cb.write_on_epoch_end.call_count == 1
+
+
+@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
+def test_prediction_writer_batch_indices(tmpdir, num_workers):
+    DummyPredictionWriter.write_on_batch_end = Mock()
+    DummyPredictionWriter.write_on_epoch_end = Mock()
+
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
+    model = BoringModel()
+    writer = DummyPredictionWriter("batch_and_epoch")
+    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
+    trainer.predict(model, dataloaders=dataloader)
+
+    writer.write_on_batch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
+            call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
+            call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
+            call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
+        ]
+    )
+
+    writer.write_on_epoch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]),
+        ]
+    )
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 0065d5947dc26..6f7e1199ab438 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -14,6 +14,7 @@
 """Test deprecated functionality which will be removed in v1.7.0."""
 import os
 from unittest import mock
+from unittest.mock import Mock
 
 import pytest
 
@@ -23,6 +24,7 @@
 from pytorch_lightning.callbacks.progress import ProgressBar
 from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
 from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
+from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
     LightningEnvironment,
@@ -528,3 +530,12 @@ def is_using_torchelastic():
             match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7"
         ):
             MyClusterEnvironment()
+
+
+def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
+    sampler = IndexBatchSamplerWrapper(Mock())
+    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
+        _ = sampler.batch_indices
+
+    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
+        sampler.batch_indices = []
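For callers migrating off the deprecated attribute exercised by the test above, the change amounts to the following (a sketch; `index_batch_sampler` is assumed to be an existing `IndexBatchSamplerWrapper` instance):

    # Before: only the most recently yielded batch, and stale when the dataloader prefetches.
    last = index_batch_sampler.batch_indices  # emits a deprecation warning from v1.5 on

    # After: the full per-batch history recorded during iteration.
    history = index_batch_sampler.seen_batch_indices
    last = history[-1] if history else []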
"""Test deprecated functionality which will be removed in v1.8.0.""" + import pytest import torch diff --git a/tests/overrides/test_distributed.py b/tests/overrides/test_distributed.py index c8d982bd733fe..e425859fe34df 100644 --- a/tests/overrides/test_distributed.py +++ b/tests/overrides/test_distributed.py @@ -54,9 +54,7 @@ def test_index_batch_sampler(tmpdir): assert batch_sampler.batch_size == index_batch_sampler.batch_size assert batch_sampler.drop_last == index_batch_sampler.drop_last assert batch_sampler.sampler is sampler - - for batch in index_batch_sampler: - assert index_batch_sampler.batch_indices == batch + assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices def test_index_batch_sampler_methods():