Refactor asr_datamodule. #15

Merged 3 commits on Aug 21, 2021
Changes from 1 commit
@@ -1,25 +1,28 @@
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union

from lhotse import Fbank, FbankConfig, load_manifest
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.dataloading import LhotseDataLoader
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader

from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool


class AsrDataModule(DataModule):
class LibriSpeechAsrDataModule(DataModule):
"""
DataModule for K2 ASR experiments.
It assumes there is always one train and valid dataloader,
@@ -47,6 +50,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. "
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
type=Path,
@@ -104,6 +114,38 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)

group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)

group.add_argument(
"--num-workers-inner",
type=int,
default=8,
help="The number of sub-workers (replicated for each of "
"training dataloader workers) that parallelize "
"the I/O to collect each batch.",
)

def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
@@ -138,9 +180,9 @@ def train_dataloaders(self) -> DataLoader:
]

train = K2SpeechRecognitionDataset(
cuts_train,
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)

if self.args.on_the_fly_feats:
@@ -154,70 +196,98 @@
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
cuts_train = cuts_train.drop_features()
train = K2SpeechRecognitionDataset(
cuts=cuts_train,
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
Fbank(FbankConfig(num_mel_bins=80)),
num_workers=self.args.num_workers_inner,
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)

if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method='equal_duration',
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(

# train_dl = DataLoader(
# train,
# sampler=train_sampler,
# batch_size=None,
# num_workers=2,
# persistent_workers=False,
# )

train_dl = LhotseDataLoader(
Collaborator:

I would say be careful with LhotseDataLoader -- it is experimental and I'm hoping to avoid needing to use it in the future. It overcomes some I/O issues with GigaSpeech, but for LibriSpeech you shouldn't see any difference in perf with a regular DataLoader.

The downside of LhotseDataLoader is that it doesn't have the elaborate shutdown mechanisms of PyTorch DataLoader and might leave your script running after the training has finished (i.e., everything runs OK, but the script doesn't exit by itself).

train,
sampler=train_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
num_workers=self.args.num_workers,
prefetch_factor=5,
)

return train_dl
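
A sketch of the plain-DataLoader alternative suggested in the comment above; it mirrors the commented-out block and assumes the train dataset and train_sampler built earlier in this method (the worker count is illustrative):

    # The lhotse sampler already yields whole batches (CutSets), so automatic
    # batching is disabled with batch_size=None and the DataLoader only
    # manages the worker processes.
    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=2,
        persistent_workers=False,
    )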

def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()

transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms

logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
cuts_valid = cuts_valid.drop_features()
validate = K2SpeechRecognitionDataset(
cuts_valid.drop_features(),
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(cuts_valid)
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
# valid_dl = DataLoader(
# validate,
# sampler=valid_sampler,
# batch_size=None,
# num_workers=2,
# persistent_workers=False,
# )

valid_dl = LhotseDataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)

return valid_dl

def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
@@ -230,21 +300,63 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
cuts_test,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
Fbank(FbankConfig(num_mel_bins=80), num_workers=4)
Collaborator:

For LibriSpeech, remove the num_workers argument from OnTheFlyFeatures -- it will attempt to spawn extra processes that are not needed for LibriSpeech (they help with GigaSpeech which has long OPUS recordings)

if self.args.on_the_fly_feats
else PrecomputedFeatures()
),
Collaborator:

I think this closing parenthesis has to be moved 3 lines up, so the code looks like:

    input_strategy=OnTheFlyFeatures(
       Fbank(FbankConfig(num_mel_bins=80), num_workers=4)
    )  if self.args.on_the_fly_feats
       else PrecomputedFeatures(),
    return_cuts=...

Currently when args.on_the_fly_feats = False, it tries to use OnTheFlyFeatures(PrecomputedFeatures()) which is an error.

Collaborator (author):

Thanks!
In that case, I think we should also change snowfall to fix that as this block of code is from snowfall.
See
https://github.com/k2-fsa/snowfall/blob/1f79957e9716c3f980c151df5b1d77bc4bb7ce78/egs/gigaspeech/asr/simple_v1/asr_datamodule.py#L337-L344

            test = K2SpeechRecognitionDataset(
                input_strategy=(
                    OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8)
                    if self.args.on_the_fly_feats
                    else PrecomputedFeatures()
                ),
                return_cuts=self.args.return_cuts,
            )

Collaborator:

yes, you’re right

return_cuts=self.args.return_cuts,
)
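
Taking the two suggestions above together (no extra num_workers for LibriSpeech, and the closing parenthesis moved so that PrecomputedFeatures is not wrapped in OnTheFlyFeatures), the test dataset construction would look roughly like this sketch:

    # Compute 80-dim Fbank on the fly when requested; otherwise read the
    # precomputed features from the manifests.
    test = K2SpeechRecognitionDataset(
        input_strategy=(
            OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures()
        ),
        return_cuts=self.args.return_cuts,
    )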
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
)
# test_dl = DataLoader(
# test, batch_size=None, sampler=sampler, num_workers=1
# )
test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2)
test_loaders.append(test_dl)

if is_list:
return test_loaders
else:
return test_loaders[0]

@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train

@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid

@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts
egs/librispeech/ASR/conformer_ctc/decode.py (19 changes: 13 additions & 6 deletions)
@@ -13,11 +13,11 @@
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
@@ -222,7 +222,7 @@ def decode_one_batch(
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa

hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -317,7 +317,11 @@ def decode_dataset(
results = []

num_cuts = 0
tot_num_batches = len(dl)

try:
num_batches = len(dl)
except TypeError:
num_batches = None

results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -346,10 +350,13 @@
num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
if num_batches is not None:
batch_str = f"{batch_idx}/{num_batches}"
else:
batch_str = f"{batch_idx}"

logging.info(
f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
f"{num_cuts}"
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results

egs/librispeech/ASR/conformer_ctc/train.py (8 changes: 2 additions & 6 deletions)
@@ -13,18 +13,17 @@
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -61,9 +60,6 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)

# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser


@@ -463,7 +459,7 @@ def train_one_epoch(

optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()

loss_cpu = loss.detach().cpu().item()