Skip to content

Commit

Permalink
Refactor ContextNetEncoder to subclass JasperEncoder
Browse files Browse the repository at this point in the history
Signed-off-by: smajumdar <[email protected]>
  • Loading branch information
titu1994 committed May 14, 2020
1 parent 66924b6 commit 81330ba
Showing 1 changed file with 15 additions and 83 deletions.
98 changes: 15 additions & 83 deletions nemo/collections/asr/contextnet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) 2019 NVIDIA Corporation
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import nemo
from .jasper import JasperEncoder
from .parts.jasper import JasperBlock, init_weights, jasper_activations
from nemo.backends.pytorch.nm import TrainableNM
from nemo.core.neural_types import *
Expand All @@ -14,7 +15,7 @@
logging = nemo.logging


class ContextNetEncoder(TrainableNM):
class ContextNetEncoder(JasperEncoder):
"""
ContextNet Encoder creates the pre-processing (prologue), QuartzNet convolution
block, and the additional pre and post processing layers as described in
Expand Down Expand Up @@ -132,25 +133,9 @@ def output_ports(self):
"encoded_lengths": NeuralType(tuple('B'), LengthsType()),
}

@property
def _disabled_deployment_input_ports(self):
return set(["length"])

@property
def _disabled_deployment_output_ports(self):
return set(["encoded_lengths"])

def _prepare_for_deployment(self):
m_count = 0
for m in self.modules():
if type(m).__name__ == "MaskedConv1d":
m.use_mask = False
m_count += 1
logging.warning(f"Turned off {m_count} masked convolutions")

def __init__(
self,
jasper: Dict[str, Any],
jasper: List[Dict[str, Any]],
activation: str,
feat_in: int,
normalization_mode: str = "batch",
Expand All @@ -160,70 +145,17 @@ def __init__(
frame_splicing: int = 1,
init_mode: str = 'xavier_uniform',
):
super().__init__()

activation = jasper_activations[activation]()
feat_in = feat_in * frame_splicing

residual_panes = []
encoder_layers = []
self.dense_residual = False
for lcfg in jasper:
dense_res = []
if lcfg.get('residual_dense', False):
residual_panes.append(feat_in)
dense_res = residual_panes
self.dense_residual = True
groups = lcfg.get('groups', 1)
separable = lcfg.get('separable', False)
heads = lcfg.get('heads', -1)
residual_mode = lcfg.get('residual_mode', residual_mode)
se = lcfg.get('se', True)
se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
se_context_window = lcfg.get('se_context_window', -1)
se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
stride_last = lcfg.get('stride_last', False)
encoder_layers.append(
JasperBlock(
feat_in,
lcfg['filters'],
repeat=lcfg['repeat'],
kernel_size=lcfg['kernel'],
stride=lcfg['stride'],
dilation=lcfg['dilation'],
dropout=lcfg['dropout'],
residual=lcfg['residual'],
groups=groups,
separable=separable,
heads=heads,
residual_mode=residual_mode,
normalization=normalization_mode,
norm_groups=norm_groups,
activation=activation,
residual_panes=dense_res,
conv_mask=conv_mask,
se=se,
se_reduction_ratio=se_reduction_ratio,
se_context_window=se_context_window,
se_interpolation_mode=se_interpolation_mode,
kernel_size_factor=kernel_size_factor,
stride_last=stride_last,
)
)
feat_in = lcfg['filters']

self.encoder = nn.Sequential(*encoder_layers)
self.apply(lambda x: init_weights(x, mode=init_mode))
self.to(self._device)

def forward(self, audio_signal, length=None):
# type: (Tensor, Optional[Tensor]) -> Tensor, Optional[Tensor]

s_input, length = self.encoder(([audio_signal], length))
if length is None:
return s_input[-1]
return s_input[-1], length
super().__init__(
jasper=jasper,
activation=activation,
feat_in=feat_in,
normalization_mode=normalization_mode,
residual_mode=residual_mode,
norm_groups=norm_groups,
conv_mask=conv_mask,
frame_splicing=frame_splicing,
init_mode=init_mode,
)


class ContextNetDecoderForCTC(TrainableNM):
Expand Down

0 comments on commit 81330ba

Please sign in to comment.