Enable distributed training with CombinedDataLoader and max_size_cycle #10374

Merged: 16 commits, Nov 9, 2021
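For context, the sketch below shows the user-facing setup this PR makes work: a `CombinedLoader` in `max_size_cycle` mode returned as the training dataloader while training with DDP, where each inner `DataLoader` should receive a `DistributedSampler`. It is a minimal sketch only; the `LitModel` class, loader names, and dataset sizes are illustrative and reuse the `BoringModel`/`RandomDataset` test helpers rather than anything from the PR itself.

```python
# Minimal sketch (illustrative): a CombinedLoader in `max_size_cycle` mode used
# as the training dataloader under DDP. With this PR, each inner DataLoader
# receives a DistributedSampler when the trainer prepares the dataloaders.
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from tests.helpers.boring_model import BoringModel, RandomDataset


class LitModel(BoringModel):
    def training_step(self, batch, batch_idx):
        # `batch` is a dict with one entry per loader: {"a": ..., "b": ...}
        return super().training_step(batch["a"], batch_idx)

    def train_dataloader(self):
        loaders = {
            "a": DataLoader(RandomDataset(32, 100), batch_size=4),
            "b": DataLoader(RandomDataset(32, 60), batch_size=4),  # shorter loader gets cycled
        }
        return CombinedLoader(loaders, mode="max_size_cycle")


if __name__ == "__main__":
    trainer = Trainer(strategy="ddp", gpus=2, max_epochs=1)
    trainer.fit(LitModel())
```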
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -104,6 +104,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345))


- Fixed `CombinedLoader` with `max_size_cycle` not receiving a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))


18 changes: 15 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -28,7 +28,7 @@
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
@@ -136,14 +136,22 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
if isinstance(dataloader, CombinedLoader):
# apply `prepare_dataloader` on all the collection of loaders
dataloader.loaders = apply_to_collection(
dataloader.loaders, DataLoader, self.prepare_dataloader, shuffle, mode=mode
dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
)
# the length needs to be recomputed across all dataloaders in case of special behavior.
dataloader._apply_cycle_iterator_length()
return dataloader

# don't do anything if it's not a dataloader
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, (DataLoader, CycleIterator)):
return dataloader

cycle_iterator: Optional[CycleIterator] = None

if isinstance(dataloader, CycleIterator):
cycle_iterator = dataloader
dataloader = dataloader.loader

if (
_fault_tolerant_training() # injects components to track the state
or self._requires_distributed_sampler(dataloader) # sets the distributed sampler
Expand All @@ -153,6 +161,10 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
dataloader = self._update_dataloader(dataloader, sampler, mode=mode)

if cycle_iterator is not None:
cycle_iterator.loader = dataloader
return cycle_iterator

return dataloader

def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
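The hunk above introduces an unwrap/re-wrap step around the sampler injection. The standalone sketch below illustrates the same pattern in isolation (the `inject_distributed_sampler` helper and the explicit replica/rank arguments are assumptions for illustration, not Lightning's actual implementation): a `CycleIterator` keeps the real `DataLoader` on its `loader` attribute, so the distributed sampler has to be applied to the inner loader and the wrapper restored afterwards.

```python
# Simplified sketch of the unwrap/re-wrap pattern used in `prepare_dataloader`
# above (hypothetical helper, not Lightning's actual implementation).
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.trainer.supporters import CycleIterator


def inject_distributed_sampler(obj, num_replicas: int, rank: int):
    cycle_iterator = None
    if isinstance(obj, CycleIterator):
        # unwrap: the real DataLoader lives on the `loader` attribute
        cycle_iterator, obj = obj, obj.loader

    sampler = DistributedSampler(obj.dataset, num_replicas=num_replicas, rank=rank, shuffle=False)
    new_loader = DataLoader(obj.dataset, batch_size=obj.batch_size, sampler=sampler)

    if cycle_iterator is not None:
        # re-wrap so the `max_size_cycle` bookkeeping object is preserved
        cycle_iterator.loader = new_loader
        return cycle_iterator
    return new_loader
```

Re-wrapping alone is not enough, though: once the samplers shard each dataset, the lengths cached on the `CycleIterator`s at construction time are stale, which is why the `CombinedLoader` changes below add `_apply_cycle_iterator_length` to recompute them.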
24 changes: 21 additions & 3 deletions pytorch_lightning/trainer/supporters.py
@@ -457,6 +457,23 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
)
state.reset()

def _apply_cycle_iterator_length(self) -> None:
"""When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
all dataloaders."""
if self.mode != "max_size_cycle":
return

all_lengths = apply_to_collection(
self.loaders, CycleIterator, lambda c: get_len(c.loader), wrong_dtype=(Sequence, Mapping)
)
length = _nested_calc_num_data(all_lengths, max)

def _apply_fn(cycle_iterator: CycleIterator) -> None:
nonlocal length
cycle_iterator.length = length

apply_to_collection(self.loaders, CycleIterator, _apply_fn)

def __iter__(self) -> Any:
"""Create and return an iterator, `CombinedLoaderIterator`, for the combined loader."""

Expand All @@ -473,11 +490,12 @@ def __getstate__patch__(*_):
return iterator

@staticmethod
def _calc_num_batches(loaders: Any) -> Union[int, float]:
def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
"""Compute the length (aka the number of batches) of `CombinedLoader`.

Args:
loaders: a collections of loaders.
mode: mode used by the ``CombinedLoader``

Returns:
length: the shortest loader length, or the longest when ``mode == "max_size_cycle"``
@@ -486,10 +504,10 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:

if isinstance(all_lengths, (int, float)):
return all_lengths
return _nested_calc_num_data(all_lengths, min)
return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)

def __len__(self) -> int:
return self._calc_num_batches(self.loaders)
return self._calc_num_batches(self.loaders, mode=self.mode)

@staticmethod
def _shutdown_workers_and_reset_iterator(dataloader) -> None:
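Putting the new `_apply_cycle_iterator_length` and the mode-aware `_calc_num_batches` together, here is a short worked example of the length bookkeeping (the loader sizes, 2-rank setup, and `batch_size=1` are assumed for illustration): in `max_size_cycle` mode the combined length is the maximum of the per-loader lengths, and after `DistributedSampler` injection each loader shrinks to roughly `ceil(n / num_replicas)` batches, so the cached `CycleIterator` lengths must be recomputed.

```python
# Worked example (assumed sizes, 2 DDP ranks, batch_size=1).
import math

lengths_before = {"a": 10, "b": 8}                       # batches per loader
assert max(lengths_before.values()) == 10                # `max_size_cycle` length before sharding

# DistributedSampler (drop_last=False) gives each rank ceil(n / num_replicas) samples.
lengths_after = {k: math.ceil(v / 2) for k, v in lengths_before.items()}
assert lengths_after == {"a": 5, "b": 4}
assert max(lengths_after.values()) == 5                  # recomputed `max_size_cycle` length per rank
# With mode="min_size" the combined length would instead be min(...) == 4.
```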
64 changes: 64 additions & 0 deletions tests/trainer/test_supporters.py
@@ -33,8 +33,10 @@
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from tests.helpers.boring_model import RandomDataset


def test_tensor_running_accum_reset():
@@ -379,3 +381,65 @@ def _assert_dataset(loader):
assert isinstance(d, CustomDataset)

apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
@mock.patch("torch.cuda.device_count", return_value=2)
@mock.patch("torch.cuda.is_available", return_value=True)
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_with_max_size_cycle_and_ddp(
cuda_available_mock, device_count_mock, replace_sampler_ddp, tmpdir
):
"""This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
with ddp and `max_size_cycle` mode."""

dataloader = CombinedLoader(
{"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
)

trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp)
dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
assert len(dataloader) == (4 if replace_sampler_ddp else 8)

for a_length in [6, 8, 10]:
dataloader = CombinedLoader(
{
"a": DataLoader(range(a_length), batch_size=1),
"b": DataLoader(range(8), batch_size=1),
},
mode="max_size_cycle",
)

length = max(a_length, 8)
assert len(dataloader) == length
trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp)
dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
if replace_sampler_ddp:
batches = [batch for batch in dataloader]
if a_length == 6:
assert batches[-1] == {"a": torch.tensor([0]), "b": torch.tensor([6])}
elif a_length == 8:
assert batches[-1] == {"a": torch.tensor([6]), "b": torch.tensor([6])}
elif a_length == 10:
assert batches[-1] == {"a": torch.tensor([8]), "b": torch.tensor([0])}

class InfiniteDataset(IterableDataset):
def __iter__(self):
while True:
yield 1

dataloader = CombinedLoader(
{
"a": DataLoader(InfiniteDataset(), batch_size=1),
"b": DataLoader(range(8), batch_size=1),
},
mode="max_size_cycle",
)
assert get_len(dataloader) == float("inf")
assert len(dataloader.loaders["b"].loader) == 8
trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp)
dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
assert get_len(dataloader) == float("inf")
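
To make the last-batch assertions in the test above easier to verify, the sketch below spells out the index arithmetic for the `a_length == 6` case (it assumes `DistributedSampler`'s non-shuffled interleaving, under which rank 0 receives indices 0, 2, 4, ...).

```python
# Index arithmetic behind the `a_length == 6` assertion above (rank 0 view).
a_rank0 = [0, 2, 4]            # DistributedSampler shard of range(6), shuffle=False
b_rank0 = [0, 2, 4, 6]         # DistributedSampler shard of range(8), shuffle=False

length = max(len(a_rank0), len(b_rank0))                  # 4 batches under max_size_cycle
last_batch = {
    "a": a_rank0[(length - 1) % len(a_rank0)],            # "a" has cycled back to index 0
    "b": b_rank0[length - 1],                             # "b" finishes at index 6
}
assert last_batch == {"a": 0, "b": 6}
```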