Fix hybridasr bug #5950

Merged 4 commits on Feb 7, 2023
@@ -192,7 +192,7 @@ model:
 
   # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder
   aux_ctc:
-    ctc_loss_weight: 0.5 # the weight used to combine the CTC loss with the RNNT loss
+    ctc_loss_weight: 0.3 # the weight used to combine the CTC loss with the RNNT loss
     use_cer: false
     ctc_reduction: 'mean_batch'
     decoder:
@@ -186,7 +186,7 @@ model:
 
   # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder
   aux_ctc:
-    ctc_loss_weight: 0.5 # the weight used to combine the CTC loss with the RNNT loss
+    ctc_loss_weight: 0.3 # the weight used to combine the CTC loss with the RNNT loss
     use_cer: false
     ctc_reduction: 'mean_batch'
     decoder:
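Both config hunks above lower `ctc_loss_weight` from 0.5 to 0.3 in the auxiliary CTC section. As a minimal sketch of what this weight controls, assuming the convex-combination form that the inline comment describes (the actual computation lives in the hybrid model code, not in these configs):

```python
import torch

# Sketch only: assumes the weight forms a convex combination of the two losses,
# as the "combine the CTC loss with the RNNT loss" comment suggests.
def combined_loss(rnnt_loss: torch.Tensor, ctc_loss: torch.Tensor, ctc_loss_weight: float = 0.3) -> torch.Tensor:
    # With ctc_loss_weight = 0.3, the RNNT loss contributes 70% of the objective
    # and the auxiliary CTC loss contributes 30%.
    return (1.0 - ctc_loss_weight) * rnnt_loss + ctc_loss_weight * ctc_loss
```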
nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py (84 changes: 83 additions & 1 deletion)
@@ -14,11 +14,14 @@

import copy
import os
-from typing import Optional, Union
+from typing import Dict, Optional, Union

import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.losses.rnnt import RNNTLoss
from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding, RNNTBPEDecodingConfig
@@ -123,6 +126,85 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # setting the RNNT decoder as the default one
        self.use_rnnt_decoder = True

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
            config=config,
            local_rank=self.local_rank,
            global_rank=self.global_rank,
            world_size=self.world_size,
            tokenizer=self.tokenizer,
            preprocessor_cfg=self.cfg.get("preprocessor", None),
        )

        if dataset is None:
            return None

        if isinstance(dataset, AudioToBPEDALIDataset):
            # DALI Dataset implements dataloader interface
            return dataset

        shuffle = config['shuffle']
        if config.get('is_tarred', False):
            shuffle = False

        if hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn
        else:
            collate_fn = dataset.datasets[0].collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
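For reference, a minimal sketch of the kind of config this method consumes, not taken from the PR: the keys mirror the ones read above, the manifest path and values are placeholders, and `model` is assumed to be an instantiated `EncDecHybridRNNTCTCBPEModel`.

```python
from omegaconf import DictConfig

# Hypothetical values for illustration only; 'train_manifest.json' is a placeholder path.
dl_cfg = DictConfig(
    {
        'manifest_filepath': 'train_manifest.json',
        'sample_rate': 16000,  # should match the model's preprocessor sample rate
        'batch_size': 8,
        'shuffle': True,
        'num_workers': 4,
        'pin_memory': True,
    }
)

# 'model' is assumed to be an EncDecHybridRNNTCTCBPEModel instance.
train_dl = model._setup_dataloader_from_config(config=dl_cfg)
```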

    def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio files.

        Args:
            config: A python dictionary which contains the following keys:
                paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \
                    Recommended length per file is between 5 and 25 seconds.
                batch_size: (int) batch size to use during inference. \
                    A larger batch size gives better throughput but uses more memory.
                temp_dir: (str) A temporary directory where the audio manifest is temporarily
                    stored.
                num_workers: (int) number of workers. Depends on the batch_size and machine. \
                    0 - only the main process will load batches, 1 - one worker (not main process)

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """

        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))

        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self.preprocessor._sample_rate,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
            'pin_memory': True,
            'channel_selector': config.get('channel_selector', None),
            'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
        }

        if config.get("augmentor"):
            dl_config['augmentor'] = config.get("augmentor")

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer
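With both methods restored on the BPE hybrid class, the usual transcription path goes through `_setup_transcribe_dataloader` again. A minimal usage sketch, where the checkpoint and audio paths are placeholders:

```python
from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel

# Placeholder checkpoint and audio paths, for illustration only.
model = EncDecHybridRNNTCTCBPEModel.restore_from("hybrid_rnnt_ctc_bpe.nemo")

# transcribe() is expected to build its temporary dataloader via
# _setup_transcribe_dataloader (short files, roughly 5-25 s each, per the docstring).
transcriptions = model.transcribe(paths2audio_files=["audio1.wav", "audio2.wav"], batch_size=2)
print(transcriptions)
```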

    def change_vocabulary(
        self,
        new_tokenizer_dir: Union[str, DictConfig],