From 40109c0d93882bae09f3421908527ad33554660d Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Sun, 22 Aug 2021 14:45:39 +0800
Subject: [PATCH 1/5] Add embedding scale to nn.Embedding.

---
 .../conformer_ctc_embedding_scale/README.md   |  31 +
 .../conformer_ctc_embedding_scale/__init__.py |   0
 .../asr_datamodule.py                         |   1 +
 .../conformer.py                              | 920 ++++++++++++++++
 .../conformer_ctc_embedding_scale/decode.py   | 548 ++++++++++
 .../embedding.py                              | 221 ++++
 .../pretrained.py                             | 350 +++++++
 .../subsampling.py                            | 144 +++
 .../test_subsampling.py                       |  33 +
 .../test_transformer.py                       |  89 ++
 .../conformer_ctc_embedding_scale/train.py    | 708 +++++++++++++
 .../transformer.py                            | 990 ++++++++++++++++++
 12 files changed, 4035 insertions(+)
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/README.md
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/__init__.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc_embedding_scale/asr_datamodule.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py

diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/README.md b/egs/librispeech/ASR/conformer_ctc_embedding_scale/README.md
new file mode 100644
index 0000000000..5106f28cf2
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/README.md
@@ -0,0 +1,31 @@
+## Differences between `conformer_ctc` and `conformer_ctc_embedding_scale`
+
+`conformer_ctc_embedding_scale` replaces `nn.Embedding` with a modified
+`Embedding`. The modified embedding contains two changes (see the sketch
+below):
+
+  - (1) The weight matrix is initialized with values in the range
+    `(-std, std)`, where `std = 1 / sqrt(embedding_dim)`
+
+  - (2) The output of the embedding is scaled by `sqrt(embedding_dim)`
+
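+Below is a minimal sketch of these two changes, assuming a subclass of
+`nn.Embedding` (the name `ScaledEmbedding` is only illustrative; see
+`embedding.py` for the actual implementation, which also handles
+`padding_idx` and the other `nn.Embedding` options):
+
+```python
+import math
+
+import torch
+import torch.nn as nn
+
+
+class ScaledEmbedding(nn.Embedding):
+    def reset_parameters(self) -> None:
+        # (1) initialize the weight matrix in the range (-std, std)
+        std = 1.0 / math.sqrt(self.embedding_dim)
+        nn.init.uniform_(self.weight, -std, std)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # (2) scale the looked-up embeddings by sqrt(embedding_dim)
+        return super().forward(input) * math.sqrt(self.embedding_dim)
+```
+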
+Also, `conformer_ctc_embedding_scale` modifies the `PositionalEncoding`
+in `transformer.py`. It replaces
+
+```python
+self.xscale = math.sqrt(self.d_model)
+x = x * self.xscale + self.pe[:, : x.size(1), :]
+```
+with
+
+```python
+self.pos_scale = 1. / math.sqrt(self.d_model)
+x = x + self.pe[:, : x.size(1), :] * self.pos_scale
+```
+
+You can use
+
+```bash
+diff conformer_ctc/transformer.py conformer_ctc_embedding_scale/transformer.py
+```
+
+to find the exact differences.
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/__init__.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/asr_datamodule.py
new file mode 120000
index 0000000000..fa1b8cca3c
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
new file mode 100644
index 0000000000..a00664a992
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
@@ -0,0 +1,920 @@
+#!/usr/bin/env python3
+
+# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+# Apache 2.0
+
+import math
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+from transformer import Supervisions, Transformer, encoder_padding_mask
+
+
+class Conformer(Transformer):
+    """
+    Args:
+        num_features (int): Number of input features
+        num_classes (int): Number of output classes
+        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
+        d_model (int): attention dimension
+        nhead (int): number of attention heads
+        dim_feedforward (int): feedforward dimension
+        num_encoder_layers (int): number of encoder layers
+        num_decoder_layers (int): number of decoder layers
+        dropout (float): dropout rate
+        cnn_module_kernel (int): Kernel size of convolution module
+        normalize_before (bool): whether to use layer_norm before the first block.
+        vgg_frontend (bool): whether to use vgg frontend.
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        cnn_module_kernel: int = 31,
+        normalize_before: bool = True,
+        vgg_frontend: bool = False,
+        is_espnet_structure: bool = False,
+        mmi_loss: bool = True,
+        use_feat_batchnorm: bool = False,
+    ) -> None:
+        super(Conformer, self).__init__(
+            num_features=num_features,
+            num_classes=num_classes,
+            subsampling_factor=subsampling_factor,
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dropout=dropout,
+            normalize_before=normalize_before,
+            vgg_frontend=vgg_frontend,
+            mmi_loss=mmi_loss,
+            use_feat_batchnorm=use_feat_batchnorm,
+        )
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+        encoder_layer = ConformerEncoderLayer(
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            cnn_module_kernel,
+            normalize_before,
+            is_espnet_structure,
+        )
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.normalize_before = normalize_before
+        self.is_espnet_structure = is_espnet_structure
+        if self.normalize_before and self.is_espnet_structure:
+            self.after_norm = nn.LayerNorm(d_model)
+        else:
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            #       and throws an error without this change.
+            self.after_norm = identity
+
+    def run_encoder(
+        self, x: Tensor, supervisions: Optional[Supervisions] = None
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        """
+        Args:
+          x:
+            The model input. Its shape is [N, T, C].
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling.
+            It is read directly from the batch, without any sorting. It is used
+            to compute encoder padding mask, which is used as memory key padding
+            mask for the decoder.
+
+        Returns:
+            Tensor: Encoder output tensor of dimension (input_length, batch_size, d_model).
+            Tensor: Mask tensor of dimension (batch_size, input_length).
+        """
+        x = self.encoder_embed(x)
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        if mask is not None:
+            mask = mask.to(x.device)
+        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+
+        if self.normalize_before and self.is_espnet_structure:
+            x = self.after_norm(x)
+
+        return x, mask
+
+
+class ConformerEncoderLayer(nn.Module):
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
+    See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
+
+    Args:
+        d_model: the number of expected features in the input (required).
+        nhead: the number of heads in the multiheadattention models (required).
+        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        dropout: the dropout value (default=0.1).
+        cnn_module_kernel (int): Kernel size of convolution module.
+        normalize_before: whether to use layer_norm before the first block.
+
+    Examples::
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = encoder_layer(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        cnn_module_kernel: int = 31,
+        normalize_before: bool = True,
+        is_espnet_structure: bool = False,
+    ) -> None:
+        super(ConformerEncoderLayer, self).__init__()
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+        )
+
+        self.feed_forward = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+
+        self.ff_scale = 0.5
+
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.normalize_before = normalize_before
+
+    def forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            pos_emb: Positional embedding tensor (required).
+            src_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            src_mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number
+        """
+
+        # macaron style feed forward module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff_macaron(src)
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
+        if not self.normalize_before:
+            src = self.norm_ff_macaron(src)
+
+        # multi-headed self-attention module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_mha(src)
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+        src = residual + self.dropout(src_att)
+        if not self.normalize_before:
+            src = self.norm_mha(src)
+
+        # convolution module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(src))
+        if not self.normalize_before:
+            src = self.norm_conv(src)
+
+        # feed forward module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+        if not self.normalize_before:
+            src = self.norm_ff(src)
+
+        if self.normalize_before:
+            src = self.norm_final(src)
+
+        return src
+
+
+class ConformerEncoder(nn.TransformerEncoder):
+    r"""ConformerEncoder is a stack of N encoder layers
+
+    Args:
+        encoder_layer: an instance of the ConformerEncoderLayer() class (required).
+        num_layers: the number of sub-encoder-layers in the encoder (required).
+        norm: the layer normalization component (optional).
+
+    Examples::
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+
+    def __init__(
+        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
+    ) -> None:
+        super(ConformerEncoder, self).__init__(
+            encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
+        )
+
+    def forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            pos_emb: Positional embedding tensor (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+        """
+        output = src
+
+        for mod in self.layers:
+            output = mod(
+                output,
+                pos_emb,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+            )
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct an PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` is the
+        # position of the key vector. We use positive relative positions when
+        # the key is to the left (i > j) and negative relative positions
+        # otherwise (i < j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
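+        # pe now has shape (1, 2*time-1, d_model): index 0 corresponds to
+        # relative position (time - 1) and the last index to -(time - 1).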
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+
+    Examples::
+
+        >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_espnet_structure: bool = False,
+    ) -> None:
+        super(RelPositionMultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+        # linear transformation for positional encoding.
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        # these two learnable biases are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+
+        self._reset_parameters()
+
+        self.is_espnet_structure = is_espnet_structure
+
+    def _reset_parameters(self) -> None:
+        nn.init.xavier_uniform_(self.in_proj.weight)
+        nn.init.constant_(self.in_proj.bias, 0.0)
+        nn.init.constant_(self.out_proj.bias, 0.0)
+
+        nn.init.xavier_uniform_(self.pos_bias_u)
+        nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+            pos_emb: Positional embedding tensor
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
+            need_weights: output attn_output_weights.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+            - Inputs:
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+            positions will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+            - Outputs:
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+        return self.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            pos_emb,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj.weight,
+            self.in_proj.bias,
+            self.dropout,
+            self.out_proj.weight,
+            self.out_proj.bias,
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+    def rel_shift(self, x: Tensor) -> Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+            x: Input tensor (batch, head, time1, 2*time1-1).
+                time1 is the length of the query.
+
+        Returns:
+            Tensor: tensor of shape (batch, head, time1, time2)
+          (note: time2 has the same value as time1, but it is for
+          the key, while time1 is for the query).
+        """
+        (batch_size, num_heads, time1, n) = x.shape
+        assert n == 2 * time1 - 1
+        # Note: TorchScript requires explicit arg for stride()
+        batch_stride = x.stride(0)
+        head_stride = x.stride(1)
+        time1_stride = x.stride(2)
+        n_stride = x.stride(3)
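+        # Return a strided view such that
+        #   output[b, h, i, j] == x[b, h, i, time1 - 1 - i + j],
+        # i.e. the score for relative position (i - j), implementing the
+        # Transformer-XL shifting trick without allocating a new tensor.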
+        return x.as_strided(
+            (batch_size, num_heads, time1, time1),
+            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+            storage_offset=n_stride * (time1 - 1),
+        )
+
+    def multi_head_attention_forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        embed_dim_to_check: int,
+        num_heads: int,
+        in_proj_weight: Tensor,
+        in_proj_bias: Tensor,
+        dropout_p: float,
+        out_proj_weight: Tensor,
+        out_proj_bias: Tensor,
+        training: bool = True,
+        key_padding_mask: Optional[Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+            pos_emb: Positional embedding tensor
+            embed_dim_to_check: total dimension of the model.
+            num_heads: parallel attention heads.
+            in_proj_weight, in_proj_bias: input projection weight and bias.
+            dropout_p: probability of an element to be zeroed.
+            out_proj_weight, out_proj_bias: the output projection weight and bias.
+            training: apply dropout if is ``True``.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. This is a binary mask. When the value is True,
+                the corresponding value on the attention layer will be filled with -inf.
+            need_weights: output attn_output_weights.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+            Inputs:
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+            length, N is the batch size, E is the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+            will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+            Outputs:
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == embed_dim_to_check
+        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+        head_dim = embed_dim // num_heads
+        assert (
+            head_dim * num_heads == embed_dim
+        ), "embed_dim must be divisible by num_heads"
+        scaling = float(head_dim) ** -0.5
+
+        if torch.equal(query, key) and torch.equal(key, value):
+            # self-attention
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
+
+        elif torch.equal(key, value):
+            # encoder-decoder attention
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+        else:
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = embed_dim * 2
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            k = nn.functional.linear(key, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim * 2
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            v = nn.functional.linear(value, _w, _b)
+
+        if not self.is_espnet_structure:
+            q = q * scaling
+
+        if attn_mask is not None:
+            assert (
+                attn_mask.dtype == torch.float32
+                or attn_mask.dtype == torch.float64
+                or attn_mask.dtype == torch.float16
+                or attn_mask.dtype == torch.uint8
+                or attn_mask.dtype == torch.bool
+            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+                attn_mask.dtype
+            )
+            if attn_mask.dtype == torch.uint8:
+                warnings.warn(
+                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+                )
+                attn_mask = attn_mask.to(torch.bool)
+
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
+            elif attn_mask.dim() == 3:
+                if list(attn_mask.size()) != [
+                    bsz * num_heads,
+                    query.size(0),
+                    key.size(0),
+                ]:
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
+            else:
+                raise RuntimeError(
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
+                )
+            # attn_mask's dim is 3 now.
+
+        # convert ByteTensor key_padding_mask to bool
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
+            warnings.warn(
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+                key_padding_mask.size(1), src_len
+            )
+
+        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+
+        pos_emb_bsz = pos_emb.size(0)
+        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
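+        # That is, the attention score decomposes as
+        #   score(i, j) = (q_i + u)^T k_j         <- terms (a) + (c)
+        #               + (q_i + v)^T r_{i - j}   <- terms (b) + (d)
+        # where u = pos_bias_u, v = pos_bias_v, and r is the (projected)
+        # relative positional encoding.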
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
+
+        # compute matrix b and matrix d
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p.transpose(-2, -1)
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        if not self.is_espnet_structure:
+            attn_output_weights = (
+                matrix_ac + matrix_bd
+            )  # (batch, head, time1, time2)
+        else:
+            attn_output_weights = (
+                matrix_ac + matrix_bd
+            ) * scaling  # (batch, head, time1, time2)
+
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
+
+        assert list(attn_output_weights.size()) == [
+            bsz * num_heads,
+            tgt_len,
+            src_len,
+        ]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        attn_output = (
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
+        )
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+        else:
+            return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+
+    """
+
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernerl_size should be a odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = Swish()
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
+
+
+class Swish(torch.nn.Module):
+    """Construct an Swish object."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swich activation function."""
+        return x * torch.sigmoid(x)
+
+
+def identity(x):
+    return x
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
new file mode 100755
index 0000000000..c3354c0a3e
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
@@ -0,0 +1,548 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+
+# (still a work in progress)
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=9,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lattice-score-scale",
+        type=float,
+        default=1.0,
+        help="The scale to be applied to `lattice.scores`."
+        "It's needed if you use any kinds of n-best based rescoring. "
+        "Currently, it is used when the decoding method is: nbest, "
+        "nbest-rescoring, attention-decoder, and nbest-oracle. "
+        "A smaller value results in more unique paths.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
+            "lang_dir": Path("data/lang_bpe"),
+            "lm_dir": Path("data/lm"),
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "is_espnet_structure": True,
+            "mmi_loss": False,
+            "use_feat_batchnorm": True,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+            # Possible values for method:
+            #  - 1best
+            #  - nbest
+            #  - nbest-rescoring
+            #  - whole-lattice-rescoring
+            #  - attention-decoder
+            #  - nbest-oracle
+            #  "method": "nbest",
+            #  "method": "nbest-rescoring",
+            #  "method": "whole-lattice-rescoring",
+            "method": "attention-decoder",
+            #  "method": "nbest-oracle",
+            # num_paths is used when method is "nbest", "nbest-rescoring",
+            # attention-decoder, and nbest-oracle
+            "num_paths": 100,
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: k2.Fsa,
+    batch: dict,
+    lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[int]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals the
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - If params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      lexicon:
+        It contains the word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = HLG.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is [N, T, C]
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is [N, T, C]
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        HLG=HLG,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is slightly worse than that of rescored lattices.
+        return nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            lexicon=lexicon,
+            scale=params.lattice_score_scale,
+        )
+
+    if params.method in ["1best", "nbest"]:
+        if params.method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                scale=params.lattice_score_scale,
+            )
+            key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            scale=params.lattice_score_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+        )
+    elif params.method == "attention-decoder":
+        # The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            scale=params.lattice_score_scale,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"
+
+    ans = dict()
+    for lm_scale_str, best_path in best_path_dict.items():
+        hyps = get_texts(best_path)
+        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        ans[lm_scale_str] = hyps
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: k2.Fsa,
+    lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph.
+      lexicon:
+        It contains the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no_rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    results = []
+
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            batch=batch,
+            lexicon=lexicon,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        for lm_scale, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for hyp_words, ref_text in zip(hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            results[lm_scale].extend(this_batch)
+
+        num_cuts += len(batch["supervisions"]["text"])
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
+    HLG = HLG.to(device)
+    assert HLG.requires_grad is False
+
+    if not hasattr(HLG, "lm_scores"):
+        HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            G = k2.Fsa.from_dict(d).to(device)
+
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        is_espnet_structure=params.is_espnet_structure,
+        mmi_loss=params.mmi_loss,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(average_checkpoints(filenames))
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    librispeech = LibriSpeechAsrDataModule(args)
+    # CAUTION: `test_sets` is for displaying only.
+    # If you want to skip test-clean, you have to skip
+    # it inside the for loop. That is, use
+    #
+    #   if test_set == 'test-clean': continue
+    #
+    test_sets = ["test-clean", "test-other"]
+    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            lexicon=lexicon,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
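
A note on the checkpoint-averaging branch above: `average_checkpoints` comes from icefall. As a rough, hedged sketch (the real implementation may differ in details, e.g. how non-float buffers are handled), averaging the last few epoch checkpoints amounts to a parameter-wise mean of their saved `state_dict`s:

```python
# Sketch only: parameter-wise mean over saved state_dicts.  Assumes each
# checkpoint stores the model weights under the "model" key, as done by
# icefall.checkpoint.save_checkpoint().
import torch


def average_state_dicts(filenames):
    avg = None
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}
```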
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py
new file mode 100644
index 0000000000..72f9ed53f6
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py
@@ -0,0 +1,221 @@
+
+# This file is copied & modified from pytorch/torch/nn/modules/sparse.py
+# It modifies nn.Embedding
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
+
+class Embedding(nn.Module):
+    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+    This module is often used to store word embeddings and retrieve them using indices.
+    The input to the module is a list of indices, and the output is the corresponding
+    word embeddings.
+
+    Args:
+        num_embeddings (int): size of the dictionary of embeddings
+        embedding_dim (int): the size of each embedding vector
+        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
+                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
+                                     i.e. it remains as a fixed "pad". For a newly constructed Embedding,
+                                     the embedding vector at :attr:`padding_idx` will default to all zeros,
+                                     but can be updated to another value to be used as the padding vector.
+        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
+                                    is renormalized to have norm :attr:`max_norm`.
+        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
+        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
+                                                the words in the mini-batch. Default ``False``.
+        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
+                                 See Notes for more details regarding sparse gradients.
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
+                         initialized from :math:`\mathcal{N}(0, \text{std}^2)` where :math:`\text{std} = 1/\sqrt{\text{embedding\_dim}}`
+
+    Shape:
+        - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
+        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
+
+    .. note::
+        Keep in mind that only a limited number of optimizers support
+        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
+        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
+
+    .. note::
+        When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
+        :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
+        modified in-place, performing a differentiable operation on ``Embedding.weight`` before
+        calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
+        :attr:`max_norm` is not ``None``. For example::
+
+            n, d, m = 3, 5, 7
+            embedding = nn.Embedding(n, d, max_norm=True)
+            W = torch.randn((m, d), requires_grad=True)
+            idx = torch.tensor([1, 2])
+            a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
+            b = embedding(idx) @ W.t()  # modifies weight in-place
+            out = (a.unsqueeze(0) + b.unsqueeze(1))
+            loss = out.sigmoid().prod()
+            loss.backward()
+
+    Examples::
+
+        >>> # an Embedding module containing 10 tensors of size 3
+        >>> embedding = nn.Embedding(10, 3)
+        >>> # a batch of 2 samples of 4 indices each
+        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
+        >>> embedding(input)
+        tensor([[[-0.0251, -1.6902,  0.7172],
+                 [-0.6431,  0.0748,  0.6969],
+                 [ 1.4970,  1.3448, -0.9685],
+                 [-0.3677, -2.7265, -0.1685]],
+
+                [[ 1.4970,  1.3448, -0.9685],
+                 [ 0.4362, -0.4004,  0.9400],
+                 [-0.6431,  0.0748,  0.6969],
+                 [ 0.9124, -2.3616,  1.1151]]])
+
+
+        >>> # example with padding_idx
+        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
+        >>> input = torch.LongTensor([[0,2,0,5]])
+        >>> embedding(input)
+        tensor([[[ 0.0000,  0.0000,  0.0000],
+                 [ 0.1535, -2.0309,  0.9315],
+                 [ 0.0000,  0.0000,  0.0000],
+                 [-0.1655,  0.9897,  0.0635]]])
+
+        >>> # example of changing `pad` vector
+        >>> padding_idx = 0
+        >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
+        >>> embedding.weight
+        Parameter containing:
+        tensor([[ 0.0000,  0.0000,  0.0000],
+                [-0.7895, -0.7089, -0.0364],
+                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
+        >>> with torch.no_grad():
+        ...     embedding.weight[padding_idx] = torch.ones(3)
+        >>> embedding.weight
+        Parameter containing:
+        tensor([[ 1.0000,  1.0000,  1.0000],
+                [-0.7895, -0.7089, -0.0364],
+                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
+    """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
+                     'norm_type', 'scale_grad_by_freq', 'sparse']
+
+    num_embeddings: int
+    embedding_dim: int
+    padding_idx: Optional[int]
+    max_norm: Optional[float]
+    norm_type: float
+    scale_grad_by_freq: bool
+    weight: Tensor
+    sparse: bool
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
+        super(Embedding, self).__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        if padding_idx is not None:
+            if padding_idx > 0:
+                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+            elif padding_idx < 0:
+                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                padding_idx = self.num_embeddings + padding_idx
+        self.padding_idx = padding_idx
+        self.max_norm = max_norm
+        self.norm_type = norm_type
+        self.scale_grad_by_freq = scale_grad_by_freq
+        self.embedding_scale = math.sqrt(self.embedding_dim)
+        if _weight is None:
+            self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
+            self.reset_parameters()
+        else:
+            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
+                'Shape of weight does not match num_embeddings and embedding_dim'
+            self.weight = Parameter(_weight)
+
+        self.sparse = sparse
+
+    def reset_parameters(self) -> None:
+        std = 1 / self.embedding_scale
+        nn.init.normal_(self.weight, std=std)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.embedding(
+            input, self.weight, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse) * self.embedding_scale
+
+    def extra_repr(self) -> str:
+        s = '{num_embeddings}, {embedding_dim}'
+        if self.padding_idx is not None:
+            s += ', padding_idx={padding_idx}'
+        if self.max_norm is not None:
+            s += ', max_norm={max_norm}'
+        if self.norm_type != 2:
+            s += ', norm_type={norm_type}'
+        if self.scale_grad_by_freq is not False:
+            s += ', scale_grad_by_freq={scale_grad_by_freq}'
+        if self.sparse is not False:
+            s += ', sparse=True'
+        return s.format(**self.__dict__)
+
+    @classmethod
+    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
+                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
+                        sparse=False):
+        r"""Creates Embedding instance from given 2-dimensional FloatTensor.
+
+        Args:
+            embeddings (Tensor): FloatTensor containing weights for the Embedding.
+                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
+            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
+            padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
+                                         therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
+                                         i.e. it remains as a fixed "pad".
+            max_norm (float, optional): See module initialization documentation.
+            norm_type (float, optional): See module initialization documentation. Default ``2``.
+            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
+            sparse (bool, optional): See module initialization documentation.
+
+        Examples::
+
+            >>> # FloatTensor containing pretrained weights
+            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+            >>> embedding = nn.Embedding.from_pretrained(weight)
+            >>> # Get embeddings for index 1
+            >>> input = torch.LongTensor([1])
+            >>> embedding(input)
+            tensor([[ 4.0000,  5.1000,  6.3000]])
+        """
+        assert embeddings.dim() == 2, \
+            'Embeddings parameter is expected to be 2-dimensional'
+        rows, cols = embeddings.shape
+        embedding = cls(
+            num_embeddings=rows,
+            embedding_dim=cols,
+            _weight=embeddings,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse)
+        embedding.weight.requires_grad = not freeze
+        return embedding
+
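
A minimal sanity check of the two modifications implemented above (a sketch for illustration, not part of the recipe): the forward output should equal a plain table lookup scaled by `sqrt(embedding_dim)`, and the weights should have a standard deviation of about `1/sqrt(embedding_dim)` rather than 1.

```python
import math

import torch

from embedding import Embedding

emb = Embedding(num_embeddings=10, embedding_dim=16)
idx = torch.tensor([1, 3, 5])

# forward() is F.embedding(...) scaled by sqrt(embedding_dim)
assert torch.allclose(emb(idx), emb.weight[idx] * math.sqrt(16))

# reset_parameters() draws weights with std = 1/sqrt(embedding_dim) = 0.25
print(emb.weight.std().item())  # roughly 0.25, not 1.0
```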
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
new file mode 100755
index 0000000000..c63616d28b
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        required=True,
+        help="Path to words.txt",
+    )
+
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) attention-decoder - Extract n paths from the rescored
+            lattice and use the transformer attention decoder for
+            rescoring.
+            We call it HLG decoding + n-gram LM rescoring + attention
+            decoder rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or attention-decoder.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or attention-decoder.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--attention-decoder-scale",
+        type=float,
+        default=1.2,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for attention decoder scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--lattice-score-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, at the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--sos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the SOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "--eos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the EOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "feature_dim": 80,
+            "nhead": 8,
+            "num_classes": 5000,
+            "sample_rate": 16000,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "is_espnet_structure": True,
+            "mmi_loss": False,
+            "use_feat_batchnorm": True,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    params.update(vars(args))
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=params.num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        is_espnet_structure=params.is_espnet_structure,
+        mmi_loss=params.mmi_loss,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"])
+    model.to(device)
+    model.eval()
+
+    logging.info(f"Loading HLG from {params.HLG}")
+    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+    HLG = HLG.to(device)
+    if not hasattr(HLG, "lm_scores"):
+        # For whole-lattice-rescoring and attention-decoder
+        HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+        logging.info(f"Loading G from {params.G}")
+        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = G.to(device)
+        # Add epsilon self-loops to G as we will compose
+        # it with the whole lattice later
+        G = k2.add_epsilon_self_loops(G)
+        G = k2.arc_sort(G)
+        G.lm_scores = G.scores.clone()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
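+    # Note: the features are log-Mel filterbanks, so padding with
+    # math.log(1e-10) (a very small energy in the log domain) behaves
+    # roughly like padding with silence.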
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
+
+    # Note: We don't use key padding mask for attention during decoding
+    with torch.no_grad():
+        nnet_output, memory, memory_key_padding_mask = model(features)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        HLG=HLG,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "1best":
+        logging.info("Use HLG decoding")
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+    elif params.method == "whole-lattice-rescoring":
+        logging.info("Use HLG decoding + LM rescoring")
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=[params.ngram_lm_scale],
+        )
+        best_path = next(iter(best_path_dict.values()))
+    elif params.method == "attention-decoder":
+        logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=params.sos_id,
+            eos_id=params.eos_id,
+            scale=params.lattice_score_scale,
+            ngram_lm_scale=params.ngram_lm_scale,
+            attention_scale=params.attention_decoder_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))
+
+    hyps = get_texts(best_path)
+    word_sym_table = k2.SymbolTable.from_file(params.words_file)
+    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
new file mode 100644
index 0000000000..5c3e1222ef
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape [N, T, idim] to an output
+    with shape [N, T', odim], where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(self, idim: int, odim: int) -> None:
+        """
+        Args:
+          idim:
+            Input dim. The input shape is [N, T, idim].
+            Caution: It requires: T >=7, idim >=7
+          odim:
+            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+        """
+        assert idim >= 7
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
+            nn.ReLU(),
+        )
+        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is [N, T, idim].
+
+        Returns:
+          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+        """
+        # On entry, x is [N, T, idim]
+        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+        x = self.conv(x)
+        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
+        return x
+
+
+class VggSubsampling(nn.Module):
+    """Trying to follow the setup described in the following paper:
+    https://arxiv.org/pdf/1910.09799.pdf
+
+    This paper is not 100% explicit so I am guessing to some extent,
+    and trying to compare with other VGG implementations.
+
+    Convert an input of shape [N, T, idim] to an output
+    with shape [N, T', odim], where
+    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
+    """
+
+    def __init__(self, idim: int, odim: int) -> None:
+        """Construct a VggSubsampling object.
+
+        This uses 2 VGG blocks with 2 Conv2d layers each,
+        subsampling its input by a factor of 4 in the time dimension.
+
+        Args:
+          idim:
+            Input dim. The input shape is [N, T, idim].
+            Caution: It requires: T >=7, idim >=7
+          odim:
+            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+        """
+        super().__init__()
+
+        cur_channels = 1
+        layers = []
+        block_dims = [32, 64]
+
+        # The decision to use padding=1 for the 1st convolution, then padding=0
+        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
+        # a back-compatibility concern so that the number of frames at the
+        # output would be equal to:
+        #  (((T-1)//2)-1)//2.
+        # We can consider changing this by using padding=1 on the
+        # 2nd convolution, so the num-frames at the output would be T//4.
+        for block_dim in block_dims:
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=cur_channels,
+                    out_channels=block_dim,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                )
+            )
+            layers.append(torch.nn.ReLU())
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=block_dim,
+                    out_channels=block_dim,
+                    kernel_size=3,
+                    padding=0,
+                    stride=1,
+                )
+            )
+            layers.append(
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
+            )
+            cur_channels = block_dim
+
+        self.layers = nn.Sequential(*layers)
+
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is [N, T, idim].
+
+        Returns:
+          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+        """
+        x = x.unsqueeze(1)
+        x = self.layers(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        return x
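
Both modules above reduce the time axis through two stride-2 stages, so the output length is `((T - 1) // 2 - 1) // 2`, i.e. slightly less than `T // 4`. A quick shape check (a sketch in the spirit of `test_subsampling.py` below):

```python
import torch

from subsampling import Conv2dSubsampling

T, idim, odim = 100, 80, 256
model = Conv2dSubsampling(idim=idim, odim=odim)
y = model(torch.randn(2, T, idim))

# ((100 - 1) // 2 - 1) // 2 = 24, whereas 100 // 4 = 25
assert y.shape == (2, ((T - 1) // 2 - 1) // 2, odim)
```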
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
new file mode 100755
index 0000000000..937845d779
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+
+from subsampling import Conv2dSubsampling
+from subsampling import VggSubsampling
+import torch
+
+
+def test_conv2d_subsampling():
+    N = 3
+    odim = 2
+
+    for T in range(7, 19):
+        for idim in range(7, 20):
+            model = Conv2dSubsampling(idim=idim, odim=odim)
+            x = torch.empty(N, T, idim)
+            y = model(x)
+            assert y.shape[0] == N
+            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
+            assert y.shape[2] == odim
+
+
+def test_vgg_subsampling():
+    N = 3
+    odim = 2
+
+    for T in range(7, 19):
+        for idim in range(7, 20):
+            model = VggSubsampling(idim=idim, odim=odim)
+            x = torch.empty(N, T, idim)
+            y = model(x)
+            assert y.shape[0] == N
+            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
+            assert y.shape[2] == odim
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
new file mode 100644
index 0000000000..08e6806074
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+
+import torch
+from transformer import (
+    Transformer,
+    encoder_padding_mask,
+    generate_square_subsequent_mask,
+    decoder_padding_mask,
+    add_sos,
+    add_eos,
+)
+
+from torch.nn.utils.rnn import pad_sequence
+
+
+def test_encoder_padding_mask():
+    supervisions = {
+        "sequence_idx": torch.tensor([0, 1, 2]),
+        "start_frame": torch.tensor([0, 0, 0]),
+        "num_frames": torch.tensor([18, 7, 13]),
+    }
+
+    max_len = ((18 - 1) // 2 - 1) // 2
+    mask = encoder_padding_mask(max_len, supervisions)
+    expected_mask = torch.tensor(
+        [
+            [False, False, False],  # ((18 - 1)//2 - 1)//2 = 3,
+            [False, True, True],  # ((7 - 1)//2 - 1)//2 = 1,
+            [False, False, True],  # ((13 - 1)//2 - 1)//2 = 2,
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_transformer():
+    num_features = 40
+    num_classes = 87
+    model = Transformer(num_features=num_features, num_classes=num_classes)
+
+    N = 31
+
+    for T in range(7, 30):
+        x = torch.rand(N, T, num_features)
+        y, _, _ = model(x)
+        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+
+def test_generate_square_subsequent_mask():
+    s = 5
+    mask = generate_square_subsequent_mask(s)
+    inf = float("inf")
+    expected_mask = torch.tensor(
+        [
+            [0.0, -inf, -inf, -inf, -inf],
+            [0.0, 0.0, -inf, -inf, -inf],
+            [0.0, 0.0, 0.0, -inf, -inf],
+            [0.0, 0.0, 0.0, 0.0, -inf],
+            [0.0, 0.0, 0.0, 0.0, 0.0],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_decoder_padding_mask():
+    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
+    y = pad_sequence(x, batch_first=True, padding_value=-1)
+    mask = decoder_padding_mask(y, ignore_id=-1)
+    expected_mask = torch.tensor(
+        [
+            [False, False, True],
+            [False, True, True],
+            [False, False, False],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_add_sos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_sos(x, sos_id=0)
+    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
+    assert y == expected_y
+
+
+def test_add_eos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_eos(x, eos_id=0)
+    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
+    assert y == expected_y
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
new file mode 100755
index 0000000000..795a2ab571
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
@@ -0,0 +1,708 @@
+#!/usr/bin/env python3
+
+# This is just at the very beginning ...
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional
+
+import k2
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.utils import fix_random_seed
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training-related parameters that are not passed from the command line
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - exp_dir: It specifies the directory where all training related
+                   files, e.g., checkpoints, log, etc, are saved
+
+        - lang_dir: It contains language related input files such as
+                    "lexicon.txt"
+
+        - lr: It specifies the initial learning rate
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - weight_decay:  The weight_decay for the optimizer.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - start_epoch:  If it is not zero, load checkpoint `start_epoch-1`
+                        and continue training from that checkpoint.
+
+        - num_epochs:  Number of epochs to train.
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - beam_size: It is used in k2.ctc_loss
+
+        - reduction: It is used in k2.ctc_loss
+
+        - use_double_scores: It is used in k2.ctc_loss
+    """
+    params = AttributeDict(
+        {
+            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
+            "lang_dir": Path("data/lang_bpe"),
+            "feature_dim": 80,
+            "weight_decay": 1e-6,
+            "subsampling_factor": 4,
+            "start_epoch": 0,
+            "num_epochs": 20,
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 10,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            "accum_grad": 1,
+            "att_rate": 0.7,
+            "attention_dim": 512,
+            "nhead": 8,
+            "num_decoder_layers": 6,
+            "is_espnet_structure": True,
+            "mmi_loss": False,
+            "use_feat_batchnorm": True,
+            "lr_factor": 5.0,
+            "warm_step": 80000,
+        }
+    )
+
+    return params
+
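
As the docstring notes, command-line options are merged into `params`, so defaults and parsed arguments are accessed the same way. A small sketch of that pattern, assuming icefall's `AttributeDict` is a `dict` subclass whose keys are also exposed as attributes:

```python
# Sketch only; AttributeDict comes from icefall.utils.
params = get_params()
params.update({"world_size": 2})  # in train.py this is vars(args)

assert params["num_epochs"] == params.num_epochs  # dict access == attribute access
print(params.world_size)  # 2
```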
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> Optional[dict]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return the saved checkpoint (a dict) if one is loaded; otherwise None.
+    """
+    if params.start_epoch <= 0:
+        return
+
+    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    batch: dict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    is_training: bool,
+):
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = graph_compiler.device
+    feature = batch["inputs"]
+    # at entry, feature is [N, T, C]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction=params.reduction,
+        use_double_scores=params.use_double_scores,
+    )
+
+    if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
+            if hasattr(model, "module"):
+                att_loss = model.module.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+            else:
+                att_loss = model.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])
+
+    # train_frames and valid_frames are used for printing.
+    if is_training:
+        params.train_frames = supervision_segments[:, 2].sum().item()
+    else:
+        params.valid_frames = supervision_segments[:, 2].sum().item()
+
+    assert loss.requires_grad == is_training
+
+    return loss, ctc_loss.detach(), att_loss.detach()
+
+
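
For reference, the `supervision_segments` produced by `encode_supervisions` in `compute_loss` above is a 2-D `int32` tensor with one row per utterance and columns `(sequence_idx, start_frame, num_frames)`, with the frame columns already divided by the subsampling factor and the rows sorted by decreasing `num_frames` (the transcripts are re-ordered to match). A hedged illustration of that layout, not the actual icefall helper:

```python
import torch

# columns: (sequence_idx, start_frame, num_frames) after subsampling,
# sorted by decreasing num_frames as required by k2.intersect_dense
supervision_segments = torch.tensor(
    [
        [0, 0, 500],
        [2, 0, 420],
        [1, 0, 310],
    ],
    dtype=torch.int32,
)
# texts must be re-ordered the same way so that texts[i] matches row i
```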
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> None:
+    """Run the validation process. The validation loss
+    is saved in `params.valid_loss`.
+    """
+    model.eval()
+
+    tot_loss = 0.0
+    tot_ctc_loss = 0.0
+    tot_att_loss = 0.0
+    tot_frames = 0.0
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, ctc_loss, att_loss = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            graph_compiler=graph_compiler,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        assert ctc_loss.requires_grad is False
+        assert att_loss.requires_grad is False
+
+        loss_cpu = loss.detach().cpu().item()
+        tot_loss += loss_cpu
+
+        tot_ctc_loss += ctc_loss.detach().cpu().item()
+        tot_att_loss += att_loss.detach().cpu().item()
+
+        tot_frames += params.valid_frames
+
+    if world_size > 1:
+        s = torch.tensor(
+            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
+            device=loss.device,
+        )
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        s = s.cpu().tolist()
+        tot_loss = s[0]
+        tot_ctc_loss = s[1]
+        tot_att_loss = s[2]
+        tot_frames = s[3]
+
+    params.valid_loss = tot_loss / tot_frames
+    params.valid_ctc_loss = tot_ctc_loss / tot_frames
+    params.valid_att_loss = tot_att_loss / tot_frames
+
+    if params.valid_loss < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = params.valid_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+    """
+    model.train()
+
+    tot_loss = 0.0  # sum of losses over all batches
+    tot_ctc_loss = 0.0
+    tot_att_loss = 0.0
+
+    tot_frames = 0.0  # sum of frames over all batches
+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        loss, ctc_loss, att_loss = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            graph_compiler=graph_compiler,
+            is_training=True,
+        )
+
+        # NOTE: We use reduction="sum", so the loss is summed over all
+        # utterances in the batch; no normalization is applied to it here.
+
+        optimizer.zero_grad()
+        loss.backward()
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
+        optimizer.step()
+
+        loss_cpu = loss.detach().cpu().item()
+        ctc_loss_cpu = ctc_loss.detach().cpu().item()
+        att_loss_cpu = att_loss.detach().cpu().item()
+
+        tot_frames += params.train_frames
+        tot_loss += loss_cpu
+        tot_ctc_loss += ctc_loss_cpu
+        tot_att_loss += att_loss_cpu
+
+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
+        tot_avg_loss = tot_loss / tot_frames
+        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
+        tot_avg_att_loss = tot_att_loss / tot_frames
+
+        if batch_idx % params.log_interval == 0:
+            logging.info(
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
+                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
+                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
+                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
+                f"total avg att loss: {tot_avg_att_loss:.4f}, "
+                f"total avg loss: {tot_avg_loss:.4f}, "
+                f"batch size: {batch_size}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/current_ctc_loss",
+                    ctc_loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/current_att_loss",
+                    att_loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/current_loss",
+                    loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_ctc_loss",
+                    tot_avg_ctc_loss,
+                    params.batch_idx_train,
+                )
+
+                tb_writer.add_scalar(
+                    "train/tot_avg_att_loss",
+                    tot_avg_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
+                )
+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0.0  # sum of losses over all batches
+            tot_ctc_loss = 0.0
+            tot_att_loss = 0.0
+
+            tot_frames = 0.0  # sum of frames over all batches
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"valid ctc loss {params.valid_ctc_loss:.4f}, "
+                f"valid att loss {params.valid_att_loss:.4f}, "
+                f"valid loss {params.valid_loss:.4f}, "
+                f"best valid loss: {params.best_valid_loss:.4f}, "
+                f"best valid epoch: {params.best_valid_epoch}"
+            )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/valid_ctc_loss",
+                    params.valid_ctc_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_att_loss",
+                    params.valid_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_loss",
+                    params.valid_loss,
+                    params.batch_idx_train,
+                )
+
+    params.train_loss = params.tot_loss / params.tot_frames
+
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=False,
+        is_espnet_structure=params.is_espnet_structure,
+        mmi_loss=params.mmi_loss,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Noam(
+        model.parameters(),
+        model_size=params.attention_dim,
+        factor=params.lr_factor,
+        warm_step=params.warm_step,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    librispeech = LibriSpeechAsrDataModule(args)
+    train_dl = librispeech.train_dataloaders()
+    valid_dl = librispeech.valid_dataloaders()
+
+    for epoch in range(params.start_epoch, params.num_epochs):
+        train_dl.sampler.set_epoch(epoch)
+
+        cur_lr = optimizer._rate
+        if tb_writer is not None:
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        if rank == 0:
+            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
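
About the `Noam` optimizer configured in `run()`: assuming the class in `transformer.py` follows the standard Noam schedule used by espnet-style implementations (an assumption, not verified here), the learning rate at step `s >= 1` is `factor * d_model**-0.5 * min(s**-0.5, s * warm_step**-1.5)`, i.e. a linear warm-up followed by inverse-square-root decay. A small sketch:

```python
def noam_lr(step: int, model_size: int = 512, factor: float = 5.0,
            warm_step: int = 80000) -> float:
    """Standard Noam schedule (sketch), valid for step >= 1."""
    return factor * model_size ** -0.5 * min(
        step ** -0.5, step * warm_step ** -1.5
    )


# With the defaults used in get_params(), the peak learning rate is
# reached at step == warm_step: about 5 * 512**-0.5 * 80000**-0.5 ≈ 7.8e-4.
print(noam_lr(80000))
```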
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
new file mode 100644
index 0000000000..f237ff8e3c
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
@@ -0,0 +1,990 @@
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Apache 2.0
+
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from embedding import Embedding
+from subsampling import Conv2dSubsampling, VggSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        normalize_before: bool = True,
+        vgg_frontend: bool = False,
+        mmi_loss: bool = True,
+        use_feat_batchnorm: bool = False,
+    ) -> None:
+        """
+        Args:
+          num_features:
+            The input dimension of the model.
+          num_classes:
+            The output dimension of the model.
+          subsampling_factor:
+            The number of output frames is roughly
+            num_in_frames // subsampling_factor.
+            Currently, subsampling_factor MUST be 4.
+          d_model:
+            Attention dimension.
+          nhead:
+            Number of heads in multi-head attention.
+            Must satisfy d_model % nhead == 0.
+          dim_feedforward:
+            The output dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            Number of encoder layers.
+          num_decoder_layers:
+            Number of decoder layers.
+          dropout:
+            Dropout in encoder/decoder.
+          normalize_before:
+            If True, use pre-layer norm; False to use post-layer norm.
+          vgg_frontend:
+            True to use vgg style frontend for subsampling.
+          mmi_loss:
+            True to add an extra output class for the sos/eos symbol in the
+            attention decoder; False if the BPE model already contains the
+            sos/eos symbol.
+          use_feat_batchnorm:
+            True to use batchnorm for the input layer.
+        """
+        super().__init__()
+        self.use_feat_batchnorm = use_feat_batchnorm
+        if use_feat_batchnorm:
+            self.feat_batchnorm = nn.BatchNorm1d(num_features)
+
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape [N, T, num_features]
+        # to the shape [N, T//subsampling_factor, d_model].
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        if vgg_frontend:
+            self.encoder_embed = VggSubsampling(num_features, d_model)
+        else:
+            self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.encoder_pos = PositionalEncoding(d_model, dropout)
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            normalize_before=normalize_before,
+        )
+
+        if normalize_before:
+            encoder_norm = nn.LayerNorm(d_model)
+        else:
+            encoder_norm = None
+
+        self.encoder = nn.TransformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            norm=encoder_norm,
+        )
+
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
+
+        if num_decoder_layers > 0:
+            if mmi_loss:
+                self.decoder_num_class = (
+                    self.num_classes + 1
+                )  # +1 for the sos/eos symbol
+            else:
+                self.decoder_num_class = (
+                    self.num_classes
+                )  # bpe model already has sos/eos symbol
+
+            self.decoder_embed = Embedding(
+                num_embeddings=self.decoder_num_class, embedding_dim=d_model
+            )
+            self.decoder_pos = PositionalEncoding(d_model, dropout)
+
+            decoder_layer = TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+                normalize_before=normalize_before,
+            )
+
+            if normalize_before:
+                decoder_norm = nn.LayerNorm(d_model)
+            else:
+                decoder_norm = None
+
+            self.decoder = nn.TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=num_decoder_layers,
+                norm=decoder_norm,
+            )
+
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
+
+            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
+        else:
+            self.decoder_criterion = None
+
+    def forward(
+        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is [N, T, C].
+          supervision:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            (CAUTION: It contains length information, i.e., start and number of
+             frames, before subsampling)
+
+        Returns:
+          Return a tuple containing 3 tensors:
+            - CTC output for ctc decoding. Its shape is [N, T, C]
+            - Encoder output with shape [T, N, C]. It can be used as key and
+              value for the decoder.
+            - Encoder output padding mask. It can be used as
+              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              It is None if `supervision` is None.
+        """
+        if self.use_feat_batchnorm:
+            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = self.feat_batchnorm(x)
+            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
+        x = self.ctc_output(encoder_memory)
+        return x, encoder_memory, memory_key_padding_mask
+
+    def run_encoder(
+        self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Run the transformer encoder.
+
+        Args:
+          x:
+            The model input. Its shape is [N, T, C].
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling
+            It is read directly from the batch, without any sorting. It is used
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
+        Returns:
+          Return a tuple with two tensors:
+            - The encoder output, with shape [T, N, C]
+            - encoder padding mask, with shape [N, T].
+              The mask is None if `supervisions` is None.
+              It is used as memory key padding mask in the decoder.
+        """
+        x = self.encoder_embed(x)
+        x = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+        x = self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)
+
+        return x, mask
+
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          x:
+            The output tensor from the transformer encoder.
+            Its shape is [T, N, C]
+
+        Returns:
+          Return a tensor that can be used for CTC decoding.
+          Its shape is [N, T, C]
+        """
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
+        return x
+
+    def decoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder with shape [T, N, C]
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          token_ids:
+            A list-of-list IDs. Each sublist contains IDs for an utterance.
+            The IDs can be either phone IDs or word piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+
+        Returns:
+            A scalar, the **sum** of label smoothing loss over utterances
+            in the batch without any normalization.
+        """
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device)
+        ys_out_pad = ys_out_pad.to(device)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+
+        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+        return decoder_loss
+
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder with shape [T, N, C]
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          token_ids:
+            A list-of-list IDs (e.g., word piece IDs).
+            Each sublist represents an utterance.
+          sos_id:
+            The token ID for SOS.
+          eos_id:
+            The token ID for EOS.
+        Returns:
+            A 2-D tensor of shape (len(token_ids), max_token_length)
+            representing the cross entropy loss (i.e., negative log-likelihood).
+        """
+        # The common part between this function and decoder_forward could be
+        # extracted as a separate function.
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+        ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )  # (T, B, F)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, B, F) -> (B, T, F)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (B, T, F)
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            pred_pad.view(-1, self.decoder_num_class),
+            ys_out_pad.view(-1),
+            ignore_index=-1,
+            reduction="none",
+        )
+
+        nll = nll.view(pred_pad.shape[0], -1)
+
+        return nll
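+
+
+# The sketch below is illustrative only and is not used by the recipe. It
+# shows, with made-up feature sizes and token IDs, how the CTC branch and
+# the attention-decoder loss of the Transformer above fit together.
+# sos_id == eos_id == 1 follows the assumption noted in `decoder_forward`.
+def _example_transformer_usage() -> None:
+    model = Transformer(num_features=40, num_classes=87)
+    model.eval()
+    N, T = 2, 100
+    x = torch.rand(N, T, 40)  # fbank-like features, [N, T, num_features]
+    with torch.no_grad():
+        # With supervision=None the returned padding mask is None.
+        ctc_output, memory, memory_key_padding_mask = model(x)
+        # ctc_output: [N, ((T - 1) // 2 - 1) // 2, num_classes]
+        # memory:     [T', N, d_model], used as key/value by the decoder
+        token_ids = [[3, 7, 11], [5, 9]]  # toy word piece IDs per utterance
+        att_loss = model.decoder_forward(
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            token_ids=token_ids,
+            sos_id=1,
+            eos_id=1,
+        )
+    print(ctc_output.shape, att_loss.item())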
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    Modified from torch.nn.TransformerEncoderLayer.
+    Add support for normalize_before,
+    i.e., applying layer_norm before the first block.
+
+    Args:
+      d_model:
+        the number of expected features in the input (required).
+      nhead:
+        the number of heads in the multiheadattention models (required).
+      dim_feedforward:
+        the dimension of the feedforward network model (default=2048).
+      dropout:
+        the dropout value (default=0.1).
+      activation:
+        the activation function of intermediate layer, relu or
+        gelu (default=relu).
+      normalize_before:
+        whether to use layer_norm before the first block.
+
+    Examples::
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = encoder_layer(src)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = True,
+    ) -> None:
+        super(TransformerEncoderLayer, self).__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+
+        self.normalize_before = normalize_before
+
+    def __setstate__(self, state):
+        if "activation" not in state:
+            state["activation"] = nn.functional.relu
+        super(TransformerEncoderLayer, self).__setstate__(state)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            src_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional)
+
+        Shape:
+            src: (S, N, E).
+            src_mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size,
+            E is the feature number.
+        """
+        residual = src
+        if self.normalize_before:
+            src = self.norm1(src)
+        src2 = self.self_attn(
+            src,
+            src,
+            src,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+        src = residual + self.dropout1(src2)
+        if not self.normalize_before:
+            src = self.norm1(src)
+
+        residual = src
+        if self.normalize_before:
+            src = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = residual + self.dropout2(src2)
+        if not self.normalize_before:
+            src = self.norm2(src)
+        return src
+
+
+class TransformerDecoderLayer(nn.Module):
+    """
+    Modified from torch.nn.TransformerDecoderLayer.
+    Add support for normalize_before,
+    i.e., applying layer_norm before the first block.
+
+    Args:
+      d_model:
+        the number of expected features in the input (required).
+      nhead:
+        the number of heads in the multiheadattention models (required).
+      dim_feedforward:
+        the dimension of the feedforward network model (default=2048).
+      dropout:
+        the dropout value (default=0.1).
+      activation:
+        the activation function of intermediate layer, relu or
+        gelu (default=relu).
+
+    Examples::
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = decoder_layer(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = True,
+    ) -> None:
+        super(TransformerDecoderLayer, self).__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+        self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+
+        self.normalize_before = normalize_before
+
+    def __setstate__(self, state):
+        if "activation" not in state:
+            state["activation"] = nn.functional.relu
+        super(TransformerDecoderLayer, self).__setstate__(state)
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Pass the inputs (and mask) through the decoder layer.
+
+        Args:
+          tgt:
+            the sequence to the decoder layer (required).
+          memory:
+            the sequence from the last layer of the encoder (required).
+          tgt_mask:
+            the mask for the tgt sequence (optional).
+          memory_mask:
+            the mask for the memory sequence (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch (optional).
+
+        Shape:
+            tgt: (T, N, E).
+            memory: (S, N, E).
+            tgt_mask: (T, T).
+            memory_mask: (T, S).
+            tgt_key_padding_mask: (N, T).
+            memory_key_padding_mask: (N, S).
+            S is the source sequence length, T is the target sequence length,
+            N is the batch size, E is the feature number
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt2 = self.self_attn(
+            tgt,
+            tgt,
+            tgt,
+            attn_mask=tgt_mask,
+            key_padding_mask=tgt_key_padding_mask,
+        )[0]
+        tgt = residual + self.dropout1(tgt2)
+        if not self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm2(tgt)
+        tgt2 = self.src_attn(
+            tgt,
+            memory,
+            memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = residual + self.dropout2(tgt2)
+        if not self.normalize_before:
+            tgt = self.norm2(tgt)
+
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = residual + self.dropout3(tgt2)
+        if not self.normalize_before:
+            tgt = self.norm3(tgt)
+        return tgt
+
+
+def _get_activation_fn(activation: str):
+    if activation == "relu":
+        return nn.functional.relu
+    elif activation == "gelu":
+        return nn.functional.gelu
+
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note::
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model:
+            Embedding dimension.
+          dropout:
+            Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.pos_scale = 1. / math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        self.pe = None
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+
+        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
+        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
+        to [1, T, d_model]. Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape [N, T, C].
+        Returns:
+          Return None.
+        """
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding.
+
+        Args:
+          x:
+            Its shape is [N, T, C]
+
+        Returns:
+          Return a tensor of shape [N, T, C]
+        """
+        self.extend_pe(x)
+        x = x + self.pe[:, : x.size(1), :] * self.pos_scale
+        return self.dropout(x)
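+
+
+# Illustrative sketch only (not part of the recipe): with a zero input and
+# dropout disabled, the output of PositionalEncoding is exactly
+# pe * pos_scale, so its magnitude is bounded by 1 / sqrt(d_model). This
+# shows the effect of scaling the encoding down by pos_scale instead of
+# scaling the input up by sqrt(d_model).
+def _example_positional_encoding_scale() -> None:
+    d_model = 256
+    pos_enc = PositionalEncoding(d_model, dropout=0.0)
+    x = torch.zeros(1, 10, d_model)
+    y = pos_enc(x)
+    assert y.abs().max().item() <= 1.0 / math.sqrt(d_model) + 1e-6
+    print(y.abs().max().item())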
+
+
+class Noam(object):
+    """
+    Implements Noam optimizer.
+
+    Proposed in
+    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
+
+    Modified from
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa
+
+    Args:
+      params:
+        iterable of parameters to optimize or dicts defining parameter groups
+      model_size:
+        attention dimension of the transformer model
+      factor:
+        learning rate factor
+      warm_step:
+        warmup steps
+      weight_decay:
+        weight decay (L2 penalty) passed to the underlying Adam optimizer
+    """
+
+    def __init__(
+        self,
+        params,
+        model_size: int = 256,
+        factor: float = 10.0,
+        warm_step: int = 25000,
+        weight_decay=0,
+    ) -> None:
+        """Construct a Noam object."""
+        self.optimizer = torch.optim.Adam(
+            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
+        )
+        self._step = 0
+        self.warmup = warm_step
+        self.factor = factor
+        self.model_size = model_size
+        self._rate = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+    def rate(self, step=None):
+        """Return the learning rate for the given step.
+
+        If step is None, the current step is used.
+        """
+        if step is None:
+            step = self._step
+        return (
+            self.factor
+            * self.model_size ** (-0.5)
+            * min(step ** (-0.5), step * self.warmup ** (-1.5))
+        )
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+            "warmup": self.warmup,
+            "factor": self.factor,
+            "model_size": self.model_size,
+            "_rate": self._rate,
+            "optimizer": self.optimizer.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                setattr(self, key, value)
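+
+
+# Illustrative sketch only (not used by the recipe): the Noam learning rate
+# grows linearly up to `warm_step` and then decays like step ** -0.5, with
+# the peak at step == warm_step. The tiny nn.Linear below exists only so
+# the optimizer has parameters to hold; all numbers are arbitrary.
+def _example_noam_schedule() -> None:
+    dummy = nn.Linear(4, 4)
+    noam = Noam(
+        dummy.parameters(), model_size=256, factor=10.0, warm_step=25000
+    )
+    for step in (100, 25000, 100000):
+        print(step, noam.rate(step))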
+
+
+class LabelSmoothingLoss(nn.Module):
+    """
+    Label-smoothing loss. It minimizes the KL-divergence between
+    q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w).
+    Modified from
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
+
+    Args:
+        size: the number of classes
+        padding_idx: the ignored class ID (used for padding positions)
+        smoothing: smoothing rate (0.0 means conventional cross entropy)
+        normalize_length: normalize the loss by sequence length if True
+        criterion: loss function to be smoothed
+    """
+
+    def __init__(
+        self,
+        size: int,
+        padding_idx: int = -1,
+        smoothing: float = 0.1,
+        normalize_length: bool = False,
+        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
+    ) -> None:
+        """Construct a LabelSmoothingLoss object."""
+        super(LabelSmoothingLoss, self).__init__()
+        self.criterion = criterion
+        self.padding_idx = padding_idx
+        assert 0.0 < smoothing <= 1.0
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.size = size
+        self.true_dist = None
+        self.normalize_length = normalize_length
+
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute loss between x and target.
+
+        Args:
+          x:
+            prediction of dimension
+            (batch_size, input_length, number_of_classes).
+          target:
+            target masked with self.padding_idx, of
+            dimension (batch_size, input_length).
+
+        Returns:
+          A scalar tensor containing the loss without normalization.
+        """
+        assert x.size(2) == self.size
+        #  batch_size = x.size(0)
+        x = x.view(-1, self.size)
+        target = target.view(-1)
+        with torch.no_grad():
+            true_dist = x.clone()
+            true_dist.fill_(self.smoothing / (self.size - 1))
+            ignore = target == self.padding_idx  # (B,)
+            total = len(target) - ignore.sum().item()
+            target = target.masked_fill(ignore, 0)  # avoid -1 index
+            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
+        #  denom = total if self.normalize_length else batch_size
+        denom = total if self.normalize_length else 1
+        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
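+
+
+# Illustrative sketch only (not used by the recipe): positions whose target
+# equals padding_idx (-1 here) are excluded from the loss, and with the
+# default normalize_length=False the returned value is a plain sum over the
+# remaining positions. The logits and targets below are made up.
+def _example_label_smoothing_loss() -> None:
+    criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
+    x = torch.randn(2, 3, 5)  # (batch_size, input_length, num_classes)
+    target = torch.tensor([[1, 2, -1], [3, -1, -1]])  # -1 marks padding
+    loss = criterion(x, target)
+    print(loss)  # scalar; only the 3 non-padding positions contribute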
+
+
+def encoder_padding_mask(
+    max_len: int, supervisions: Optional[Supervisions] = None
+) -> Optional[torch.Tensor]:
+    """Make a mask tensor marking the padded positions.
+
+    TODO::
+      This function **assumes** that the model uses
+      a subsampling factor of 4. We should remove that
+      assumption later.
+
+    Args:
+      max_len:
+        Maximum length of input features.
+        CAUTION: It is the length after subsampling.
+      supervisions:
+        Supervision in lhotse format.
+        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+        (CAUTION: It contains length information, i.e., start and number of
+         frames, before subsampling)
+
+    Returns:
+        A mask tensor of dimension (batch_size, input_length).
+        True denotes the masked (i.e., padded) positions.
+    """
+    if supervisions is None:
+        return None
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"],
+            supervisions["num_frames"],
+        ),
+        1,
+    ).to(torch.int32)
+
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
+    for idx in range(supervision_segments.size(0)):
+        # Note: TorchScript doesn't allow to unpack tensors as tuples
+        sequence_idx = supervision_segments[idx, 0].item()
+        start_frame = supervision_segments[idx, 1].item()
+        num_frames = supervision_segments[idx, 2].item()
+        lengths[sequence_idx] = start_frame + num_frames
+
+    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
+    bs = int(len(lengths))
+    seq_range = torch.arange(0, max_len, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+    # Note: TorchScript doesn't implement Tensor.new()
+    seq_length_expand = torch.tensor(
+        lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
+    ).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    return mask
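+
+
+# Illustrative sketch only (not used by the recipe): with a subsampling
+# factor of 4, an utterance with `num_frames` input frames keeps
+# ((num_frames - 1) // 2 - 1) // 2 output frames; every position after that
+# is masked with True. The supervision values below are made up.
+def _example_encoder_padding_mask() -> None:
+    supervisions = {
+        "sequence_idx": torch.tensor([0, 1, 2]),
+        "start_frame": torch.tensor([0, 0, 0]),
+        "num_frames": torch.tensor([18, 7, 13]),
+    }
+    max_len = ((18 - 1) // 2 - 1) // 2  # 3 frames for the longest utterance
+    mask = encoder_padding_mask(max_len, supervisions)
+    # Expected:
+    #   [[False, False, False],   # 3 valid frames
+    #    [False, True,  True ],   # 1 valid frame
+    #    [False, False, True ]]   # 2 valid frames
+    print(mask)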
+
+
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
+    """Generate a length mask for input.
+
+    Masked positions are filled with True;
+    unmasked positions are filled with False.
+
+    Args:
+      ys_pad:
+        padded tensor of dimension (batch_size, input_length).
+      ignore_id:
+        the ignored number (the padding number) in ys_pad
+
+    Returns:
+      Tensor:
+        a bool tensor of the same shape as the input tensor.
+    """
+    ys_mask = ys_pad == ignore_id
+    return ys_mask
+
+
+def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+    """Generate a square mask for the sequence. The masked positions are
+    filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0.]])
+
+    Args:
+      sz: mask size
+
+    Returns:
+      A square mask of dimension (sz, sz)
+    """
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (
+        mask.float()
+        .masked_fill(mask == 0, float("-inf"))
+        .masked_fill(mask == 1, float(0.0))
+    )
+    return mask
+
+
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      sos_id:
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist starts
+      with SOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append([sos_id] + utt)
+    return ans
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      eos_id:
+        The ID of the EOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist ends
+      with EOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append(utt + [eos_id])
+    return ans
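+
+
+# Illustrative sketch only (not used by the recipe): how decoder inputs and
+# targets are built from token IDs, mirroring what `decoder_forward` does
+# internally. The IDs are made up; sos_id == eos_id == 1 matches the
+# assumption noted in `decoder_forward`.
+def _example_decoder_targets() -> None:
+    token_ids = [[3, 7, 11], [5, 9]]
+    sos_id = eos_id = 1
+    ys_in = [torch.tensor(y) for y in add_sos(token_ids, sos_id=sos_id)]
+    ys_out = [torch.tensor(y) for y in add_eos(token_ids, eos_id=eos_id)]
+    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+    ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+    tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1])
+    tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+    # decoder_forward additionally resets column 0 of tgt_key_padding_mask
+    # to False because that column holds sos_id, which equals eos_id here.
+    print(ys_in_pad)   # [[1, 3, 7, 11], [1, 5, 9, 1]]
+    print(ys_out_pad)  # [[3, 7, 11, 1], [5, 9, 1, -1]]
+    print(tgt_mask)
+    print(tgt_key_padding_mask)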

From 69a2bd5179f1b8b6186fdf81c0b2da5cf6cacdfc Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Thu, 26 Aug 2021 14:52:00 +0800
Subject: [PATCH 2/5] Merge master.

---
 .../conformer.py                              |  19 +-
 .../conformer_ctc_embedding_scale/decode.py   | 102 +++--
 .../pretrained.py                             | 351 +-----------------
 .../subsampling.py                            |  17 +
 .../test_subsampling.py                       |  33 --
 .../test_transformer.py                       |  89 -----
 .../conformer_ctc_embedding_scale/train.py    |  42 ++-
 .../transformer.py                            |  26 +-
 8 files changed, 155 insertions(+), 524 deletions(-)
 mode change 100755 => 120000 egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
 delete mode 100755 egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
 delete mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py

diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
index a00664a992..08287d686d 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
@@ -1,7 +1,20 @@
 #!/usr/bin/env python3
-
 # Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import math
 import warnings
@@ -396,7 +409,7 @@ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1 : self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
index c3354c0a3e..676e4bf6a2 100755
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
@@ -1,8 +1,20 @@
 #!/usr/bin/env python3
-
 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# (still working in progress)
 
 import argparse
 import logging
@@ -45,28 +57,63 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=9,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=1,
+        default=20,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound that any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
     parser.add_argument(
         "--lattice-score-scale",
         type=float,
         default=1.0,
-        help="The scale to be applied to `lattice.scores`."
-        "It's needed if you use any kinds of n-best based rescoring. "
-        "Currently, it is used when the decoding method is: nbest, "
-        "nbest-rescoring, attention-decoder, and nbest-oracle. "
-        "A smaller value results in more unique paths.",
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best-based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
     )
 
     return parser
@@ -92,21 +139,6 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            #  - 1best
-            #  - nbest
-            #  - nbest-rescoring
-            #  - whole-lattice-rescoring
-            #  - attention-decoder
-            #  - nbest-oracle
-            #  "method": "nbest",
-            #  "method": "nbest-rescoring",
-            #  "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
-            #  "method": "nbest-oracle",
-            # num_paths is used when method is "nbest", "nbest-rescoring",
-            # attention-decoder, and nbest-oracle
-            "num_paths": 100,
         }
     )
     return params
@@ -117,7 +149,7 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -151,8 +183,8 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      lexicon:
-        It contains word symbol table.
+      word_table:
+        The word symbol table.
       sos_id:
         The token ID of the SOS.
       eos_id:
@@ -205,7 +237,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
-            lexicon=lexicon,
+            word_table=word_table,
             scale=params.lattice_score_scale,
         )
 
@@ -225,7 +257,7 @@ def decode_one_batch(
             key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}
 
     assert params.method in [
@@ -271,7 +303,7 @@ def decode_one_batch(
     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
         hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
         ans[lm_scale_str] = hyps
     return ans
 
@@ -281,7 +313,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     HLG: k2.Fsa,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -297,8 +329,8 @@ def decode_dataset(
         The neural model.
       HLG:
         The decoding graph.
-      lexicon:
-        It contains word symbol table.
+      word_table:
+        It is the word symbol table.
       sos_id:
         The token ID for SOS.
       eos_id:
@@ -332,7 +364,7 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
-            lexicon=lexicon,
+            word_table=word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
@@ -528,7 +560,7 @@ def main():
             params=params,
             model=model,
             HLG=HLG,
-            lexicon=lexicon,
+            word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
deleted file mode 100755
index c63616d28b..0000000000
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
+++ /dev/null
@@ -1,350 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import logging
-import math
-from typing import List
-
-import k2
-import kaldifeat
-import torch
-import torchaudio
-from conformer import Conformer
-from torch.nn.utils.rnn import pad_sequence
-
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_whole_lattice,
-)
-from icefall.utils import AttributeDict, get_texts
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
-    )
-
-    parser.add_argument(
-        "--words-file",
-        type=str,
-        required=True,
-        help="Path to words.txt",
-    )
-
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
-
-    parser.add_argument(
-        "--method",
-        type=str,
-        default="1best",
-        help="""Decoding method.
-        Possible values are:
-        (1) 1best - Use the best path as decoding output. Only
-            the transformer encoder output is used for decoding.
-            We call it HLG decoding.
-        (2) whole-lattice-rescoring - Use an LM to rescore the
-            decoding lattice and then use 1best to decode the
-            rescored lattice.
-            We call it HLG decoding + n-gram LM rescoring.
-        (3) attention-decoder - Extract n paths from he rescored
-            lattice and use the transformer attention decoder for
-            rescoring.
-            We call it HLG decoding + n-gram LM rescoring + attention
-            decoder rescoring.
-        """,
-    )
-
-    parser.add_argument(
-        "--G",
-        type=str,
-        help="""An LM for rescoring.
-        Used only when method is
-        whole-lattice-rescoring or attention-decoder.
-        It's usually a 4-gram LM.
-        """,
-    )
-
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies the size of n-best list.""",
-    )
-
-    parser.add_argument(
-        "--ngram-lm-scale",
-        type=float,
-        default=1.3,
-        help="""
-        Used only when method is whole-lattice-rescoring and attention-decoder.
-        It specifies the scale for n-gram LM scores.
-        (Note: You need to tune it on a dataset.)
-        """,
-    )
-
-    parser.add_argument(
-        "--attention-decoder-scale",
-        type=float,
-        default=1.2,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies the scale for attention decoder scores.
-        (Note: You need to tune it on a dataset.)
-        """,
-    )
-
-    parser.add_argument(
-        "--lattice-score-scale",
-        type=float,
-        default=0.5,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies the scale for lattice.scores when
-        extracting n-best lists. A smaller value results in
-        more unique number of paths with the risk of missing
-        the best path.
-        """,
-    )
-
-    parser.add_argument(
-        "--sos-id",
-        type=float,
-        default=1,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies ID for the SOS token.
-        """,
-    )
-
-    parser.add_argument(
-        "--eos-id",
-        type=float,
-        default=1,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies ID for the EOS token.
-        """,
-    )
-
-    parser.add_argument(
-        "sound_files",
-        type=str,
-        nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
-    )
-
-    return parser
-
-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "feature_dim": 80,
-            "nhead": 8,
-            "num_classes": 5000,
-            "sample_rate": 16000,
-            "attention_dim": 512,
-            "subsampling_factor": 4,
-            "num_decoder_layers": 6,
-            "vgg_frontend": False,
-            "is_espnet_structure": True,
-            "mmi_loss": False,
-            "use_feat_batchnorm": True,
-            "search_beam": 20,
-            "output_beam": 8,
-            "min_active_states": 30,
-            "max_active_states": 10000,
-            "use_double_scores": True,
-        }
-    )
-    return params
-
-
-def read_sound_files(
-    filenames: List[str], expected_sample_rate: float
-) -> List[torch.Tensor]:
-    """Read a list of sound files into a list 1-D float32 torch tensors.
-    Args:
-      filenames:
-        A list of sound filenames.
-      expected_sample_rate:
-        The expected sample rate of the sound files.
-    Returns:
-      Return a list of 1-D float32 torch tensors.
-    """
-    ans = []
-    for f in filenames:
-        wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
-        # We use only the first channel
-        ans.append(wave[0])
-    return ans
-
-
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-
-    params = get_params()
-    params.update(vars(args))
-    logging.info(f"{params}")
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    logging.info("Creating model")
-    model = Conformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=params.num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=params.vgg_frontend,
-        is_espnet_structure=params.is_espnet_structure,
-        mmi_loss=params.mmi_loss,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
-
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
-    model.to(device)
-    model.eval()
-
-    logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
-    HLG = HLG.to(device)
-    if not hasattr(HLG, "lm_scores"):
-        # For whole-lattice-rescoring and attention-decoder
-        HLG.lm_scores = HLG.scores.clone()
-
-    if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
-        logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
-        G = G.to(device)
-        # Add epsilon self-loops to G as we will compose
-        # it with the whole lattice later
-        G = k2.add_epsilon_self_loops(G)
-        G = k2.arc_sort(G)
-        G.lm_scores = G.scores.clone()
-
-    logging.info("Constructing Fbank computer")
-    opts = kaldifeat.FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = params.sample_rate
-    opts.mel_opts.num_bins = params.feature_dim
-
-    fbank = kaldifeat.Fbank(opts)
-
-    logging.info(f"Reading sound files: {params.sound_files}")
-    waves = read_sound_files(
-        filenames=params.sound_files, expected_sample_rate=params.sample_rate
-    )
-    waves = [w.to(device) for w in waves]
-
-    logging.info(f"Decoding started")
-    features = fbank(waves)
-
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
-
-    # Note: We don't use key padding mask for attention during decoding
-    with torch.no_grad():
-        nnet_output, memory, memory_key_padding_mask = model(features)
-
-    batch_size = nnet_output.shape[0]
-    supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
-        dtype=torch.int32,
-    )
-
-    lattice = get_lattice(
-        nnet_output=nnet_output,
-        HLG=HLG,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
-
-    if params.method == "1best":
-        logging.info("Use HLG decoding")
-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
-    elif params.method == "whole-lattice-rescoring":
-        logging.info("Use HLG decoding + LM rescoring")
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=[params.ngram_lm_scale],
-        )
-        best_path = next(iter(best_path_dict.values()))
-    elif params.method == "attention-decoder":
-        logging.info("Use HLG + LM rescoring + attention decoder rescoring")
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
-        )
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=params.sos_id,
-            eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
-            ngram_lm_scale=params.ngram_lm_scale,
-            attention_scale=params.attention_decoder_scale,
-        )
-        best_path = next(iter(best_path_dict.values()))
-
-    hyps = get_texts(best_path)
-    word_sym_table = k2.SymbolTable.from_file(params.words_file)
-    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-
-    s = "\n"
-    for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info(f"Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
new file mode 120000
index 0000000000..cd27e4304b
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
@@ -0,0 +1 @@
+../conformer_ctc/pretrained.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
index 5c3e1222ef..720ed6c228 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
@@ -1,3 +1,20 @@
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import torch
 import torch.nn as nn
 
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
deleted file mode 100755
index 937845d779..0000000000
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
-import torch
-
-
-def test_conv2d_subsampling():
-    N = 3
-    odim = 2
-
-    for T in range(7, 19):
-        for idim in range(7, 20):
-            model = Conv2dSubsampling(idim=idim, odim=odim)
-            x = torch.empty(N, T, idim)
-            y = model(x)
-            assert y.shape[0] == N
-            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
-            assert y.shape[2] == odim
-
-
-def test_vgg_subsampling():
-    N = 3
-    odim = 2
-
-    for T in range(7, 19):
-        for idim in range(7, 20):
-            model = VggSubsampling(idim=idim, odim=odim)
-            x = torch.empty(N, T, idim)
-            y = model(x)
-            assert y.shape[0] == N
-            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
-            assert y.shape[2] == odim
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
deleted file mode 100644
index 08e6806074..0000000000
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#!/usr/bin/env python3
-
-import torch
-from transformer import (
-    Transformer,
-    encoder_padding_mask,
-    generate_square_subsequent_mask,
-    decoder_padding_mask,
-    add_sos,
-    add_eos,
-)
-
-from torch.nn.utils.rnn import pad_sequence
-
-
-def test_encoder_padding_mask():
-    supervisions = {
-        "sequence_idx": torch.tensor([0, 1, 2]),
-        "start_frame": torch.tensor([0, 0, 0]),
-        "num_frames": torch.tensor([18, 7, 13]),
-    }
-
-    max_len = ((18 - 1) // 2 - 1) // 2
-    mask = encoder_padding_mask(max_len, supervisions)
-    expected_mask = torch.tensor(
-        [
-            [False, False, False],  # ((18 - 1)//2 - 1)//2 = 3,
-            [False, True, True],  # ((7 - 1)//2 - 1)//2 = 1,
-            [False, False, True],  # ((13 - 1)//2 - 1)//2 = 2,
-        ]
-    )
-    assert torch.all(torch.eq(mask, expected_mask))
-
-
-def test_transformer():
-    num_features = 40
-    num_classes = 87
-    model = Transformer(num_features=num_features, num_classes=num_classes)
-
-    N = 31
-
-    for T in range(7, 30):
-        x = torch.rand(N, T, num_features)
-        y, _, _ = model(x)
-        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
-
-
-def test_generate_square_subsequent_mask():
-    s = 5
-    mask = generate_square_subsequent_mask(s)
-    inf = float("inf")
-    expected_mask = torch.tensor(
-        [
-            [0.0, -inf, -inf, -inf, -inf],
-            [0.0, 0.0, -inf, -inf, -inf],
-            [0.0, 0.0, 0.0, -inf, -inf],
-            [0.0, 0.0, 0.0, 0.0, -inf],
-            [0.0, 0.0, 0.0, 0.0, 0.0],
-        ]
-    )
-    assert torch.all(torch.eq(mask, expected_mask))
-
-
-def test_decoder_padding_mask():
-    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
-    y = pad_sequence(x, batch_first=True, padding_value=-1)
-    mask = decoder_padding_mask(y, ignore_id=-1)
-    expected_mask = torch.tensor(
-        [
-            [False, False, True],
-            [False, True, True],
-            [False, False, False],
-        ]
-    )
-    assert torch.all(torch.eq(mask, expected_mask))
-
-
-def test_add_sos():
-    x = [[1, 2], [3], [2, 5, 8]]
-    y = add_sos(x, sos_id=0)
-    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
-    assert y == expected_y
-
-
-def test_add_eos():
-    x = [[1, 2], [3], [2, 5, 8]]
-    y = add_eos(x, eos_id=0)
-    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
-    assert y == expected_y
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
index 795a2ab571..b0dbe72adb 100755
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
@@ -1,6 +1,20 @@
 #!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# This is just at the very beginning ...
 
 import argparse
 import logging
@@ -60,6 +74,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=35,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
 
 
@@ -89,11 +120,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor:  The subsampling factor for the model.
 
-        - start_epoch:  If it is not zero, load checkpoint `start_epoch-1`
-                        and continue training from that checkpoint.
-
-        - num_epochs:  Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -124,13 +150,11 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
+            "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 20,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
index f237ff8e3c..74e61b645c 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
@@ -1,5 +1,19 @@
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+# Copyright    2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import math
 from typing import Dict, List, Optional, Tuple
@@ -641,7 +655,7 @@ def __init__(self, d_model: int, dropout: float = 0.1) -> None:
         """
         super().__init__()
         self.d_model = d_model
-        self.pos_scale = 1. / math.sqrt(self.d_model)
+        self.pos_scale = 1.0 / math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
         self.pe = None
 
@@ -780,7 +794,8 @@ def load_state_dict(self, state_dict):
 
 class LabelSmoothingLoss(nn.Module):
     """
-    Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
+    Label-smoothing loss. KL-divergence between
+    q_{smoothed ground truth prob.}(w)
     and p_{prob. computed by model}(w) is minimized.
     Modified from
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
@@ -865,7 +880,8 @@ def encoder_padding_mask(
          frames, before subsampling)
 
     Returns:
-        Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
+        Tensor: Mask tensor of dimension (batch_size, input_length),
+        True denotes the masked indices.
     """
     if supervisions is None:
         return None
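For reference, the reflowed `encoder_padding_mask` docstring above describes a `(batch_size, input_length)` boolean mask with `True` at padded positions, where the lengths are the subsampled frame counts `((num_frames - 1) // 2 - 1) // 2` used in the (now removed) tests. Below is a self-contained sketch of that behaviour; `padding_mask_sketch` is illustrative only and is not the module's actual implementation:

```python
import torch


def padding_mask_sketch(num_frames: torch.Tensor) -> torch.Tensor:
    """Return a mask where True marks padded positions after 4x subsampling."""
    lengths = ((num_frames - 1) // 2 - 1) // 2
    max_len = int(lengths.max())
    # A position is padding if it lies at or beyond the sequence's length.
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


print(padding_mask_sketch(torch.tensor([18, 7, 13])))
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])
```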

From d09784fb8b104586535cbf381876ea6165000cec Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Thu, 26 Aug 2021 15:11:34 +0800
Subject: [PATCH 3/5] Add madam optimizer.

---
 .flake8                                       |    6 +-
 .../conformer_ctc_embedding_scale/madam.py    | 1136 +++++++++++++++++
 .../conformer_ctc_embedding_scale/train.py    |   18 +-
 3 files changed, 1146 insertions(+), 14 deletions(-)
 create mode 100644 egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py

diff --git a/.flake8 b/.flake8
index 3f1227b9b9..d1aa1205d2 100644
--- a/.flake8
+++ b/.flake8
@@ -4,8 +4,10 @@ statistics=true
 max-line-length = 80
 per-file-ignores =
     # line too long
-    egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
+    egs/librispeech/ASR/conformer_ctc*/conformer.py: E501,
 
 exclude =
   .git,
-  **/data/**
+  **/data/**,
+  egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py,
+  egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
new file mode 100644
index 0000000000..bba17c375d
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
@@ -0,0 +1,1136 @@
+# fmt: off
+import logging
+import math
+import random
+from typing import List, Tuple
+
+import torch
+from torch import Tensor, nn
+from torch.optim.optimizer import Optimizer
+
+# After this many warnings about infinite gradients we'll die.
+inf_grad_count = 0
+inf_grad_max_count = 20
+
+class Madam(Optimizer):
+    r"""Madam is a modification of the Adam algorithm, with various changes
+    intended to support certain "common-sense" ideas and solve common
+    pathologies that can happen particularly in transformer-type models that
+    have multiplication of parameters (particularly, key and query matrices)--
+    these can be vulnerable to "subspace loss" where, if you have any l2
+    regularization, certain subspaces in the key/query space might get
+    regularized towards zero.  We solve this with a special formula that
+    changes how the l2/weight-decay is done (see compute_l2_grad()).
+    I'll try to write the math down at some point.  This formula only
+    applies to tensors that have at least two dimensions; for one-dimensional
+    tensors we simply won't do l2 regularization.
+
+    One more thing-- there is a special pathology that can sometimes afflict
+    models like LSTMs, where a particular element of a minibatch experiences
+    gradient blowup in the backward pass.  We'd like to identify such cases and
+    fix it somehow, e.g. by removing or scaling down the gradient for that
+    particular minibatch.  We can identify and somewhat fix this by seeing that the
+    gradient norm (computed over all the parameters in a parameter group) is
+    much more than on previous minibatches, and limiting it to (the preceding
+    average step size times some constant).
+
+    Like most optimization algorithms, for this to work well you need to
+    have an appropriate learning rate schedule, either decreasing with
+    time, or increasing (warm-up) and then decreasing.  The LR schedule may
+    possibly need to decrease a little more aggressively than you would with
+    Adam, or at least have smaller values overall than Adam, because
+    the smaller parameters will mean the effective (relative) learning
+    rate is higher.
+
+    This is modified from PyTorch's optim/adam.py
+
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        grad_norm_buffer_size (int, optional):  Buffer size used in detecting
+            minibatches with unusually large gradients and scaling them down.
+        limit_grad_factor (float): factor by which we don't allow the
+            gradient to be greater than the average of previous gradients
+            (we'll scale the gradient down, over the whole param-group,
+            to enforce this).  Must be greater than 1.  Set to float('inf')
+            to disable norm clipping.
+        min_target_rms:  A floor on the "target rms" of each Tensor, so
+            that Tensors that, when initialized, have less than this
+            rms value will have their target rms value floored to this
+        l2:  True to enable l2 regularization
+        l2_period:  You may set this to a value greater than one to save
+            computation by only periodically doing the l2 update.
+            We include a scaling factor in the formula so that, as far
+            as possible (for small learning rates) this shouldn't affect
+            the results.  (Note: this probably isn't necessary to set,
+            since it turns out the update is quite fast, at least on GPU,
+            and the gradient clipping is actually more of a problem)
+
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+
+    """
+
+    def __init__(self, params,
+                 lr: float = 1e-3,
+                 betas: Tuple[float, float] = (0.9, 0.999),
+                 eps: float = 1e-8,
+                 grad_norm_buffer_size: int = 8,
+                 limit_grad_factor: float = 2.0,
+                 min_target_rms: float = 0.05,
+                 l2: bool = True,
+                 l2_period: int = 1):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not (isinstance(grad_norm_buffer_size, int) and grad_norm_buffer_size > 1):
+            raise ValueError("Invalid grad_norm_buffer_size value: {}".format(grad_norm_buffer_size))
+        if not limit_grad_factor > 1.0:
+            raise ValueError("Invalid limit_grad_factor: {}".format(limit_grad_factor))
+        if not isinstance(l2, bool):
+            raise ValueError("Invalid l2 value: {}".format(l2))
+        if not l2_period >= 1:
+            raise ValueError("Invalid l2_period value: {}".format(l2_period))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        grad_norm_buffer_size=grad_norm_buffer_size,
+                        limit_grad_factor=limit_grad_factor,
+                        l2=l2, l2_period=l2_period,
+                        min_target_rms=min_target_rms)
+        super(Madam, self).__init__(params, defaults)
+
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+
+            beta1, beta2 = group['betas']
+            grad_norm_buffer_size = group['grad_norm_buffer_size']
+            limit_grad_factor = group['limit_grad_factor']
+            min_target_rms = group['min_target_rms']
+
+            # The next 5 lists are part of the original Adam optimizer
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+
+            # The next 3 lists are not part of the original Adam optimizer.
+            target_rms_values = [] # relates to weight decay.  Target root-mean-square
+                                  # values of the elements of each parameter
+                                  # we are optimizing
+            prev_norm_stats = []  # for each parameter, a Tensor with 2 elements:
+                                  # the [sum_squared, count] of its gradient
+                                  # norms on previous minibatches (up to
+                                  # grad_norm_buffer_size minibatches)
+            cur_grad_norms = []   # and `cur_grad_norms` contains the squared l2
+                                  # norm of this step's gradient for this
+                                  # parameter, as a Tensor.
+
+
+            for p in group['params']:
+                if p.grad is not None:
+                    params_with_grad.append(p)
+                    if p.grad.is_sparse:
+                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                    grads.append(p.grad)
+
+                    state = self.state[p]
+                    # Lazy state initialization
+                    if len(state) == 0:
+                        state['step'] = 0
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                        # The things below are not part of original Adam, they are the Madam extension..
+                        state['target_rms'] = _get_target_rms(p, min_target_rms)
+                        # grad_norm_buf is a rotating buffer containing (grad_norm**2, count), where
+                        # count is 1 for grad_norms that are set and 0 for those that are not set because
+                        # we're near step 0 or because they were infinite.
+                        state['grad_norm_buf'] = torch.zeros(grad_norm_buffer_size, 2, device=p.device)
+
+                    exp_avgs.append(state['exp_avg'])
+                    exp_avg_sqs.append(state['exp_avg_sq'])
+
+                    target_rms_values.append(state['target_rms'])
+
+                    cur_step = state['step']
+                    if limit_grad_factor != float('inf'):
+                        grad_norm_buf = state['grad_norm_buf']
+                        cur_grad_norm = (p.grad ** 2).sum()  # actually the squared norm
+                        prev_mean_norm = grad_norm_buf.sum(0)  # prev_mean_norm is a Tensor [ tot_norm_squared, count ]
+                        grad_norm_buf[cur_step % grad_norm_buffer_size][0] = cur_grad_norm
+                        grad_norm_buf[cur_step % grad_norm_buffer_size][1].fill_(1.0)
+                        prev_norm_stats.append(prev_mean_norm)
+                        cur_grad_norms.append(cur_grad_norm)
+
+                    # update the steps for each param group update
+                    cur_step += 1
+                    state['step'] = cur_step
+                    # record the step after step update
+                    state_steps.append(cur_step)
+
+            if limit_grad_factor != float('inf'):
+                self._apply_grad_norm_clipping(group['params'],
+                                               prev_norm_stats, cur_grad_norms, grads,
+                                               limit_grad_factor, grad_norm_buffer_size)
+
+            _madam(params_with_grad,
+                   grads,
+                   exp_avgs,
+                   exp_avg_sqs,
+                   state_steps,
+                   target_rms_values,
+                   beta1=beta1,
+                   beta2=beta2,
+                   lr=group['lr'],
+                   eps=group['eps'],
+                   l2=group['l2'],
+                   l2_period=group['l2_period'])
+
+        return loss
+
+
+    def _apply_grad_norm_clipping(self,
+                                  params_list,
+                                  prev_norm_stats: List[Tensor],
+                                  cur_grad_norms: List[Tensor],
+                                  grads: List[Tensor],
+                                  limit_grad_factor: float,
+                                  grad_norm_buffer_size: int) -> None:
+        """
+         This function applies gradient norm clipping for this parameter group if this
+         minibatch has substantially larger gradients in this param group than
+         recent minibatches.  The idea is to catch cases like where an LSTM
+         happens to blow up in the backward pass, or some code bug causes very
+         large or infinite gradients on a particular minibatch; so we scale
+         down any very large gradients and zero infinite ones.
+
+        Args:
+            params_list:  an iterable or list of the params in this group
+            prev_norm_stats:  a list which, for each parameter in this group
+                with a grad, contains a Tensor with 2 elements: the
+                [sum_squared, count] of up to `grad_norm_buffer_size` squared
+                norms of this parameter on previous minibatches
+            cur_grad_norms:  a list of Tensors containing, for each parameter
+                in this group, the squared norm of this step's gradient
+            grads:  list of gradients, in the same order as prev_norm_stats
+                and cur_grad_norms
+            limit_grad_factor:  a float > 1.0 (e.g. 4.0) that dictates how much
+                larger than average a gradient may be before we clip it
+            grad_norm_buffer_size:  an int that determines the rolling buffer
+                size over which we store gradient norms
+        """
+        num_params = len(prev_norm_stats)
+        assert len(grads) == num_params
+
+        all_prev_norm_stats, all_cur_grad_norms = _to_device('cpu',
+                                                             torch.stack(prev_norm_stats),
+                                                             torch.stack(cur_grad_norms))
+        assert all_prev_norm_stats.shape == (num_params, 2)
+        assert all_cur_grad_norms.shape == (num_params,)
+
+        # divide totals by counts (i.e. counts of iterations where we stored
+        # a finite grad)
+        all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1]
+        # prev_norm and cur_norm are floats, they are actually squared norms.
+        prev_norm = all_prev_grad_norms.sum().item()
+        cur_norm = all_cur_grad_norms.sum().item()
+
+        if prev_norm - prev_norm != 0.0:
+            # There were zero counts; fix this by using the current grad norm
+            # for affected parameters, and recompute all_prev_grad_norms and
+            # prev_norm.
+            for i in range(num_params):
+                if all_prev_norm_stats[i][1] == 0.0:
+                    # if count is 0 and cur norm is finite, use cur norm as our estimate
+                    # of previous norms.  This would only be useful if some but not
+                    # all params were in this situation of having no previous estimates.
+                    cur = all_cur_grad_norms[i]
+                    if cur - cur == 0.0:  # finite..
+                        all_prev_norm_stats[i][0] = cur
+                        all_prev_norm_stats[i][1] = 1.0
+                    else:
+                        # 0.0 is a default; likely won't matter, as if we
+                        # get infinite `cur`, we'll abandon this minibatch.
+                        all_prev_norm_stats[i][0] = 0.0
+            all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1]
+            prev_norm = all_prev_grad_norms.sum().item()
+
+        # Deal with infinite gradients.
+        if cur_norm - cur_norm != 0:  # cur_norm is infinite or NaN
+            global inf_grad_count
+            logging.warning(f'Infinite gradient-norm detected (cur/prev: {cur_norm}/{prev_norm}): will '
+                            f'zero grad ({inf_grad_count}/{inf_grad_max_count} times until dying)')
+            inf_grad_count += 1
+            if inf_grad_count >= inf_grad_max_count:
+                assert 0, "Reached max count of infinite gradient-norm stats"
+            # Zero all gradients in this group
+            for g in grads:
+                g[:] = 0.
+            # .. and zero the stored gradient norms in grad_norm_buf (so
+            # that infinities don't ruin our stats of previous batches)
+            for p in params_list:
+                if p.grad is not None:
+                    state = self.state[p]
+                    grad_norm_buf = state['grad_norm_buf']
+                    # cur_step is the location where we would have written the grad_norm.
+                    # We didn't check if it was infinity before, because we didn't want to
+                    # incur lots of GPU->CPU transfers.
+                    cur_step = state['step'] - 1
+                    # Remove this 'bad' step from the buffer.
+                    grad_norm_buf[cur_step % grad_norm_buffer_size][:] = 0.0
+        else:
+            # cur_norm is finite.  Check whether we have to clip this iteration's grad.
+            # we always remove infinities/NaNs from the buffer, so prev_norm should not
+            # be infinite or NaN.
+            assert prev_norm - prev_norm == 0.0
+            # cur_norm and prev_norm are actually squared norms, so we need to
+            # square limit_grad_factor..
+            limit_grad_factor2 = limit_grad_factor ** 2
+            if cur_norm > prev_norm * limit_grad_factor2:
+                grad_factor2 = (prev_norm * limit_grad_factor2) / cur_norm
+                grad_factor = grad_factor2 ** 0.5
+                cur_norm_f, prev_norm_f, grad_factor_f = ('%.2g' % cur_norm, '%.2g' % prev_norm,
+                                                          '%.2g' % grad_factor)
+                logging.warning(f'Gradient norm exceeds average of last {grad_norm_buffer_size} '
+                                f'gradients times {limit_grad_factor}: cur/prev {cur_norm_f}/{prev_norm_f}: '
+                                f'scaling it by {grad_factor_f}.')
+                for g in grads:
+                    g[:] *= grad_factor
+                # .. and scale down the stored gradient norms in grad_norm_buf, to
+                # avoid the bound getting too loose too quickly.
+                for p in params_list:
+                    if p.grad is not None:
+                        state = self.state[p]
+                        grad_norm_buf = state['grad_norm_buf']
+                        cur_step = state['step'] - 1
+                        # the buffer contains squared norms, so multiply by grad_factor2
+                        grad_norm_buf[cur_step % grad_norm_buffer_size][0] *= grad_factor2
+
+
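To make the `Madam` API above concrete, here is a minimal usage sketch; the tiny model and the hyperparameter values are placeholders chosen for illustration (they mirror the constructor defaults, not values tuned for this recipe), and it assumes `madam.py` is importable from the working directory:

```python
import torch
from madam import Madam

model = torch.nn.Linear(10, 20)
optimizer = Madam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                  grad_norm_buffer_size=8, limit_grad_factor=2.0,
                  min_target_rms=0.05, l2=True, l2_period=1)

x = torch.randn(4, 10)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()       # gradient-norm limiting, then the Madam update
optimizer.zero_grad()
```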
+def _to_device(device, *args):
+    """
+    Transfers a tuple of Tensors from one device to another, using a single transfer.  Must have
+    same dtype but may have different shapes.
+    E.g.
+      (cpu_tensor_a, cpu_tensor_b) = _to_device('cpu', gpu_tensor_a, gpu_tensor_b)
+    """
+    if device == args[0].device:
+        return args
+    else:
+        arg0 = args[0]
+        combined_src = torch.cat([ x.reshape(-1) for x in args ])
+        combined_dest = combined_src.to(device)
+        dests = []
+        offset = 0
+        for src in args:
+            numels = src.numel()
+            dests.append(combined_dest[offset:offset+numels].reshape(src.shape))
+            offset += numels
+        return tuple(dests)
+
+
+
+def _get_target_rms(x: Tensor, min_target_rms: float) -> Tensor:
+    """
+    Returns Tensor with one element, representing a target root-mean-square
+    value of elements of x, that we consider "reasonable", and will use
+    as a "target rms" in our modified weight-decay formula.  It returns
+    the maximum of the current RMS of the values of x, and `min_target_rms`,
+    as a Tensor on the same device as x.
+    """
+    with torch.no_grad():
+        # `rms` is the root-mean-square of the elements of x (and this function
+        # should be called right after parameter initialization)
+        rms = ((x ** 2).sum() / x.numel()).sqrt()
+        largest_dim = max(list(x.shape))
+        numel = x.numel()
+        if min_target_rms > 0.0:
+            rms = rms.clamp(min=min_target_rms)
+    if x.ndim > 1 and __name__ == '__main__':  # will only be used for x.ndim > 1.
+        print("Target rms = ", rms)   # Print this in testing only.
+    return rms
+
+
+def _madam(params: List[Tensor],
+           grads: List[Tensor],
+           exp_avgs: List[Tensor],
+           exp_avg_sqs: List[Tensor],
+           state_steps: List[int],
+           target_rms_values: List[Tensor],
+           *,
+           beta1: float,
+           beta2: float,
+           lr: float,
+           eps: float,
+           l2: bool,
+           l2_period: int):
+    r"""This is a modification of adam() from torch's optim/_functional.py.
+
+    It has been modified to:
+      (i) remove the amsgrad option; this shouldn't be as necessary due to
+          the adaptive gradient norm clipping we have added
+      (ii) add our special formula for l2 regularization.  This doesn't have
+           any tunable parameters, other than the target standard deviation
+           of the elements of the tensor (which is passed in as target_rms).
+    Args:
+        params: list of Tensor, containing the parameters to be optimized
+        grads: list of Tensor, containing the gradients corresponding to
+             each of the params (grads[i] should correspond to params[i].grad,
+             although it may have undergone gradient clipping).
+        exp_avgs: list of Tensor, containing tensors with the same dimensions
+             as params and grads, that contain the moving-averages of
+             `grads`.
+        exp_avg_sqs: list of Tensor, containing tensors with the same dimensions
+             as params and grads, that contain the moving-averages of
+             `grads ** 2`.
+        state_steps: list of int, containing the step for each parameter (step >= 1)
+        target_rms_values: list of Tensor with one element each, containing the
+             target root-mean-square values of each parameter tensor in `params`
+        l2: a bool, where if true we will activate the l2 regularization
+             formula.
+        l2_period:  an integer that determines how often (i.e. every how many
+             minibatches) we apply the l2 update.  We include a scaling factor
+             so that as far as possible the result will not be too sensitive
+             to the value of this.
+
+        beta1: decay factor for gradients, e.g. 0.9
+        beta2: decay factor for gradients squared, e.g. 0.999
+        lr:  learning rate, e.g. 0.0001
+        eps: a small constant used to prevent division by zero, e.g. 1.0e-8
+
+    See :class:`~torch.optim.Adam` for details.
+    """
+    assert len(params) == len(grads) == len(state_steps) == len(exp_avgs) == len(exp_avg_sqs) ==  len(target_rms_values)
+
+    for i, param in enumerate(params):
+
+        grad = grads[i]
+
+        exp_avg = exp_avgs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        step = state_steps[i]
+        target_rms = target_rms_values[i]
+
+        bias_correction1 = 1 - beta1 ** step
+        bias_correction2 = 1 - beta2 ** step
+
+        do_l2 = param.ndim > 1 and l2 and step % l2_period == 0
+
+        if do_l2:
+            # This represents just the "noise term" of the gradient, i.e. the grad minus the
+            # running mean.  We'll later divide by denom.
+            cur_grad_noise = (grad - exp_avg)
+
+        # Decay the first and second moment running average coefficient
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+        step_size = lr / bias_correction1
+
+        if not do_l2:
+            param.addcdiv_(exp_avg, denom, value=-step_size)
+        else:
+            # We can treat "pseudo_grad" as if it were a gradient (even though it's
+            # actually a gradient times a per-element learning rate).  The analysis
+            # that we used to figure out what the l2 should be did not use the fact
+            # that the gradients were actually gradients, it simply analyzed it as a
+            # quantity that can be treated as close to zero-mean and with a certain
+            # structure of variance, and added to the param with the formula:
+            #
+            #  param -= step_size * grad
+            #
+            # The original analysis assumed the gradients were independent from frame
+            # to frame; in fact these are not, but the important difference can be captured
+            # in a scalar `grad_scale` that expresses the scale of pseudo_grad relative
+            # to the independent gradients that we are effectively adding on each frame
+            # (but with a delay).
+
+            pseudo_grad = exp_avg / denom
+            cur_pseudo_grad = cur_grad_noise / denom
+
+            # grad_scale expresses the expected size of cur_pseudo_grad relative to the
+            # original grads if we had not done the moving-average; it is the sqrt of
+            # the sum of the squares of coefficients of previous gradients:
+            # c_n = (1-beta1) beta1^n, for
+            # n = 0, 1, ..
+            # .. plus one which is the sumsq of the coefficient of 'grad' itself in
+            # (grad - exp_avg).
+            # It is relevant that the sum of the coefficients (i.e. not squared) is 1;
+            # if this were not so we'd have to incorporate that into the formula for l2.
+            grad_scale = (((1 - beta1)**2) / (1 - beta1**2) + 1) ** 0.5
+
+            with torch.no_grad():
+                l2_grad = _compute_l2_grad(param, cur_pseudo_grad, target_rms,
+                                           rho=step_size, grad_scale=grad_scale,
+                                           period_scale=l2_period,
+                                           eps=eps, safe=True)
+
+            # TODO: could alternate computing l2 on only, say, odd frames, and scale it
+            # up by 2, to save time.
+            param.add_(pseudo_grad + l2_grad, alpha=-step_size)
+
+
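The closed form used for `grad_scale` above is stated to be the square root of the summed squared EMA coefficients plus one; the following purely illustrative check (names are ad hoc) confirms that identity numerically:

```python
beta1 = 0.9
# Coefficients of previous gradients inside (grad - exp_avg):
# c_n = (1 - beta1) * beta1 ** n, for n = 0, 1, ...
sumsq_prev = sum(((1 - beta1) * beta1 ** n) ** 2 for n in range(10000))
grad_scale_series = (sumsq_prev + 1.0) ** 0.5        # +1 for grad itself
grad_scale_closed = (((1 - beta1) ** 2) / (1 - beta1 ** 2) + 1) ** 0.5
assert abs(grad_scale_series - grad_scale_closed) < 1e-9
print(grad_scale_closed)  # about 1.026 for beta1 = 0.9
```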
+
+def _view_as_matrix(x: Tensor, dim: int) -> Tensor:
+    """
+    Returns a Tensor of shape (n, x.shape[dim]), where n is the product
+    of the sizes of the other dimensions of x.  This may involve a copy,
+    if x cannot be reshaped in this way.
+    """
+    ndim = x.ndim
+    assert ndim > 1 and dim >= 0 and dim < ndim
+    # Move the dim to the last position in x..
+    if dim != ndim - 1:
+        x = x.transpose(dim, ndim - 1)
+    return x.reshape(-1, x.shape[-1])
+
+
+def _outer_product(x: Tensor, dim: int) -> Tensor:
+    """
+    Returns a Tensor of shape (x.shape[dim], x.shape[dim]) formed by
+    summing the outer products of all the vectors in x of size
+    `x.shape[dim]`, that we get by indexing x with all tuples of dimensions
+    on other axes.  E.g. if x is a matrix and dim == 0, this would
+    be torch.matmul(x, x.transpose(0, 1)).
+
+    Note: x must have at least 2 dimensions, x.ndim >= 2.
+    """
+    x = _view_as_matrix(x, dim)
+    return torch.matmul(x.transpose(0, 1), x)
+
+def _multiply_on_dim(x: Tensor, m: Tensor, dim: int) -> Tensor:
+    """
+    Multiplies x by the matrix m which must be of shape:
+    (x.shape[dim], n), with `dim` as the dimension/axis on
+    x to be multiplied.
+
+    Caution: result may not have the same layout/strides as x,
+    although it will have the same shape.
+
+    Args:
+         x: Tensor to be multiplied; must have ndim >= 2
+         m: Symmetric matrix to multiply x by; must have
+            m.shape == (x.shape[dim], x.shape[dim])
+       dim: Dimension of x to multiply on, with 0 <= dim < x.ndim
+    Return:
+         The matrix product, of the same shape as
+         x, except with the size on dimension `dim` being n.
+    """
+    ndim = x.ndim
+    if dim != ndim - 1:
+        x = x.transpose(dim, ndim - 1)
+    ans = torch.matmul(x, m)
+    if dim != ndim - 1:
+        # Swap the dimensions back to what they were originally.
+        ans = ans.transpose(dim, ndim - 1)
+    return ans
+
+
+def _multiply_product_combined(l2: Tensor, grad: Tensor, dim: int,
+                               need_grad_sumsq: bool):
+    """
+    This function is an optimized version of the following code:
+        outer_prod = _outer_product(grad, dim)
+        l2 = _multiply_on_dim(l2, outer_prod, dim)
+        if dim == 0:  # could choose any dim for this
+            grad_sumsq = torch.trace(outer_prod)
+    Args:
+         l2: The l2 matrix which starts out as the parameter tensor x, must have >= 2 dims
+         grad: The gradient tensor (or a gradient-like quantity); must
+             have same shape as l2.
+         dim: The dimension of l2 and grad that we want this to
+             act on, with 0 <= dim < l2.ndim.  We multiply l2, on
+             this dim, by a symmetric quantity of shape
+             (l2.shape[dim], l2.shape[dim]), that is formed
+             by a product and sum on grad (this is a matrix
+             product, if there are 2 axes).
+    Returns:
+        Returns (l2, grad_sumsq), where l2 is the result of
+        multiplying l2 by the product mentioned above, and
+        grad_sumsq is either None, or a Tensor representing
+        the sum-of-squares of `grad`; for at least one
+        dim with 0 <= dim < l2.ndim, we guarantee to
+        return such a Tensor.
+    """
+    grad = _view_as_matrix(grad, dim)
+    if grad.shape[1] <= grad.shape[0]:
+        # Minimize the size of the intermediate product, which will probably well reflect
+        # the compute time since memory access can be limiting on CUDA.
+        grad_product = torch.matmul(grad.transpose(0, 1), grad)
+        l2 = _multiply_on_dim(l2, grad_product, dim)
+        if need_grad_sumsq:
+            grad_sumsq = torch.trace(grad_product)
+        else:
+            grad_sumsq = None
+        return (l2, grad_sumsq)
+    else:
+        l2 = _multiply_on_dim(l2, grad.transpose(0, 1), dim)
+        l2 = _multiply_on_dim(l2, grad, dim)
+        # This branch does not compute grad_sumsq, but we're bound to
+        # take the other branch on at least one occasion.
+        return (l2, None)
+
+
+
+def _compute_l2_grad(x: Tensor, grad: Tensor, target_stddev: float, rho: float,
+                     grad_scale: float = 1.0, period_scale: int = 1,
+                     eps: float = 1.0e-08,
+                     safe: bool = True) -> Tensor:
+    """
+    Returns the l2 gradient of x, which will be added to 'grad'.
+    This is a more principled replacement for the typical l2 regularization
+    formula where we do:
+          grad += weight_decay * x.
+    (Note: this must only be called if x.ndim >= 2).
+
+    For x with 2 axes, we instead do this:
+
+        grad += (rho / (2*target_stddev**2)) * (grad grad^T) x (grad^T grad) / trace(grad^T grad),
+
+    where the implicit multiplication above refers to matrix multiplication; note, x means
+    the variable x.  We'll have to write the justification of this, which is a little
+    complicated, separately; it has to do with using exactly the amount of l2 in each
+    subspace of each dimension of x, to cancel out the gradient noise.
+
+    Args:
+          x: parameter to be updated.  MUST HAVE x.ndim >= 2.
+       grad: Gradient for x on this iteration (or at least, something that
+           is treated like a gradient in the update formula)
+target_stddev:  The target standard deviation (uncentered), of elements of x.
+           This is our estimate of what standard deviation these elements would
+           have in a well-trained model; it is set by some kind of heuristic.
+       rho:  The learning rate we are going to use, as in:   x -= (grad + l2) * rho.
+   grad_scale: A scale whereby the caller asserts that `grad` is some
+            quantity that is distributed like the real
+            gradient times `grad_scale` (this is useful when the provided `grad`
+            is really a moving average gradient).  Because the l2 term's magnitude
+            is proportional to the gradient squared, we need to divide it by the
+            square of grad_scale, so this function uses 1/grad_scale^2 as a scaling
+            factor.
+period_scale: An integer scale that we use to compensate for the fact that this
+            weight decay is only applied periodically, once every
+            `period_scale` minibatches.  Accordingly, we make the l2 term
+            that many times larger.
+       eps:  A small constant used to avoid division by zero
+      safe:  If true, use a safe version of the formula that checks for
+             'overshoot' of l2 regularization and fixes the issue (might
+             be an issue for models that are getting unstable or have high
+             learning rate)
+
+
+    Returns:
+       Returns l2 pseudo-gradient (term to be added to `grad`).
+    """
+    assert x.shape == grad.shape
+    assert x.ndim >= 2
+
+    l2 = x
+    grad_sumsq = None
+    num_ignored_dims = 0  # for an optimization for when size=1 on some dim.
+    for dim in range(x.ndim):
+        # The code below is an optimization of the following few lines,
+        # which were perhaps easier to understand:
+        # outer_prod = _outer_product(grad, dim)
+        # l2 = _multiply_on_dim(l2, outer_prod, dim)
+        # if dim == 0:  # could choose any dim for this
+        #    grad_sumsq = torch.trace(outer_prod)
+        if x.shape[dim] <= 1:
+            num_ignored_dims += 1
+            continue
+        (l2, maybe_grad_sumsq) = _multiply_product_combined(l2, grad, dim,
+                                                            grad_sumsq is None)
+        if maybe_grad_sumsq is not None:
+            grad_sumsq = maybe_grad_sumsq
+    if grad_sumsq is None:
+        # We shouldn't reach here, except if at some point we start calling this
+        # code for tensors with ndim <= 1, or with numel() == 1.
+        grad_sumsq = (grad ** 2).sum()
+
+    # l2 is the amount of l2, we'll subtract this from x, as in:
+    #   x -= rho * (grad + l2).
+
+    factor = rho * period_scale / (2.0 * (target_stddev * grad_scale)**2)
+    l2 = l2 * (factor / (grad_sumsq ** (x.ndim - 1 - num_ignored_dims) + eps))
+
+    if safe and rho > 0:
+        #x2_sum = (x ** 2).sum()
+        l2_sum = (l2 ** 2).sum() * (rho * rho)
+        cross_sum = (x * l2).sum() * rho
+        alpha = cross_sum / (l2_sum + eps)
+        # We want to minimize the sum-of-squares of (x - alpha * rho * l2), where alpha
+        # is a constant in [0,1] that we are about to estimate, intended to prevent
+        # instability by scaling down our weight decay formula.  Right now (and treating
+        # things as if they were scalars for brevity):
+        #  x2_sum = x * x
+        #  l2_sum = rho * rho * l2 * l2
+        #  cross_sum = x * rho * l2
+        # We want to minimize the sum-sq of  (x - alpha * rho * l2),
+        # i.e. we want to choose alpha to minimize:
+        #   x2_sum - 2 * alpha * cross_sum + alpha^2 * l2_sum
+        # d/dalpha of this, is:
+        #  -2*cross_sum + 2 * alpha * l2_sum
+        # and setting this to zero and solving for alpha, we have:
+        #  alpha = cross_sum / l2_sum.
+        # If it turns out that alpha >= 1, then we just use alpha=1
+        # (the original formula), as there is no problem with
+        # instability/overshoot.
+        l2.mul_(alpha.clamp(max=1.0))
+        if random.random() < 0.001 and alpha < 1.0:
+            logging.info(f'madam optimizer: alpha={alpha}, shape={tuple(x.shape)}')
+    return l2
+
+
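For a 2-D parameter the docstring of `_compute_l2_grad` reduces to `(grad grad^T) x (grad^T grad) / trace(grad^T grad)` scaled by `rho * period_scale / (2 * (target_stddev * grad_scale)**2)`. The sketch below writes out exactly that 2-D special case in plain torch, ignoring the `eps` guard and the `safe` clamping; it is an illustration of the formula, not a replacement for `_compute_l2_grad`:

```python
import torch


def l2_grad_2d_sketch(x, grad, target_stddev, rho,
                      grad_scale=1.0, period_scale=1):
    """2-D special case of the l2 term described above (no eps, no clamping)."""
    outer_rows = grad @ grad.t()      # (d0, d0)
    outer_cols = grad.t() @ grad      # (d1, d1)
    factor = rho * period_scale / (2.0 * (target_stddev * grad_scale) ** 2)
    return factor * (outer_rows @ x @ outer_cols) / torch.trace(outer_cols)


x = torch.randn(5, 3)
g = torch.randn(5, 3)
print(l2_grad_2d_sketch(x, g, target_stddev=0.05, rho=1e-3).shape)  # (5, 3)
```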
+
+class Moam(object):
+    """
+    Implements Moam optimizer.  This is a modified version of the Noam optimizer
+    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
+    but changed to use Madam (see above) instead of Adam as the base optimizer.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+
+    Caution: you probably want to set 'factor' to a smaller value than you would typically
+    use for a corresponding Noam optimizer, because Moam does a kind of l2 regularization which
+    keeps the parameters fairly small, so the relative changes in model parameters
+    will be larger than Noam, for any given learning rate.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+        model_size: attention dimension of the transformer model
+        factor: learning rate factor, that multiplies the output of the
+              formula based on model size
+        warm_step: number of warmup steps before the learning rate starts to decrease
+              (it increases until this point).
+        min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
+             on the "target root-mean-square value" that is used when the initialization
+             of a tensor is zero or below this value.  It may be worth optimizing.
+             Don't worry about tensors with fewer than 2 dimensions when setting this,
+             these are not subject to our l2 formula.
+        limit_grad_factor: you can set this to a finite value, e.g. 2.0, to activate
+             a mechanism that limits the norms of larger-than-usual gradients.
+             This seems to cause a slowdown, likely due to GPU->CPU transfers.
+        l2_period: mechanism to improve the optimization speed, by only applying the l2
+            regularization (which is a complicated formula) every this-many
+            minibatches.  E.g. can set it to 2 or 4.
+    """
+
+    def __init__(self, params, model_size: int = 256,
+                 factor: float = 2.0, warm_step: int = 25000,
+                 min_target_rms: float = 0.05,
+                 limit_grad_factor: float = float('inf'),
+                 l2_period: int = 1) -> None:
+        """Construct an Noam object."""
+        self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
+                               min_target_rms=min_target_rms,
+                               limit_grad_factor=limit_grad_factor,
+                               l2_period=l2_period)
+        self._step = 0
+        self.warmup = warm_step
+        self.factor = factor
+        self.model_size = model_size
+        self._rate = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+    def rate(self, step=None):
+        """Implement `lrate` above."""
+        if step is None:
+            step = self._step
+        return (
+                self.factor
+                * self.model_size ** (-0.5)
+                * min(step ** (-0.5), step * self.warmup ** (-1.5))
+        )
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+            "warmup": self.warmup,
+            "factor": self.factor,
+            "model_size": self.model_size,
+            "_rate": self._rate,
+            "optimizer": self.optimizer.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                setattr(self, key, value)
+
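The `rate()` method of `Moam` above is the usual Noam schedule, `factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)`. A standalone sketch follows (default values copied from the constructor; the printed numbers are only meant to show the warm-up-then-decay shape):

```python
def noam_rate(step: int, model_size: int = 256,
              factor: float = 2.0, warm_step: int = 25000) -> float:
    return factor * model_size ** -0.5 * min(step ** -0.5,
                                             step * warm_step ** -1.5)


for step in (1000, 25000, 100000):
    print(step, f"{noam_rate(step):.2e}")
# rises linearly up to warm_step, then decays like 1/sqrt(step)
```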
+
+class Foam(object):
+    """
+    Implements Foam optimizer.  This is a modified version of the Noam optimizer
+    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
+    but changed to use Madam (see above) instead of Adam as the base optimizer, and then
+    to change the learning rate schedule and how it is specified.
+
+
+    This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+
+        warm_step: number of warmup steps before the learning rate starts to decrease
+              (it increases until this point).
+        max_lrate: The learning rate at its maximum, on step `warm_step`
+        knee_factor:  The multiple of `warm_step` after which the learning rate will
+                 start to decrease more like 1/x.  It increases linearly from 0 to
+                `warm_step`, then decreases approximately as 1/sqrt(x) from
+                `warm_step` to `warm_step * knee_factor`, then decreases
+                 approximately as 1/x from `warm_step * knee_factor` onwards.
+
+        min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
+             on the "target root-mean-square value" that is used when the initialization
+             of a tensor is zero or below this value.  It may be worth optimizing.
+             Don't worry about tensors with fewer than 2 dimensions when setting this,
+             these are not subject to our l2 formula.
+        limit_grad_factor: Another parameter of Madam, you can set this to a finite
+             value, e.g. 2.0, to activate a mechanism that limits the norms of
+             larger-than-usual gradients. This seems to cause a slowdown, likely due
+             to GPU->CPU transfers, and it is disabled by setting it to infinity.
+        l2_period: mechanism to improve the optimization speed, by only applying the l2
+            regularization (which is a complicated formula) every this-many
+            minibatches.  E.g. can set it to 2 or 4.
+    """
+
+    def __init__(self,
+                 params,
+                 max_lrate: float = 5.0e-04,
+                 warm_step: int = 25000,
+                 knee_factor: float = 8.0,
+                 min_target_rms: float = 0.05,
+                 limit_grad_factor: float = float('inf'),
+                 l2_period: int = 1) -> None:
+        """Construct an Foam object."""
+        self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
+                               min_target_rms=min_target_rms,
+                               limit_grad_factor=limit_grad_factor,
+                               l2_period=l2_period)
+        self._step = 0
+
+        self._max_lrate = max_lrate
+        self._warm_step = warm_step
+        self._knee_factor = knee_factor
+        self._rate = 0
+
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+
+    def rate(self, step=None):
+        """
+        Suppose the step of optimization is 's', i.e. with s = 0, 1, 2...
+        We define 't = s / warm_step', i.e. t is the step s, normalized so that it
+        is 1.0 at warm_step.  Our formula for the learning rate as a function of
+        t is:
+             rate = max_lrate * (t <= 1.0 ? t :
+                                sqrt((2 + alpha) / (1 + t + alpha t^2)))
+        where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
+        at t == knee_factor (this means alpha = 1.0/knee_factor).   So the
+        learning rate increases linearly from t=0 to t=1, and decreases
+        after that.  You can see
+        that sqrt((2 + alpha) / (1 + t + alpha t^2)) is 1.0 when t == 1,
+        which is why the line and the curve meet at that point.
+
+        On the denominator of that ratio, the "t" term makes it decrease a
+        bit like 1/sqrt(t) for 1 <= t <= knee_factor; the "alpha t^2" term
+        makes it decrease a bit like 1/t for t > knee_factor; and the "1"
+        term makes it decrease a bit slower than 1/sqrt(t) when t is quite
+        close to 1.0 (so we linger a little, near the maximum learning rate).
+
+        This learning rate schedule ultimately decreases more aggressively
+        than Noam, i.e. as 1 / t instead of 1 / sqrt(t).  The reason we
+        feel this will work better in conjunction with Madam, is that Madam
+        keeps the norms of the parameters approximately constant throughout
+        training; whereas with Noam, if there is no weight decay, these
+        norms tend to increase as training progresses (although rather
+        unevenly across different parameter tensors).
+        As the norms of the parameters increase, the relative changes
+        in parameters get smaller (the step sizes don't change because
+        Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
+        So Noam doesn't have to decrease the learning rate too aggressively
+        because even with a fixed learning rate, the effective learning rate
+        would be decreasing (again, this only applies without weight decay).
+        """
+        if step is None:
+            step = self._step
+        t = step / self._warm_step  # floating point division..  t is the normalized step.
+        alpha = 1.0 / self._knee_factor
+        return self._max_lrate * (t if t <= 1.0 else
+                                  ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5)
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+            "warmup": self.warmup,
+            "factor": self.factor,
+            "model_size": self.model_size,
+            "_rate": self._rate,
+            "optimizer": self.optimizer.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                setattr(self, key, value)
+
+
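Likewise, `Foam.rate()` above implements `rate(t) = max_lrate * (t if t <= 1 else sqrt((2 + alpha) / (1 + t + alpha * t**2)))` with `t = step / warm_step` and `alpha = 1 / knee_factor`. Below is a standalone sketch of just that formula, using the constructor defaults; it only mirrors the schedule, not the optimizer state handling:

```python
def foam_rate(step: int, max_lrate: float = 5.0e-04,
              warm_step: int = 25000, knee_factor: float = 8.0) -> float:
    t = step / warm_step              # normalized step
    alpha = 1.0 / knee_factor
    if t <= 1.0:
        return max_lrate * t          # linear warm-up
    return max_lrate * ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5


for step in (12500, 25000, 200000, 800000):
    print(step, f"{foam_rate(step):.2e}")
# peaks at warm_step, decays ~1/sqrt(t) until warm_step * knee_factor, ~1/t after
```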
+
+class TestModel(torch.nn.Module):
+    """Class for testing the Madam optimizer"""
+    def __init__(self):
+        super(TestModel, self).__init__()
+        self.first_layers = torch.nn.Sequential(
+            torch.nn.Linear(100, 200),
+            torch.nn.ReLU(),
+            torch.nn.Linear(200, 300),
+            torch.nn.ReLU())
+        self.conv1 = torch.nn.Conv1d(in_channels=300, out_channels=200,
+                                    kernel_size=1)
+        self.relu = torch.nn.ReLU()
+        self.conv2 = torch.nn.Conv1d(in_channels=200, out_channels=250,
+                                    kernel_size=3)
+
+
+    def forward(self, x):
+        # from (B, T, 100) to (B, T, 300)
+        x = self.first_layers(x)
+        # B, T, C -> B, C, T
+        x = x.transpose(1, 2)
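+        # conv1 is pointwise (kernel_size=1): 300 -> 200 channels; conv2
+        # (kernel_size=3, no padding) gives 200 -> 250 and shortens T by 2.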
+        x = self.conv2(self.relu(self.conv1(x)))
+        # B, C, T -> B, T, C
+        x = x.transpose(1, 2)
+        return x
+
+def test_madam():
+    print("Testing Madam optimizer")
+    global inf_grad_max_count
+    inf_grad_max_count = 200
+    if torch.cuda.is_available():
+        devices_and_l2 = [(torch.device('cuda'), True),
+                          (torch.device('cuda'), False),
+                          (torch.device('cpu'), True),
+                          (torch.device('cpu'), False)]
+    else:
+        devices_and_l2 = [(torch.device('cpu'), True),
+                          (torch.device('cpu'), False)]
+
+
+    for (device, l2) in devices_and_l2:
+        model = TestModel().to(device)
+        # min_target_rms=0.01 is for testing, so the target equals the initial RMS
+        # and we can more easily tell whether our update has the desired effect.
+        # I also tested this with betas=(0.1, 0.98), to check that the effect of
+        # `grad_scale` was correct (it only makes much difference for small beta).
+        optimizer = Madam(model.parameters(), lr=0.0005, betas=(0.9, 0.98),
+                          l2=l2, min_target_rms=0.01, l2_period=1)
+        #optimizer = torch.optim.Adam(model.parameters())
+
+        def get_elems_rms(x: Tensor) -> float:
+            return ((x ** 2).sum() / x.numel()).sqrt().item()
+
+        for i in range(1000):
+            if i % 100 == 0:
+                rms_values = (get_elems_rms(model.first_layers[0].weight),
+                              get_elems_rms(model.first_layers[2].weight),
+                              get_elems_rms(model.conv1.weight),
+                              get_elems_rms(model.conv2.weight))
+                print(f"Iter {i}, l2={l2}, device={device}: stddevs = {rms_values} ")
+            B = 4
+            T = 20
+            x = torch.randn(B, T, 100).to(device)
+            y = model(x)
+            yderiv = torch.randn_like(y)
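+            # Test-only perturbation: every so often the gradients are scaled
+            # up by 100x, and occasionally made infinite, to check that the
+            # optimizer copes with huge and non-finite gradients.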
+            if i % 190 <= 3 and i > 0:
+                yderiv *= 100.0
+            if i % 550 == 0 and i > 0:
+                yderiv *= float('inf')
+
+            y.backward(gradient=yderiv)
+            optimizer.step()
+            model.zero_grad()
+        print("")
+
+def test_moam():
+    print("Testing Moam optimizer")
+    model = TestModel()
+    # min_target_rms=0.01 is for testing, so the target equals the initial RMS
+    # and we can more easily tell whether our update has the desired effect.
+    optimizer = Moam(model.parameters(), factor=1.0, warm_step=300,
+                     min_target_rms=0.01)
+
+
+    def get_elems_rms(x: Tensor) -> float:
+        return ((x ** 2).sum() / x.numel()).sqrt().item()
+
+    for i in range(1000):
+        if i % 100 == 0:
+            rms_values = (get_elems_rms(model.first_layers[0].weight),
+                          get_elems_rms(model.first_layers[2].weight),
+                          get_elems_rms(model.conv1.weight),
+                          get_elems_rms(model.conv2.weight))
+            print(f"Iter {i} (Moam): stddevs = {rms_values} ")
+        B = 4
+        T = 20
+        x = torch.randn(B, T, 100)
+        y = model(x)
+        yderiv = torch.randn_like(y)
+        if i % 190 <= 3 and i > 0:
+            yderiv *= 100.0
+        if i % 550 == 0 and i > 0:
+            yderiv *= float('inf')
+
+        y.backward(gradient=yderiv)
+        optimizer.step()
+        model.zero_grad()
+    print("")
+
+
+def test_foam():
+    print("Testing Foam optimizer")
+    model = TestModel()
+    # min_target_rms=0.01 is for testing, so the target equals the initial RMS
+    # and we can more easily tell whether our update has the desired effect.
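+    # limit_grad_factor=4.0 is passed below so that the gradient-limiting
+    # behaviour (our reading of the parameter name; see the Madam class) gets
+    # exercised by the large gradients injected later in this loop.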
+    optimizer = Foam(model.parameters(),
+                     max_lrate=1.0e-03, warm_step=300,
+                     min_target_rms=0.01,
+                     limit_grad_factor=4.0)
+
+
+    def get_elems_rms(x: Tensor) -> float:
+        return ((x ** 2).sum() / x.numel()).sqrt().item()
+
+    for i in range(1000):
+        if i % 100 == 0:
+            rms_values = (get_elems_rms(model.first_layers[0].weight),
+                          get_elems_rms(model.first_layers[2].weight),
+                          get_elems_rms(model.conv1.weight),
+                          get_elems_rms(model.conv2.weight))
+            print(f"Iter {i} (Foam): stddevs = {rms_values} ")
+        B = 4
+        T = 20
+        x = torch.randn(B, T, 100)
+        y = model(x)
+        yderiv = torch.randn_like(y)
+        if i % 190 <= 3 and i > 0:
+            yderiv *= 100.0
+        if i % 550 == 0 and i > 0:
+            yderiv *= float('inf')
+
+        y.backward(gradient=yderiv)
+        optimizer.step()
+        model.zero_grad()
+    print("")
+
+
+
+def test_to_device():
+    if not torch.cuda.is_available():
+        return
+    a_gpu = torch.ones(1,2,3,4, device='cuda')
+    b_gpu = torch.zeros(3,8, device='cuda')
+    (a_cpu, b_cpu) = _to_device('cpu', a_gpu, b_gpu)
+    print("a_cpu,b_cpu = ", a_cpu, b_cpu)
+    (a_gpu2, b_gpu2) = _to_device('cuda', a_cpu, b_cpu)
+    print("a_gpu2,b_gpu2 = ", a_gpu2, b_gpu2)
+
+# Caution: this testing code is not very automated; it requires looking at the output to
+# make sure it looks right.  The main thing is that with l2=True the printed stddevs stay
+# close to the "Target rms" values, which are printed out, while with l2=False the stddevs
+# increase to values significantly higher than that.
+#
+# The test of the Moam optimizer is mainly to make sure it runs.  The scale of the
+# gradients and the learning rate are such that one of the rms's stays quite a bit
+# above the target value, i.e. (0.047, 0.044, 0.047) vs. targets of
+# (0.057, 0.04, 0.019).  I think this has to do with the alpha<1 stability mechanism
+# being activated.  The l2 does have an effect, as I verified by changing the code to
+# set l2=False.
+def main():
+    # Set number of threads to 1, or Torch can do weird things that make it extremely slow.
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    #test_to_device()
+    random.seed(0)
+    torch.random.manual_seed(0)
+    test_foam()
+    test_moam()
+    test_madam()
+
+
+
+if __name__ == '__main__':
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
index b0dbe72adb..dfb6b36e6f 100755
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
@@ -30,10 +30,10 @@
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from madam import Foam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
 
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
@@ -111,13 +111,9 @@ def get_params() -> AttributeDict:
         - lang_dir: It contains language related input files such as
                     "lexicon.txt"
 
-        - lr: It specifies the initial learning rate
-
         - feature_dim: The model input dim. It has to match the one used
                        in computing features.
 
-        - weight_decay:  The weight_decay for the optimizer.
-
         - subsampling_factor:  The subsampling factor for the model.
 
         - best_train_loss: Best training loss so far. It is used to select
@@ -150,10 +146,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -175,7 +170,8 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": True,
             "lr_factor": 5.0,
-            "warm_step": 80000,
+            "max_lrate": 5.0e-04,
+            "warm_step": 25000,
         }
     )
 
@@ -657,12 +653,10 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = Noam(
+    optimizer = Foam(
         model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
+        max_lrate=params.max_lrate,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )
 
     if checkpoints:

From 66467f2da8213667e24c15b60b0cc4ebe4b53554 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Thu, 26 Aug 2021 15:21:11 +0800
Subject: [PATCH 4/5] Reduce number of logs.

---
 egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
index dfb6b36e6f..47d1ecadb1 100755
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
@@ -155,8 +155,8 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 10,
-            "reset_interval": 200,
+            "log_interval": 100,
+            "reset_interval": 1000,
             "valid_interval": 3000,
             "beam_size": 10,
             "reduction": "sum",

From b7d4a4f983730d4b6f9a6b2e774c784efae03e34 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Thu, 26 Aug 2021 22:28:18 +0800
Subject: [PATCH 5/5] Fix errors in madam.py

---
 .../ASR/conformer_ctc_embedding_scale/madam.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
index bba17c375d..019dbbe433 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
@@ -855,7 +855,7 @@ def __init__(self,
                  min_target_rms: float = 0.05,
                  limit_grad_factor: float = float('inf'),
                  l2_period: int = 1) -> None:
-        """Construct an Foam object."""
+        """Construct an Noam object."""
         self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
                                min_target_rms=min_target_rms,
                                limit_grad_factor=limit_grad_factor,
@@ -933,20 +933,15 @@ def state_dict(self):
         """Return state_dict."""
         return {
             "_step": self._step,
-            "warmup": self.warmup,
-            "factor": self.factor,
-            "model_size": self.model_size,
-            "_rate": self._rate,
-            "optimizer": self.optimizer.state_dict(),
         }
 
     def load_state_dict(self, state_dict):
-        """Load state_dict."""
+        """Load state_dict.  This is compatible with reading a Moam state_dict"""
         for key, value in state_dict.items():
             if key == "optimizer":
                 self.optimizer.load_state_dict(state_dict["optimizer"])
-            else:
-                setattr(self, key, value)
+            elif key == '_step':
+                self._step = value
 
 
 
@@ -1096,6 +1091,11 @@ def get_elems_rms(x: Tensor) -> Tensor:
         model.zero_grad()
     print("")
 
+    state_dict = optimizer.state_dict()
+    step = optimizer._step
+    optimizer._step = 0
+    optimizer.load_state_dict(state_dict)
+    assert optimizer._step == step
 
 
 def test_to_device():