diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bb024220a20..0f3be1532564 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,7 @@ To release a new version, please update the changelog as followed: - Speed augmentation on CPU, TimeStretch augmentation on CPU+GPU ([PR #594](https://github.com/NVIDIA/NeMo/pull/565)) - @titu1994 - Added TarredAudioToTextDataLayer, which allows for loading ASR datasets with tarred audio. Existing datasets can be converted with the `convert_to_tarred_audio_dataset.py` script. ([PR #602](https://github.com/NVIDIA/NeMo/pull/602)) - Online audio augmentation notebook in ASR examples ([PR #605](https://github.com/NVIDIA/NeMo/pull/605)) - @titu1994 +- ContextNet Encoder + Decoder Initial Support ([PR #630](https://github.com/NVIDIA/NeMo/pull/630)) - @titu1994 ### Changed diff --git a/examples/asr/contextnet.py b/examples/asr/contextnet.py new file mode 100644 index 000000000000..6e6845142d8f --- /dev/null +++ b/examples/asr/contextnet.py @@ -0,0 +1,324 @@ +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.**** + +import argparse +import copy +import os +from functools import partial + +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +import nemo.utils.argparse as nm_argparse +from nemo.collections.asr.helpers import monitor_asr_train_progress, process_evaluation_batch, process_evaluation_epoch +from nemo.utils.lr_policies import CosineAnnealing + +logging = nemo.logging + + +def parse_args(): + parser = argparse.ArgumentParser( + parents=[nm_argparse.NemoArgParser()], description='ContextNet', conflict_handler='resolve', + ) + parser.set_defaults( + checkpoint_dir=None, + optimizer="novograd", + batch_size=32, + eval_batch_size=64, + lr=0.01, + weight_decay=0.001, + amp_opt_level="O0", + create_tb_writer=True, + ) + + # Overwrite default args + parser.add_argument( + "--num_epochs", + type=int, + default=None, + required=True, + help="number of epochs to train. 
You should specify either num_epochs or max_steps", + ) + parser.add_argument( + "--model_config", type=str, required=True, help="model configuration file: model.yaml", + ) + + # Create new args + parser.add_argument("--exp_name", default="ContextNet", type=str) + parser.add_argument("--project", default=None, type=str) + parser.add_argument("--beta1", default=0.95, type=float) + parser.add_argument("--beta2", default=0.5, type=float) + parser.add_argument("--warmup_steps", default=1000, type=int) + parser.add_argument("--warmup_ratio", default=None, type=float) + parser.add_argument('--min_lr', default=1e-5, type=float) + parser.add_argument("--load_dir", default=None, type=str) + parser.add_argument("--synced_bn", action='store_true', help="Use synchronized batch norm") + parser.add_argument("--synced_bn_groupsize", default=0, type=int) + parser.add_argument("--update_freq", default=50, type=int, help="Metrics update freq") + parser.add_argument("--eval_freq", default=1000, type=int, help="Evaluation frequency") + parser.add_argument('--kernel_size_factor', default=1.0, type=float) + + args = parser.parse_args() + if args.max_steps is not None: + raise ValueError("ContextNet uses num_epochs instead of max_steps") + + return args + + +def construct_name(name, lr, batch_size, num_epochs, wd, optimizer, kernel_size_factor): + return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-kf_{6}".format( + name, lr, batch_size, num_epochs, wd, optimizer, kernel_size_factor + ) + + +def create_all_dags(args, neural_factory): + ''' + creates train and eval dags as well as their callbacks + returns train loss tensor and callbacks''' + + # parse the config files + yaml = YAML(typ="safe") + with open(args.model_config) as f: + contextnet_params = yaml.load(f) + + vocab = contextnet_params['labels'] + sample_rate = contextnet_params['sample_rate'] + + # Calculate num_workers for dataloader + total_cpus = os.cpu_count() + cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) + + # create data layer for training + train_dl_params = copy.deepcopy(contextnet_params["AudioToTextDataLayer"]) + train_dl_params.update(contextnet_params["AudioToTextDataLayer"]["train"]) + del train_dl_params["train"] + del train_dl_params["eval"] + # del train_dl_params["normalize_transcripts"] + + data_layer_train = nemo_asr.AudioToTextDataLayer( + manifest_filepath=args.train_dataset, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.batch_size, + num_workers=cpu_per_traindl, + **train_dl_params, + ) + + N = len(data_layer_train) + steps_per_epoch = int(N / (args.batch_size * args.iter_per_step * args.num_gpus)) + + # create separate data layers for eval + # we need separate eval dags for separate eval datasets + # but all other modules in these dags will be shared + + eval_dl_params = copy.deepcopy(contextnet_params["AudioToTextDataLayer"]) + eval_dl_params.update(contextnet_params["AudioToTextDataLayer"]["eval"]) + del eval_dl_params["train"] + del eval_dl_params["eval"] + + data_layers_eval = [] + if args.eval_datasets: + for eval_dataset in args.eval_datasets: + data_layer_eval = nemo_asr.AudioToTextDataLayer( + manifest_filepath=eval_dataset, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.eval_batch_size, + num_workers=cpu_per_traindl, + **eval_dl_params, + ) + + data_layers_eval.append(data_layer_eval) + else: + logging.warning("There were no val datasets passed") + + # create shared modules + + data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( + sample_rate=sample_rate, 
**contextnet_params["AudioToMelSpectrogramPreprocessor"], + ) + + # Inject the `kernel_size_factor` kwarg to the ContextNet config + # Skip the last layer as that must be a pointwise kernel + for idx in range(len(contextnet_params["ContextNetEncoder"]["jasper"]) - 1): + contextnet_params["ContextNetEncoder"]["jasper"][idx]["kernel_size_factor"] = args.kernel_size_factor + + # (ContextNet uses the Jasper baseline encoder and decoder) + encoder = nemo_asr.ContextNetEncoder( + feat_in=contextnet_params["AudioToMelSpectrogramPreprocessor"]["features"], + **contextnet_params["ContextNetEncoder"], + ) + + decoder = nemo_asr.JasperDecoderForCTC( + feat_in=contextnet_params["ContextNetEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab), + ) + + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab), zero_infinity=True) + + greedy_decoder = nemo_asr.GreedyCTCDecoder() + + # create augmentation modules (only used for training) if their configs + # are present + + multiply_batch_config = contextnet_params.get('MultiplyBatch', None) + if multiply_batch_config: + multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config) + + spectr_augment_config = contextnet_params.get('SpectrogramAugmentation', None) + if spectr_augment_config: + data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(**spectr_augment_config) + + # assemble train DAG + + (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t,) = data_layer_train() + + processed_signal_t, p_length_t = data_preprocessor(input_signal=audio_signal_t, length=a_sig_length_t) + + if multiply_batch_config: + (processed_signal_t, p_length_t, transcript_t, transcript_len_t,) = multiply_batch( + in_x=processed_signal_t, in_x_len=p_length_t, in_y=transcript_t, in_y_len=transcript_len_t, + ) + + if spectr_augment_config: + processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t) + + encoded_t, encoded_len_t = encoder(audio_signal=processed_signal_t, length=p_length_t) + log_probs_t = decoder(encoder_output=encoded_t) + predictions_t = greedy_decoder(log_probs=log_probs_t) + loss_t = ctc_loss( + log_probs=log_probs_t, targets=transcript_t, input_length=encoded_len_t, target_length=transcript_len_t, + ) + + # create train callbacks + train_callback = nemo.core.SimpleLossLoggerCallback( + tensors=[loss_t, predictions_t, transcript_t, transcript_len_t], + print_func=partial(monitor_asr_train_progress, labels=vocab), + get_tb_values=lambda x: [["loss", x[0]]], + tb_writer=neural_factory.tb_writer, + step_freq=args.update_freq, + ) + + callbacks = [train_callback] + + if args.checkpoint_dir or args.load_dir: + chpt_callback = nemo.core.CheckpointCallback( + folder=args.checkpoint_dir, load_from_folder=args.load_dir, step_freq=args.checkpoint_save_freq, + ) + + callbacks.append(chpt_callback) + + # Log training metrics to wandb + if args.project is not None: + wand_callback = nemo.core.WandbCallback( + train_tensors=[loss_t], + wandb_name=args.exp_name, + wandb_project=args.project, + update_freq=args.update_freq, + args=args, + ) + callbacks.append(wand_callback) + + # assemble eval DAGs + for i, eval_dl in enumerate(data_layers_eval): + (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e,) = eval_dl() + processed_signal_e, p_length_e = data_preprocessor(input_signal=audio_signal_e, length=a_sig_length_e) + encoded_e, encoded_len_e = encoder(audio_signal=processed_signal_e, length=p_length_e) + log_probs_e = decoder(encoder_output=encoded_e) + predictions_e = greedy_decoder(log_probs=log_probs_e) + loss_e = 
ctc_loss( + log_probs=log_probs_e, targets=transcript_e, input_length=encoded_len_e, target_length=transcript_len_e, + ) + + # create corresponding eval callback + tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] + + eval_callback = nemo.core.EvaluatorCallback( + eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e,], + user_iter_callback=partial(process_evaluation_batch, labels=vocab), + user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname), + eval_step=args.eval_freq, + tb_writer=neural_factory.tb_writer, + ) + + callbacks.append(eval_callback) + + return loss_t, callbacks, steps_per_epoch + + +def main(): + args = parse_args() + + name = construct_name( + args.exp_name, + args.lr, + args.batch_size, + args.num_epochs, + args.weight_decay, + args.optimizer, + args.kernel_size_factor, + ) + work_dir = name + if args.work_dir: + work_dir = os.path.join(args.work_dir, name) + + # instantiate Neural Factory with supported backend + neural_factory = nemo.core.NeuralModuleFactory( + backend=nemo.core.Backend.PyTorch, + local_rank=args.local_rank, + optimization_level=args.amp_opt_level, + log_dir=work_dir, + checkpoint_dir=args.checkpoint_dir, + create_tb_writer=args.create_tb_writer, + files_to_copy=[args.model_config, __file__], + cudnn_benchmark=args.cudnn_benchmark, + tensorboard_dir=args.tensorboard_dir, + ) + args.num_gpus = neural_factory.world_size + + args.checkpoint_dir = neural_factory.checkpoint_dir + + if args.local_rank is not None: + logging.info('Doing ALL GPU') + + # build dags + train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory) + + # train model + neural_factory.train( + tensors_to_optimize=[train_loss], + callbacks=callbacks, + lr_policy=CosineAnnealing( + args.num_epochs * steps_per_epoch, + warmup_steps=args.warmup_steps, + warmup_ratio=args.warmup_ratio, + min_lr=args.min_lr, + ), + optimizer=args.optimizer, + optimization_params={ + "num_epochs": args.num_epochs, + "lr": args.lr, + "betas": (args.beta1, args.beta2), + "weight_decay": args.weight_decay, + "grad_norm_clip": None, + "amp_min_loss_scale": 1e-4, + }, + batches_per_step=args.iter_per_step, + synced_batchnorm=args.synced_bn, + synced_batchnorm_groupsize=args.synced_bn_groupsize, + ) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/__init__.py b/nemo/collections/asr/__init__.py index 0a8ec950564b..b1fe49531777 100644 --- a/nemo/collections/asr/__init__.py +++ b/nemo/collections/asr/__init__.py @@ -14,6 +14,7 @@ # ============================================================================= from .audio_preprocessing import * from .beam_search_decoder import BeamSearchDecoderWithLM +from .contextnet import ContextNetDecoderForCTC, ContextNetEncoder from .data_layer import ( AudioToSpeechLabelDataLayer, AudioToTextDataLayer, @@ -50,6 +51,8 @@ 'JasperDecoderForClassification', 'JasperDecoderForSpkrClass', 'JasperRNNConnector', + 'ContextNetEncoder', + 'ContextNetDecoderForCTC', 'CTCLossNM', 'CrossEntropyLossNM', ] diff --git a/nemo/collections/asr/contextnet.py b/nemo/collections/asr/contextnet.py new file mode 100644 index 000000000000..c09be485d67a --- /dev/null +++ b/nemo/collections/asr/contextnet.py @@ -0,0 +1,213 @@ +# Copyright (c) 2019 NVIDIA Corporation +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import nemo +from .jasper import JasperEncoder +from .parts.jasper import init_weights +from nemo.backends.pytorch.nm import 
TrainableNM +from nemo.core.neural_types import * +from nemo.utils.decorators import add_port_docs + +logging = nemo.logging + + +class ContextNetEncoder(JasperEncoder): + """ + ContextNet Encoder creates the pre-processing (prologue), QuartzNet convolution + block, and the additional pre and post processing layers as described in + ContextNet (https://arxiv.org/abs/2005.03191) + + Args: + jasper (list): A list of dictionaries. Each element in the list + represents the configuration of one Jasper Block. Each element + should contain:: + + { + # Required parameters + 'filters' (int) # Number of output channels, + 'repeat' (int) # Number of sub-blocks, + 'kernel' (int) # Size of conv kernel, + 'stride' (int) # Conv stride + 'dilation' (int) # Conv dilation + 'dropout' (float) # Dropout probability + 'residual' (bool) # Whether to use residual or not. + # Optional parameters + 'residual_dense' (bool) # Whether to use Dense Residuals + # or not. 'residual' must be True for 'residual_dense' + # to be enabled. + # Defaults to False. + 'separable' (bool) # Whether to use separable convolutions. + # Defaults to False + 'groups' (int) # Number of groups in each conv layer. + # Defaults to 1 + 'heads' (int) # Sharing of separable filters + # Defaults to -1 + 'tied' (bool) # Whether to use the same weights for all + # sub-blocks. + # Defaults to False + 'se' (bool) # Whether to add Squeeze and Excitation + # sub-blocks. + # Defaults to False + 'se_reduction_ratio' (int) # The reduction ratio of the Squeeze + # sub-module. + # Must be an integer > 1. + # Defaults to 8. + 'se_context_window' (int) # The size of the temporal context + # provided to SE sub-module. + # Must be an integer. If value <= 0, will perform global + # temporal pooling (global context). + # If value >= 1, will perform stride 1 average pooling to + # compute context window. + 'se_interpolation_mode' (str) # Interpolation mode of timestep dimension. + # Used only if context window is > 1. + # The modes available for resizing are: `nearest`, `linear` (3D-only), + # `bilinear`, `area` + 'kernel_size_factor' (float) # Conv kernel size multiplier + # Can be either an int or float + # Kernel size is recomputed as below: + # new_kernel_size = int(max(1, (kernel_size * kernel_width))) + # to prevent kernel sizes than 1. + # Note: If rescaled kernel size is an even integer, + # adds 1 to the rescaled kernel size to allow "same" + # padding. + 'stride_last' (bool) # Bool flag to determine whether each + # of the the repeated sub-blockss will perform a stride, + # or only the last sub-block will perform a strided convolution. + } + + activation (str): Activation function used for each sub-blocks. Can be + one of ["hardtanh", "relu", "selu", "swish"]. + feat_in (int): Number of channels being input to this module + normalization_mode (str): Normalization to be used in each sub-block. + Can be one of ["batch", "layer", "instance", "group"] + Defaults to "batch". + residual_mode (str): Type of residual connection. + Can be "add", "stride_add" or "max". + "stride_add" mode performs strided convolution prior to residual + addition. + Defaults to "add". + norm_groups (int): Number of groups for "group" normalization type. + If set to -1, number of channels is used. + Defaults to -1. + conv_mask (bool): Controls the use of sequence length masking prior + to convolutions. + Defaults to True. + frame_splicing (int): Defaults to 1. + init_mode (str): Describes how neural network parameters are + initialized. 
Options are ['xavier_uniform', 'xavier_normal', + 'kaiming_uniform','kaiming_normal']. + Defaults to "xavier_uniform". + """ + + length: Optional[torch.Tensor] + + @property + @add_port_docs() + def input_ports(self): + """Returns definitions of module input ports. + """ + return { + # "audio_signal": NeuralType( + # {0: AxisType(BatchTag), 1: AxisType(SpectrogramSignalTag), 2: AxisType(ProcessedTimeTag),} + # ), + # "length": NeuralType({0: AxisType(BatchTag)}), + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + @add_port_docs() + def output_ports(self): + """Returns definitions of module output ports. + """ + return { + # "outputs": NeuralType( + # {0: AxisType(BatchTag), 1: AxisType(EncodedRepresentationTag), 2: AxisType(ProcessedTimeTag),} + # ), + # "encoded_lengths": NeuralType({0: AxisType(BatchTag)}), + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + def __init__( + self, + jasper: List[Dict[str, Any]], + activation: str, + feat_in: int, + normalization_mode: str = "batch", + residual_mode: str = "add", + norm_groups: int = -1, + conv_mask: bool = False, + frame_splicing: int = 1, + init_mode: str = 'xavier_uniform', + ): + super().__init__( + jasper=jasper, + activation=activation, + feat_in=feat_in, + normalization_mode=normalization_mode, + residual_mode=residual_mode, + norm_groups=norm_groups, + conv_mask=conv_mask, + frame_splicing=frame_splicing, + init_mode=init_mode, + ) + + +class ContextNetDecoderForCTC(TrainableNM): + """ + ContextNet Decoder creates the final layer in ContextNet that maps from the outputs + of ContextNet Encoder to the vocabulary of interest. + + Args: + feat_in (int): Number of channels being input to this module + num_classes (int): Number of characters in ASR model's vocab/labels. + This count should not include the CTC blank symbol. + hidden_size (int): Number of units in the hidden state of the LSTM RNN. + init_mode (str): Describes how neural network parameters are + initialized. Options are ['xavier_uniform', 'xavier_normal', + 'kaiming_uniform','kaiming_normal']. + Defaults to "xavier_uniform". + """ + + @property + @add_port_docs() + def input_ports(self): + """Returns definitions of module input ports. + """ + return { + # "encoder_output": NeuralType( + # {0: AxisType(BatchTag), 1: AxisType(EncodedRepresentationTag), 2: AxisType(ProcessedTimeTag),} + # ) + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()) + } + + @property + @add_port_docs() + def output_ports(self): + """Returns definitions of module output ports. 
+ """ + # return {"output": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),})} + return {"output": NeuralType(('B', 'T', 'D'), LogprobsType())} + + def __init__(self, feat_in: int, num_classes: int, hidden_size: int = 640, init_mode: str = "xavier_uniform"): + super().__init__() + + self._feat_in = feat_in + # Add 1 for blank char + self._num_classes = num_classes + 1 + + self.rnn = nn.LSTM(feat_in, hidden_size, bias=True, batch_first=True) + self.clf = nn.Linear(hidden_size, self._num_classes) + self.clf.apply(lambda x: init_weights(x, mode=init_mode)) + self.to(self._device) + + def forward(self, encoder_output): + encoder_output = encoder_output.transpose(1, 2) # [B, T, D] + output, states = self.rnn(encoder_output) + logits = self.clf(output) + return F.log_softmax(logits, dim=-1) diff --git a/nemo/collections/asr/jasper.py b/nemo/collections/asr/jasper.py index 6a36cc2cb41a..b5de8f6b7af4 100644 --- a/nemo/collections/asr/jasper.py +++ b/nemo/collections/asr/jasper.py @@ -54,7 +54,17 @@ class JasperEncoder(TrainableNM): 'se_reduction_ratio' (int) # The reduction ratio of the Squeeze # sub-module. # Must be an integer > 1. - # Defaults to 16 + # Defaults to 8. + 'se_context_window' (int) # The size of the temporal context + # provided to SE sub-module. + # Must be an integer. If value <= 0, will perform global + # temporal pooling (global context). + # If value >= 1, will perform stride 1 average pooling to + # compute context window. + 'se_interpolation_mode' (str) # Interpolation mode of timestep dimension. + # Used only if context window is > 1. + # The modes available for resizing are: `nearest`, `linear` (3D-only), + # `bilinear`, `area` 'kernel_size_factor' (float) # Conv kernel size multiplier # Can be either an int or float # Kernel size is recomputed as below: @@ -63,16 +73,21 @@ class JasperEncoder(TrainableNM): # Note: If rescaled kernel size is an even integer, # adds 1 to the rescaled kernel size to allow "same" # padding. + 'stride_last' (bool) # Bool flag to determine whether each + # of the the repeated sub-blockss will perform a stride, + # or only the last sub-block will perform a strided convolution. } activation (str): Activation function used for each sub-blocks. Can be - one of ["hardtanh", "relu", "selu"]. + one of ["hardtanh", "relu", "selu", "swish"]. feat_in (int): Number of channels being input to this module normalization_mode (str): Normalization to be used in each sub-block. Can be one of ["batch", "layer", "instance", "group"] Defaults to "batch". residual_mode (str): Type of residual connection. - Can be "add" or "max". + Can be "add", "stride_add" or "max". + "stride_add" mode performs strided convolution prior to residual + addition. Defaults to "add". norm_groups (int): Number of groups for "group" normalization type. If set to -1, number of channels is used. 
@@ -162,9 +177,13 @@ def __init__( groups = lcfg.get('groups', 1) separable = lcfg.get('separable', False) heads = lcfg.get('heads', -1) + residual_mode = lcfg.get('residual_mode', residual_mode) se = lcfg.get('se', False) - se_reduction_ratio = lcfg.get('se_reduction_ratio', 16) + se_reduction_ratio = lcfg.get('se_reduction_ratio', 8) + se_context_window = lcfg.get('se_context_window', -1) + se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest') kernel_size_factor = lcfg.get('kernel_size_factor', 1.0) + stride_last = lcfg.get('stride_last', False) encoder_layers.append( JasperBlock( feat_in, @@ -186,7 +205,10 @@ def __init__( conv_mask=conv_mask, se=se, se_reduction_ratio=se_reduction_ratio, + se_context_window=se_context_window, + se_interpolation_mode=se_interpolation_mode, kernel_size_factor=kernel_size_factor, + stride_last=stride_last, ) ) feat_in = lcfg['filters'] diff --git a/nemo/collections/asr/losses.py b/nemo/collections/asr/losses.py index d8714187cd2e..f3ca8f5a4d25 100644 --- a/nemo/collections/asr/losses.py +++ b/nemo/collections/asr/losses.py @@ -13,6 +13,9 @@ class CTCLossNM(LossNM): Args: num_classes (int): Number of characters in ASR model's vocab/labels. This count should not include the CTC blank symbol. + zero_infinity (bool): Whether to zero infinite losses and the associated gradients. + By default, it is False. Infinite losses mainly occur when the inputs are too + short to be aligned to the targets. """ @property @@ -41,11 +44,11 @@ def output_ports(self): # return {"loss": NeuralType(None)} return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, num_classes): + def __init__(self, num_classes, zero_infinity=False): super().__init__() self._blank = num_classes - self._criterion = nn.CTCLoss(blank=self._blank, reduction='none') + self._criterion = nn.CTCLoss(blank=self._blank, reduction='none', zero_infinity=zero_infinity) def _loss(self, log_probs, targets, input_length, target_length): input_length = input_length.long() diff --git a/nemo/collections/asr/parts/jasper.py b/nemo/collections/asr/parts/jasper.py index 428b967f19cd..a07fd1fb3b50 100644 --- a/nemo/collections/asr/parts/jasper.py +++ b/nemo/collections/asr/parts/jasper.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -179,21 +179,66 @@ def forward(self, x): class SqueezeExcite(nn.Module): - def __init__(self, channels, reduction_ratio): + def __init__( + self, + channels: int, + reduction_ratio: int, + context_window: int = -1, + interpolation_mode: str = 'nearest', + activation: Optional[Callable] = None, + ): + """ + Squeeze-and-Excitation sub-module. + + Args: + channels: Input number of channels. + reduction_ratio: Reduction ratio for "squeeze" layer. + context_window: Integer number of timesteps that the context + should be computed over, using stride 1 average pooling. + If value < 1, then global context is computed. + interpolation_mode: Interpolation mode of timestep dimension. + Used only if context window is > 1. + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `area` + activation: Intermediate activation function used. Must be a + callable activation function. 
+ """ super(SqueezeExcite, self).__init__() - self.pool = nn.AdaptiveAvgPool1d(1) + self.context_window = int(context_window) + self.interpolation_mode = interpolation_mode + + if self.context_window <= 0: + self.pool = nn.AdaptiveAvgPool1d(1) # context window = T + else: + self.pool = nn.AvgPool1d(self.context_window, stride=1) + + if activation is None: + activation = nn.ReLU(inplace=True) + self.fc = nn.Sequential( nn.Linear(channels, channels // reduction_ratio, bias=False), - nn.ReLU(inplace=True), + activation, nn.Linear(channels // reduction_ratio, channels, bias=False), - nn.Sigmoid(), ) def forward(self, x): - batch, channels, _ = x.size() - y = self.pool(x).view(batch, channels) - y = self.fc(y).view(batch, channels, 1) - return x * y.expand_as(x) + batch, channels, timesteps = x.size() + y = self.pool(x) # [B, C, T - context_window + 1] + y = y.transpose(1, 2) # [B, T - context_window + 1, C] + y = self.fc(y) # [B, T - context_window + 1, C] + y = y.transpose(1, 2) # [B, C, T - context_window + 1] + + if self.context_window > 0: + y = torch.nn.functional.interpolate(y, size=timesteps, mode=self.interpolation_mode) + + y = torch.sigmoid(y) + + return x * y + + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) class JasperBlock(nn.Module): @@ -222,6 +267,9 @@ def __init__( conv_mask=False, se=False, se_reduction_ratio=16, + se_context_window=None, + se_interpolation_mode='nearest', + stride_last=False, ): super(JasperBlock, self).__init__() @@ -244,12 +292,18 @@ def __init__( conv = nn.ModuleList() for _ in range(repeat - 1): + # Stride last means only the last convolution in block will have stride + if stride_last: + stride_val = [1] + else: + stride_val = stride + conv.extend( self._get_conv_bn_layer( inplanes_loop, planes, kernel_size=kernel_size, - stride=stride, + stride=stride_val, dilation=dilation, padding=padding_val, groups=groups, @@ -262,9 +316,6 @@ def __init__( conv.extend(self._get_act_dropout_layer(drop_prob=dropout, activation=activation)) - if se and not residual: - conv.append(SqueezeExcite(planes, reduction_ratio=se_reduction_ratio)) - inplanes_loop = planes conv.extend( @@ -283,8 +334,16 @@ def __init__( ) ) - if se and not residual: - conv.append(SqueezeExcite(planes, reduction_ratio=se_reduction_ratio)) + if se: + conv.append( + SqueezeExcite( + planes, + reduction_ratio=se_reduction_ratio, + context_window=se_context_window, + interpolation_mode=se_interpolation_mode, + activation=activation, + ) + ) self.mconv = conv @@ -293,19 +352,27 @@ def __init__( if residual: res_list = nn.ModuleList() + + if residual_mode == 'stride_add': + stride_val = stride + else: + stride_val = [1] + if len(residual_panes) == 0: res_panes = [inplanes] self.dense_residual = False for ip in res_panes: res = nn.ModuleList( self._get_conv_bn_layer( - ip, planes, kernel_size=1, normalization=normalization, norm_groups=norm_groups, + ip, + planes, + kernel_size=1, + normalization=normalization, + norm_groups=norm_groups, + stride=stride_val, ) ) - if se: - res.append(SqueezeExcite(planes, reduction_ratio=se_reduction_ratio)) - res_list.append(res) self.res = res_list @@ -462,7 +529,7 @@ def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]): else: res_out = res_layer(res_out) - if self.residual_mode == 'add': + if self.residual_mode == 'add' or self.residual_mode == 'stride_add': out = out + res_out else: out = torch.max(out, res_out) @@ -473,3 +540,7 @@ def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]): return xs + [out], lens 
return [out], lens + + +# Register swish activation function +jasper_activations['swish'] = Swish diff --git a/nemo/collections/asr/parts/spectr_augment.py b/nemo/collections/asr/parts/spectr_augment.py index ff733cc2f352..a2f4bd2f587a 100755 --- a/nemo/collections/asr/parts/spectr_augment.py +++ b/nemo/collections/asr/parts/spectr_augment.py @@ -14,7 +14,12 @@ class SpecAugment(nn.Module): freq_masks - how many frequency segments should be cut time_masks - how many time segments should be cut freq_width - maximum number of frequencies to be cut in one segment - time_width - maximum number of time steps to be cut in one segment + time_width - maximum number of time steps to be cut in one segment. + Can be a positive integer or a float value in the range [0, 1]. + If positive integer value, defines maximum number of time steps + to be cut in one segment. + If a float value, defines maximum percentage of timesteps that + are cut adaptively. """ def __init__( @@ -30,10 +35,23 @@ def __init__( self.freq_width = freq_width self.time_width = time_width + if isinstance(time_width, int): + self.adaptive_temporal_width = False + else: + if time_width > 1.0 or time_width < 0.0: + raise ValueError('If `time_width` is a float value, must be in range [0, 1]') + + self.adaptive_temporal_width = True + @torch.no_grad() def forward(self, x): sh = x.shape + if self.adaptive_temporal_width: + time_width = max(1, int(sh[2] * self.time_width)) + else: + time_width = self.time_width + mask = torch.zeros(x.shape).byte() for idx in range(sh[0]): @@ -45,9 +63,9 @@ def forward(self, x): mask[idx, x_left : x_left + w, :] = 1 for i in range(self.time_masks): - y_left = int(self._rng.uniform(0, sh[2] - self.time_width)) + y_left = int(self._rng.uniform(0, sh[2] - time_width)) - w = int(self._rng.uniform(0, self.time_width)) + w = int(self._rng.uniform(0, time_width)) mask[idx, :, y_left : y_left + w] = 1 diff --git a/tests/data/contextnet_32.yaml b/tests/data/contextnet_32.yaml new file mode 100644 index 000000000000..5e56e0d44048 --- /dev/null +++ b/tests/data/contextnet_32.yaml @@ -0,0 +1,77 @@ +model: "ContextNet" +sample_rate: 16000 +repeat: &repeat 2 +dropout: &dropout 0.0 +stride: &stride 2 + + +AudioToTextDataLayer: + max_duration: 16.7 + trim_silence: true + + train: + shuffle: true + + eval: + shuffle: false + max_duration: null + +AudioToMelSpectrogramPreprocessor: + window_size: 0.025 + window_stride: 0.01 + window: "hann" + normalize: "per_feature" + n_fft: 512 + features: 80 + dither: 0.00001 + pad_to: 16 + stft_conv: false + +SpectrogramAugmentation: + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + +ContextNetEncoder: + activation: "relu" + conv_mask: true + + jasper: + - filters: 32 + repeat: 1 + kernel: [5] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: false + separable: true + se: true + se_context_size: -1 + + - filters: 32 + repeat: *repeat + kernel: [5] + stride: [1] + dilation: [1] + dropout: *dropout + residual: true + separable: true + se: true + se_context_size: 256 + + - filters: 32 + repeat: *repeat + kernel: [5] + stride: [*stride] + dilation: [1] + dropout: *dropout + residual: true + separable: true + se: true + se_context_size: -1 + stride_last: true + residual_mode: "stride_add" + +labels: [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", + "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] diff --git a/tests/integration/test_asr_gradient_step_and_eval.py b/tests/integration/test_asr_gradient_step_and_eval.py index 
66688794ec72..d68898c076b6 100644
--- a/tests/integration/test_asr_gradient_step_and_eval.py
+++ b/tests/integration/test_asr_gradient_step_and_eval.py
@@ -212,6 +212,64 @@ def test_quartznet_training(self):
         # Assert that training loss went down
         assert loss_list[-1] < loss_list[0]
 
+    @pytest.mark.integration
+    def test_contextnet_ctc_training(self):
+        """Integration test that instantiates a small ContextNet model and tests training with the sample asr data.
+        Training is run for 3 forward and backward steps and asserts that loss after 3 steps is smaller than the loss
+        at the first step.
+        Note: Training is done with batch gradient descent as opposed to stochastic gradient descent due to CTC loss.
+        Checks SE-block with fixed context size and global context, residual_mode='stride_add' and 'stride_last' flags.
+        """
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/contextnet_32.yaml"))) as f:
+            contextnet_model_definition = self.yaml.load(f)
+        dl = nemo_asr.AudioToTextDataLayer(manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=30)
+        pre_process_params = {
+            'frame_splicing': 1,
+            'features': 80,
+            'window_size': 0.025,
+            'n_fft': 512,
+            'dither': 1e-05,
+            'window': 'hann',
+            'sample_rate': 16000,
+            'normalize': 'per_feature',
+            'window_stride': 0.01,
+        }
+        preprocessing = nemo_asr.AudioToMelSpectrogramPreprocessor(**pre_process_params)
+
+        spec_aug = nemo_asr.SpectrogramAugmentation(**contextnet_model_definition['SpectrogramAugmentation'])
+
+        contextnet_encoder = nemo_asr.ContextNetEncoder(
+            feat_in=contextnet_model_definition['AudioToMelSpectrogramPreprocessor']['features'],
+            **contextnet_model_definition['ContextNetEncoder'],
+        )
+        contextnet_decoder = nemo_asr.ContextNetDecoderForCTC(feat_in=32, hidden_size=16, num_classes=len(self.labels))
+        ctc_loss = nemo_asr.CTCLossNM(num_classes=len(self.labels))
+
+        # DAG
+        audio_signal, a_sig_length, transcript, transcript_len = dl()
+        processed_signal, p_length = preprocessing(input_signal=audio_signal, length=a_sig_length)
+
+        processed_signal = spec_aug(input_spec=processed_signal)
+
+        encoded, encoded_len = contextnet_encoder(audio_signal=processed_signal, length=p_length)
+        log_probs = contextnet_decoder(encoder_output=encoded)
+        loss = ctc_loss(
+            log_probs=log_probs, targets=transcript, input_length=encoded_len, target_length=transcript_len,
+        )
+
+        loss_list = []
+        callback = nemo.core.SimpleLossLoggerCallback(
+            tensors=[loss], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1
+        )
+
+        self.nf.train(
+            [loss], callbacks=[callback], optimizer="sgd", optimization_params={"max_steps": 3, "lr": 0.001},
+        )
+        self.nf.reset_trainer()
+
+        # Assert that training loss went down
+        assert loss_list[-1] < loss_list[0]
+
     @pytest.mark.integration
     def test_stft_conv_training(self):
         """Integtaion test that instantiates a small Jasper model and tests training with the sample asr data.