Samplers always return the same number of batches in distributed mode #267

Merged
merged 7 commits into from Apr 13, 2021

Changes from 6 commits
120 changes: 58 additions & 62 deletions lhotse/dataset/sampling.py
@@ -1,8 +1,6 @@
import logging
import random
import warnings
from dataclasses import dataclass
from math import ceil
from typing import Iterable, List, Optional, Tuple, Type, Union

import torch.distributed as dist
@@ -32,16 +30,22 @@ def __len__(self):

class CutSampler(Sampler[List[str]]):
"""
CutSampler is responsible for collecting batches of cuts, given specified criteria.
It implements correct handling of distributed sampling in DataLoader,
``CutSampler`` is responsible for collecting batches of cuts, given specified criteria.
It implements correct handling of distributed sampling in ``DataLoader``,
so that the cuts are not duplicated across workers.

Sampling in a CutSampler is intended to be very quick - it only uses the metadata in
Sampling in a ``CutSampler`` is intended to be very quick - it only uses the metadata in
``CutSet`` manifest to select the cuts, and is not intended to perform any I/O.

CutSampler works similarly to PyTorch's DistributedSampler - when :attr:`shuffle=True`,
you should call ``sampler.set_epoch(epoch)`` at each new epoch to have a different
ordering of returned elements.
ordering of returned elements. However, its actual behaviour is different than that of
DistributedSampler -- instead of partitioning the underlying cuts into equally sized chunks,
it will return every N-th batch and skip the other batches (where ``N == world_size``).
The formula used to determine which batches are returned is
``(batch_idx + rank) % world_size == 0``.
Contributor (review comment):

    Should it be

        (batch_idx + (world_size - rank)) % world_size == 0

    ?

Collaborator (author):

    yes, you're right, thanks

This ensures that we can return an equal number of batches in all distributed workers
in spite of using a dynamic batch size, at the cost of skipping at most ``world_size`` batches.
Contributor (review comment):

    Should it be

        skipping at most ``world_size - 1`` batches.

    ?

Example usage::

@@ -55,10 +59,10 @@ class CutSampler(Sampler[List[str]]):
.. note::

For implementers of new samplers:
Subclasses of CutSampler are expected to implement ``__next__()`` to introduce specific
Subclasses of CutSampler are expected to implement ``self._next_batch()`` to introduce specific
sampling logic (e.g. based on filters such as max number of frames/tokens/etc.).
CutSampler defines ``__iter__()``, which optionally shuffles the cut IDs, and resets
``self.current_idx`` to zero (to be used and incremented inside of ``__next__()``.
``self.cut_idx`` to zero (to be used and incremented inside of ``_next_batch()``).
"""

def __init__(
@@ -81,27 +85,29 @@ def __init__(
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
"""
data_source = list(cut_ids)
data_source = DataSource(list(cut_ids))
super().__init__(data_source)
self.full_data_source = data_source
self.data_source = data_source
self.shuffle = shuffle
self.seed = seed
self.epoch = 0
self.cut_idx = 0

self._maybe_init_distributed(world_size=world_size, rank=rank)
self.data_source = DataSource(
partition_cut_ids(self.full_data_source, world_size=self.world_size, rank=self.rank)
)

self.num_batches = None

def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
if world_size is not None:
assert world_size >= 1
if rank is not None:
assert rank >= 0
if not dist.is_available() or not dist.is_initialized():
self.world_size = 1
self.rank = 0
self.world_size = 1 if world_size is None else world_size
self.rank = 0 if rank is None else rank
return
self.world_size = dist.get_world_size() if world_size is None else world_size
self.rank = dist.get_rank() if rank is None else rank
assert self.rank < self.world_size

def set_epoch(self, epoch: int) -> None:
r"""
@@ -114,22 +120,34 @@ def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
self.num_batches = None

def _next_batch(self) -> List[str]:
raise NotImplementedError("Sub-classes of CutSampler have to implement self._next_batch()")

def __iter__(self) -> 'CutSampler':
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
"""
if self.shuffle:
self.data_source.shuffle(self.seed + self.epoch)
self.current_idx = 0
self.cut_idx = 0
return self

def __len__(self) -> int:
if self.num_batches is None:
self.num_batches = sum(1 for item in self)
self.num_batches = sum(1 for _ in self)
return self.num_batches

def __next__(self) -> List[str]:
raise NotImplemented
# We use the following trick to ensure equal number of batches for each distributed
# worker:
# Every time a next batch is required, we will sample self.world_size batches first,
# and then return the one at position self.rank.
# This way, if any of the batches raises StopIteration, we'll know to stop early,
# even if a given batch was available for one of the nodes but not for the others.
batches = []
for _ in range(self.world_size):
batches.append(self._next_batch())
return batches[self.rank]
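
A toy simulation (again, not lhotse code) of why this yields equal per-rank batch
counts: every rank consumes the underlying batches in rounds of ``world_size``,
and the first incomplete round stops all ranks at the same point::

    def simulate(num_batches: int, world_size: int, rank: int) -> int:
        underlying = iter(range(num_batches))  # stand-in for _next_batch()
        count = 0
        while True:
            try:
                round_batches = []
                for _ in range(world_size):
                    round_batches.append(next(underlying))
            except StopIteration:
                # An incomplete round ends iteration for every rank.
                return count
            _ = round_batches[rank]  # the batch this rank would return
            count += 1

    # All three ranks report floor(10 / 3) == 3 batches.
    assert {simulate(10, 3, r) for r in range(3)} == {3}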


@dataclass
@@ -202,6 +220,16 @@ class SingleCutSampler(CutSampler):
the batch size is dynamic.
Exactly zero or one of those constraints can be specified.
Padding required to collate the batch does not contribute to max frames/samples/duration.

Example usage::

>>> dataset = K2SpeechRecognitionDataset(cuts)
>>> sampler = SingleCutSampler(cuts, shuffle=True)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
... sampler.set_epoch(epoch)
... train(loader)

"""

def __init__(
@@ -237,7 +265,7 @@ def __init__(
# Constraints
assert is_none_or_gt(self.max_cuts, 0)

def __next__(self) -> List[str]:
def _next_batch(self) -> List[str]:
# Keep iterating the underlying CutSet as long as we hit or exceed the constraints
# provided by user (the max number of frames or max number of cuts).
# Note: no actual data is loaded into memory yet because the manifests contain all the metadata
@@ -246,9 +274,9 @@ def __next__(self) -> List[str]:
cut_ids = []
while True:
# Check that we have not reached the end of the dataset.
if self.current_idx < len(self.data_source):
if self.cut_idx < len(self.data_source):
# We didn't - grab the next cut
next_cut_id = self.data_source[self.current_idx]
next_cut_id = self.data_source[self.cut_idx]
else:
if cut_ids:
# We did and we have a partial batch - return it.
@@ -263,7 +291,7 @@ def __next__(self) -> List[str]:
if not self.time_constraint.exceeded() and (self.max_cuts is None or next_num_cuts <= self.max_cuts):
# No - add the next cut to the batch, and keep trying.
cut_ids.append(next_cut.id)
self.current_idx += 1
self.cut_idx += 1
else:
# Yes. Do we have at least one cut in the batch?
if cut_ids:
@@ -275,7 +303,7 @@ def __next__(self) -> List[str]:
warnings.warn("The first cut drawn in batch collection violates the max_frames or max_cuts "
"constraints - we'll return it anyway. Consider increasing max_frames/max_cuts.")
cut_ids.append(next_cut.id)
self.current_idx += 1
self.cut_idx += 1
return cut_ids


@@ -336,7 +364,7 @@ def __init__(
)
self.max_cuts = max_cuts

def __next__(self) -> List[str]:
def _next_batch(self) -> List[str]:
# Keep iterating the underlying CutSets as long as we hit or exceed the constraints
# provided by user (the max number of source_feats or max number of cuts).
# Note: no actual data is loaded into memory yet because the manifests contain all the metadata
@@ -346,9 +374,9 @@ def __next__(self) -> List[str]:
cut_ids = []
while True:
# Check that we have not reached the end of the dataset.
if self.current_idx < len(self.data_source):
if self.cut_idx < len(self.data_source):
# We didn't - grab the next cut
next_cut_id = self.data_source[self.current_idx]
next_cut_id = self.data_source[self.cut_idx]
else:
if cut_ids:
# We did and we have a partial batch - return it.
@@ -367,7 +395,7 @@ def __next__(self) -> List[str]:
and (self.max_cuts is None or next_num_cuts <= self.max_cuts):
# No - add the next cut to the batch, and keep trying.
cut_ids.append(next_source_cut.id)
self.current_idx += 1
self.cut_idx += 1
else:
# Yes. Do we have at least one cut in the batch?
if cut_ids:
@@ -379,7 +407,7 @@ def __next__(self) -> List[str]:
warnings.warn("The first cut drawn in batch collection violates one of the max_... constraints"
"we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc.")
cut_ids.append(next_source_cut.id)
self.current_idx += 1
self.cut_idx += 1
return cut_ids


@@ -466,7 +494,7 @@ def __iter__(self) -> 'BucketingSampler':
self.depleted = [False] * self.num_buckets
return self

def __next__(self) -> List[str]:
def _next_batch(self) -> List[str]:
while not self.is_depleted:
idx, sampler = self.bucket_rng.choice(self._nondepleted_samplers_with_idxs)
try:
@@ -494,35 +522,3 @@ def _nondepleted_samplers_with_idxs(self):
enumerate(zip(self.bucket_samplers, self.depleted))
if not depleted
]


def partition_cut_ids(
data_source: List[str],
world_size: int = 1,
rank: int = 0
) -> List[str]:
"""
Returns a list of cut IDs to be used by a single dataloading process.
For multiple dataloader workers or ``DistributedDataParallel`` training,
that list will be a subset of ``sampler.full_data_source``.

:param data_source: a list of Cut IDs, representing the full dataset.
:param world_size: Total number of distributed nodes. Set only when using ``DistributedDataParallel``.
:param rank: Index of distributed node. Set only when using ``DistributedDataParallel``.
"""

# First, split depending on the world_size and rank.
if world_size == 1:
return data_source
else:
# Distributed training is active - split full dataset into a subset.
total = len(data_source)
per_partition = int(ceil(total / float(world_size)))
partition_start = rank * per_partition
partition_end = min(partition_start + per_partition, total)
logging.info(f'Distributed training with world size of {world_size} detected '
f'(node\'s local rank is {rank}. '
f'Splitting cuts into {world_size} partitions ('
f'this partition has cut IDs range [{partition_start, partition_end}].')

return data_source[partition_start: partition_end]
20 changes: 20 additions & 0 deletions test/dataset/test_sampling.py
@@ -1,3 +1,4 @@
import random
from itertools import groupby

import pytest
@@ -491,3 +492,22 @@ def test_bucketing_sampler_time_constraints(constraint):
for batch in sampler:
cut_ids.extend(batch)
assert set(cut_set.ids) == set(cut_ids)


@pytest.mark.parametrize('world_size', [2, 3, 4])
@pytest.mark.parametrize('n_cuts', [995, 996, 997, 998, 999, 1000, 1001, 1002, 1003])
@pytest.mark.parametrize('sampler_cls', [SingleCutSampler, BucketingSampler])
def test_partitions_are_equal(world_size, n_cuts, sampler_cls):
# Create a dummy CutSet.
cut_set = DummyManifest(CutSet, begin_id=0, end_id=n_cuts)
# Randomize the durations of cuts to increase the chance we run into edge cases.
for c in cut_set:
c.duration += (10 * random.random())
# Create a sampler for each "distributed worker."
samplers = [
sampler_cls(cut_set, max_duration=25.0, rank=i, world_size=world_size)
for i in range(world_size)
]
# Check that every rank's sampler yields the same number of batches.
n_batches = [len(s) for s in samplers]
assert all(nb == n_batches[0] for nb in n_batches)