Fix support for CombinedLoader while checking for warning raised with eval dataloaders #10994

Merged · 2 commits · Dec 14, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))


-


32 changes: 22 additions & 10 deletions pytorch_lightning/trainer/data_loading.py
@@ -17,6 +17,7 @@
from typing import Any, Callable, Collection, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
@@ -297,19 +298,17 @@ def _reset_eval_dataloader(
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]

if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",
category=PossibleUserWarning,
)

if any(dl is None for dl in dataloaders):
rank_zero_warn("One of given dataloaders is None and it will be skipped.")

for loader in dataloaders:
apply_to_collection(
loader.loaders if isinstance(loader, CombinedLoader) else loader,
DataLoader,
self._check_eval_shuffling,
mode=mode,
)

# add samplers
dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

@@ -459,3 +458,16 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:
dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)

return dataloader

@staticmethod
def _check_eval_shuffling(dataloader, mode):
if (
hasattr(dataloader, "sampler")
and not isinstance(dataloader.sampler, SequentialSampler)
and not isinstance(dataloader.dataset, IterableDataset)
):
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",
category=PossibleUserWarning,
)
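
The new flow replaces the flat `loader.sampler` check with a traversal: `apply_to_collection` unwraps a `CombinedLoader` via its `.loaders` attribute (which may hold a single `DataLoader`, a list, or a dict) and calls `_check_eval_shuffling` on every `DataLoader` it finds, skipping `IterableDataset`s. A minimal standalone sketch of that traversal, not the Lightning implementation itself (names such as `apply_to_dataloaders` are illustrative):

```python
# Minimal sketch of the traversal, not the Lightning implementation: apply a
# check to every DataLoader nested inside a dict/list of loaders.
import warnings

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def check_eval_shuffling(dataloader: DataLoader) -> None:
    # Same idea as `_check_eval_shuffling`: a non-sequential sampler on an
    # eval dataloader usually means `shuffle=True` was set by accident.
    if hasattr(dataloader, "sampler") and not isinstance(dataloader.sampler, SequentialSampler):
        warnings.warn("Eval dataloader appears to be shuffled; consider `shuffle=False`.")


def apply_to_dataloaders(collection, fn) -> None:
    # Simplified stand-in for `apply_to_collection(collection, DataLoader, fn)`.
    if isinstance(collection, DataLoader):
        fn(collection)
    elif isinstance(collection, (list, tuple)):
        for item in collection:
            apply_to_dataloaders(item, fn)
    elif isinstance(collection, dict):
        for item in collection.values():
            apply_to_dataloaders(item, fn)


dataset = TensorDataset(torch.randn(8, 2))
loaders = {"dl1": DataLoader(dataset), "dl2": DataLoader(dataset, shuffle=True)}
apply_to_dataloaders(loaders, check_eval_shuffling)  # warns only for "dl2"
```
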
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/supporters.py
@@ -307,10 +307,10 @@ def __len__(self) -> int:


class CombinedLoader:
"""Combines different dataloaders and allows sampling in parallel. Supported modes are 'min_size', which raises
StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
'max_size_cycle` which raises StopIteration after the longest loader (the one with most batches) is done, while
cycling through the shorter loaders.
"""Combines different dataloaders and allows sampling in parallel. Supported modes are ``"min_size"``, which
raises StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is done,
while cycling through the shorter loaders.

Examples:
>>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4),
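
A small usage sketch of the two modes described in the reworded docstring; the batch counts in the comments follow from the chosen toy dataset sizes:

```python
# Usage sketch for the two CombinedLoader modes described above.
import torch
from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {
    "a": torch.utils.data.DataLoader(range(6), batch_size=4),   # 2 batches
    "b": torch.utils.data.DataLoader(range(15), batch_size=5),  # 3 batches
}

# "min_size": stops with the shortest loader, so 2 combined batches.
print(len(CombinedLoader(loaders, "min_size")))        # 2

# "max_size_cycle": cycles the shorter loader until the longest finishes, so 3.
print(len(CombinedLoader(loaders, "max_size_cycle")))  # 3
```
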
27 changes: 27 additions & 0 deletions tests/trainer/test_data_loading.py
@@ -19,9 +19,12 @@
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.data import _update_dataloader
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -334,3 +337,27 @@ def test_pre_made_batches():
loader = DataLoader(RandomDataset(32, 10), batch_size=None)
trainer = Trainer(fast_dev_run=1)
trainer.predict(LoaderTestModel(), loader)


@pytest.mark.parametrize(
"val_dl",
[
DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)),
CombinedLoader(
[DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)]
),
CombinedLoader(
{
"dl1": DataLoader(dataset=RandomDataset(32, 64)),
"dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
}
),
],
)
def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
trainer = Trainer()
model = BoringModel()
trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"):
trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
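
For contrast with the parametrized cases above, a sketch of the eval setup the warning nudges users toward: no `shuffle=True`, so each wrapped `DataLoader` keeps a `SequentialSampler` and no `PossibleUserWarning` is raised (`make_val_dataloaders` is an illustrative helper; `RandomDataset` is the test helper imported in this file):

```python
# Sketch of a warning-free eval setup: no `shuffle=True` on val dataloaders.
from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader
from tests.helpers import RandomDataset  # test helper used in this file


def make_val_dataloaders() -> CombinedLoader:
    # shuffle defaults to False, so each DataLoader uses a SequentialSampler
    # and `_check_eval_shuffling` stays silent.
    return CombinedLoader(
        {
            "dl1": DataLoader(RandomDataset(32, 64)),
            "dl2": DataLoader(RandomDataset(32, 64)),
        }
    )
```
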