diff --git a/lhotse/dataset/sampling.py b/lhotse/dataset/sampling.py
index f44cc746a..16306cba2 100644
--- a/lhotse/dataset/sampling.py
+++ b/lhotse/dataset/sampling.py
@@ -1,8 +1,6 @@
-import logging
 import random
 import warnings
 from dataclasses import dataclass
-from math import ceil
 from typing import Iterable, List, Optional, Tuple, Type, Union
 
 import torch.distributed as dist
@@ -32,16 +30,22 @@ def __len__(self):
 
 class CutSampler(Sampler[List[str]]):
     """
-    CutSampler is responsible for collecting batches of cuts, given specified criteria.
-    It implements correct handling of distributed sampling in DataLoader,
+    ``CutSampler`` is responsible for collecting batches of cuts, given specified criteria.
+    It implements correct handling of distributed sampling in ``DataLoader``,
     so that the cuts are not duplicated across workers.
 
-    Sampling in a CutSampler is intended to be very quick - it only uses the metadata in
+    Sampling in a ``CutSampler`` is intended to be very quick - it only uses the metadata in
     ``CutSet`` manifest to select the cuts, and is not intended to perform any I/O.
 
     CutSampler works similarly to PyTorch's DistributedSampler - when :attr:`shuffle=True`,
     you should call ``sampler.set_epoch(epoch)`` at each new epoch to have a different
-    ordering of returned elements.
+    ordering of returned elements. However, its actual behaviour is different than that of
+    DistributedSampler -- instead of partitioning the underlying cuts into equally sized chunks,
+    it will return every N-th batch and skip the other batches (where ``N == world_size``).
+    The formula used to determine which batches are returned is:
+    ``(batch_idx + (world_size - rank)) % world_size == 0``.
+    This ensures that we can return an equal number of batches in all distributed workers
+    in spite of using a dynamic batch size, at the cost of skipping at most ``world_size - 1`` batches.
 
     Example usage::
 
@@ -55,10 +59,10 @@ class CutSampler(Sampler[List[str]]):
     .. note::
 
         For implementers of new samplers:
-        Subclasses of CutSampler are expected to implement ``__next__()`` to introduce specific
+        Subclasses of CutSampler are expected to implement ``self._next_batch()`` to introduce specific
         sampling logic (e.g. based on filters such as max number of frames/tokens/etc.).
        CutSampler defines ``__iter__()``, which optionally shuffles the cut IDs, and resets
-        ``self.current_idx`` to zero (to be used and incremented inside of ``__next__()``.
+        ``self.cut_idx`` to zero (to be used and incremented inside of ``_next_batch()``.
     """
 
     def __init__(
@@ -81,27 +85,29 @@ def __init__(
         :param rank: Index of distributed node. We will try to infer it by default.
         :param seed: Random seed used to consistently shuffle the dataset across different processes.
         """
-        data_source = list(cut_ids)
+        data_source = DataSource(list(cut_ids))
         super().__init__(data_source)
-        self.full_data_source = data_source
+        self.data_source = data_source
         self.shuffle = shuffle
         self.seed = seed
         self.epoch = 0
+        self.cut_idx = 0
         self._maybe_init_distributed(world_size=world_size, rank=rank)
 
-        self.data_source = DataSource(
-            partition_cut_ids(self.full_data_source, world_size=self.world_size, rank=self.rank)
-        )
-        self.num_batches = None
 
     def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
+        if world_size is not None:
+            assert world_size >= 1
+        if rank is not None:
+            assert rank >= 0
         if not dist.is_available() or not dist.is_initialized():
-            self.world_size = 1
-            self.rank = 0
+            self.world_size = 1 if world_size is None else world_size
+            self.rank = 0 if rank is None else rank
             return
 
         self.world_size = dist.get_world_size() if world_size is None else world_size
         self.rank = dist.get_rank() if rank is None else rank
+        assert self.rank < self.world_size
 
     def set_epoch(self, epoch: int) -> None:
         r"""
@@ -114,22 +120,34 @@ def set_epoch(self, epoch: int) -> None:
         self.epoch = epoch
         self.num_batches = None
 
+    def _next_batch(self) -> List[str]:
+        raise NotImplementedError("Sub-classes of CutSampler have to implement self._next_batch()")
+
     def __iter__(self) -> 'CutSampler':
         """
         Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
         """
         if self.shuffle:
             self.data_source.shuffle(self.seed + self.epoch)
-        self.current_idx = 0
+        self.cut_idx = 0
         return self
 
     def __len__(self) -> int:
         if self.num_batches is None:
-            self.num_batches = sum(1 for item in self)
+            self.num_batches = sum(1 for _ in self)
         return self.num_batches
 
     def __next__(self) -> List[str]:
-        raise NotImplemented
+        # We use the following trick to ensure equal number of batches for each distributed
+        # worker:
+        # Every time a next batch is required, we will sample self.world_size batches first,
+        # and then return the one at position self.rank.
+        # This way, if any of the batches raises StopIteration, we'll know to stop early
+        # when a given batch was available for one of the nodes, but not for the others.
+        batches = []
+        for _ in range(self.world_size):
+            batches.append(self._next_batch())
+        return batches[self.rank]
 
 
 @dataclass
@@ -202,6 +220,16 @@ class SingleCutSampler(CutSampler):
     the batch size is dynamic.
     Exactly zero or one of those constraints can be specified.
     Padding required to collate the batch does not contribute to max frames/samples/duration.
+
+    Example usage::
+
+        >>> dataset = K2SpeechRecognitionDataset(cuts)
+        >>> sampler = SingleCutSampler(cuts, shuffle=True)
+        >>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
+        >>> for epoch in range(start_epoch, n_epochs):
+        ...     sampler.set_epoch(epoch)
+        ...     train(loader)
+
     """
 
     def __init__(
@@ -237,7 +265,7 @@ def __init__(
         # Constraints
         assert is_none_or_gt(self.max_cuts, 0)
 
-    def __next__(self) -> List[str]:
+    def _next_batch(self) -> List[str]:
         # Keep iterating the underlying CutSet as long as we hit or exceed the constraints
         # provided by user (the max number of frames or max number of cuts).
         # Note: no actual data is loaded into memory yet because the manifests contain all the metadata
@@ -246,9 +274,9 @@ def __next__(self) -> List[str]:
         cut_ids = []
         while True:
             # Check that we have not reached the end of the dataset.
-            if self.current_idx < len(self.data_source):
+            if self.cut_idx < len(self.data_source):
                 # We didn't - grab the next cut
-                next_cut_id = self.data_source[self.current_idx]
+                next_cut_id = self.data_source[self.cut_idx]
             else:
                 if cut_ids:
                     # We did and we have a partial batch - return it.
@@ -263,7 +291,7 @@ def __next__(self) -> List[str]:
             if not self.time_constraint.exceeded() and (self.max_cuts is None or next_num_cuts <= self.max_cuts):
                 # No - add the next cut to the batch, and keep trying.
                 cut_ids.append(next_cut.id)
-                self.current_idx += 1
+                self.cut_idx += 1
             else:
                 # Yes. Do we have at least one cut in the batch?
                 if cut_ids:
@@ -275,7 +303,7 @@ def __next__(self) -> List[str]:
                     warnings.warn("The first cut drawn in batch collection violates the max_frames or max_cuts "
                                   "constraints - we'll return it anyway. Consider increasing max_frames/max_cuts.")
                     cut_ids.append(next_cut.id)
-                    self.current_idx += 1
+                    self.cut_idx += 1
         return cut_ids
 
 
@@ -336,7 +364,7 @@ def __init__(
         )
         self.max_cuts = max_cuts
 
-    def __next__(self) -> List[str]:
+    def _next_batch(self) -> List[str]:
         # Keep iterating the underlying CutSets as long as we hit or exceed the constraints
         # provided by user (the max number of source_feats or max number of cuts).
         # Note: no actual data is loaded into memory yet because the manifests contain all the metadata
@@ -346,9 +374,9 @@ def __next__(self) -> List[str]:
         cut_ids = []
         while True:
             # Check that we have not reached the end of the dataset.
-            if self.current_idx < len(self.data_source):
+            if self.cut_idx < len(self.data_source):
                 # We didn't - grab the next cut
-                next_cut_id = self.data_source[self.current_idx]
+                next_cut_id = self.data_source[self.cut_idx]
             else:
                 if cut_ids:
                     # We did and we have a partial batch - return it.
@@ -367,7 +395,7 @@ def __next__(self) -> List[str]:
                     and (self.max_cuts is None or next_num_cuts <= self.max_cuts):
                 # No - add the next cut to the batch, and keep trying.
                 cut_ids.append(next_source_cut.id)
-                self.current_idx += 1
+                self.cut_idx += 1
             else:
                 # Yes. Do we have at least one cut in the batch?
                 if cut_ids:
@@ -379,7 +407,7 @@ def __next__(self) -> List[str]:
                     warnings.warn("The first cut drawn in batch collection violates one of the max_... constraints"
                                   "we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc.")
                     cut_ids.append(next_source_cut.id)
-                    self.current_idx += 1
+                    self.cut_idx += 1
         return cut_ids
 
 
@@ -466,7 +494,7 @@ def __iter__(self) -> 'BucketingSampler':
         self.depleted = [False] * self.num_buckets
         return self
 
-    def __next__(self) -> List[str]:
+    def _next_batch(self) -> List[str]:
         while not self.is_depleted:
             idx, sampler = self.bucket_rng.choice(self._nondepleted_samplers_with_idxs)
             try:
@@ -494,35 +522,3 @@ def _nondepleted_samplers_with_idxs(self):
             enumerate(zip(self.bucket_samplers, self.depleted))
             if not depleted
         ]
-
-
-def partition_cut_ids(
-        data_source: List[str],
-        world_size: int = 1,
-        rank: int = 0
-) -> List[str]:
-    """
-    Returns a list of cut IDs to be used by a single dataloading process.
-    For multiple dataloader workers or ``DistributedDataParallel`` training,
-    that list will be a subset of ``sampler.full_data_source``.
-
-    :param data_source: a list of Cut IDs, representing the full dataset.
-    :param world_size: Total number of distributed nodes. Set only when using ``DistributedDataParallel``.
-    :param rank: Index of distributed node. Set only when using ``DistributedDataParallel``.
-    """
-
-    # First, split depending on the world_size and rank.
-    if world_size == 1:
-        return data_source
-    else:
-        # Distributed training is active - split full dataset into a subset.
-        total = len(data_source)
-        per_partition = int(ceil(total / float(world_size)))
-        partition_start = rank * per_partition
-        partition_end = min(partition_start + per_partition, total)
-        logging.info(f'Distributed training with world size of {world_size} detected '
-                     f'(node\'s local rank is {rank}. '
-                     f'Splitting cuts into {world_size} partitions ('
-                     f'this partition has cut IDs range [{partition_start, partition_end}].')
-
-        return data_source[partition_start: partition_end]
diff --git a/test/dataset/test_sampling.py b/test/dataset/test_sampling.py
index ec8857f96..4a62c1a31 100644
--- a/test/dataset/test_sampling.py
+++ b/test/dataset/test_sampling.py
@@ -1,3 +1,4 @@
+import random
 from itertools import groupby
 
 import pytest
@@ -491,3 +492,22 @@ def test_bucketing_sampler_time_constraints(constraint):
     for batch in sampler:
         cut_ids.extend(batch)
     assert set(cut_set.ids) == set(cut_ids)
+
+
+@pytest.mark.parametrize('world_size', [2, 3, 4])
+@pytest.mark.parametrize('n_cuts', [995, 996, 997, 998, 999, 1000, 1001, 1002, 1003])
+@pytest.mark.parametrize('sampler_cls', [SingleCutSampler, BucketingSampler])
+def test_partitions_are_equal(world_size, n_cuts, sampler_cls):
+    # Create a dummy CutSet.
+    cut_set = DummyManifest(CutSet, begin_id=0, end_id=n_cuts)
+    # Randomize the durations of cuts to increase the chance we run into edge cases.
+    for c in cut_set:
+        c.duration += (10 * random.random())
+    # Create a sampler for each "distributed worker."
+    samplers = [
+        sampler_cls(cut_set, max_duration=25.0, rank=i, world_size=world_size)
+        for i in range(world_size)
+    ]
+    # Check that it worked.
+    n_batches = [len(s) for s in samplers]
+    assert all(nb == n_batches[0] for nb in n_batches)
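Illustrative sketch (not part of the patch): the toy make_batches generator and its constants below are hypothetical stand-ins for a sampler with a dynamic batch size; only the "draw world_size batches and keep the one at position rank" step mirrors the new CutSampler.__next__ in the diff above (equivalent to the docstring formula (batch_idx + (world_size - rank)) % world_size == 0). It shows why every rank ends up with the same number of batches even though the batch size varies.

# Standalone sketch -- make_batches is a hypothetical stand-in, not lhotse API.
import random


def make_batches(n_items: int, seed: int = 0):
    """Yield dynamically sized batches of item indices (toy replacement for a sampler)."""
    rng = random.Random(seed)
    i = 0
    while i < n_items:
        size = rng.randint(1, 5)  # dynamic batch size
        yield list(range(i, min(i + size, n_items)))
        i += size


def batches_for_rank(rank: int, world_size: int, n_items: int = 100):
    """Draw world_size batches at a time and keep only the one at position ``rank``;
    stop as soon as any of the world_size draws is unavailable, so every rank
    stops after the same number of steps."""
    kept = []
    it = make_batches(n_items)
    while True:
        group = []
        try:
            for _ in range(world_size):
                group.append(next(it))
        except StopIteration:
            break  # some rank would miss a batch here -> all ranks stop together
        kept.append(group[rank])
    return kept


if __name__ == "__main__":
    world_size = 3
    counts = [len(batches_for_rank(r, world_size)) for r in range(world_size)]
    assert len(set(counts)) == 1, counts  # equal number of batches on every rank
    print(counts)

At most world_size - 1 trailing batches are dropped by this scheme, which is the trade-off the new docstring describes.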