Samplers always return the same number of batches in distributed mode #267
Merged
pzelasko merged 7 commits into master from feature/fix-distributed-dynamic-batch-size-sampler on Apr 13, 2021
Changes from 6 commits
Commits (7):
c5e843c  Add test for equal number of batches of distributed samplers with dyn…  (pzelasko)
d987f85  Ensure equal number of batches across all distributed workers in an e…  (pzelasko)
02621f9  Ensure equal number of batches across all distributed workers in an e…  (pzelasko)
7ee8d59  Remove print and add comments  (pzelasko)
c836656  Add BucketingSampler into the test  (pzelasko)
af4c1f5  Documentation lifting  (pzelasko)
8671fe9  Fix formulas in docs  (pzelasko)
@@ -1,8 +1,6 @@
import logging
import random
import warnings
from dataclasses import dataclass
from math import ceil
from typing import Iterable, List, Optional, Tuple, Type, Union

import torch.distributed as dist
@@ -32,16 +30,22 @@ def __len__(self):

class CutSampler(Sampler[List[str]]):
"""
CutSampler is responsible for collecting batches of cuts, given specified criteria.
It implements correct handling of distributed sampling in DataLoader,
``CutSampler`` is responsible for collecting batches of cuts, given specified criteria.
It implements correct handling of distributed sampling in ``DataLoader``,
so that the cuts are not duplicated across workers.

Sampling in a CutSampler is intended to be very quick - it only uses the metadata in
Sampling in a ``CutSampler`` is intended to be very quick - it only uses the metadata in
``CutSet`` manifest to select the cuts, and is not intended to perform any I/O.

CutSampler works similarly to PyTorch's DistributedSampler - when :attr:`shuffle=True`,
you should call ``sampler.set_epoch(epoch)`` at each new epoch to have a different
ordering of returned elements.
ordering of returned elements. However, its actual behaviour is different than that of
DistributedSampler -- instead of partitioning the underlying cuts into equally sized chunks,
it will return every N-th batch and skip the other batches (where ``N == world_size``).
The formula used to determine which batches are returned is
``(batch_idx + rank) % world_size == 0``.
This ensures that we can return an equal number of batches in all distributed workers
in spite of using a dynamic batch size, at the cost of skipping at most ``world_size`` batches.
Review comment on the formula above: Should it be … ?
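Editor's note, a concrete illustration of the behaviour described above (the numbers are illustrative): the implementation later in this diff draws ``world_size`` candidate batches per step and keeps the one at position ``rank``, so with ``world_size == 3`` rank 0 ends up with batch indices 0, 3, 6, ..., rank 1 with 1, 4, 7, ..., and rank 2 with 2, 5, 8, ..., which is equivalent to selecting batches where ``batch_idx % world_size == rank``. The formula as written in this revision is the one questioned in the review comment above; it appears to be revised by the later commit 8671fe9 ("Fix formulas in docs"), which is not part of this 6-commit view.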

Example usage::

@@ -55,10 +59,10 @@ class CutSampler(Sampler[List[str]]):
.. note::

For implementers of new samplers:
Subclasses of CutSampler are expected to implement ``__next__()`` to introduce specific
Subclasses of CutSampler are expected to implement ``self._next_batch()`` to introduce specific
sampling logic (e.g. based on filters such as max number of frames/tokens/etc.).
CutSampler defines ``__iter__()``, which optionally shuffles the cut IDs, and resets
``self.current_idx`` to zero (to be used and incremented inside of ``__next__()``.
``self.cut_idx`` to zero (to be used and incremented inside of ``_next_batch()``.
"""
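A minimal editor's sketch of the subclass contract described in the note above. The class below is hypothetical (not part of this PR or of lhotse) and assumes only what is visible in this diff: the base-class constructor arguments, ``self.data_source``, and ``self.cut_idx``:

    class ToyFixedSizeSampler(CutSampler):
        """Hypothetical sampler that returns a fixed number of cut IDs per batch."""

        def __init__(self, cut_ids, batch_size: int = 4, **kwargs):
            super().__init__(cut_ids, **kwargs)
            self.batch_size = batch_size

        def _next_batch(self) -> List[str]:
            # __iter__() in the base class has already shuffled (if requested) and reset self.cut_idx.
            if self.cut_idx >= len(self.data_source):
                raise StopIteration()
            batch = []
            while self.cut_idx < len(self.data_source) and len(batch) < self.batch_size:
                batch.append(self.data_source[self.cut_idx])
                self.cut_idx += 1
            return batch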

def __init__(
@@ -81,27 +85,29 @@ def __init__(
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
"""
data_source = list(cut_ids)
data_source = DataSource(list(cut_ids))
super().__init__(data_source)
self.full_data_source = data_source
self.data_source = data_source
self.shuffle = shuffle
self.seed = seed
self.epoch = 0
self.cut_idx = 0

self._maybe_init_distributed(world_size=world_size, rank=rank)
self.data_source = DataSource(
partition_cut_ids(self.full_data_source, world_size=self.world_size, rank=self.rank)
)

self.num_batches = None

def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
if world_size is not None:
assert world_size >= 1
if rank is not None:
assert rank >= 0
if not dist.is_available() or not dist.is_initialized():
self.world_size = 1
self.rank = 0
self.world_size = 1 if world_size is None else world_size
self.rank = 0 if rank is None else rank
return
self.world_size = dist.get_world_size() if world_size is None else world_size
self.rank = dist.get_rank() if rank is None else rank
assert self.rank < self.world_size

def set_epoch(self, epoch: int) -> None:
r"""
@@ -114,22 +120,34 @@ def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
self.num_batches = None

def _next_batch(self) -> List[str]:
raise NotImplementedError("Sub-classes of CutSampler have to implement self._next_batch()")

def __iter__(self) -> 'CutSampler':
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
"""
if self.shuffle:
self.data_source.shuffle(self.seed + self.epoch)
self.current_idx = 0
self.cut_idx = 0
return self

def __len__(self) -> int:
if self.num_batches is None:
self.num_batches = sum(1 for item in self)
self.num_batches = sum(1 for _ in self)
return self.num_batches
def __next__(self) -> List[str]:
raise NotImplemented
# We use the following trick to ensure equal number of batches for each distributed
# worker:
# Every time a next batch is required, we will sample self.world_size batches first,
# and then return the one at position self.rank.
# This way, if any of the batches raises StopIteration, we'll know to stop early
# even if a given batch was available for one of the nodes, but not for the others.
batches = []
for _ in range(self.world_size):
batches.append(self._next_batch())
return batches[self.rank]
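An editor's sketch (not part of the PR) that simulates the loop above with plain Python, to show why every rank yields the same number of batches and stops together even though batch sizes are dynamic; all names and data below are made up:

    def select_for_rank(all_batches, rank, world_size):
        # Mirrors the trick above: draw world_size candidate batches, keep the one at `rank`,
        # and stop as soon as a full group of candidates can no longer be drawn.
        it = iter(all_batches)
        kept = []
        while True:
            group = []
            try:
                for _ in range(world_size):
                    group.append(next(it))  # raises StopIteration when the data runs out
            except StopIteration:
                break
            kept.append(group[rank])
        return kept

    dynamic_batches = [["a"], ["b", "c"], ["d"], ["e", "f", "g"], ["h"]]
    print(select_for_rank(dynamic_batches, rank=0, world_size=2))  # [['a'], ['d']]
    print(select_for_rank(dynamic_batches, rank=1, world_size=2))  # [['b', 'c'], ['e', 'f', 'g']]
    # Both ranks yield 2 batches; the trailing batch ['h'] is skipped on every rank.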

@dataclass
@@ -202,6 +220,16 @@ class SingleCutSampler(CutSampler):
the batch size is dynamic.
Exactly zero or one of those constraints can be specified.
Padding required to collate the batch does not contribute to max frames/samples/duration.

Example usage::

>>> dataset = K2SpeechRecognitionDataset(cuts)
>>> sampler = SingleCutSampler(cuts, shuffle=True)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
... sampler.set_epoch(epoch)
... train(loader)

"""

def __init__(
@@ -237,7 +265,7 @@ def __init__(
# Constraints
assert is_none_or_gt(self.max_cuts, 0)

def __next__(self) -> List[str]:
def _next_batch(self) -> List[str]:
# Keep iterating the underlying CutSet as long as we hit or exceed the constraints
# provided by user (the max number of frames or max number of cuts).
# Note: no actual data is loaded into memory yet because the manifests contain all the metadata
@@ -246,9 +274,9 @@ def __next__(self) -> List[str]:
cut_ids = []
while True:
# Check that we have not reached the end of the dataset.
if self.current_idx < len(self.data_source):
if self.cut_idx < len(self.data_source):
# We didn't - grab the next cut
next_cut_id = self.data_source[self.current_idx]
next_cut_id = self.data_source[self.cut_idx]
else:
if cut_ids:
# We did and we have a partial batch - return it.
@@ -263,7 +291,7 @@ def __next__(self) -> List[str]:
if not self.time_constraint.exceeded() and (self.max_cuts is None or next_num_cuts <= self.max_cuts):
# No - add the next cut to the batch, and keep trying.
cut_ids.append(next_cut.id)
self.current_idx += 1
self.cut_idx += 1
else:
# Yes. Do we have at least one cut in the batch?
if cut_ids:
@@ -275,7 +303,7 @@ def __next__(self) -> List[str]:
warnings.warn("The first cut drawn in batch collection violates the max_frames or max_cuts "
"constraints - we'll return it anyway. Consider increasing max_frames/max_cuts.")
cut_ids.append(next_cut.id)
self.current_idx += 1
self.cut_idx += 1
return cut_ids

@@ -336,7 +364,7 @@ def __init__(
)
self.max_cuts = max_cuts

def __next__(self) -> List[str]:
def _next_batch(self) -> List[str]:
# Keep iterating the underlying CutSets as long as we hit or exceed the constraints
# provided by user (the max number of source_feats or max number of cuts).
# Note: no actual data is loaded into memory yet because the manifests contain all the metadata
@@ -346,9 +374,9 @@ def __next__(self) -> List[str]:
cut_ids = []
while True:
# Check that we have not reached the end of the dataset.
if self.current_idx < len(self.data_source):
if self.cut_idx < len(self.data_source):
# We didn't - grab the next cut
next_cut_id = self.data_source[self.current_idx]
next_cut_id = self.data_source[self.cut_idx]
else:
if cut_ids:
# We did and we have a partial batch - return it.
@@ -367,7 +395,7 @@ def __next__(self) -> List[str]:
and (self.max_cuts is None or next_num_cuts <= self.max_cuts):
# No - add the next cut to the batch, and keep trying.
cut_ids.append(next_source_cut.id)
self.current_idx += 1
self.cut_idx += 1
else:
# Yes. Do we have at least one cut in the batch?
if cut_ids:
@@ -379,7 +407,7 @@ def __next__(self) -> List[str]:
warnings.warn("The first cut drawn in batch collection violates one of the max_... constraints"
"we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc.")
cut_ids.append(next_source_cut.id)
self.current_idx += 1
self.cut_idx += 1
return cut_ids

@@ -466,7 +494,7 @@ def __iter__(self) -> 'BucketingSampler':
self.depleted = [False] * self.num_buckets
return self

def __next__(self) -> List[str]:
def _next_batch(self) -> List[str]:
while not self.is_depleted:
idx, sampler = self.bucket_rng.choice(self._nondepleted_samplers_with_idxs)
try:
@@ -494,35 +522,3 @@ def _nondepleted_samplers_with_idxs(self):
enumerate(zip(self.bucket_samplers, self.depleted))
if not depleted
]


def partition_cut_ids(
data_source: List[str],
world_size: int = 1,
rank: int = 0
) -> List[str]:
"""
Returns a list of cut IDs to be used by a single dataloading process.
For multiple dataloader workers or ``DistributedDataParallel`` training,
that list will be a subset of ``sampler.full_data_source``.

:param data_source: a list of Cut IDs, representing the full dataset.
:param world_size: Total number of distributed nodes. Set only when using ``DistributedDataParallel``.
:param rank: Index of distributed node. Set only when using ``DistributedDataParallel``.
"""

# First, split depending on the world_size and rank.
if world_size == 1:
return data_source
else:
# Distributed training is active - split full dataset into a subset.
total = len(data_source)
per_partition = int(ceil(total / float(world_size)))
partition_start = rank * per_partition
partition_end = min(partition_start + per_partition, total)
logging.info(f'Distributed training with world size of {world_size} detected '
f'(node\'s local rank is {rank}. '
f'Splitting cuts into {world_size} partitions ('
f'this partition has cut IDs range [{partition_start, partition_end}].')

return data_source[partition_start: partition_end]
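Editor's illustration (not part of the PR) of why the removed ``ceil``-based partitioning motivated this change: when the dataset size does not divide evenly, ranks receive different numbers of cuts, and with dynamic batch sizes that easily becomes a different number of batches per rank. The numbers below are made up:

    from math import ceil

    cut_ids = [f"cut-{i}" for i in range(10)]
    world_size = 3
    per_partition = int(ceil(len(cut_ids) / float(world_size)))  # ceil(10 / 3) == 4
    partitions = [
        cut_ids[rank * per_partition: min((rank + 1) * per_partition, len(cut_ids))]
        for rank in range(world_size)
    ]
    print([len(p) for p in partitions])  # [4, 4, 2] -> rank 2 can run out of batches first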
Review thread on the batch-selection formula:

"Should it be … ?"

"yes, you're right, thanks"