Tarred audio support in ASR data layer (#602)
* Initial draft of WebDataset integration for reading tarred audio datasets
* WebDataset integration bugfixes
* WebDataset integration: add batch_size and num_workers options, fix collate_fn for non-distributed
* Add wider collate_fn support in actions.py for DataLayers w/ Datasets
* Don't create distributed sampler if provided dataset is an IterableDataset
* Adding torch.distributed multiprocessing support to TarredAudioToTextDataLayer (prevent duplicate samples)
* Add filter (pipe) for when WebDataset tries to retrieve the entry for an already-filtered-out sample.
* Add script to convert non-tarred ASR datasets to tarred ones compatible with TarredAudioToTextDataLayer.
* Add leftover files to last shard in dataset conversion script
* Fix for docstring of TarredAudioToTextDataLayer
* Added changelog entry and fixed imports
* Removed unused imports
* Add unit tests for tarred data loading
* Add more arguments to docstring, add tarfile requirement for conversion script
* Remove tarfile from requirements--already in standard library.

Signed-off-by: Jocelyn Huang <[email protected]>
redoctopus authored May 6, 2020
1 parent 8025d3d commit d219483
Showing 10 changed files with 515 additions and 61 deletions.
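
Taken together, the commit messages above describe a data layer that reads training samples straight out of tar shards (via WebDataset) rather than opening one small file per utterance, plus a filter so that tar entries belonging to already-excluded samples are skipped instead of decoded. The snippet below is a minimal conceptual sketch of that idea using only Python's standard `tarfile` module and a torch `IterableDataset`; the class name, manifest fields, and tar layout are illustrative assumptions, not the actual TarredAudioToTextDataLayer implementation.

```python
# Conceptual sketch only (not the PR's implementation): stream audio/transcript
# pairs out of a single tar shard, skipping tar entries whose samples were
# filtered out of the manifest. Class name and manifest format are assumptions.
import json
import tarfile

from torch.utils.data import IterableDataset


class TarredAudioSketch(IterableDataset):
    def __init__(self, tar_path, manifest_path):
        self.tar_path = tar_path
        # Build the set of sample names the manifest keeps (e.g. after duration
        # filtering); anything else found in the tar is skipped on the fly.
        with open(manifest_path) as f:
            self.keep = {json.loads(line)["audio_filepath"] for line in f}

    def __iter__(self):
        with tarfile.open(self.tar_path, mode="r") as tar:
            for member in tar:
                if not member.isfile():
                    continue
                if member.name not in self.keep:
                    continue  # filtered-out sample: don't extract it at all
                audio_bytes = tar.extractfile(member).read()
                yield audio_bytes, member.name
```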
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -74,6 +74,7 @@ To release a new version, please update the changelog as followed:
([PR #538](https://github.com/NVIDIA/NeMo/pull/538)) - @yzhang123
- Online Data Augmentation for ASR Collection. ([PR #565](https://github.com/NVIDIA/NeMo/pull/565)) - @titu1994
- Speed augmentation on CPU, TimeStretch augmentation on CPU+GPU ([PR #594](https://github.com/NVIDIA/NeMo/pull/565)) - @titu1994
+- Added TarredAudioToTextDataLayer, which allows for loading ASR datasets with tarred audio. Existing datasets can be converted with the `convert_to_tarred_audio_dataset.py` script. ([PR #602](https://github.com/NVIDIA/NeMo/pull/602))

### Changed

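
The new changelog entry points to `convert_to_tarred_audio_dataset.py` for turning an existing one-file-per-utterance dataset into tar shards. That script is not part of the excerpt shown here, so the sketch below only illustrates the general idea, splitting a JSON-lines manifest into N tar shards and, as one of the commit messages notes, folding any leftover files into the last shard. Function and field names are assumptions.

```python
# Illustrative sharding sketch (assumed names; not the NeMo conversion script).
import json
import tarfile


def shard_audio_dataset(manifest_path, num_shards, shard_prefix="audio_shard"):
    with open(manifest_path) as f:
        entries = [json.loads(line) for line in f]

    per_shard = len(entries) // num_shards
    shards = [entries[i * per_shard:(i + 1) * per_shard] for i in range(num_shards)]
    # Leftover entries that don't divide evenly go into the last shard,
    # rather than creating an extra undersized shard.
    shards[-1].extend(entries[num_shards * per_shard:])

    for idx, shard in enumerate(shards):
        with tarfile.open(f"{shard_prefix}_{idx}.tar", mode="w") as tar:
            for entry in shard:
                tar.add(entry["audio_filepath"])
```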
127 changes: 76 additions & 51 deletions nemo/backends/pytorch/actions.py
@@ -526,16 +526,21 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False):
# )

if dl_nm.dataset is not None:
-sampler = torch.utils.data.distributed.DistributedSampler(
-    dataset=dl_nm.dataset, shuffle=dl_nm.shuffle
-)
-eval_dataloader = torch.utils.data.DataLoader(
-    dataset=dl_nm.dataset,
-    sampler=sampler,
-    num_workers=dl_nm.num_workers,
-    batch_size=dl_nm.batch_size,
-    shuffle=False,
-)
+sampler = None
+if not isinstance(dl_nm.dataset, torch.utils.data.IterableDataset):
+    sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset=dl_nm.dataset, shuffle=dl_nm.shuffle
+    )
+dataloader_params = {
+    'dataset': dl_nm.dataset,
+    'sampler': sampler,
+    'num_workers': dl_nm.num_workers,
+    'batch_size': dl_nm.batch_size,
+    'shuffle': False,
+}
+if hasattr(dl_nm, 'collate_fn'):
+    dataloader_params['collate_fn'] = dl_nm.collate_fn
+eval_dataloader = torch.utils.data.DataLoader(**dataloader_params)
else:
eval_dataloader = dl_nm.data_iterator

@@ -544,13 +549,16 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False):
else: # Not distributed
if dl_nm.dataset is not None:
# Todo: remove local_parameters
-eval_dataloader = torch.utils.data.DataLoader(
-    dataset=dl_nm.dataset,
-    sampler=None,  # not distributed sampler
-    num_workers=dl_nm.num_workers,
-    batch_size=dl_nm.batch_size,
-    shuffle=dl_nm.shuffle,
-)
+dataloader_params = {
+    'dataset': dl_nm.dataset,
+    'sampler': None,  # not distributed sampler
+    'num_workers': dl_nm.num_workers,
+    'batch_size': dl_nm.batch_size,
+    'shuffle': dl_nm.shuffle,
+}
+if hasattr(dl_nm, 'collate_fn'):
+    dataloader_params['collate_fn'] = dl_nm.collate_fn
+eval_dataloader = torch.utils.data.DataLoader(**dataloader_params)
else:
eval_dataloader = dl_nm.data_iterator
# after this eval_dataloader is ready to be used
@@ -693,16 +701,21 @@ def _infer(
# )
# )
if dl_nm.dataset is not None:
-sampler = torch.utils.data.distributed.DistributedSampler(
-    dataset=dl_nm.dataset, shuffle=dl_nm.shuffle
-)
-eval_dataloader = torch.utils.data.DataLoader(
-    dataset=dl_nm.dataset,
-    sampler=sampler,
-    num_workers=dl_nm.num_workers,
-    batch_size=dl_nm.batch_size,
-    shuffle=False,
-)
+sampler = None
+if not isinstance(dl_nm.dataset, torch.utils.data.IterableDataset):
+    sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset=dl_nm.dataset, shuffle=dl_nm.shuffle
+    )
+dataloader_params = {
+    'dataset': dl_nm.dataset,
+    'sampler': sampler,
+    'num_workers': dl_nm.num_workers,
+    'batch_size': dl_nm.batch_size,
+    'shuffle': False,
+}
+if hasattr(dl_nm, 'collate_fn'):
+    dataloader_params['collate_fn'] = dl_nm.collate_fn
+eval_dataloader = torch.utils.data.DataLoader(**dataloader_params)
else:
eval_dataloader = dl_nm.data_iterator
eval_dataloader.sampler.set_epoch(0)
@@ -711,13 +724,16 @@
# When caching, the DAG must cache all outputs from dataloader
if dl_nm.dataset is not None:
# Todo: remove local_parameters
-eval_dataloader = torch.utils.data.DataLoader(
-    dataset=dl_nm.dataset,
-    sampler=None,  # not distributed sampler
-    num_workers=dl_nm.num_workers,
-    batch_size=dl_nm.batch_size,
-    shuffle=dl_nm.shuffle,
-)
+dataloader_params = {
+    'dataset': dl_nm.dataset,
+    'sampler': None,  # not distributed sampler
+    'num_workers': dl_nm.num_workers,
+    'batch_size': dl_nm.batch_size,
+    'shuffle': dl_nm.shuffle,
+}
+if hasattr(dl_nm, 'collate_fn'):
+    dataloader_params['collate_fn'] = dl_nm.collate_fn
+eval_dataloader = torch.utils.data.DataLoader(**dataloader_params)
else:
eval_dataloader = dl_nm.data_iterator
# after this eval_dataloader is ready to be used
@@ -1231,16 +1247,21 @@ def train(
# "optimizers")
logging.info("Doing distributed training")
if t_dataset is not None:
-train_sampler = torch.utils.data.distributed.DistributedSampler(
-    dataset=t_dataset, shuffle=dataNM.shuffle
-)
-train_dataloader = torch.utils.data.DataLoader(
-    dataset=t_dataset,
-    sampler=train_sampler,
-    num_workers=dataNM.num_workers,
-    batch_size=dataNM.batch_size,
-    shuffle=False,
-)
+train_sampler = None
+if not isinstance(t_dataset, torch.utils.data.IterableDataset):
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset=t_dataset, shuffle=dataNM.shuffle
+    )
+dataloader_params = {
+    'dataset': t_dataset,
+    'sampler': train_sampler,
+    'num_workers': dataNM.num_workers,
+    'batch_size': dataNM.batch_size,
+    'shuffle': False,
+}
+if hasattr(dataNM, 'collate_fn'):
+    dataloader_params['collate_fn'] = dataNM.collate_fn
+train_dataloader = torch.utils.data.DataLoader(**dataloader_params)
else:
train_dataloader = dataNM.data_iterator
if hasattr(train_dataloader, 'sampler'):
@@ -1313,13 +1334,17 @@ def train(
else:
if t_dataset is not None:
train_sampler = None
-train_dataloader = torch.utils.data.DataLoader(
-    dataset=t_dataset,
-    sampler=None,
-    num_workers=dataNM.num_workers,
-    batch_size=dataNM.batch_size,
-    shuffle=dataNM.shuffle,
-)
+dataloader_params = {
+    'dataset': t_dataset,
+    'sampler': None,
+    'num_workers': dataNM.num_workers,
+    'batch_size': dataNM.batch_size,
+    'shuffle': dataNM.shuffle,
+}
+if hasattr(dataNM, 'collate_fn'):
+    dataloader_params['collate_fn'] = dataNM.collate_fn
+
+train_dataloader = torch.utils.data.DataLoader(**dataloader_params)
else:
train_dataloader = dataNM.data_iterator
train_sampler = None
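
The actions.py changes above avoid wrapping an IterableDataset in a DistributedSampler (samplers don't apply to iterable-style datasets) and pass through a data layer's collate_fn whenever one is defined. The commit messages say the tarred data layer itself prevents duplicate samples across torch.distributed processes; a common way to do that (shown below as a sketch of the general technique, not necessarily how TarredAudioToTextDataLayer implements it) is to give each rank a disjoint subset of the tar shards.

```python
# Sketch of per-rank shard assignment so that each distributed process reads a
# disjoint subset of shards and no sample is duplicated across ranks.
# Illustrative only; the PR's data layer may partition work differently.
import torch.distributed as dist


def shards_for_rank(shard_paths):
    if dist.is_available() and dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
    else:
        rank, world_size = 0, 1
    # Round-robin assignment keeps per-rank shard counts balanced.
    return [path for i, path in enumerate(shard_paths) if i % world_size == rank]
```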
9 changes: 8 additions & 1 deletion nemo/collections/asr/__init__.py
@@ -14,7 +14,13 @@
# =============================================================================
from .audio_preprocessing import *
from .beam_search_decoder import BeamSearchDecoderWithLM
-from .data_layer import AudioToSpeechLabelDataLayer, AudioToTextDataLayer, KaldiFeatureDataLayer, TranscriptDataLayer
+from .data_layer import (
+    AudioToSpeechLabelDataLayer,
+    AudioToTextDataLayer,
+    KaldiFeatureDataLayer,
+    TarredAudioToTextDataLayer,
+    TranscriptDataLayer,
+)
from .greedy_ctc_decoder import GreedyCTCDecoder
from .jasper import JasperDecoderForClassification, JasperDecoderForCTC, JasperEncoder
from .las.misc import JasperRNNConnector
@@ -25,6 +31,7 @@
__all__ = [
'Backend',
'AudioToTextDataLayer',
+'TarredAudioToTextDataLayer',
'AudioToSpeechLabelDataLayer',
'AudioPreprocessing',
'AudioPreprocessor',
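
With the export above, the new data layer is importable from `nemo.collections.asr` alongside AudioToTextDataLayer. A hypothetical usage sketch follows; the constructor arguments are assumptions, since the real signature lives in data_layer.py, which is not shown in this excerpt.

```python
# Hypothetical usage sketch; every constructor argument below is an assumption.
from nemo.collections.asr import TarredAudioToTextDataLayer

data_layer = TarredAudioToTextDataLayer(
    audio_tar_filepaths="audio_shard_{0..7}.tar",  # assumed parameter name
    manifest_filepath="tarred_manifest.json",      # assumed parameter name
    labels=[" ", "a", "b", "c"],                   # assumed parameter name
    batch_size=32,                                 # assumed parameter name
)
```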
