diff --git a/nemo/collections/asr/contextnet.py b/nemo/collections/asr/contextnet.py index 0c939ea72830..35fcdce94c88 100644 --- a/nemo/collections/asr/contextnet.py +++ b/nemo/collections/asr/contextnet.py @@ -1,5 +1,5 @@ # Copyright (c) 2019 NVIDIA Corporation -from typing import Optional +from typing import Any, Dict, Optional import torch import torch.nn as nn @@ -61,6 +61,10 @@ class ContextNetEncoder(TrainableNM): # temporal pooling (global context). # If value >= 1, will perform stride 1 average pooling to # compute context window. + 'se_interpolation_mode' (str) # Interpolation mode of timestep dimension. + # Used only if context window is > 1. + # The modes available for resizing are: `nearest`, `linear` (3D-only), + # `bilinear`, `area` 'kernel_size_factor' (float) # Conv kernel size multiplier # Can be either an int or float # Kernel size is recomputed as below: @@ -146,15 +150,15 @@ def _prepare_for_deployment(self): def __init__( self, - jasper, - activation, - feat_in, - normalization_mode="batch", - residual_mode="add", - norm_groups=-1, - conv_mask=False, - frame_splicing=1, - init_mode='xavier_uniform', + jasper: Dict[str, Any], + activation: str, + feat_in: int, + normalization_mode: str = "batch", + residual_mode: str = "add", + norm_groups: int = -1, + conv_mask: bool = False, + frame_splicing: int = 1, + init_mode: str = 'xavier_uniform', ): super().__init__() @@ -177,6 +181,7 @@ def __init__( se = lcfg.get('se', True) se_reduction_ratio = lcfg.get('se_reduction_ratio', 8) se_context_window = lcfg.get('se_context_window', -1) + se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest') kernel_size_factor = lcfg.get('kernel_size_factor', 1.0) stride_last = lcfg.get('stride_last', False) encoder_layers.append( @@ -201,6 +206,7 @@ def __init__( se=se, se_reduction_ratio=se_reduction_ratio, se_context_window=se_context_window, + se_interpolation_mode=se_interpolation_mode, kernel_size_factor=kernel_size_factor, stride_last=stride_last, ) @@ -256,7 +262,7 @@ def output_ports(self): # return {"output": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),})} return {"output": NeuralType(('B', 'T', 'D'), LogprobsType())} - def __init__(self, feat_in, num_classes, hidden_size=640, init_mode="xavier_uniform"): + def __init__(self, feat_in: int, num_classes: int, hidden_size: int = 640, init_mode: str = "xavier_uniform"): super().__init__() self._feat_in = feat_in diff --git a/nemo/collections/asr/jasper.py b/nemo/collections/asr/jasper.py index 02b82409fb29..b5de8f6b7af4 100644 --- a/nemo/collections/asr/jasper.py +++ b/nemo/collections/asr/jasper.py @@ -61,6 +61,10 @@ class JasperEncoder(TrainableNM): # temporal pooling (global context). # If value >= 1, will perform stride 1 average pooling to # compute context window. + 'se_interpolation_mode' (str) # Interpolation mode of timestep dimension. + # Used only if context window is > 1. + # The modes available for resizing are: `nearest`, `linear` (3D-only), + # `bilinear`, `area` 'kernel_size_factor' (float) # Conv kernel size multiplier # Can be either an int or float # Kernel size is recomputed as below: @@ -177,6 +181,7 @@ def __init__( se = lcfg.get('se', False) se_reduction_ratio = lcfg.get('se_reduction_ratio', 8) se_context_window = lcfg.get('se_context_window', -1) + se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest') kernel_size_factor = lcfg.get('kernel_size_factor', 1.0) stride_last = lcfg.get('stride_last', False) encoder_layers.append( @@ -201,6 +206,7 @@ def __init__( se=se, se_reduction_ratio=se_reduction_ratio, se_context_window=se_context_window, + se_interpolation_mode=se_interpolation_mode, kernel_size_factor=kernel_size_factor, stride_last=stride_last, ) diff --git a/nemo/collections/asr/losses.py b/nemo/collections/asr/losses.py index c34b7cc357c3..f3ca8f5a4d25 100644 --- a/nemo/collections/asr/losses.py +++ b/nemo/collections/asr/losses.py @@ -13,6 +13,9 @@ class CTCLossNM(LossNM): Args: num_classes (int): Number of characters in ASR model's vocab/labels. This count should not include the CTC blank symbol. + zero_infinity (bool): Whether to zero infinite losses and the associated gradients. + By default, it is False. Infinite losses mainly occur when the inputs are too + short to be aligned to the targets. """ @property diff --git a/nemo/collections/asr/parts/jasper.py b/nemo/collections/asr/parts/jasper.py index 5ac0ca8bacd2..a07fd1fb3b50 100644 --- a/nemo/collections/asr/parts/jasper.py +++ b/nemo/collections/asr/parts/jasper.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -180,7 +180,12 @@ def forward(self, x): class SqueezeExcite(nn.Module): def __init__( - self, channels, reduction_ratio, context_window: int = -1, interpolation_mode='nearest', activation=None + self, + channels: int, + reduction_ratio: int, + context_window: int = -1, + interpolation_mode: str = 'nearest', + activation: Optional[Callable] = None, ): """ Squeeze-and-Excitation sub-module. @@ -193,7 +198,10 @@ def __init__( If value < 1, then global context is computed. interpolation_mode: Interpolation mode of timestep dimension. Used only if context window is > 1. - activation: Intermediate activation function used. + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `area` + activation: Intermediate activation function used. Must be a + callable activation function. """ super(SqueezeExcite, self).__init__() self.context_window = int(context_window) @@ -260,6 +268,7 @@ def __init__( se=False, se_reduction_ratio=16, se_context_window=None, + se_interpolation_mode='nearest', stride_last=False, ): super(JasperBlock, self).__init__() @@ -328,7 +337,11 @@ def __init__( if se: conv.append( SqueezeExcite( - planes, reduction_ratio=se_reduction_ratio, context_window=se_context_window, activation=activation + planes, + reduction_ratio=se_reduction_ratio, + context_window=se_context_window, + interpolation_mode=se_interpolation_mode, + activation=activation, ) )