Pool stats with padding (NVIDIA#5403)
* add padded stats pool

Signed-off-by: shane carroll <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix issue with padded inputs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: shane carroll <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nithin Rao <[email protected]>
3 people authored and titu1994 committed Mar 24, 2023
1 parent 8c1829f commit 4cb0d07
Showing 1 changed file with 67 additions and 21 deletions.
88 changes: 67 additions & 21 deletions nemo/collections/asr/parts/submodules/tdnn_attention.py
@@ -23,39 +23,85 @@
 
 
 class StatsPoolLayer(nn.Module):
-    """
-    Statistics and time average pooling (TAP) layer
-    This computes mean and variance statistics across time dimension (dim=-1)
-    input:
-        feat_in: input channel feature length
-        pool_mode: type of pool mode
-            supported modes are xvector (mean and variance),
-            tap (mean)
-    output:
-        pooled: statistics of feature input
-    """
+    """Statistics and time average pooling (TAP) layer
+
+    This computes mean and, optionally, standard deviation statistics across the time dimension.
+
+    Args:
+        feat_in: Input features with shape [B, D, T]
+        pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
+            average pooling, i.e., mean)
+        eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
+        biased: Whether to use the biased estimator for the standard deviation when using 'xvector' mode. The default
+            for torch.Tensor.std() is True.
+
+    Returns:
+        Pooled statistics with shape [B, D].
+
+    Raises:
+        ValueError if an unsupported pooling mode is specified.
+    """
 
-    def __init__(self, feat_in: int, pool_mode: str = 'xvector'):
+    def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, biased: bool = True):
         super().__init__()
+        supported_modes = {"xvector", "tap"}
+        if pool_mode not in supported_modes:
+            raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
         self.pool_mode = pool_mode
         self.feat_in = feat_in
+        self.eps = eps
+        self.biased = biased
         if self.pool_mode == 'xvector':
-            self.feat_in += feat_in
-        elif self.pool_mode == 'tap':
-            self.feat_in = feat_in
-        else:
-            raise ValueError("pool mode for stats must be either tap or xvector based")
+            # Mean + std
+            self.feat_in *= 2
 
     def forward(self, encoder_output, length=None):
-        mean = encoder_output.mean(dim=-1)  # Time Axis
-        if self.pool_mode == 'xvector':
-            std = encoder_output.std(dim=-1)
-            pooled = torch.cat([mean, std], dim=-1)
-        else:
-            pooled = mean
+        if length is None:
+            mean = encoder_output.mean(dim=-1)  # Time Axis
+            if self.pool_mode == 'xvector':
+                std = encoder_output.std(dim=-1)
+                pooled = torch.cat([mean, std], dim=-1)
+            else:
+                pooled = mean
+        else:
+            mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False)
+            encoder_output = encoder_output.masked_fill(mask, 0.0)
+            # [B, D, T] -> [B, D]
+            means = encoder_output.mean(dim=-1)
+            # Re-scale to get padded means
+            means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
+            if self.pool_mode == "xvector":
+                stds = (
+                    encoder_output.sub(means.unsqueeze(-1))
+                    .masked_fill(mask, 0.0)
+                    .pow(2.0)
+                    .sum(-1)  # [B, D, T] -> [B, D]
+                    .div(length.view(-1, 1).sub(1 if self.biased else 0))
+                    .clamp(min=self.eps)
+                    .sqrt()
+                )
+                pooled = torch.cat((means, stds), dim=-1)
+            else:
+                pooled = means
         return pooled
 
 
+@torch.jit.script_if_tracing
+def make_seq_mask_like(
+    like: torch.Tensor, lengths: torch.Tensor, valid_ones: bool = True, time_dim: int = -1
+) -> torch.Tensor:
+    mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.unsqueeze(-1))
+    # Match number of dims in `like` tensor
+    for _ in range(like.dim() - mask.dim()):
+        mask = mask.unsqueeze(1)
+    # If time dim != -1, transpose to proper dim.
+    if time_dim != -1:
+        mask = mask.transpose(time_dim, -1)
+    if not valid_ones:
+        mask = ~mask
+    return mask
+
+
 def lens_to_mask(lens: List[int], max_len: int, device: str = None):
     """
     outputs masking labels for list of lengths of audio features, with max length of any
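For readers of this change: because the masked path zeroes padded frames before pooling, the sum over all T frames equals the sum over only the valid frames, so rescaling the full-length mean by T/length recovers the per-sequence mean; the standard deviation is likewise accumulated over masked residuals and divided by length - 1 (with the default biased=True). A minimal usage sketch, not part of the commit (shapes, values, and tolerances are illustrative), checking the masked path against plain pooling on the unpadded slice:

    import torch
    from nemo.collections.asr.parts.submodules.tdnn_attention import StatsPoolLayer

    pool = StatsPoolLayer(feat_in=8, pool_mode='xvector')  # pooled dim is 2 * 8 = 16

    B, D, T = 2, 8, 50
    x = torch.randn(B, D, T)
    lengths = torch.tensor([50, 30])
    x[1, :, 30:] = 123.0  # garbage in the padded region; the layer masks it out

    pooled = pool(x, length=lengths)  # [B, 2 * D]
    assert pooled.shape == (2, 16)

    # The masked path should agree with unmasked pooling on the valid slice.
    ref = pool(x[1:, :, :30])  # no `length` given -> unmasked branch
    torch.testing.assert_close(pooled[1], ref[0], atol=1e-5, rtol=1e-5)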
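A second small sketch, again illustrative rather than from the commit, showing what the new make_seq_mask_like helper produces: with valid_ones=False the mask is True exactly at the padded positions, and it is unsqueezed so it broadcasts over the feature dimension.

    import torch
    from nemo.collections.asr.parts.submodules.tdnn_attention import make_seq_mask_like

    x = torch.randn(2, 4, 5)  # [B, D, T]
    lengths = torch.tensor([5, 3])
    mask = make_seq_mask_like(like=x, lengths=lengths, valid_ones=False)
    print(mask.shape)  # torch.Size([2, 1, 5]) -- broadcastable over D
    print(mask[1, 0])  # tensor([False, False, False,  True,  True])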
