Skip to content

Commit

Permalink
Add StreamingFeatureBufferer class for real-life streaming decoding (N…
Browse files Browse the repository at this point in the history
…VIDIA#5534)

* Add StreamingFeatureBufferer class for real-life streaming decoding

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed the outdated commit

Signed-off-by: Taejin Park <[email protected]>

* resolve conflict

Signed-off-by: Taejin Park <[email protected]>

* removed unnecessary files

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reflected comments from PR review

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Updated the type annotations

Signed-off-by: Taejin Park <[email protected]>

* Updated type annotations

Signed-off-by: Taejin Park <[email protected]>

Signed-off-by: Taejin Park <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and titu1994 committed Mar 24, 2023
1 parent 45f78a2 commit 8c1829f
Showing 1 changed file with 126 additions and 1 deletion.
127 changes: 126 additions & 1 deletion nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,128 @@ def inplace_buffer_merge(buffer, data, timesteps, model):
return buffer


class StreamingFeatureBufferer:
"""
Class to append each feature frame to a buffer and return an array of buffers.
This class is designed to perform a real-life streaming decoding where only a single chunk
is provided at each step of a streaming pipeline.
"""

def __init__(self, asr_model, chunk_size, buffer_size):
'''
Args:
asr_model:
Reference to the asr model instance for which the feature needs to be created
chunk_size (float):
Duration of the new chunk of audio
buffer_size (float):
Size of the total audio in seconds maintained in the buffer
'''

self.NORM_CONSTANT = 1e-5
if asr_model.cfg.preprocessor.log:
self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
else:
self.ZERO_LEVEL_SPEC_DB_VAL = 0.0
self.asr_model = asr_model
self.sr = asr_model.cfg.sample_rate
self.model_normalize_type = asr_model.cfg.preprocessor.normalize
self.chunk_size = chunk_size
timestep_duration = asr_model.cfg.preprocessor.window_stride

self.n_chunk_look_back = int(timestep_duration * self.sr)
self.n_chunk_samples = int(chunk_size * self.sr)
self.buffer_size = buffer_size
total_buffer_len = int(buffer_size / timestep_duration)
self.n_feat = asr_model.cfg.preprocessor.features
self.sample_buffer = torch.zeros(int(self.buffer_size * self.sr))
self.buffer = torch.ones([self.n_feat, total_buffer_len], dtype=torch.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
self.feature_chunk_len = int(chunk_size / timestep_duration)
self.feature_buffer_len = total_buffer_len

self.reset()
cfg = copy.deepcopy(asr_model.cfg)
OmegaConf.set_struct(cfg.preprocessor, False)

cfg.preprocessor.dither = 0.0
cfg.preprocessor.pad_to = 0
cfg.preprocessor.normalize = "None"
self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
self.raw_preprocessor.to(asr_model.device)

def reset(self):
'''
Reset frame_history and decoder's state
'''
self.buffer = torch.ones(self.buffer.shape, dtype=torch.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
self.frame_buffers = []
self.sample_buffer = torch.zeros(int(self.buffer_size * self.sr))
self.feature_buffer = (
torch.ones([self.n_feat, self.feature_buffer_len], dtype=torch.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
)

def _add_chunk_to_buffer(self, chunk):
"""
Add time-series audio signal to `sample_buffer`
Args:
chunk (Tensor):
Tensor filled with time-series audio signal
"""
self.sample_buffer[: -self.n_chunk_samples] = self.sample_buffer[self.n_chunk_samples :].clone()
self.sample_buffer[-self.n_chunk_samples :] = chunk.clone()

def _update_feature_buffer(self, feat_chunk):
"""
Add an extracted feature to `feature_buffer`
"""
self.feature_buffer[:, : -self.feature_chunk_len] = self.feature_buffer[:, self.feature_chunk_len :].clone()
self.feature_buffer[:, -self.feature_chunk_len :] = feat_chunk.clone()

def get_raw_feature_buffer(self):
return self.feature_buffer

def get_normalized_feature_buffer(self):
normalized_buffer, _, _ = normalize_batch(
x=self.feature_buffer.unsqueeze(0),
seq_len=torch.tensor([len(self.feature_buffer)]),
normalize_type=self.model_normalize_type,
)
return normalized_buffer.squeeze(0)

def _convert_buffer_to_features(self):
"""
Extract features from the time-series audio buffer `sample_buffer`.
"""
# samples for conversion to features.
# Add look_back to have context for the first feature
samples = self.sample_buffer[: -(self.n_chunk_samples + self.n_chunk_look_back)]
device = self.asr_model.device
audio_signal = samples.unsqueeze_(0).to(device)
audio_signal_len = torch.Tensor([samples.shape[1]]).to(device)
features, features_len = self.raw_preprocessor(input_signal=audio_signal, length=audio_signal_len,)
features = features.squeeze()
self._update_feature_buffer(features[:, -self.feature_chunk_len :])

def update_feature_buffer(self, chunk):
"""
Update time-series signal `chunk` to the buffer then generate features out of the
signal in the audio buffer.
Args:
chunk (Tensor):
Tensor filled with time-series audio signal
"""
if len(chunk) > self.n_chunk_samples:
raise ValueError(f"chunk should be of length {self.n_chunk_samples} or less")
if len(chunk) < self.n_chunk_samples:
temp_chunk = torch.zeros(self.n_chunk_samples, dtype=torch.float32)
temp_chunk[: chunk.shape[0]] = chunk
chunk = temp_chunk
self._add_chunk_to_buffer(chunk)
self._convert_buffer_to_features()


class AudioFeatureIterator(IterableDataset):
def __init__(self, samples, frame_len, preprocessor, device):
self._samples = samples
Expand Down Expand Up @@ -454,7 +576,10 @@ def __init__(self, asr_model, frame_len=1.6, batch_size=4, total_buffer=4.0):
frame_overlap: duration of overlaps before and after current frame, seconds
offset: number of symbols to drop for smooth streaming
'''
self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
if asr_model.cfg.preprocessor.log:
self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
else:
self.ZERO_LEVEL_SPEC_DB_VAL = 0.0
self.asr_model = asr_model
self.sr = asr_model._cfg.sample_rate
self.frame_len = frame_len
Expand Down

0 comments on commit 8c1829f

Please sign in to comment.