From 4bbd32f5011bd409e71b0d5f345ea81557440f74 Mon Sep 17 00:00:00 2001 From: Nithin Rao Date: Thu, 17 Oct 2024 12:24:22 -0400 Subject: [PATCH] =?UTF-8?q?Add=20lhotse=20fixes=20for=20rnnt=20model=20tra?= =?UTF-8?q?ining=20and=20WER=20hanging=20issue=20with=20f=E2=80=A6=20(#108?= =?UTF-8?q?21)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add lhotse fixes for rnnt model training and WER hanging issue with f… (#10787) * Add lhotse fixes for rnnt model training and WER hanging issue with fuse batching Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok --------- Signed-off-by: Nithin Rao Koluguri Signed-off-by: nithinraok Co-authored-by: Nithin Rao Koluguri Co-authored-by: nithinraok * Apply isort and black reformatting Signed-off-by: nithinraok * Apply isort and black reformatting Signed-off-by: artbataev --------- Signed-off-by: Nithin Rao Koluguri Signed-off-by: nithinraok Signed-off-by: artbataev Co-authored-by: nithinraok Co-authored-by: artbataev --- nemo/collections/asr/data/audio_to_text_lhotse.py | 15 ++++++--------- nemo/collections/asr/metrics/wer.py | 3 ++- nemo/collections/asr/modules/rnnt.py | 7 +++++++ nemo/collections/common/data/lhotse/dataloader.py | 1 - 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse.py b/nemo/collections/asr/data/audio_to_text_lhotse.py index 576ea8234c874..f916ae1de56b6 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse.py @@ -51,15 +51,12 @@ def __init__(self, tokenizer): def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) tokens = [ - torch.as_tensor( - sum( - ( - # Supervisions may come pre-tokenized from the dataloader. - s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language) - for s in c.supervisions - ), - start=[], - ) + torch.cat( + [ + torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language)) + for s in c.supervisions + ], + dim=0, ) for c in cuts ] diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index a135e5c51e84b..7bda3a77b278a 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -254,8 +254,9 @@ def __init__( fold_consecutive=True, batch_dim_index=0, dist_sync_on_step=False, + sync_on_compute=True, ): - super().__init__(dist_sync_on_step=dist_sync_on_step) + super().__init__(dist_sync_on_step=dist_sync_on_step, sync_on_compute=sync_on_compute) self.decoding = decoding self.use_cer = use_cer diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index eaa0445f56cc2..3ab6a432b9479 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1457,6 +1457,10 @@ def forward( sub_transcripts = sub_transcripts.detach() # Update WER on each process without syncing + if self.training: + original_sync = self.wer._to_sync + self.wer._to_sync = False + self.wer.update( predictions=sub_enc, predictions_lengths=sub_enc_lens, @@ -1467,6 +1471,9 @@ def forward( wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() + if self.training: + self.wer._to_sync = original_sync + wers.append(wer) wer_nums.append(wer_num) wer_denoms.append(wer_denom) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 9c22a43d736fa..98b63a07fa9df 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -317,7 +317,6 @@ def get_lhotse_dataloader_from_config( duration_bins=determine_bucket_duration_bins(config), num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate, buffer_size=config.bucket_buffer_size, - concurrent=config.concurrent_bucketing, rank=0 if is_tarred else global_rank, world_size=1 if is_tarred else world_size, )