add exportable mel spec #5512

Merged · 2 commits · Nov 28, 2022
27 changes: 20 additions & 7 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -22,9 +22,9 @@
from packaging import version

from nemo.collections.asr.parts.numba.spec_augment import SpecAugmentNumba, spec_augment_launch_heuristics
- from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
+ from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, FilterbankFeaturesTA
from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout
- from nemo.core.classes import NeuralModule, typecheck
+ from nemo.core.classes import Exportable, NeuralModule, typecheck
from nemo.core.neural_types import (
AudioSignal,
LengthsType,
@@ -92,11 +92,8 @@ def get_features(self, input_signal, length):
pass


- class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
+ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor, Exportable):
"""Featurizer module that converts wavs to mel spectrograms.
- We don't use torchaudio's implementation here because the original
- implementation is not the same, so for the sake of backwards-compatibility
- this will use the old FilterbankFeatures for now.

Args:
sample_rate (int): Sample rate of the input audio data.
Expand Down Expand Up @@ -158,6 +155,7 @@ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
Defaults to 0.0
nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
Defaults to 4000
use_torchaudio: Whether to use the `torchaudio` implementation.
stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
"""
@@ -222,6 +220,7 @@ def __init__(
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
use_torchaudio: bool = False,
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
@@ -239,7 +238,12 @@ def __init__(
if window_stride:
n_window_stride = int(window_stride * self._sample_rate)

- self.featurizer = FilterbankFeatures(
+ # Given the long and similar argument list, point to the class and instantiate it by reference
+ if not use_torchaudio:
+     featurizer_class = FilterbankFeatures
+ else:
+     featurizer_class = FilterbankFeaturesTA
+ self.featurizer = featurizer_class(
sample_rate=self._sample_rate,
n_window_size=n_window_size,
n_window_stride=n_window_stride,
@@ -266,6 +270,14 @@ def __init__(
stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
)

def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
batch_size = torch.randint(low=1, high=max_batch, size=[1]).item()
max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item()
signals = torch.rand(size=[batch_size, max_length]) * 2 - 1
lengths = torch.randint(low=min_length, high=max_length + 1, size=[batch_size])  # keep lengths <= signal length
lengths[0] = max_length
return signals, lengths
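
# Note: NeMo's `Exportable` mixin draws its dummy inputs from
# `input_example()` when tracing the module for ONNX/TorchScript export, so
# the shapes generated here define the example batch used at export time.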

def get_features(self, input_signal, length):
return self.featurizer(input_signal, length)

@@ -699,6 +711,7 @@ class AudioToMelSpectrogramPreprocessorConfig:
rng: Optional[str] = None
nb_augmentation_prob: float = 0.0
nb_max_freq: int = 4000
use_torchaudio: bool = False
stft_exact_pad: bool = False # Deprecated argument, kept for compatibility with older checkpoints.
stft_conv: bool = False # Deprecated argument, kept for compatibility with older checkpoints.
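
A minimal usage sketch for the new flag (values are illustrative; the export call assumes NeMo's standard `Exportable.export()` entry point):

from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor

preprocessor = AudioToMelSpectrogramPreprocessor(
    sample_rate=16000,
    features=80,
    use_torchaudio=True,  # dispatch to the exportable FilterbankFeaturesTA
)
preprocessor.eval()
preprocessor.export("mel_spec.onnx")  # dummy inputs come from input_example()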

225 changes: 225 additions & 0 deletions nemo/collections/asr/parts/preprocessing/features.py
@@ -34,6 +34,7 @@
# This file contains code artifacts adapted from https://github.com/ryanleary/patter
import math
import random
from typing import Optional, Tuple, Union

import librosa
import numpy as np
@@ -44,6 +45,14 @@
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.utils import logging

try:
import torchaudio

HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
HAVE_TORCHAUDIO = False


CONSTANT = 1e-5


@@ -99,6 +108,39 @@ def splice_frames(x, frame_splicing):
return torch.cat(seq, dim=1)


@torch.jit.script_if_tracing
def make_seq_mask_like(
lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True
) -> torch.Tensor:
"""

Args:
lengths: Tensor with shape [B] containing the sequence length of each batch element
like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
length in the time dimension of this Tensor.
time_dim: Time dimension of the `like` tensor and the resulting mask. Zero-based.
valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.

Returns:
A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
`time_dim == -1`, mask will have shape `[3, 1, 5]`.
"""
# Mask with shape [B, T]
mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1))
# [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor
for _ in range(like.dim() - mask.dim()):
mask = mask.unsqueeze(1)
# If needed, transpose time dim
if time_dim != -1 and time_dim != mask.dim() - 1:
mask = mask.transpose(-1, time_dim)
# Maybe invert the padded vs. valid token values
if not valid_ones:
mask = ~mask
return mask
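
# Example (illustrative): for lengths = [2, 3] and `like` of shape [2, 4, 3],
# with time_dim=-1 and valid_ones=False, the mask has shape [2, 1, 3] and
# mask[0, 0] == [False, False, True], i.e. only the padded step is True.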


class WaveformFeaturizer(object):
def __init__(self, sample_rate=16000, int_values=False, augmentor=None):
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
Expand Down Expand Up @@ -401,3 +443,186 @@ def forward(self, x, seq_len):
if pad_amt != 0:
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
return x, seq_len


class FilterbankFeaturesTA(nn.Module):
"""
Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction.

See `AudioToMelSpectrogramPreprocessor` for args.

"""

def __init__(
self,
sample_rate: int = 16000,
n_window_size: int = 320,
n_window_stride: int = 160,
normalize: Optional[str] = "per_feature",
nfilt: int = 64,
n_fft: Optional[int] = None,
preemph: float = 0.97,
lowfreq: float = 0,
highfreq: Optional[float] = None,
log: bool = True,
log_zero_guard_type: str = "add",
log_zero_guard_value: Union[float, str] = 2 ** -24,
dither: float = 1e-5,
window: str = "hann",
pad_to: int = 0,
pad_value: float = 0.0,
# These options appear to be unused; avoid complicating the code by supporting them.
use_grads: bool = False, # Deprecated arguments; kept for config compatibility
max_duration: float = 16.7, # Deprecated arguments; kept for config compatibility
frame_splicing: int = 1, # Deprecated arguments; kept for config compatibility
exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
nb_augmentation_prob: float = 0.0, # Deprecated arguments; kept for config compatibility
nb_max_freq: int = 4000, # Deprecated arguments; kept for config compatibility
mag_power: float = 2.0, # Deprecated arguments; kept for config compatibility
rng: Optional[random.Random] = None, # Deprecated arguments; kept for config compatibility
stft_exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
stft_conv: bool = False, # Deprecated arguments; kept for config compatibility
):
super().__init__()
if not HAVE_TORCHAUDIO:
raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}")

# Make sure log zero guard is supported, if given as a string
supported_log_zero_guard_strings = {"eps", "tiny"}
if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings:
raise ValueError(
f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}"
)

# Copied from `AudioPreprocessor` due to the ad-hoc structuring of the Mel Spec extractor class
self.torch_windows = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'ones': torch.ones,
None: torch.ones,
}

# Ensure we can look up the window function
if window not in self.torch_windows:
raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")

self._win_length = n_window_size
self._hop_length = n_window_stride
self._sample_rate = sample_rate
self._normalize_strategy = normalize
self._use_log = log
self._preemphasis_value = preemph
self._log_zero_guard_type = log_zero_guard_type
self._log_zero_guard_value: Union[str, float] = log_zero_guard_value
self._dither_value = dither
self._pad_to = pad_to
self._pad_value = pad_value
self._num_fft = n_fft
self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self._sample_rate,
win_length=self._win_length,
hop_length=self._hop_length,
n_mels=nfilt,
window_fn=self.torch_windows[window],
mel_scale="slaney",
norm="slaney",
n_fft=n_fft,
f_max=highfreq,
f_min=lowfreq,
wkwargs={"periodic": False},
)
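
# Note: mel_scale="slaney" and norm="slaney" mirror librosa's defaults, which
# the original FilterbankFeatures uses to build its filterbanks, so the two
# implementations produce comparable (though not bit-identical) features.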

@property
def filter_banks(self):
""" Matches the analogous class """
return self._mel_spec_extractor.mel_scale.fb

def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
if isinstance(self._log_zero_guard_value, float):
return self._log_zero_guard_value
return getattr(torch.finfo(dtype), self._log_zero_guard_value)
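
# For example, with float32 features: "eps" resolves to
# torch.finfo(torch.float32).eps ~= 1.19e-07 and "tiny" to
# torch.finfo(torch.float32).tiny ~= 1.18e-38.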

def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
if self.training and self._dither_value > 0.0:
noise = torch.randn_like(signals) * self._dither_value
signals = signals + noise
return signals

def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
if self._preemphasis_value is not None:
padded = torch.nn.functional.pad(signals, (1, 0))
signals = signals - self._preemphasis_value * padded[:, :-1]
return signals

def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
out_lengths = input_lengths.div(self._hop_length, rounding_mode="floor").add(1).long()
return out_lengths
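
# Worked example (illustrative): with hop_length = 160 (10 ms at 16 kHz), a
# 16,000-sample signal yields floor(16000 / 160) + 1 = 101 frames, matching
# torchaudio's MelSpectrogram output when center=True (its default).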

def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
# Only apply during training; else need to capture dynamic shape for exported models
if not self.training or self._pad_to == 0 or features.shape[-1] % self._pad_to == 0:
return features
pad_length = self._pad_to - (features.shape[-1] % self._pad_to)
return torch.nn.functional.pad(features, pad=(0, pad_length), value=self._pad_value)

def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
if self._use_log:
zero_guard = self._resolve_log_zero_guard_value(features.dtype)
if self._log_zero_guard_type == "add":
features = features + zero_guard
elif self._log_zero_guard_type == "clamp":
features = features.clamp(min=zero_guard)
else:
raise ValueError(f"Unsupported log zero guard type: '{self._log_zero_guard_type}'")
features = features.log()
return features

def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor:
# Complex FFT needs to be done in single precision
with torch.cuda.amp.autocast(enabled=False):
features = self._mel_spec_extractor(waveform=signals)
return features

def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
# For consistency, this function always does a masked fill even if not normalizing.
mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False)
features = features.masked_fill(mask, 0.0)
# Maybe don't normalize
if self._normalize_strategy is None:
return features
# Use the log zero guard for the sqrt zero guard
guard_value = self._resolve_log_zero_guard_value(features.dtype)
if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features":
# 'all_features' reduces over each sample; 'per_feature' reduces over each channel
reduce_dim = 2
if self._normalize_strategy == "all_features":
reduce_dim = [1, 2]
# [B, D, T] -> [B, D, 1] or [B, 1, 1]
means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1))
stds = (
features.sub(means)
.masked_fill(mask, 0.0)
.pow(2.0)
.sum(dim=reduce_dim, keepdim=True) # [B, D, T] -> [B, D, 1] or [B, 1, 1]
.div(lengths.view(-1, 1, 1) - 1)  # unbiased (Bessel-corrected) estimator
.clamp(min=guard_value) # avoid sqrt(0)
.sqrt()
)
features = (features - means) / (stds + eps)
else:
# Deprecating constant std/mean
raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}")
features = features.masked_fill(mask, 0.0)
return features
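
# Shapes, for reference: with features of shape [B, D, T], "per_feature"
# yields means/stds of shape [B, D, 1] (one statistic per mel channel), while
# "all_features" yields [B, 1, 1] (one per sample); padded frames are masked
# out of the statistics and zero-filled in the output.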

def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
feature_lengths = self._compute_output_lengths(input_lengths=length)
signals = self._apply_dithering(signals=input_signal)
signals = self._apply_preemphasis(signals=signals)
features = self._extract_spectrograms(signals=signals)
features = self._apply_log(features=features)
features = self._apply_normalization(features=features, lengths=feature_lengths)
features = self._apply_pad_to(features=features)
return features, feature_lengths
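
A minimal end-to-end sketch of the new featurizer (hypothetical values; n_fft is set explicitly since torchaudio's MelSpectrogram requires an integer n_fft):

import torch
from nemo.collections.asr.parts.preprocessing.features import FilterbankFeaturesTA

featurizer = FilterbankFeaturesTA(
    sample_rate=16000, n_window_size=400, n_window_stride=160, nfilt=80, n_fft=512,
)
featurizer.eval()  # disable dithering for deterministic output
signals = torch.rand(2, 16000) * 2 - 1   # [B, T], values in [-1, 1)
lengths = torch.tensor([16000, 8000])
feats, feat_lens = featurizer(signals, lengths)
print(feats.shape)   # torch.Size([2, 80, 101])
print(feat_lens)     # tensor([101,  51])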