"""
Take Performer as T2T Transformer
"""
import math
import torch
import torch.nn as nn


class Token_performer(nn.Module):
    def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2=0.1):
        super().__init__()
        self.emb = in_dim * head_cnt  # head_cnt is 1 here, so the multiplication is a no-op
        self.kqv = nn.Linear(dim, 3 * self.emb)
        self.dp = nn.Dropout(dp1)
        self.proj = nn.Linear(self.emb, self.emb)
        self.head_cnt = head_cnt
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(self.emb)
        self.epsilon = 1e-8  # for numerical stability in the division
        self.mlp = nn.Sequential(
            nn.Linear(self.emb, 1 * self.emb),
            nn.GELU(),
            nn.Linear(1 * self.emb, self.emb),
            nn.Dropout(dp2),
        )

        # m fixed random projection directions for the Performer feature map:
        # orthogonal rows scaled by sqrt(m), frozen via requires_grad=False
        self.m = int(self.emb * kernel_ratio)
        self.w = torch.randn(self.m, self.emb)
        self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False)

    def prm_exp(self, x):
        # Part of this function is borrowed from https://github.com/lucidrains/performer-pytorch
        # and Simo Ryu (https://github.com/cloneofsimo).
        # ==== positive random features for Gaussian kernels ====
        # x: (B, T, hs), w: (m, hs); returns (B, T, m)
        # SM(x, y) = E_w[exp(w^T x - |x|^2 / 2) * exp(w^T y - |y|^2 / 2)],
        # so return exp(w^T x - |x|^2 / 2) / sqrt(m)
        xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2
        wtx = torch.einsum('bti,mi->btm', x.float(), self.w)
        return torch.exp(wtx - xd) / math.sqrt(self.m)
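
    # Note (added; not in the original file): the estimator averages over the
    # self.m random features, so its variance shrinks roughly as O(1/m);
    # kernel_ratio therefore trades approximation accuracy for compute,
    # e.g. kernel_ratio=0.5 uses m = emb // 2 feature dimensions.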

    def single_attn(self, x):
        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
        kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
        D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
        y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb) / diag
        # skip connection: as in the token_transformer of the T2T layer, v is used as the residual
        y = v + self.dp(self.proj(y))
        return y
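
    # Note (added; not in the original file): the einsum chain above realizes
    # the Performer's linear attention. It costs O(B * T * m * emb) time and
    # never materializes the (T, T) attention matrix required by softmax
    # attention, which is what makes it affordable on long token sequences.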

    def forward(self, x):
        x = self.single_attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
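

# A minimal usage sketch (added; not part of the original file). The shapes
# below are illustrative assumptions, not values from the T2T-ViT repository.
if __name__ == '__main__':
    torch.manual_seed(0)
    B, T, dim = 2, 196, 64                        # batch, tokens, channel dim
    block = Token_performer(dim=dim, in_dim=dim)  # emb = in_dim * head_cnt = 64
    tokens = torch.randn(B, T, dim)
    out = block(tokens)
    print(out.shape)  # torch.Size([2, 196, 64]): token count kept, emb channels out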