Skip to content

Commit

Permalink
Corrections to docstrings
Browse files Browse the repository at this point in the history
Signed-off-by: smajumdar <[email protected]>
  • Loading branch information
titu1994 committed May 13, 2020
1 parent 46cc5c7 commit a8d7f4c
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 15 deletions.
28 changes: 17 additions & 11 deletions nemo/collections/asr/contextnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2019 NVIDIA Corporation
from typing import Optional
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
Expand Down Expand Up @@ -61,6 +61,10 @@ class ContextNetEncoder(TrainableNM):
# temporal pooling (global context).
# If value >= 1, will perform stride 1 average pooling to
# compute context window.
'se_interpolation_mode' (str) # Interpolation mode of timestep dimension.
# Used only if context window is > 1.
# The modes available for resizing are: `nearest`, `linear` (3D-only),
# `bilinear` (4D-only), `area`
'kernel_size_factor' (float) # Conv kernel size multiplier
# Can be either an int or float
# Kernel size is recomputed as below:
Expand Down Expand Up @@ -146,15 +150,15 @@ def _prepare_for_deployment(self):

def __init__(
self,
jasper,
activation,
feat_in,
normalization_mode="batch",
residual_mode="add",
norm_groups=-1,
conv_mask=False,
frame_splicing=1,
init_mode='xavier_uniform',
jasper: Dict[str, Any],
activation: str,
feat_in: int,
normalization_mode: str = "batch",
residual_mode: str = "add",
norm_groups: int = -1,
conv_mask: bool = False,
frame_splicing: int = 1,
init_mode: str = 'xavier_uniform',
):
super().__init__()

Expand All @@ -177,6 +181,7 @@ def __init__(
se = lcfg.get('se', True)
se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
se_context_window = lcfg.get('se_context_window', -1)
se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
stride_last = lcfg.get('stride_last', False)
encoder_layers.append(
Expand All @@ -201,6 +206,7 @@ def __init__(
se=se,
se_reduction_ratio=se_reduction_ratio,
se_context_window=se_context_window,
se_interpolation_mode=se_interpolation_mode,
kernel_size_factor=kernel_size_factor,
stride_last=stride_last,
)
Expand Down Expand Up @@ -256,7 +262,7 @@ def output_ports(self):
# return {"output": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),})}
return {"output": NeuralType(('B', 'T', 'D'), LogprobsType())}

def __init__(self, feat_in, num_classes, hidden_size=640, init_mode="xavier_uniform"):
def __init__(self, feat_in: int, num_classes: int, hidden_size: int = 640, init_mode: str = "xavier_uniform"):
super().__init__()

self._feat_in = feat_in
Expand Down
6 changes: 6 additions & 0 deletions nemo/collections/asr/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class JasperEncoder(TrainableNM):
# temporal pooling (global context).
# If value >= 1, will perform stride 1 average pooling to
# compute context window.
'se_interpolation_mode' (str) # Interpolation mode of timestep dimension.
# Used only if context window is > 1.
# The modes available for resizing are: `nearest`, `linear` (3D-only),
# `bilinear` (4D-only), `area`
'kernel_size_factor' (float) # Conv kernel size multiplier
# Can be either an int or float
# Kernel size is recomputed as below:
Expand Down Expand Up @@ -177,6 +181,7 @@ def __init__(
se = lcfg.get('se', False)
se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
se_context_window = lcfg.get('se_context_window', -1)
se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
stride_last = lcfg.get('stride_last', False)
encoder_layers.append(
Expand All @@ -201,6 +206,7 @@ def __init__(
se=se,
se_reduction_ratio=se_reduction_ratio,
se_context_window=se_context_window,
se_interpolation_mode=se_interpolation_mode,
kernel_size_factor=kernel_size_factor,
stride_last=stride_last,
)
Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/asr/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class CTCLossNM(LossNM):
Args:
num_classes (int): Number of characters in ASR model's vocab/labels.
This count should not include the CTC blank symbol.
zero_infinity (bool): Whether to zero infinite losses and the associated gradients.
By default, it is False. Infinite losses mainly occur when the inputs are too
short to be aligned to the targets.
"""

@property
Expand Down
21 changes: 17 additions & 4 deletions nemo/collections/asr/parts/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -180,7 +180,12 @@ def forward(self, x):

class SqueezeExcite(nn.Module):
def __init__(
self, channels, reduction_ratio, context_window: int = -1, interpolation_mode='nearest', activation=None
self,
channels: int,
reduction_ratio: int,
context_window: int = -1,
interpolation_mode: str = 'nearest',
activation: Optional[Callable] = None,
):
"""
Squeeze-and-Excitation sub-module.
Expand All @@ -193,7 +198,10 @@ def __init__(
If value < 1, then global context is computed.
interpolation_mode: Interpolation mode of timestep dimension.
Used only if context window is > 1.
activation: Intermediate activation function used.
The modes available for resizing are: `nearest`, `linear` (3D-only),
`bilinear` (4D-only), `area`
activation: Intermediate activation function used. Must be a
callable activation function.
"""
super(SqueezeExcite, self).__init__()
self.context_window = int(context_window)
Expand Down Expand Up @@ -260,6 +268,7 @@ def __init__(
se=False,
se_reduction_ratio=16,
se_context_window=None,
se_interpolation_mode='nearest',
stride_last=False,
):
super(JasperBlock, self).__init__()
Expand Down Expand Up @@ -328,7 +337,11 @@ def __init__(
if se:
conv.append(
SqueezeExcite(
planes, reduction_ratio=se_reduction_ratio, context_window=se_context_window, activation=activation
planes,
reduction_ratio=se_reduction_ratio,
context_window=se_context_window,
interpolation_mode=se_interpolation_mode,
activation=activation,
)
)

Expand Down

0 comments on commit a8d7f4c

Please sign in to comment.