import torch
import torch.nn as nn


## Core code for Flow-Attention. Please refer to each folder for the corresponding experiments.
class Flow_Attention(nn.Module):
    # flow attention in normal version
    def __init__(self, d_input, d_model, d_output, n_heads, drop_out=0.05, eps=1e-6):
        super(Flow_Attention, self).__init__()
        self.n_heads = n_heads
        self.query_projection = nn.Linear(d_input, d_model)
        self.key_projection = nn.Linear(d_input, d_model)
        self.value_projection = nn.Linear(d_input, d_model)
        self.out_projection = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(drop_out)
        self.eps = eps

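    # Non-negative feature map: sigmoid keeps queries/keys in (0, 1), so the
    # "flow" quantities computed in forward() stay strictly positive.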
    def kernel_method(self, x):
        return torch.sigmoid(x)

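    # Linear-attention style product: contract keys with values first (a d x m matrix
    # per head), then multiply by the queries, so the cost is linear in sequence length.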
    def dot_product(self, q, k, v):
        kv = torch.einsum("nhld,nhlm->nhdm", k, v)
        qkv = torch.einsum("nhld,nhdm->nhlm", q, kv)
        return qkv

    def forward(self, queries, keys, values):
        ## input: B (L or S) D; output: B L D
        ## Note: queries, keys, values are not projected yet
        ## 1. Linear projection
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        queries = self.query_projection(queries).view(B, L, self.n_heads, -1)
        keys = self.key_projection(keys).view(B, S, self.n_heads, -1)
        values = self.value_projection(values).view(B, S, self.n_heads, -1)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
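        # shapes from here on: queries [B, H, L, d_head], keys/values [B, H, S, d_head]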
        # 2. Non-negative projection
        queries = self.kernel_method(queries)
        keys = self.kernel_method(keys)
        ## 3. Flow-Attention
        # (1) Calculate incoming and outgoing flow
        sink_incoming = 1.0 / (torch.einsum("nhld,nhd->nhl", queries + self.eps, keys.sum(dim=2) + self.eps))
        source_outgoing = 1.0 / (torch.einsum("nhld,nhd->nhl", keys + self.eps, queries.sum(dim=2) + self.eps))
        # (2) Conservation refinement for source and sink
        conserved_sink = torch.einsum("nhld,nhd->nhl", queries + self.eps,
                                      (keys * source_outgoing[:, :, :, None]).sum(dim=2) + self.eps)
        conserved_source = torch.einsum("nhld,nhd->nhl", keys + self.eps,
                                        (queries * sink_incoming[:, :, :, None]).sum(dim=2) + self.eps)
        conserved_source = torch.clamp(conserved_source, min=-1.0, max=1.0)  # for stability
        # (3) Competition & Allocation
        sink_allocation = torch.sigmoid(conserved_sink * (float(queries.shape[2]) / float(keys.shape[2])))
        source_competition = torch.softmax(conserved_source, dim=-1) * float(keys.shape[2])
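        # sink_allocation in (0, 1) gates how much attention each query (sink) receives;
        # source_competition re-weights the values so the keys (sources) compete for a
        # fixed total amount of flow.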
        # (4) dot product
        x = (self.dot_product(queries * sink_incoming[:, :, :, None],  # for value normalization
                              keys,
                              values * source_competition[:, :, :, None])  # competition
             * sink_allocation[:, :, :, None]).transpose(1, 2)  # allocation
        ## (5) Final projection
        x = x.reshape(B, L, -1)
        x = self.out_projection(x)
        x = self.dropout(x)
        return x
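

# The causal version mirrors the module above, but every global sum over the sequence
# is replaced by a cumulative sum, so position l only aggregates information from
# positions <= l (suitable for autoregressive decoding).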
class Flow_Attention_Causal(nn.Module):
    # flow attention in causal version
    def __init__(self, d_input, d_model, d_output, n_heads, drop_out=0.05, eps=1e-6):
        super(Flow_Attention_Causal, self).__init__()
        self.n_heads = n_heads
        self.query_projection = nn.Linear(d_input, d_model)
        self.key_projection = nn.Linear(d_input, d_model)
        self.value_projection = nn.Linear(d_input, d_model)
        self.out_projection = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(drop_out)
        self.eps = eps

    def kernel_method(self, x):
        return torch.sigmoid(x)

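    # Causal counterpart of dot_product: prefix sums over the per-position key/value
    # outer products enforce the causal mask; note the [B, H, L, d, m] intermediate.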
    def causal_dot_product(self, q, k, v):
        kv = torch.einsum("nhld,nhlm->nhldm", k, v)
        kv = torch.cumsum(kv, dim=2)
        qkv = torch.einsum("nhld,nhldm->nhlm", q, kv)
        return qkv

    def forward(self, queries, keys, values):
        ## input: B (L or S) D; output: B L D
        ## Note: queries, keys, values are not projected yet
        ## 1. Linear projection
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        queries = self.query_projection(queries).view(B, L, self.n_heads, -1)
        keys = self.key_projection(keys).view(B, S, self.n_heads, -1)
        values = self.value_projection(values).view(B, S, self.n_heads, -1)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # 2. Non-negative projection
        queries = self.kernel_method(queries)
        keys = self.kernel_method(keys)
        ## 3. Causal Flow-Attention
        # (1) Calculate incoming and outgoing flow
        sink_incoming = 1.0 / (torch.einsum("nhld,nhld->nhl", queries + self.eps, keys.cumsum(dim=2) + self.eps))
        source_outgoing = 1.0 / (torch.einsum("nhld,nhld->nhl", keys + self.eps, queries.cumsum(dim=2) + self.eps))
        # approximate the column- and row-wise conservation by rescaling with the
        # number of elements seen so far at each position
        normal = (((torch.arange(queries.shape[2])).float() + 1.0)).to(queries.device)[None, None, :]
        sink_incoming = sink_incoming * normal
        source_outgoing = source_outgoing * normal
        # (2) Conservation refinement for source and sink
        conserved_sink = torch.einsum("nhld,nhld->nhl", queries + self.eps,
                                      (keys * source_outgoing[:, :, :, None]).cumsum(dim=2) + self.eps) / normal
        conserved_source = torch.einsum("nhld,nhld->nhl", keys + self.eps,
                                        (queries * sink_incoming[:, :, :, None]).cumsum(dim=2) + self.eps) / normal
        conserved_source = torch.clamp(conserved_source, min=-1.0, max=1.0)  # for stability
        # (3) Competition & Allocation
        sink_allocation = torch.sigmoid(conserved_sink)
        conserved_source = torch.exp(conserved_source)
        source_competition = (conserved_source / conserved_source.cumsum(dim=-1)) * normal
        # (4) Causal dot product
        x = (self.causal_dot_product(queries * (sink_incoming[:, :, :, None] / normal[:, :, :, None]),  # for value normalization
                                     keys,
                                     values * source_competition[:, :, :, None])  # competition
             * sink_allocation[:, :, :, None]).transpose(1, 2)  # allocation
        ## (5) Final projection
        x = x.reshape(B, L, -1)
        x = self.out_projection(x)
        x = self.dropout(x)
        return x


if __name__ == '__main__':
    # just a simple shape check
    attn_normal = Flow_Attention(10, 16, 16, 4)
    attn_causal = Flow_Attention_Causal(10, 16, 16, 4)
    x = torch.rand([1, 100, 10])
    x_attn_normal = attn_normal(x, x, x)
    x_attn_causal = attn_causal(x, x, x)
    assert x_attn_normal.shape == (1, 100, 16)
    assert x_attn_causal.shape == (1, 100, 16)
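
    # A minimal causality sanity check (a sketch; the 50/100 split is arbitrary and
    # dropout is switched off via eval() so outputs are deterministic): perturbing
    # future timesteps must not change the causal attention's earlier outputs.
    attn_causal.eval()
    with torch.no_grad():
        y_full = attn_causal(x, x, x)
        x_pert = x.clone()
        x_pert[:, 50:, :] += 1.0  # modify only timesteps 50..99
        y_pert = attn_causal(x_pert, x_pert, x_pert)
    assert torch.allclose(y_full[:, :50], y_pert[:, :50], atol=1e-5)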