transformer_improved.py

import torch
import copy
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTM
from torch.nn.modules.normalization import LayerNorm


class TransformerEncoderLayer(Module):
    r"""Transformer encoder layer in which the position-wise feed-forward
    sub-layer is replaced by a bidirectional LSTM followed by a linear
    projection back to ``d_model``."""

    def __init__(self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Improved part: a single-layer bidirectional LSTM replaces the usual
        # two-layer feed-forward network.
        # dim_feedforward is kept in the signature but is not used by this layer.
        self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
        self.dropout = Dropout(dropout)
        # The LSTM is bidirectional, so its output width is 2 * hidden_size.
        self.linear = Linear(hidden_size * 2, d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Backward compatibility: older pickled states may lack the activation attribute.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Sub-layer 1: self-attention with residual connection and layer norm.
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Sub-layer 2: bidirectional LSTM -> activation -> dropout -> linear
        # projection back to d_model, again with residual connection and layer norm.
        src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0])))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


def _get_clones(module, N):
    # Produce N independent deep copies of a module (useful for stacking layers).
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
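

# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of how this layer might be stacked and run, assuming the
# hyperparameters below (d_model=512, nhead=8, hidden_size=256, dropout=0.1)
# and the (seq_len, batch, d_model) input layout expected by MultiheadAttention
# and by the LSTM as constructed here (batch_first=False). Building a stack
# with _get_clones is an assumption about intended use, mirroring
# torch.nn.TransformerEncoder.
if __name__ == "__main__":
    layer = TransformerEncoderLayer(d_model=512, nhead=8, hidden_size=256,
                                    dim_feedforward=2048, dropout=0.1)
    layers = _get_clones(layer, 2)  # two independent copies of the layer

    src = torch.rand(10, 4, 512)  # (seq_len=10, batch=4, d_model=512)
    out = src
    for enc_layer in layers:
        out = enc_layer(out)
    print(out.shape)  # expected: torch.Size([10, 4, 512])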