More cleanup and simplified the call to transcribe
Signed-off-by: jbalam-nv <[email protected]>
jbalam-nv committed Jun 18, 2021
1 parent bf63949 commit c5826a2
Showing 2 changed files with 59 additions and 106 deletions.
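With this change, callers no longer construct a FeatureFrameBufferer themselves: FrameBatchASR now builds its own bufferer and raw preprocessor, and gains a read_audio_file() helper. A minimal usage sketch of the new interface follows; the checkpoint name, audio path, and 0.04 s model stride are assumptions for illustration, not part of this commit.

import math
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR

# Hypothetical checkpoint; any CTC-BPE .nemo model should work here.
asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from("ctc_bpe_model.nemo")

chunk_len = 1.6              # seconds of new audio per step (--chunk_len_in_ms default)
total_buffer = 4.0           # chunk plus left and right context, in seconds
model_stride_in_secs = 0.04  # assumed model stride; architecture-dependent

frame_asr = FrameBatchASR(
    asr_model=asr_model, frame_len=chunk_len,
    total_buffer=total_buffer, batch_size=32,
)

tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)

frame_asr.read_audio_file("sample.wav", mid_delay, model_stride_in_secs)
hyp = frame_asr.transcribe(tokens_per_chunk, mid_delay)
print(hyp)
frame_asr.reset()  # reset state before transcribing another file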
74 changes: 16 additions & 58 deletions examples/asr/speech_to_text_buffered_infer.py
@@ -23,8 +23,6 @@

import torch

from nemo.collections.asr.metrics.wer import WER, word_error_rate
from nemo.collections.asr.models import EncDecCTCModel
from nemo.utils import logging

try:
@@ -39,81 +37,46 @@ def autocast(enabled=None):

can_gpu = torch.cuda.is_available()
import json
import time
import os
from omegaconf import OmegaConf
import copy
import collections
import nemo.collections.asr as nemo_asr
import torch
from nemo.collections.asr.metrics.wer import word_error_rate

import numpy as np
import math
from nemo.collections.asr.parts.utils.streaming_utils import FeatureFrameBufferer, FrameBatchASR, get_samples, AudioFeatureIterator

def clean_label(_str):
"""
Remove punctuation from a string, lowercase it, and collapse extra spaces
Parameters
----------
_str : the original string
Returns
-------
string
"""
if _str is None:
return
_str = _str.strip()
_str = _str.lower()
_str = _str.replace(".", "")
_str = _str.replace(",", "")
_str = _str.replace("?", "")
_str = _str.replace("!", "")
_str = _str.replace(":", "")
_str = _str.replace("-", " ")
_str = _str.replace("_", " ")
_str = _str.replace(" ", " ")
return _str


def get_wer_feat(mfst, frame_bufferer, asr, frame_len, tokens_per_chunk, delay, preprocessor_cfg, model_stride_in_secs, device):
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR


def get_wer_feat(mfst, asr, frame_len, tokens_per_chunk, delay, preprocessor_cfg, model_stride_in_secs, device):
# Create a preprocessor to convert audio samples into raw features.
# Normalization is done per buffer in frame_bufferer, so do not
# normalize here regardless of the model's preprocessor setting.
preprocessor_cfg.normalize = "None"
preprocessor = nemo_asr.models.EncDecCTCModelBPE.from_config_dict(preprocessor_cfg)
preprocessor.to(device)
hyps = collections.defaultdict(list)
hyps = []
refs = []
wer_dict = {}

first = True
with open(mfst, "r") as mfst_f:
for l in mfst_f:
frame_bufferer.reset()
asr.reset()
row = json.loads(l.strip())
samples = get_samples(row['audio_filepath'])
samples = np.pad(samples, (0, int(delay * model_stride_in_secs * preprocessor_cfg.sample_rate)))
frame_reader = AudioFeatureIterator(samples, frame_len, preprocessor, device)
frame_bufferer.set_frame_reader(frame_reader)
asr.infer_logits()
asr.read_audio_file(row['audio_filepath'], delay, model_stride_in_secs)
hyp = asr.transcribe(tokens_per_chunk, delay)
hyps[(tokens_per_chunk, delay)].append(hyp)
hyps.append(hyp)
refs.append(row['text'])

for key in hyps.keys():
wer_dict[key] = word_error_rate(hypotheses=hyps[key], references=refs)
return hyps[(tokens_per_chunk, delay)], refs, wer_dict
wer = word_error_rate(hypotheses=hyps, references=refs)
return hyps, refs, wer

def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, required=True, help="Path to asr model .nemo file",
)
parser.add_argument("--test_manifest", type=str, required=True, help="path to evaluation data")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--total_buffer_in_secs", type=float, default=4.0, help="Length of buffer (chunk + left and right padding) in seconds ")
parser.add_argument("--chunk_len_in_ms", type=int, default=1600, help="Chunk length in milliseconds")
parser.add_argument("--output_path", type=str, help="path to output file", default=None)
@@ -154,20 +117,15 @@ def main():
mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
print(tokens_per_chunk, mid_delay)
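# Worked example with the defaults above, assuming a model stride of
# 0.04 s (the stride is architecture-dependent and not part of this diff):
#   chunk_len = 1.6 s, total_buffer = 4.0 s
#   tokens_per_chunk = ceil(1.6 / 0.04) = 40
#   mid_delay = ceil((1.6 + (4.0 - 1.6) / 2) / 0.04) = ceil(2.8 / 0.04) = 70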

frame_bufferer = FeatureFrameBufferer(
asr_model=asr_model,
frame_len=chunk_len,
batch_size=64,
total_buffer=args.total_buffer_in_secs)

frame_asr = FrameBatchASR(frame_bufferer,
frame_asr = FrameBatchASR(
asr_model=asr_model,
)
frame_len=chunk_len,
total_buffer=args.total_buffer_in_secs,
batch_size=args.batch_size,)

hyps, refs, wer_dict = get_wer_feat(args.test_manifest, frame_bufferer, frame_asr, chunk_len, tokens_per_chunk, mid_delay,
hyps, refs, wer = get_wer_feat(args.test_manifest, frame_asr, chunk_len, tokens_per_chunk, mid_delay,
cfg.preprocessor, model_stride_in_secs, asr_model.device)
for key in wer_dict.keys():
logging.info(f"WER is {wer_dict[key]} when decoded with a delay of {mid_delay*model_stride_in_secs}s")
logging.info(f"WER is {round(wer, 2)} when decoded with a delay of {round(mid_delay*model_stride_in_secs, 2)}s")

if args.output_path is not None:

91 changes: 43 additions & 48 deletions nemo/collections/asr/parts/utils/streaming_utils.py
@@ -18,7 +18,9 @@
from torch.utils.data import DataLoader
import numpy as np
import soundfile as sf
import math
from omegaconf import OmegaConf
import copy
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE

class AudioFeatureIterator(IterableDataset):

@@ -64,9 +66,7 @@ def speech_collate_fn(batch):
encoded tokens, and encoded tokens length. This collate func
assumes the signals are 1d torch tensors (i.e. mono audio).
"""
# print(batch[0][1])
_, audio_lengths = zip(*batch)
# print(audio_lengths)
max_audio_len = 0
has_audio = audio_lengths[0] is not None
if has_audio:
@@ -151,22 +151,19 @@ def __init__(self,asr_model,
self.sr = asr_model._cfg.sample_rate
self.frame_len = frame_len
timestep_duration = asr_model._cfg.preprocessor.window_stride
# ['AudioToMelSpectrogramPreprocessor']['window_stride']
self.n_frame_len = int(frame_len / timestep_duration)

total_buffer_len = int(total_buffer / timestep_duration)
self.n_feat = asr_model._cfg.preprocessor.features
# model_definition['AudioToMelSpectrogramPreprocessor']['features']
self.buffer = np.ones([self.n_feat, total_buffer_len],
dtype=np.float32) * self.ZERO_LEVEL_SPEC_DB_VAL

self.batch_size = batch_size

self.data_layer = AudioBuffersDataLayer()
self.data_loader = DataLoader(self.data_layer, batch_size=self.batch_size, collate_fn=speech_collate_fn)
self.signal_end = False
self.frame_reader = None
self.feature_buffer_len = total_buffer_len

self.feature_buffer = np.ones([self.n_feat, self.feature_buffer_len], dtype=np.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
self.frame_buffers = []
self.buffered_features_size = 0
@@ -180,8 +177,6 @@ def reset(self):
self.buffer = np.ones(shape=self.buffer.shape, dtype=np.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
self.prev_char = ''
self.unmerged = []
self.data_layer = AudioBuffersDataLayer()
self.data_loader = DataLoader(self.data_layer, batch_size=self.batch_size, collate_fn=speech_collate_fn)
self.frame_buffers = []
self.buffered_len = 0
self.feature_buffer = np.ones([self.n_feat, self.feature_buffer_len], dtype=np.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
@@ -206,7 +201,6 @@ def get_frame_buffers(self, frames):
self.buffer[:, -self.n_frame_len:] = frame
self.buffered_len += frame.shape[1]
self.frame_buffers.append(np.copy(self.buffer))
# print(frame_buffers[0])
return self.frame_buffers

def set_frame_reader(self, frame_reader):
@@ -224,8 +218,6 @@ def get_norm_consts_per_frame(self, batch_frames):
self._update_feature_buffer(frame)
mean_from_buffer = np.mean(self.feature_buffer, axis=1)
stdev_from_buffer = np.std(self.feature_buffer, axis=1)
# mean_from_buffer = np.mean(self.frame_buffers[i], axis=1)
# stdev_from_buffer = np.std(self.frame_buffers[i], axis=1)
norm_consts.append((mean_from_buffer.reshape(self.n_feat, 1), stdev_from_buffer.reshape(self.n_feat, 1)))
return norm_consts

@@ -254,15 +246,24 @@ def get_buffers_batch(self):
# 2) call transcribe(frame) to do ASR on
# contiguous signal's frames
class FrameBatchASR:
"""
class for streaming frame-based ASR use reset() method to reset FrameASR's
state call transcribe(frame) to do ASR on contiguous signal's frames
"""

def __init__(self, frame_bufferer, asr_model, batch_size=4,):
def __init__(self, asr_model, frame_len=1.6, total_buffer=4.0, batch_size=4,):
'''
Args:
    asr_model: NeMo CTC ASR model to run inference with
    frame_len: duration of each frame (chunk), in seconds
    total_buffer: total buffer length (chunk plus left and right context), in seconds
    batch_size: number of frame buffers to batch per forward pass
'''
self.frame_bufferer = frame_bufferer
self.frame_bufferer = FeatureFrameBufferer(
asr_model=asr_model,
frame_len=frame_len,
batch_size=batch_size,
total_buffer=total_buffer)

self.asr_model = asr_model

self.batch_size = batch_size
@@ -271,19 +272,26 @@ def __init__(self, frame_bufferer, asr_model, batch_size=4,):

self.unmerged = []

self.data_layer = AudioBuffersDataLayer()
self.data_loader = DataLoader(self.data_layer, batch_size=self.batch_size, collate_fn=speech_collate_fn)

self.blank_id = len(asr_model.decoder.vocabulary)
self.tokenizer = asr_model.tokenizer
self.toks_unmerged = []
self.frame_buffers = []
self.reset()
cfg = copy.deepcopy(asr_model._cfg)
self.frame_len = frame_len
OmegaConf.set_struct(cfg.preprocessor, False)

# some changes for streaming scenario
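# (Assumed rationale: dither=0.0 and pad_to=0 make feature extraction
# deterministic and unpadded so chunk features line up across buffers, and
# normalization is left to the FeatureFrameBufferer, which normalizes per
# buffer rather than per file.)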
cfg.preprocessor.dither = 0.0
cfg.preprocessor.pad_to = 0
cfg.preprocessor.normalize = "None"
self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
self.raw_preprocessor.to(asr_model.device)

def reset(self):
'''
"""
Reset frame_history and decoder's state
'''
"""
self.prev_char = ''
self.unmerged = []
self.data_layer = AudioBuffersDataLayer()
@@ -292,7 +300,17 @@ def reset(self):
self.all_preds = []
self.toks_unmerged = []
self.frame_buffers = []
self.frame_bufferer.reset()

def read_audio_file(self, audio_filepath: str, delay, model_stride_in_secs):
samples = get_samples(audio_filepath)
samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate)))
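# Pad the end with delay * model_stride_in_secs seconds of silence so the
# last tokens of the file can still be decoded with the same delay as
# mid-utterance tokens (assumed intent, inferred from transcribe()'s trimming).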
frame_reader = AudioFeatureIterator(samples, self.frame_len, self.raw_preprocessor, self.asr_model.device)
self.set_frame_reader(frame_reader)


def set_frame_reader(self, frame_reader):
self.frame_bufferer.set_frame_reader(frame_reader)

@torch.no_grad()
def infer_logits(self):
@@ -303,61 +321,38 @@ def infer_logits(self):
self.data_layer.set_signal(frame_buffers[:])
self._get_batch_preds()
frame_buffers = self.frame_bufferer.get_buffers_batch()
# print(self.frame_buffers)


@torch.no_grad()
def _get_batch_preds(self):

device = self.asr_model.device
for batch in iter(self.data_loader):

feat_signal, feat_signal_len = batch
feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)

log_probs, encoded_len, predictions = self.asr_model(processed_signal=feat_signal,
processed_signal_length=feat_signal_len)
preds = torch.unbind(predictions)
for pred in preds:
self.all_preds.append(pred.cpu().numpy())
del log_probs
del encoded_len
del predictions

def transcribe(self, tokens_per_chunk: int, delay: int, ):
self.infer_logits()
self.unmerged = []
self.toks_unmerged = []

decoded_frames = []
all_toks = []
for pred in self.all_preds:
ids, toks = self._greedy_decoder(pred, self.tokenizer)
decoded_frames.append(ids)
all_toks.append(toks)

for decoded in decoded_frames:
decoded = pred.tolist()
self.unmerged += decoded[len(decoded) - 1 - delay:len(decoded) - 1 - delay + tokens_per_chunk]
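# Worked trimming example (hypothetical numbers): a 4.0 s buffer at a
# 0.04 s stride yields 100 tokens per prediction; with delay = 70 and
# tokens_per_chunk = 40 the slice above keeps tokens 29..68, i.e. the
# region corresponding to the 1.6 s chunk centered in the buffer.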

for i, tok in enumerate(all_toks):
self.toks_unmerged += tok[len(tok) // 2:len(tok) // 2 + 1 + tokens_per_chunk]

return self.greedy_merge(self.unmerged)

def _greedy_decoder(self, preds, tokenizer):
s = []
ids = []
for i in range(preds.shape[0]):
if preds[i] == self.blank_id:
s.append("_")
else:
pred = preds[i]
s.append(tokenizer.ids_to_tokens([pred.item()])[0])
ids.append(preds[i])
return ids, s

def greedy_merge(self, preds):
decoded_prediction = []
previous = self.blank_id
for p in preds:
if (p != previous or previous == self.blank_id) and p != self.blank_id:
decoded_prediction.append(p.item())
decoded_prediction.append(p)
previous = p
hypothesis = self.tokenizer.ids_to_text(decoded_prediction)
return hypothesis
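# Example of greedy_merge (hypothetical ids, blank_id = 128):
#   preds = [128, 5, 5, 128, 7, 7]  ->  decoded_prediction = [5, 7]
# Repeats of a non-blank id are collapsed and blanks are dropped: standard
# CTC greedy merging over already-argmaxed frame predictions.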
