Add warning for few workers #1378

Merged 6 commits on Apr 5, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378))

### Changed

15 changes: 12 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,8 +1,9 @@
import warnings
from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Callable

import torch.distributed as torch_distrib
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
@@ -73,6 +74,12 @@ def _percent_range_check(self, name: str) -> None:
if not 0. <= value <= 1.:
raise ValueError(msg)

def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
' Consider increasing the value of the `num_workers` argument'
' in the `DataLoader` init to improve performance.')

def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

# don't do anything if it's not a dataloader
@@ -112,11 +119,13 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
model: The current `LightningModule`
"""
self.train_dataloader = self.request_dataloader(model.train_dataloader)

self.num_training_batches = 0

# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self._worker_check(self.train_dataloader, 'train dataloader')
self._percent_range_check('train_percent_check')

if not _has_len(self.train_dataloader):
@@ -176,10 +185,10 @@ def _reset_eval_dataloader(self, model: LightningModule,
# determine number of batches
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for dataloader in dataloaders:
for i, dataloader in enumerate(dataloaders):
self._worker_check(dataloader, f'{mode} dataloader {i}')
if not _has_len(dataloader):
num_batches = float('inf')
break

percent_check = getattr(self, f'{mode}_percent_check')

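
For context, a minimal, self-contained sketch (not part of the diff) of the behaviour the new check introduces. The standalone helper, dataset, and loader below are hypothetical stand-ins for the trainer-internal `_worker_check`:

import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset

def worker_check(dataloader: DataLoader, name: str) -> None:
    # Mirrors the logic added to data_loading.py: warn when a DataLoader
    # is configured with num_workers <= 2.
    if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
        warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                      ' Consider increasing the value of the `num_workers` argument'
                      ' in the `DataLoader` init to improve performance.')

dataset = TensorDataset(torch.randn(64, 10))
worker_check(DataLoader(dataset, batch_size=8), 'train dataloader')                 # warns: num_workers defaults to 0
worker_check(DataLoader(dataset, batch_size=8, num_workers=4), 'train dataloader')  # silent
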
42 changes: 42 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -15,6 +15,7 @@
LightValStepFitMultipleDataloadersMixin,
LightValStepFitSingleDataloaderMixin,
LightTrainDataloader,
LightValidationDataloader,
LightInfTrainDataloader,
LightInfValDataloader,
LightInfTestDataloader,
@@ -485,6 +486,47 @@ class CurrentTestModel(
trainer.fit(model)


def test_warning_with_few_workers(tmpdir):
""" Test that error is raised if dataloader with only a few workers is used """
tutils.reset_seed()

class CurrentTestModel(
LightTrainDataloader,
LightValStepFitSingleDataloaderMixin,
LightTestFitSingleTestDataloadersMixin,
LightEmptyTestStep,
TestModelBase,
):
pass

hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)

fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloaders=model._dataloader(train=False),
test_dataloaders=model._dataloader(train=False))

trainer = Trainer(**trainer_options)

# fit model
with pytest.warns(UserWarning, match='train'):
trainer.fit(model, **fit_options)

with pytest.warns(UserWarning, match='val'):
trainer.fit(model, **fit_options)

with pytest.warns(UserWarning, match='test'):
trainer.test()
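
A quick illustration (not from the PR) of the `pytest.warns(..., match=...)` pattern the test relies on: `match` is applied as a regular-expression search against the warning message, so 'train' matches the message `_worker_check` emits for the train dataloader. The helper name below is hypothetical:

import warnings
import pytest

def emit_worker_warning():
    # Hypothetical stand-in for the message produced during trainer.fit().
    warnings.warn('The dataloader, train dataloader, does not have many workers'
                  ' which may be a bottleneck.', UserWarning)

with pytest.warns(UserWarning, match='train'):
    emit_worker_warning()
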


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_reinit_for_subclass():
