Commit ab44b81

Enable distributed training with CombinedDataLoader and max_size_cycle (#10374)

Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Thomas Chaton <[email protected]>
3 people authored and lexierule committed Nov 16, 2021
1 parent 43a5075 commit ab44b81
Showing 4 changed files with 89 additions and 7 deletions.
CHANGELOG.md (2 additions, 1 deletion)
@@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `CombinedLoader` with `max_size_cycle` not receiving a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))



## [1.5.1] - 2021-11-09
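Before the code diffs below, a short usage sketch may help frame the fix. This snippet is illustrative, not part of the commit; it mirrors the new test in `tests/trainer/test_supporters.py` and assumes an environment where a 2-device DDP `Trainer` can be constructed. It exercises exactly the path that previously skipped `DistributedSampler` injection: a `CombinedLoader` in `max_size_cycle` mode passed through `Trainer.prepare_dataloader`.

```python
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader

# Two loaders of different lengths, cycled until the longest is exhausted.
combined = CombinedLoader(
    {
        "a": DataLoader(range(6), batch_size=1),
        "b": DataLoader(range(8), batch_size=1),
    },
    mode="max_size_cycle",
)
assert len(combined) == 8  # length of the longest loader

# With this fix, preparing the loader for DDP unwraps each CycleIterator and
# injects a DistributedSampler into the underlying DataLoader, so the
# per-process length is halved when running on 2 devices.
trainer = Trainer(strategy="ddp", accelerator="auto", devices=2)
prepared = trainer.prepare_dataloader(combined, shuffle=False)
assert len(prepared) == 4
```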
pytorch_lightning/trainer/data_loading.py (15 additions, 3 deletions)
@@ -28,7 +28,7 @@
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
@@ -136,14 +136,22 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
        if isinstance(dataloader, CombinedLoader):
            # apply `prepare_dataloader` to the whole collection of loaders
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, DataLoader, self.prepare_dataloader, shuffle, mode=mode
                dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
            )
            # the length needs to be recomputed across all dataloaders in case of special behavior.
            dataloader._apply_cycle_iterator_length()
            return dataloader

        # don't do anything if it's not a dataloader
        if not isinstance(dataloader, DataLoader):
        if not isinstance(dataloader, (DataLoader, CycleIterator)):
            return dataloader

        cycle_iterator: Optional[CycleIterator] = None

        if isinstance(dataloader, CycleIterator):
            cycle_iterator = dataloader
            dataloader = dataloader.loader

        if (
            _fault_tolerant_training()  # injects components to track the state
            or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
@@ -153,6 +161,10 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
            sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
            dataloader = self._update_dataloader(dataloader, sampler, mode=mode)

        if cycle_iterator is not None:
            cycle_iterator.loader = dataloader
            return cycle_iterator

        return dataloader

    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
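The heart of the `prepare_dataloader` change above is an unwrap/re-wrap step: a `CycleIterator` is opened up so its inner `DataLoader` can receive a new sampler, then re-wrapped so `max_size_cycle` cycling is preserved. Below is a minimal, hedged sketch of that pattern; `rebuild_with_sampler` is a hypothetical callback standing in for the Trainer's sampler resolution and dataloader re-creation logic.

```python
from typing import Any, Callable, Optional

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CycleIterator


def prepare(obj: Any, rebuild_with_sampler: Callable[[DataLoader], DataLoader]) -> Any:
    """Sketch of the unwrap/re-wrap pattern; `rebuild_with_sampler` is hypothetical."""
    if not isinstance(obj, (DataLoader, CycleIterator)):
        return obj  # leave anything else untouched

    cycle_iterator: Optional[CycleIterator] = None
    if isinstance(obj, CycleIterator):
        cycle_iterator = obj
        obj = obj.loader  # unwrap so the inner DataLoader can get a new sampler

    obj = rebuild_with_sampler(obj)  # e.g. re-create the loader with a DistributedSampler

    if cycle_iterator is not None:
        cycle_iterator.loader = obj  # re-wrap to keep max_size_cycle semantics
        return cycle_iterator
    return obj
```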
pytorch_lightning/trainer/supporters.py (17 additions, 3 deletions)
@@ -457,6 +457,19 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
            )
            state.reset()

    def _apply_cycle_iterator_length(self) -> None:
        """When the mode is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
        all dataloaders."""
        if self.mode != "max_size_cycle":
            return

        def set_len(cycle_iterator: CycleIterator, length: int) -> None:
            cycle_iterator.length = length

        all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
        max_length = _nested_calc_num_data(all_lengths, max)
        apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)

    def __iter__(self) -> Any:
        """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader."""

@@ -473,11 +486,12 @@ def __getstate__patch__(*_):
        return iterator

    @staticmethod
    def _calc_num_batches(loaders: Any) -> Union[int, float]:
    def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
        """Compute the length (aka the number of batches) of `CombinedLoader`.

        Args:
            loaders: a collection of loaders.
            mode: Mode used by the CombinedDataloader

        Returns:
            length: the minimum length of loaders
@@ -486,10 +500,10 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:

        if isinstance(all_lengths, (int, float)):
            return all_lengths
        return _nested_calc_num_data(all_lengths, min)
        return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)

    def __len__(self) -> int:
        return self._calc_num_batches(self.loaders)
        return self._calc_num_batches(self.loaders, mode=self.mode)

    @staticmethod
    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
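Taken together, the supporters.py changes make `len(CombinedLoader)` mode-aware and add a hook to recompute the cycle length after the wrapped loaders are replaced. A hedged sketch of the resulting length semantics (illustrative asserts only, using the `min_size` and `max_size_cycle` mode names that appear in the diff):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {
    "a": DataLoader(range(6), batch_size=1),
    "b": DataLoader(range(8), batch_size=1),
}

# "min_size" stops with the shortest loader; "max_size_cycle" cycles the shorter
# loaders until the longest is exhausted, so its length is the maximum.
assert len(CombinedLoader(loaders, mode="min_size")) == 6
assert len(CombinedLoader(loaders, mode="max_size_cycle")) == 8
```

`_apply_cycle_iterator_length` matters because, after `prepare_dataloader` swaps distributed samplers into the inner loaders, each `CycleIterator` may have a new, shorter length; recomputing the maximum and writing it back to every wrapper keeps the cycling aligned across loaders.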
tests/trainer/test_supporters.py (55 additions, 0 deletions)
@@ -33,8 +33,10 @@
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from tests.helpers.boring_model import RandomDataset


def test_tensor_running_accum_reset():
@@ -379,3 +381,56 @@ def _assert_dataset(loader):
            assert isinstance(d, CustomDataset)

    apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)


@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, tmpdir):
    """This test makes sure a distributed sampler is properly injected into the dataloaders when using
    CombinedLoader with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator="auto", devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
    assert get_len(dataloader) == float("inf")
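As context for the `length // 2` assertions in the test above: with `replace_sampler_ddp=True` and two devices, each inner `DataLoader` gets a `DistributedSampler` that shards its dataset across the two ranks (padding to an even split), so the per-process length is roughly half. A small standalone check of that arithmetic, using plain PyTorch with no Lightning involved:

```python
from torch.utils.data import DistributedSampler

dataset = list(range(8))

# Passing num_replicas/rank explicitly avoids needing an initialized process group.
per_rank_lengths = [
    len(DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False))
    for rank in (0, 1)
]
assert per_rank_lengths == [4, 4]  # 8 samples split evenly across 2 ranks
```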
