Skip to content

Commit

Permalink
Transcribe for multi-channel signals (NVIDIA#5479)
Browse files (browse the repository at this point in the history)
Transcribe for multi-channel signals (NVIDIA#5479)

Signed-off-by: Ante Jukić <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
anteju authored and Hainan Xu committed Nov 29, 2022
1 parent 233e9de commit ed778d8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
5 changes: 5 additions & 0 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class TranscriptionConfig:
pretrained_name: Optional[str] = None # Name of a pretrained model
audio_dir: Optional[str] = None # Path to a directory which contains audio files
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
channel_selector: Optional[int] = None # Used to select a single channel from multi-channel files
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest

# General configs
output_filename: Optional[str] = None
Expand Down Expand Up @@ -198,6 +200,7 @@ def autocast():
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
)
else:
logging.warning(
Expand All @@ -208,13 +211,15 @@ def autocast():
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
)
else:
transcriptions = asr_model.transcribe(
paths2audio_files=filepaths,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
)

logging.info(f"Finished transcribing {len(filepaths)} files !")
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'shuffle': False,
'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
'pin_memory': True,
'channel_selector': config.get('channel_selector', None),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
Expand Down
7 changes: 5 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json
import os
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig
Expand Down Expand Up @@ -220,7 +220,8 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:
has_two_fields.append(True)
else:
has_two_fields.append(False)
audio_file = Path(item['audio_filepath'])
audio_key = cfg.get('audio_key', 'audio_filepath')
audio_file = Path(item[audio_key])
if not audio_file.is_file() and not audio_file.is_absolute():
audio_file = manifest_dir / audio_file
filepaths.append(str(audio_file.absolute()))
Expand Down Expand Up @@ -290,6 +291,7 @@ def transcribe_partial_audio(
logprobs: bool = False,
return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[int] = None,
) -> List[str]:

assert isinstance(asr_model, EncDecCTCModel), "Currently support CTC model only."
Expand Down Expand Up @@ -325,6 +327,7 @@ def transcribe_partial_audio(
'manifest_filepath': path2manifest,
'batch_size': batch_size,
'num_workers': num_workers,
'channel_selector': channel_selector,
}

temporary_datalayer = asr_model._setup_transcribe_dataloader(config)
Expand Down

0 comments on commit ed778d8

Please sign in to comment.