Refactor asr_datamodule. #15
Changes from 1 commit
```diff
@@ -1,25 +1,28 @@
 import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
 from typing import List, Union

-from lhotse import Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
     CutMix,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
+from lhotse.dataset.dataloading import LhotseDataLoader
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from torch.utils.data import DataLoader

 from icefall.dataset.datamodule import DataModule
 from icefall.utils import str2bool


-class AsrDataModule(DataModule):
+class LibriSpeechAsrDataModule(DataModule):
     """
     DataModule for K2 ASR experiments.
     It assumes there is always one train and valid dataloader,
```
```diff
@@ -47,6 +50,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )
+        group.add_argument(
+            "--full-libri",
+            type=str2bool,
+            default=True,
+            help="When enabled, use 960h LibriSpeech. "
+            "Otherwise, use 100h subset.",
+        )
         group.add_argument(
             "--feature-dir",
             type=Path,
```
```diff
@@ -104,6 +114,38 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
             "extraction. Will drop existing precomputed feature manifests "
             "if available.",
         )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--num-workers-inner",
+            type=int,
+            default=8,
+            help="The number of sub-workers (replicated for each of "
+            "training dataloader workers) that parallelize "
+            "the I/O to collect each batch.",
+        )

     def train_dataloaders(self) -> DataLoader:
         logging.info("About to get train cuts")
```
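A hypothetical usage sketch of these options: a training script registers the datamodule's arguments on its own parser, parses the command line, and builds the dataloaders. The constructor call is an assumption here; it is not part of this diff:

```python
import argparse

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
# e.g. train on the 100h subset with 4 outer dataloader workers
args = parser.parse_args(["--full-libri", "false", "--num-workers", "4"])

datamodule = LibriSpeechAsrDataModule(args)  # constructor signature assumed
train_dl = datamodule.train_dataloaders()
valid_dl = datamodule.valid_dataloaders()
```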
```diff
@@ -138,9 +180,9 @@ def train_dataloaders(self) -> DataLoader:
         ]

         train = K2SpeechRecognitionDataset(
-            cuts_train,
             cut_transforms=transforms,
             input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
         )

         if self.args.on_the_fly_feats:
```
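The `transforms` and `input_transforms` lists used here are built in lines elided from this hunk. A hedged sketch of how such lists are typically assembled from the imported lhotse transforms; the concrete parameters below are illustrative, not taken from this diff:

```python
# Cut-level transforms: mix in noise cuts (cuts_musan assumed to be loaded
# elsewhere); CutConcatenate may be prepended when --concatenate-cuts is set.
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]

# Input-level transforms: SpecAugment masking applied to the feature matrix.
input_transforms = [
    SpecAugment(
        num_frame_masks=2,
        features_mask_size=27,
        num_feature_masks=2,
        frames_mask_size=100,
    )
]
```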
```diff
@@ -154,70 +196,98 @@ def train_dataloaders(self) -> DataLoader:
             # to be strict (e.g. could be randomized)
             # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
             # Drop feats to be on the safe side.
             cuts_train = cuts_train.drop_features()
             train = K2SpeechRecognitionDataset(
-                cuts=cuts_train,
                 cut_transforms=transforms,
                 input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
+                    Fbank(FbankConfig(num_mel_bins=80)),
+                    num_workers=self.args.num_workers_inner,
                 ),
                 input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
             )

         if self.args.bucketing_sampler:
             logging.info("Using BucketingSampler.")
             train_sampler = BucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
-                shuffle=True,
+                shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                bucket_method='equal_duration',
+                bucket_method="equal_duration",
                 drop_last=True,
             )
         else:
             logging.info("Using SingleCutSampler.")
             train_sampler = SingleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
-                shuffle=True,
+                shuffle=self.args.shuffle,
             )
         logging.info("About to create train dataloader")
-        train_dl = DataLoader(
+
+        # train_dl = DataLoader(
+        #     train,
+        #     sampler=train_sampler,
+        #     batch_size=None,
+        #     num_workers=2,
+        #     persistent_workers=False,
+        # )
+
+        train_dl = LhotseDataLoader(
             train,
             sampler=train_sampler,
-            batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=self.args.num_workers,
+            prefetch_factor=5,
         )

         return train_dl

     def valid_dataloaders(self) -> DataLoader:
         logging.info("About to get dev cuts")
         cuts_valid = self.valid_cuts()

         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
                 CutConcatenate(
                     duration_factor=self.args.duration_factor, gap=self.args.gap
                 )
             ] + transforms

         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
+            cuts_valid = cuts_valid.drop_features()
             validate = K2SpeechRecognitionDataset(
-                cuts_valid.drop_features(),
                 cut_transforms=transforms,
                 input_strategy=OnTheFlyFeatures(
                     Fbank(FbankConfig(num_mel_bins=80))
                 ),
+                return_cuts=self.args.return_cuts,
             )
         else:
-            validate = K2SpeechRecognitionDataset(cuts_valid)
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
         valid_sampler = SingleCutSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
-        valid_dl = DataLoader(
+        # valid_dl = DataLoader(
+        #     validate,
+        #     sampler=valid_sampler,
+        #     batch_size=None,
+        #     num_workers=2,
+        #     persistent_workers=False,
+        # )
+
+        valid_dl = LhotseDataLoader(
             validate,
             sampler=valid_sampler,
-            batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=2,
         )

         return valid_dl

     def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
```
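What `--return-cuts` buys you downstream, as a short hedged sketch; it assumes the standard batch layout produced by `K2SpeechRecognitionDataset`:

```python
for batch in train_dl:
    features = batch["inputs"]            # (N, T, F) feature tensor
    supervisions = batch["supervisions"]  # texts, start frames, num frames, ...
    if "cut" in supervisions:             # present when return_cuts=True
        cut_ids = [cut.id for cut in supervisions["cut"]]
    break
```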
```diff
@@ -230,21 +300,63 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                cuts_test,
                 input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
+                    Fbank(FbankConfig(num_mel_bins=80), num_workers=4)
```
**Review comment:** For LibriSpeech, remove the ...
```diff
+                    if self.args.on_the_fly_feats
+                    else PrecomputedFeatures()
                 ),
```
**Review comment:** I think this closing parenthesis has to be moved 3 lines up, so the code looks like:

```python
input_strategy=OnTheFlyFeatures(
    Fbank(FbankConfig(num_mel_bins=80), num_workers=4)
) if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=...
```

Currently, when `args.on_the_fly_feats = False`, it tries to use `OnTheFlyFeatures(PrecomputedFeatures())`, which is an error.

**Review comment:** Thanks!

```python
test = K2SpeechRecognitionDataset(
    input_strategy=(
        OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8)
        if self.args.on_the_fly_feats
        else PrecomputedFeatures()
    ),
    return_cuts=self.args.return_cuts,
)
```

**Review comment:** Yes, you're right.
```diff
+                return_cuts=self.args.return_cuts,
             )
             sampler = SingleCutSampler(
                 cuts_test, max_duration=self.args.max_duration
             )
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            # test_dl = DataLoader(
+            #     test, batch_size=None, sampler=sampler, num_workers=1
+            # )
+            test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2)
             test_loaders.append(test_dl)

         if is_list:
             return test_loaders
         else:
             return test_loaders[0]

+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        cuts_train = load_manifest(
+            self.args.feature_dir / "cuts_train-clean-100.json.gz"
+        )
+        if self.args.full_libri:
+            cuts_train = (
+                cuts_train
+                + load_manifest(
+                    self.args.feature_dir / "cuts_train-clean-360.json.gz"
+                )
+                + load_manifest(
+                    self.args.feature_dir / "cuts_train-other-500.json.gz"
+                )
+            )
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        logging.info("About to get dev cuts")
+        cuts_valid = load_manifest(
+            self.args.feature_dir / "cuts_dev-clean.json.gz"
+        ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
+        return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> List[CutSet]:
+        test_sets = ["test-clean", "test-other"]
+        cuts = []
+        for test_set in test_sets:
+            logging.debug("About to get test cuts")
+            cuts.append(
+                load_manifest(
+                    self.args.feature_dir / f"cuts_{test_set}.json.gz"
+                )
+            )
+        return cuts
```
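Because the cut getters are wrapped in `@lru_cache()`, repeated calls return the same `CutSet` objects instead of re-reading the manifests. A small hedged usage sketch (the constructor call is assumed):

```python
datamodule = LibriSpeechAsrDataModule(args)
assert datamodule.train_cuts() is datamodule.train_cuts()  # cached

# test_cuts() returns a list of two CutSets, so test_dataloaders()
# returns a list of two dataloaders (test-clean, test-other):
test_clean_dl, test_other_dl = datamodule.test_dataloaders()
```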
**Review comment:** I would say be careful with `LhotseDataLoader`: it is experimental, and I'm hoping to avoid needing to use it in the future. It overcomes some I/O issues with GigaSpeech, but for LibriSpeech you shouldn't see any difference in performance with a regular `DataLoader`. The downside of `LhotseDataLoader` is that it doesn't have the elaborate shutdown mechanisms of PyTorch's `DataLoader` and might leave your script running after the training has finished (i.e., everything runs OK, but the script doesn't exit by itself).
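A hedged sketch of what this suggests for LibriSpeech: keep the dataset and sampler unchanged and plug them into a standard PyTorch `DataLoader`, mirroring the commented-out code in the diff above:

```python
from torch.utils.data import DataLoader

train_dl = DataLoader(
    train,                  # K2SpeechRecognitionDataset
    sampler=train_sampler,  # BucketingSampler or SingleCutSampler
    batch_size=None,        # the sampler already yields whole batches
    num_workers=2,
    persistent_workers=False,
)
```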