From 0f7afe3aa9571c693a9940c877c34a7197106f31 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Fri, 10 May 2024 21:08:13 +0200
Subject: [PATCH] Add conformer model

---
 kraken/lib/layers/__init__.py                |  20 ++
 kraken/lib/layers/conformer/__init__.py      |  13 +
 kraken/lib/layers/conformer/attention.py     | 150 +++++++++++
 kraken/lib/layers/conformer/convolution.py   | 251 ++++++++++++++++++
 kraken/lib/layers/conformer/embedding.py     |  67 +++++
 kraken/lib/layers/conformer/encoder.py       | 205 ++++++++++++++
 kraken/lib/layers/conformer/feed_forward.py  |  54 ++++
 kraken/lib/layers/conformer/model.py         |  89 +++++++
 kraken/lib/layers/conformer/modules.py       |  41 +++
 .../lib/{layers.py => layers/vgsl_layers.py} |   3 +-
 10 files changed, 892 insertions(+), 1 deletion(-)
 create mode 100644 kraken/lib/layers/__init__.py
 create mode 100644 kraken/lib/layers/conformer/__init__.py
 create mode 100644 kraken/lib/layers/conformer/attention.py
 create mode 100644 kraken/lib/layers/conformer/convolution.py
 create mode 100644 kraken/lib/layers/conformer/embedding.py
 create mode 100644 kraken/lib/layers/conformer/encoder.py
 create mode 100644 kraken/lib/layers/conformer/feed_forward.py
 create mode 100644 kraken/lib/layers/conformer/model.py
 create mode 100644 kraken/lib/layers/conformer/modules.py
 rename kraken/lib/{layers.py => layers/vgsl_layers.py} (99%)

diff --git a/kraken/lib/layers/__init__.py b/kraken/lib/layers/__init__.py
new file mode 100644
index 000000000..aea57a284
--- /dev/null
+++ b/kraken/lib/layers/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2024 Benjamin Kiessling
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""
+Top-level module containing layers and NN implementations.
+"""
+from .conformer.model import ConformerRecognitionModule as Conformer
+from .vgsl_layers import (Addition, MaxPool, Reshape, Dropout,
+                          TransposedSummarizingRNN, LinSoftmax, ActConv2D,
+                          GroupNorm)
diff --git a/kraken/lib/layers/conformer/__init__.py b/kraken/lib/layers/conformer/__init__.py
new file mode 100644
index 000000000..243f40f1f
--- /dev/null
+++ b/kraken/lib/layers/conformer/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021, Soohwan Kim. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/kraken/lib/layers/conformer/attention.py b/kraken/lib/layers/conformer/attention.py
new file mode 100644
index 000000000..5f3fccca9
--- /dev/null
+++ b/kraken/lib/layers/conformer/attention.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2021, Soohwan Kim. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from typing import Optional
+
+from .embedding import RelPositionalEncoding
+
+
+class RelativeMultiHeadAttention(nn.Module):
+    """
+    Multi-head attention with relative positional encoding, as proposed in
+    "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
+
+    Args:
+        d_model (int): Dimension of the model
+        num_heads (int): Number of attention heads
+        dropout_p (float): Probability of dropout
+
+    Inputs: query, key, value, pos_embedding, mask
+        - **query** (batch, time, dim): Tensor containing the query vectors
+        - **key** (batch, time, dim): Tensor containing the key vectors
+        - **value** (batch, time, dim): Tensor containing the value vectors
+        - **pos_embedding** (batch, time, dim): Positional embedding tensor
+        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+
+    Returns:
+        - **outputs**: Tensor produced by the relative multi-head attention module.
+    """
+    def __init__(
+            self,
+            d_model: int = 512,
+            num_heads: int = 16,
+            dropout_p: float = 0.1,
+    ):
+        super(RelativeMultiHeadAttention, self).__init__()
+        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+        self.d_model = d_model
+        self.d_head = d_model // num_heads
+        self.num_heads = num_heads
+        self.sqrt_dim = math.sqrt(self.d_head)
+
+        self.query_proj = nn.Linear(d_model, d_model)
+        self.key_proj = nn.Linear(d_model, d_model)
+        self.value_proj = nn.Linear(d_model, d_model)
+        self.pos_proj = nn.Linear(d_model, d_model, bias=False)
+
+        self.dropout = nn.Dropout(p=dropout_p)
+        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+        torch.nn.init.xavier_uniform_(self.u_bias)
+        torch.nn.init.xavier_uniform_(self.v_bias)
+
+        self.out_proj = nn.Linear(d_model, d_model)
+
+    def forward(
+            self,
+            query: Tensor,
+            key: Tensor,
+            value: Tensor,
+            pos_embedding: Tensor,
+            mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        batch_size = value.size(0)
+
+        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
+
+        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
+        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
+        pos_score = self._relative_shift(pos_score)
+
+        score = (content_score + pos_score) / self.sqrt_dim
+
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            score.masked_fill_(mask, -1e9)
+
+        attn = F.softmax(score, -1)
+        attn = self.dropout(attn)
+
+        context = torch.matmul(attn, value).transpose(1, 2)
+        context = context.contiguous().view(batch_size, -1, self.d_model)
+
+        return self.out_proj(context)
+
+    def _relative_shift(self, pos_score: Tensor) -> Tensor:
+        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
+        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[:, :, :, : seq_length2 // 2 + 1]
+
+        return pos_score
+
+
+class MultiHeadedSelfAttentionModule(nn.Module):
+    """
+    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
+    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
+    module to generalize better to different input lengths, and the resulting encoder is more robust to variance in
+    the utterance length. Conformer uses pre-norm residual units with dropout, which helps with training
+    and regularizing deeper models.
+
+    Args:
+        d_model (int): Dimension of the model
+        num_heads (int): Number of attention heads
+        dropout_p (float): Probability of dropout
+
+    Inputs: inputs, mask
+        - **inputs** (batch, time, dim): Tensor containing the input vectors
+        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+
+    Returns:
+        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
+    """
+    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
+        super(MultiHeadedSelfAttentionModule, self).__init__()
+        self.positional_encoding = RelPositionalEncoding(d_model)
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
+        self.dropout = nn.Dropout(p=dropout_p)
+
+    def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
+        batch_size = inputs.size(0)
+        pos_embedding = self.positional_encoding(inputs)
+        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
+
+        inputs = self.layer_norm(inputs)
+        outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)
+
+        return self.dropout(outputs)
diff --git a/kraken/lib/layers/conformer/convolution.py b/kraken/lib/layers/conformer/convolution.py
new file mode 100644
index 000000000..c3d1ea335
--- /dev/null
+++ b/kraken/lib/layers/conformer/convolution.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2024, Benjamin Kiessling
+# Copyright (c) 2021, Soohwan Kim. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Tuple
+
+from .modules import Transpose
+
+
+class DepthwiseConv1d(nn.Module):
+    """
+    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
+    this operation is termed in the literature a depthwise convolution.
+
+    Args:
+        in_channels (int): Number of channels in the input
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
+
+    Inputs: inputs
+        - **inputs** (batch, in_channels, time): Tensor containing the input vectors
+
+    Returns: outputs
+        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
+    """
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int,
+            stride: int = 1,
+            padding: int = 0,
+            bias: bool = False,
+    ) -> None:
+        super(DepthwiseConv1d, self).__init__()
+        assert out_channels % in_channels == 0, "out_channels should be a constant multiple of in_channels"
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            groups=in_channels,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        return self.conv(inputs)
+
+
+class PointwiseConv1d(nn.Module):
+    """
+    A conv1d with kernel size == 1, termed in the literature a pointwise convolution.
+    This operation is often used to match dimensions.
+
+    Args:
+        in_channels (int): Number of channels in the input
+        out_channels (int): Number of channels produced by the convolution
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
+
+    Inputs: inputs
+        - **inputs** (batch, in_channels, time): Tensor containing the input vectors
+
+    Returns: outputs
+        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
+    """
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            stride: int = 1,
+            padding: int = 0,
+            bias: bool = True,
+    ) -> None:
+        super(PointwiseConv1d, self).__init__()
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        return self.conv(inputs)
+
+
+class ConformerConvModule(nn.Module):
+    """
+    The conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
+    This is followed by a single 1-D depthwise convolution layer. Batch norm is deployed just after the
+    convolution to aid training deep models.
+
+    Args:
+        in_channels (int): Number of channels in the input
+        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31
+        dropout_p (float, optional): Probability of dropout
+
+    Inputs: inputs
+        inputs (batch, time, dim): Tensor containing the input sequences
+
+    Outputs: outputs
+        outputs (batch, time, dim): Tensor produced by the conformer convolution module.
+    """
+    def __init__(
+            self,
+            in_channels: int,
+            kernel_size: int = 31,
+            expansion_factor: int = 2,
+            dropout_p: float = 0.1,
+    ) -> None:
+        super(ConformerConvModule, self).__init__()
+        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
+        assert expansion_factor == 2, "Currently only supports expansion_factor 2"
+
+        self.sequential = nn.Sequential(
+            nn.LayerNorm(in_channels),
+            Transpose(shape=(1, 2)),
+            PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True),
+            nn.GLU(dim=1),
+            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+            nn.BatchNorm1d(in_channels),
+            nn.SiLU(),
+            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
+            nn.Dropout(p=dropout_p),
+        )
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        return self.sequential(inputs).transpose(1, 2)
+
+
+class Conv2dSubsampling(nn.Module):
+    """
+    Depthwise convolutional subsampling with variable subsampling factor.
+
+    Args:
+        in_channels: Number of channels in the input image
+        out_channels: Number of channels produced by the convolution
+        in_feats: Number of features in the height dimension of the input image
+        conv_channels: Channels of the convolution filter(s).
+        input_dropout_p: Dropout probability after the final linear layer.
+        subsampling_factor: The subsampling factor, which must be a power of 2
+
+    Inputs: inputs
+        - **inputs** (batch, time, dim): Tensor containing the sequence of inputs
+
+    Returns: outputs, output_lengths
+        - **outputs** (batch, time, dim): Tensor produced by the convolution
+        - **output_lengths** (batch): list of sequence output lengths
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 in_feats: int = 80,
+                 conv_channels: int = 256,
+                 input_dropout_p: float = 0.1,
+                 subsampling_factor: int = 8) -> None:
+        super(Conv2dSubsampling, self).__init__()
+
+        self._conv_channels = conv_channels
+        self._in_channels = in_channels
+        self._in_feats = in_feats
+        self._input_dropout_p = input_dropout_p
+
+        if not math.log(subsampling_factor, 2).is_integer():
+            raise ValueError('Subsampling factor should be a power of 2.')
+        self._sampling_num = int(math.log(subsampling_factor, 2))
+        self.subsampling_factor = subsampling_factor
+
+        self._stride = 2
+        self._kernel_size = 3
+
+        self._padding = (self._kernel_size - 1) // 2
+
+        layers = []
+        layers.append(nn.Conv2d(in_channels=in_channels,
+                                out_channels=conv_channels,
+                                kernel_size=self._kernel_size,
+                                stride=self._stride,
+                                padding=self._padding))
+        in_channels = conv_channels
+        layers.append(nn.ReLU(True))
+        for i in range(self._sampling_num - 1):
+            layers.append(nn.Conv2d(in_channels=in_channels,
+                                    out_channels=in_channels,
+                                    kernel_size=self._kernel_size,
+                                    stride=self._stride,
+                                    padding=self._padding,
+                                    groups=in_channels))
+
+            layers.append(nn.Conv2d(in_channels=in_channels,
+                                    out_channels=conv_channels,
+                                    kernel_size=1,
+                                    stride=1,
+                                    padding=0,
+                                    groups=1))
+            layers.append(nn.ReLU(True))
+            in_channels = conv_channels
+
+        self.conv = nn.Sequential(*layers)
+        in_length = torch.tensor(in_feats, dtype=torch.float)
+        out_length = calc_length(lengths=in_length,
+                                 all_paddings=2 * self._padding,
+                                 kernel_size=self._kernel_size,
+                                 stride=self._stride,
+                                 repeat_num=self._sampling_num)
+
+        self.out = nn.Sequential(nn.Linear(conv_channels * int(out_length), out_channels),
+                                 nn.Dropout(p=input_dropout_p))
+
+    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
+        output_lengths = calc_length(input_lengths,
+                                     all_paddings=2 * self._padding,
+                                     kernel_size=self._kernel_size,
+                                     stride=self._stride,
+                                     repeat_num=self._sampling_num)
+        x = inputs.unsqueeze(1)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).reshape(b, t, -1))
+        return x, output_lengths
+
+
+def calc_length(lengths, all_paddings, kernel_size, stride, repeat_num=1):
+    """ Calculates the output length of a Tensor passed through a convolution or max pooling layer """
+    add_pad: float = all_paddings - kernel_size
+    one: float = 1.0
+    for i in range(repeat_num):
+        lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
+        lengths = torch.floor(lengths)
+    return lengths.to(dtype=torch.int)
diff --git a/kraken/lib/layers/conformer/embedding.py b/kraken/lib/layers/conformer/embedding.py
new file mode 100644
index 000000000..03e62b3b3
--- /dev/null
+++ b/kraken/lib/layers/conformer/embedding.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2021, Soohwan Kim. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import torch.nn as nn +from torch import Tensor + + +class RelPositionalEncoding(nn.Module): + """ + Relative positional encoding module. + Args: + d_model: Embedding dimension. + max_len: Maximum input length. + """ + + def __init__(self, d_model: int = 512, max_len: int = 5000) -> None: + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + if self.pe is not None: + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + pe_positive = torch.zeros(x.size(1), self.d_model) + pe_negative = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) + ) + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + pe_negative = pe_negative[1:].unsqueeze(0) + pe = torch.cat([pe_positive, pe_negative], dim=1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x : Input tensor B X T X C + Returns: + torch.Tensor: Encoded tensor B X T X C + """ + self.extend_pe(x) + pos_emb = self.pe[:, self.pe.size(1) // 2 - x.size(1) + 1:self.pe.size(1) // 2 + x.size(1)] + return pos_emb diff --git a/kraken/lib/layers/conformer/encoder.py b/kraken/lib/layers/conformer/encoder.py new file mode 100644 index 000000000..f06fc1ee5 --- /dev/null +++ b/kraken/lib/layers/conformer/encoder.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024, Benjamin Kiessling. All rights reserved. +# Copyright (c) 2021, Soohwan Kim. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn +from torch import Tensor +from typing import Tuple + +from .feed_forward import FeedForwardModule +from .attention import MultiHeadedSelfAttentionModule +from .convolution import ConformerConvModule, Conv2dSubsampling +from .modules import ResidualConnectionModule + + +class ConformerBlock(nn.Module): + """ + Conformer block contains two Feed Forward modules sandwiching the Multi-Headed Self-Attention module + and the Convolution module. 
This sandwich structure is inspired by Macaron-Net, which proposes replacing
+    the original feed-forward layer in the Transformer block with two half-step feed-forward layers,
+    one before the attention layer and one after.
+
+    Args:
+        encoder_dim (int, optional): Dimension of the conformer encoder
+        num_attention_heads (int, optional): Number of attention heads
+        feed_forward_expansion_factor (int, optional): Expansion factor of the feed forward module
+        conv_expansion_factor (int, optional): Expansion factor of the conformer convolution module
+        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
+        attention_dropout_p (float, optional): Probability of attention module dropout
+        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
+        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
+        half_step_residual (bool): Flag indicating whether to use half-step residuals
+
+    Inputs: inputs
+        - **inputs** (batch, time, dim): Tensor containing the input vectors
+
+    Returns: outputs
+        - **outputs** (batch, time, dim): Tensor produced by the conformer block.
+    """
+    def __init__(
+            self,
+            encoder_dim: int = 512,
+            num_attention_heads: int = 8,
+            feed_forward_expansion_factor: int = 4,
+            conv_expansion_factor: int = 2,
+            feed_forward_dropout_p: float = 0.1,
+            attention_dropout_p: float = 0.1,
+            conv_dropout_p: float = 0.1,
+            conv_kernel_size: int = 31,
+            half_step_residual: bool = True,
+    ):
+        super(ConformerBlock, self).__init__()
+        if half_step_residual:
+            self.feed_forward_residual_factor = 0.5
+        else:
+            self.feed_forward_residual_factor = 1
+
+        self.sequential = nn.Sequential(
+            ResidualConnectionModule(
+                module=FeedForwardModule(
+                    encoder_dim=encoder_dim,
+                    expansion_factor=feed_forward_expansion_factor,
+                    dropout_p=feed_forward_dropout_p,
+                ),
+                module_factor=self.feed_forward_residual_factor,
+            ),
+            ResidualConnectionModule(
+                module=MultiHeadedSelfAttentionModule(
+                    d_model=encoder_dim,
+                    num_heads=num_attention_heads,
+                    dropout_p=attention_dropout_p,
+                ),
+            ),
+            ResidualConnectionModule(
+                module=ConformerConvModule(
+                    in_channels=encoder_dim,
+                    kernel_size=conv_kernel_size,
+                    expansion_factor=conv_expansion_factor,
+                    dropout_p=conv_dropout_p,
+                ),
+            ),
+            ResidualConnectionModule(
+                module=FeedForwardModule(
+                    encoder_dim=encoder_dim,
+                    expansion_factor=feed_forward_expansion_factor,
+                    dropout_p=feed_forward_dropout_p,
+                ),
+                module_factor=self.feed_forward_residual_factor,
+            ),
+            nn.LayerNorm(encoder_dim),
+        )
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        return self.sequential(inputs)
+
+
+class ConformerEncoder(nn.Module):
+    """
+    The conformer encoder first processes the input with a convolutional subsampling layer and then
+    with a number of conformer blocks.
+
+    Args:
+        in_channels: Number of input channels
+        input_dim: Dimension (height) of the input vectors
+        encoder_dim: Dimension of the conformer encoder
+        num_layers: Number of conformer blocks
+        num_attention_heads: Number of attention heads
+        feed_forward_expansion_factor: Expansion factor of the feed forward module
+        conv_expansion_factor: Expansion factor of the conformer convolution module
+        feed_forward_dropout_p: Probability of feed forward module dropout
+        attention_dropout_p: Probability of attention module dropout
+        conv_dropout_p: Probability of conformer convolution module dropout
+        conv_kernel_size: Size of the convolving kernel
+        half_step_residual: Flag indicating whether to use half-step residuals
+        subsampling_conv_channels: Channels in the subsampling convolutional filter
+        subsampling_factor: Subsampling factor. Must be a power of 2.
+
+    Inputs: inputs, input_lengths
+        - **inputs** (batch, time, dim): Tensor containing the input vectors
+        - **input_lengths** (batch): list of sequence input lengths
+
+    Returns: outputs, output_lengths
+        - **outputs** (batch, time, encoder_dim): Tensor produced by the conformer encoder.
+        - **output_lengths** (batch): list of sequence output lengths
+    """
+    def __init__(
+            self,
+            in_channels: int = 1,
+            input_dim: int = 80,
+            encoder_dim: int = 512,
+            num_layers: int = 17,
+            num_attention_heads: int = 8,
+            feed_forward_expansion_factor: int = 4,
+            conv_expansion_factor: int = 2,
+            input_dropout_p: float = 0.1,
+            feed_forward_dropout_p: float = 0.1,
+            attention_dropout_p: float = 0.1,
+            conv_dropout_p: float = 0.1,
+            conv_kernel_size: int = 31,
+            half_step_residual: bool = True,
+            subsampling_conv_channels: int = 256,
+            subsampling_factor: int = 4,
+    ):
+        super(ConformerEncoder, self).__init__()
+        self.conv_subsample = Conv2dSubsampling(in_channels=in_channels,
+                                                out_channels=encoder_dim,
+                                                in_feats=input_dim,
+                                                conv_channels=subsampling_conv_channels,
+                                                input_dropout_p=input_dropout_p,
+                                                subsampling_factor=subsampling_factor)
+
+        self.layers = nn.ModuleList([ConformerBlock(
+            encoder_dim=encoder_dim,
+            num_attention_heads=num_attention_heads,
+            feed_forward_expansion_factor=feed_forward_expansion_factor,
+            conv_expansion_factor=conv_expansion_factor,
+            feed_forward_dropout_p=feed_forward_dropout_p,
+            attention_dropout_p=attention_dropout_p,
+            conv_dropout_p=conv_dropout_p,
+            conv_kernel_size=conv_kernel_size,
+            half_step_residual=half_step_residual,
+        ) for _ in range(num_layers)])
+
+    def count_parameters(self) -> int:
+        """ Counts the parameters of the encoder """
+        return sum(p.numel() for p in self.parameters())
+
+    def update_dropout(self, dropout_p: float) -> None:
+        """ Updates the dropout probability of the encoder """
+        for name, module in self.named_modules():
+            if isinstance(module, nn.Dropout):
+                module.p = dropout_p
+
+    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Forward propagates `inputs` for encoder training.
+
+        Args:
+            inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
+                `FloatTensor` of size ``(batch, seq_length, dimension)``.
+            input_lengths (torch.LongTensor): The lengths of the input tensor. ``(batch)``
+
+        Returns:
+            (Tensor, Tensor)
+
+            * outputs (torch.FloatTensor): An output sequence of the encoder. `FloatTensor` of size
+                ``(batch, seq_length, dimension)``
+            * output_lengths (torch.LongTensor): The lengths of the output tensor.
``(batch)``
+        """
+        outputs, output_lengths = self.conv_subsample(inputs, input_lengths)
+
+        for layer in self.layers:
+            outputs = layer(outputs)
+
+        return outputs, output_lengths
diff --git a/kraken/lib/layers/conformer/feed_forward.py b/kraken/lib/layers/conformer/feed_forward.py
new file mode 100644
index 000000000..81794cae6
--- /dev/null
+++ b/kraken/lib/layers/conformer/feed_forward.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024, Benjamin Kiessling
+# Copyright (c) 2021, Soohwan Kim. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+from torch import Tensor
+
+
+class FeedForwardModule(nn.Module):
+    """
+    The conformer feed forward module follows pre-norm residual units and applies layer normalization within the
+    residual unit and on the input before the first linear layer. This module also applies the Swish activation
+    and dropout, which help regularize the network.
+
+    Args:
+        encoder_dim (int): Dimension of the conformer encoder
+        expansion_factor (int): Expansion factor of the feed forward module.
+        dropout_p (float): Probability of dropout
+
+    Inputs: inputs
+        - **inputs** (batch, time, dim): Tensor containing the input sequences
+
+    Outputs: outputs
+        - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
+    """
+    def __init__(
+            self,
+            encoder_dim: int = 512,
+            expansion_factor: int = 4,
+            dropout_p: float = 0.1,
+    ) -> None:
+        super(FeedForwardModule, self).__init__()
+        self.sequential = nn.Sequential(
+            nn.LayerNorm(encoder_dim),
+            nn.Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
+            nn.SiLU(),
+            nn.Dropout(p=dropout_p),
+            nn.Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
+            nn.Dropout(p=dropout_p),
+        )
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        return self.sequential(inputs)
diff --git a/kraken/lib/layers/conformer/model.py b/kraken/lib/layers/conformer/model.py
new file mode 100644
index 000000000..2fc5198d6
--- /dev/null
+++ b/kraken/lib/layers/conformer/model.py
@@ -0,0 +1,89 @@
+#
+# Copyright 2024 Benjamin Kiessling
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""
+kraken.lib.layers.conformer.model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+FasterConformer model
+"""
+import torch
+
+from torch import nn
+from typing import Optional, Tuple
+
+from .encoder import ConformerEncoder
+
+
+class ConformerRecognitionModule(nn.Module):
+    def __init__(self,
+                 num_classes: int,
+                 height: int,
+                 encoder_dim: int,
+                 num_encoder_layers: int,
+                 num_attention_heads: int,
+                 feed_forward_expansion_factor: int,
+                 conv_expansion_factor: int,
+                 input_dropout_p: float,
+                 feed_forward_dropout_p: float,
+                 attention_dropout_p: float,
+                 conv_dropout_p: float,
+                 conv_kernel_size: int,
+                 half_step_residual: bool,
+                 subsampling_conv_channels: int,
+                 subsampling_factor: int,
+                 **kwargs):
+        """
+        An nn.Module version of a conformer_ocr.model.RecognitionModel for
+        inference.
+        """
+        super().__init__()
+        encoder = ConformerEncoder(in_channels=1,
+                                   input_dim=height,
+                                   encoder_dim=encoder_dim,
+                                   num_layers=num_encoder_layers,
+                                   num_attention_heads=num_attention_heads,
+                                   feed_forward_expansion_factor=feed_forward_expansion_factor,
+                                   conv_expansion_factor=conv_expansion_factor,
+                                   input_dropout_p=input_dropout_p,
+                                   feed_forward_dropout_p=feed_forward_dropout_p,
+                                   attention_dropout_p=attention_dropout_p,
+                                   conv_dropout_p=conv_dropout_p,
+                                   conv_kernel_size=conv_kernel_size,
+                                   half_step_residual=half_step_residual,
+                                   subsampling_conv_channels=subsampling_conv_channels,
+                                   subsampling_factor=subsampling_factor)
+        decoder = nn.Linear(encoder_dim, num_classes, bias=True)
+        self.nn = nn.ModuleDict({'encoder': encoder,
+                                 'decoder': decoder})
+
+    def forward(self, line: torch.Tensor, lens: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Performs a forward pass on a torch tensor of one or more lines with
+        shape (N, C, H, W) and returns an (N, W, C) logits tensor.
+
+        Args:
+            line: NCHW line(s) tensor
+            lens: N-shape Tensor containing sequence lengths
+
+        Returns:
+            Tuple with (N, W, C) shaped logits tensor and final output sequence
+            lengths.
+        """
+        if getattr(self, 'device', None) is not None:
+            line = line.to(self.device)
+        line = line.squeeze(1).transpose(1, 2)
+        encoder_outputs, encoder_lens = self.nn.encoder(line, lens)
+        logits = self.nn.decoder(encoder_outputs)
+        return logits, encoder_lens
diff --git a/kraken/lib/layers/conformer/modules.py b/kraken/lib/layers/conformer/modules.py
new file mode 100644
index 000000000..2496ecd75
--- /dev/null
+++ b/kraken/lib/layers/conformer/modules.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021, Soohwan Kim. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+from torch import Tensor
+
+
+class ResidualConnectionModule(nn.Module):
+    """
+    Residual connection module:
+ outputs = (module(inputs) x module_factor + inputs x input_factor) + """ + def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0): + super(ResidualConnectionModule, self).__init__() + self.module = module + self.module_factor = module_factor + self.input_factor = input_factor + + def forward(self, inputs: Tensor) -> Tensor: + return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) + + +class Transpose(nn.Module): + """ Wrapper class of torch.transpose() for Sequential module. """ + def __init__(self, shape: tuple): + super(Transpose, self).__init__() + self.shape = shape + + def forward(self, x: Tensor) -> Tensor: + return x.transpose(*self.shape) diff --git a/kraken/lib/layers.py b/kraken/lib/layers/vgsl_layers.py similarity index 99% rename from kraken/lib/layers.py rename to kraken/lib/layers/vgsl_layers.py index a937fdd32..6519392dd 100644 --- a/kraken/lib/layers.py +++ b/kraken/lib/layers/vgsl_layers.py @@ -19,7 +19,8 @@ # all tensors are ordered NCHW, the "feature" dimension is C, so the output of # an LSTM will be put into C same as the filters of a CNN. -__all__ = ['Addition', 'MaxPool', 'Reshape', 'Dropout', 'TransposedSummarizingRNN', 'LinSoftmax', 'ActConv2D'] +__all__ = ['Addition', 'MaxPool', 'Reshape', 'Dropout', + 'TransposedSummarizingRNN', 'LinSoftmax', 'ActConv2D', 'GroupNorm'] class MultiParamSequential(Sequential):
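-- 
Reviewer note, not part of the patch: the sketch below wires the new
ConformerRecognitionModule up end to end on dummy line images. All
hyperparameter values here are illustrative assumptions rather than defaults
shipped by this patch; any combination accepted by ConformerEncoder should
work.

    import torch

    from kraken.lib.layers.conformer.model import ConformerRecognitionModule

    # Illustrative hyperparameters (assumptions, not shipped defaults).
    net = ConformerRecognitionModule(num_classes=200,
                                     height=80,
                                     encoder_dim=256,
                                     num_encoder_layers=12,
                                     num_attention_heads=4,
                                     feed_forward_expansion_factor=4,
                                     conv_expansion_factor=2,
                                     input_dropout_p=0.1,
                                     feed_forward_dropout_p=0.1,
                                     attention_dropout_p=0.1,
                                     conv_dropout_p=0.1,
                                     conv_kernel_size=31,
                                     half_step_residual=True,
                                     subsampling_conv_channels=256,
                                     subsampling_factor=4)
    net.device = None  # forward() consults an optional `device` attribute; None keeps the CPU.
    net.eval()

    # Two dummy text lines, NCHW with C=1 and H=height; widths padded to the longest line.
    lines = torch.randn(2, 1, 80, 400)
    lens = torch.tensor([400, 320])

    with torch.no_grad():
        logits, out_lens = net(lines, lens)

    # Each of the two subsampling stages (factor 4 = 2**2) roughly halves the width,
    # so 400 -> 100 and 320 -> 80 time steps: logits.shape == (2, 100, 200) and
    # out_lens == tensor([100, 80]).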