From 29e65e406c03e7f9b99e9b05759b74052467e4c9 Mon Sep 17 00:00:00 2001 From: Vahid Date: Tue, 7 Feb 2023 09:43:13 -0800 Subject: [PATCH 1/4] fixd the bug. Signed-off-by: Vahid --- .../conformer_hybrid_transducer_ctc_bpe.yaml | 2 +- .../conformer_hybrid_transducer_ctc_char.yaml | 2 +- .../asr/models/hybrid_rnnt_ctc_bpe_models.py | 79 +++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml index 3e03d3495174..18f877701568 100644 --- a/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml +++ b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml @@ -192,7 +192,7 @@ model: # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder aux_ctc: - ctc_loss_weight: 0.5 # the weight used to combine the CTC loss with the RNNT loss + ctc_loss_weight: 0.3 # the weight used to combine the CTC loss with the RNNT loss use_cer: false ctc_reduction: 'mean_batch' decoder: diff --git a/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml index dbbde6875383..ea5a31bba0bc 100644 --- a/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml +++ b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml @@ -186,7 +186,7 @@ model: # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder aux_ctc: - ctc_loss_weight: 0.5 # the weight used to combine the CTC loss with the RNNT loss + ctc_loss_weight: 0.3 # the weight used to combine the CTC loss with the RNNT loss use_cer: false ctc_reduction: 'mean_batch' decoder: diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 3b94084e0d8b..c3dc25755e73 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -123,6 +123,85 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # setting the RNNT decoder as the default one self.use_rnnt_decoder = True + def _setup_dataloader_from_config(self, config: Optional[Dict]): + dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( + config=config, + local_rank=self.local_rank, + global_rank=self.global_rank, + world_size=self.world_size, + tokenizer=self.tokenizer, + preprocessor_cfg=self.cfg.get("preprocessor", None), + ) + + if dataset is None: + return None + + if isinstance(dataset, AudioToBPEDALIDataset): + # DALI Dataset implements dataloader interface + return dataset + + shuffle = config['shuffle'] + if config.get('is_tarred', False): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + else: + collate_fn = dataset.datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + num_workers: (int) number of workers. Depends of the batch_size and machine. \ + 0 - only the main process will load batches, 1 - one worker (not main process) + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + + if 'manifest_filepath' in config: + manifest_filepath = config['manifest_filepath'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': batch_size, + 'shuffle': False, + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, + 'channel_selector': config.get('channel_selector', None), + 'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False), + } + + if config.get("augmentor"): + dl_config['augmentor'] = config.get("augmentor") + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + def change_vocabulary( self, new_tokenizer_dir: Union[str, DictConfig], From 54b0732b40b68670fa57246d6badd2656d21225f Mon Sep 17 00:00:00 2001 From: Vahid Date: Tue, 7 Feb 2023 09:50:21 -0800 Subject: [PATCH 2/4] fixd the bug. Signed-off-by: Vahid --- nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index c3dc25755e73..858f1d83be63 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -14,7 +14,8 @@ import copy import os -from typing import Optional, Union +from typing import Optional, Union, Dict +import torch from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from pytorch_lightning import Trainer @@ -27,6 +28,8 @@ from nemo.collections.asr.parts.mixins import ASRBPEMixin from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging, model_utils +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset class EncDecHybridRNNTCTCBPEModel(EncDecHybridRNNTCTCModel, ASRBPEMixin): From 82bf04cd4cb8fb4be86615032abe79578196f8be Mon Sep 17 00:00:00 2001 From: Vahid Date: Tue, 7 Feb 2023 09:51:54 -0800 Subject: [PATCH 3/4] fixd the bug. Signed-off-by: Vahid --- nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 858f1d83be63..ef74775e4669 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -28,8 +28,8 @@ from nemo.collections.asr.parts.mixins import ASRBPEMixin from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging, model_utils -from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset +from nemo.collections.asr.data import audio_to_text_dataset class EncDecHybridRNNTCTCBPEModel(EncDecHybridRNNTCTCModel, ASRBPEMixin): From dceeac9aa65a9440d52a1a021d5e9a2eb64b0104 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:55:09 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index ef74775e4669..25bda96fc5a2 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -14,12 +14,14 @@ import copy import os -from typing import Optional, Union, Dict -import torch +from typing import Dict, Optional, Union +import torch from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.rnnt import RNNTLoss from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding, RNNTBPEDecodingConfig @@ -28,8 +30,6 @@ from nemo.collections.asr.parts.mixins import ASRBPEMixin from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging, model_utils -from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset -from nemo.collections.asr.data import audio_to_text_dataset class EncDecHybridRNNTCTCBPEModel(EncDecHybridRNNTCTCModel, ASRBPEMixin):