From 4cb0d07bcd88d6811e3099d7de918e1e8662f549 Mon Sep 17 00:00:00 2001
From: Shane Carroll <50530592+1-800-BAD-CODE@users.noreply.github.com>
Date: Fri, 2 Dec 2022 11:06:45 -0500
Subject: [PATCH] Pool stats with padding (#5403)

* add padded stats pool

Signed-off-by: shane carroll

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix issue with padded inputs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: shane carroll
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nithin Rao
---
 .../asr/parts/submodules/tdnn_attention.py | 88 ++++++++++++++-----
 1 file changed, 67 insertions(+), 21 deletions(-)

diff --git a/nemo/collections/asr/parts/submodules/tdnn_attention.py b/nemo/collections/asr/parts/submodules/tdnn_attention.py
index 03127300099a3..14f27ef41af7d 100644
--- a/nemo/collections/asr/parts/submodules/tdnn_attention.py
+++ b/nemo/collections/asr/parts/submodules/tdnn_attention.py
@@ -23,39 +23,85 @@
 
 class StatsPoolLayer(nn.Module):
-    """
-    Statistics and time average pooling (TAP) layer
-    This computes mean and variance statistics across time dimension (dim=-1)
-    input:
-        feat_in: input channel feature length
-        pool_mode: type of pool mode
-        supported modes are xvector (mean and variance),
-        tap (mean)
-    output:
-        pooled: statistics of feature input
+    """Statistics and time average pooling (TAP) layer.
+
+    This computes mean and, optionally, standard deviation statistics across the time dimension.
+
+    Args:
+        feat_in: Input features with shape [B, D, T].
+        pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
+            average pooling, i.e., mean).
+        eps: Epsilon, the minimum value the variance is clamped to before taking the square root in 'xvector' mode.
+        biased: Whether to subtract 1 from the sequence length when normalizing the standard deviation in 'xvector'
+            mode. The default of True divides by N - 1, matching the default (unbiased=True) behavior of
+            torch.Tensor.std().
+
+    Returns:
+        Pooled statistics with shape [B, D] for 'tap' and [B, 2 * D] for 'xvector'.
+
+    Raises:
+        ValueError if an unsupported pooling mode is specified.
""" - def __init__(self, feat_in: int, pool_mode: str = 'xvector'): + def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, biased: bool = True): super().__init__() + supported_modes = {"xvector", "tap"} + if pool_mode not in supported_modes: + raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'") self.pool_mode = pool_mode self.feat_in = feat_in + self.eps = eps + self.biased = biased if self.pool_mode == 'xvector': - self.feat_in += feat_in - elif self.pool_mode == 'tap': - self.feat_in = feat_in - else: - raise ValueError("pool mode for stats must be either tap or xvector based") + # Mean + std + self.feat_in *= 2 def forward(self, encoder_output, length=None): - mean = encoder_output.mean(dim=-1) # Time Axis - if self.pool_mode == 'xvector': - std = encoder_output.std(dim=-1) - pooled = torch.cat([mean, std], dim=-1) + if length is None: + mean = encoder_output.mean(dim=-1) # Time Axis + if self.pool_mode == 'xvector': + std = encoder_output.std(dim=-1) + pooled = torch.cat([mean, std], dim=-1) + else: + pooled = mean else: - pooled = mean + mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False) + encoder_output = encoder_output.masked_fill(mask, 0.0) + # [B, D, T] -> [B, D] + means = encoder_output.mean(dim=-1) + # Re-scale to get padded means + means = means * (encoder_output.shape[-1] / length).unsqueeze(-1) + if self.pool_mode == "xvector": + stds = ( + encoder_output.sub(means.unsqueeze(-1)) + .masked_fill(mask, 0.0) + .pow(2.0) + .sum(-1) # [B, D, T] -> [B, D] + .div(length.view(-1, 1).sub(1 if self.biased else 0)) + .clamp(min=self.eps) + .sqrt() + ) + pooled = torch.cat((means, stds), dim=-1) + else: + pooled = means return pooled +@torch.jit.script_if_tracing +def make_seq_mask_like( + like: torch.Tensor, lengths: torch.Tensor, valid_ones: bool = True, time_dim: int = -1 +) -> torch.Tensor: + mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.unsqueeze(-1)) + # Match number of dims in `like` tensor + for _ in range(like.dim() - mask.dim()): + mask = mask.unsqueeze(1) + # If time dim != -1, transpose to proper dim. + if time_dim != -1: + mask = mask.transpose(time_dim, -1) + if not valid_ones: + mask = ~mask + return mask + + def lens_to_mask(lens: List[int], max_len: int, device: str = None): """ outputs masking labels for list of lengths of audio features, with max length of any