-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathflash_attention.py
234 lines (169 loc) · 7.67 KB
/
flash_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import math
import torch
from functools import partial
from torch import nn, einsum
from torch.autograd.function import Function
from einops import rearrange
# constants
# floor applied to every per-tile softmax row sum in FlashAttentionFunction.forward,
# so the later division by (and log of) the accumulated row sums never sees zero
EPSILON = 1e-10
# helper functions
def exists(val):
    """Return True for any value other than None (falsy values such as 0 count as existing)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return val unless it is None, in which case return the fallback d."""
    if exists(val):
        return val
    return d
# flash attention forwards and backwards
# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf
class FlashAttentionFunction(Function):
    """Tiled (flash) attention implemented as a custom autograd function.

    Queries are walked in row tiles of ``q_bucket_size`` and keys / values in
    column tiles of ``k_bucket_size``. Only one tile of attention weights is
    held at a time; the softmax is accumulated across column tiles from a
    running row maximum and a running row sum.
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper

        Args:
            ctx: autograd context; receives the tensors and settings needed by backward.
            q, k, v: tensors with the sequence on dim -2 and the head feature on dim -1.
            mask: None, or a mask where True marks positions to keep (it is inverted
                with ``~``, so a boolean tensor is expected). A 2-d mask is read
                as ``(batch, key_len)``; otherwise dim -2 / dim -1 must either be 1 or
                span the query / key length.
            causal: when truthy, a query may only attend to keys at or before its own
                position, with the query block aligned to the END of the keys when
                there are more keys than queries.
            q_bucket_size: number of query rows per tile.
            k_bucket_size: number of key / value rows per tile.

        Returns:
            The attention output, same shape and dtype as ``q``.
        """

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # keep the running statistics in the default dtype as before, but widen them
        # when q itself is wider (float64): -finfo(float64).max is not representable
        # in float32 and the log-sum-exp would otherwise lose precision
        stats_dtype = torch.promote_types(q.dtype, torch.get_default_dtype())

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype = stats_dtype, device = device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype = stats_dtype, device = device)

        scale = (q.shape[-1] ** -0.5)

        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        # turn the mask into a (row tile) x (column tile) grid of sub-masks,
        # repeating it along any axis where it has size 1

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)

        # the splits below are views, so in-place updates write into o / all_row_sums / all_row_maxes

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
            all_row_maxes.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            # position of the first query of this tile on the key axis
            # with more keys than queries, the queries line up with the LAST keys,
            # so the length difference is added (subtracting it masked every key)
            q_start_index = ind * q_bucket_size + qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                # only tiles that reach past the first query position need a causal mask
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                # a fully masked row has exp(0) = 1 everywhere, so zero masked entries explicitly
                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                # rescale what was accumulated under the previous row maximum
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            # normalise once per row tile, after every column tile has been seen
            oc.div_(row_sums)

        # log-sum-exp of every attention row, reused by backward to rebuild the probabilities
        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper

        Args:
            ctx: autograd context populated by forward.
            do: gradient with respect to the forward output.

        Returns:
            ``(dq, dk, dv, None, None, None, None)`` - gradients for q, k and v, and
            None for mask, causal and the two bucket sizes.
        """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # mask is already the tile grid built in forward

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            # must match forward: queries are aligned with the last keys
            q_start_index = ind * q_bucket_size + qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # softmax probabilities recomputed from the saved log-sum-exp
                p = torch.exp(attn_weights - lsec)

                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # row-wise sum of do * o, the correction term of the softmax gradient
                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # the chunks are views, so these accumulate into dq / dk / dv
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
# main class
# this is flash attention written in plain pytorch
# it will be much slower than an implementation written in CUDA
# it is meant for tinkering and educational purposes
class FlashAttention(nn.Module):
    """Multi-head attention layer whose attention step runs through FlashAttentionFunction.

    Args:
        dim: feature size of the inputs fed to the projections, and of the output.
        heads: number of attention heads.
        dim_head: feature size of each head.
        causal: passed to FlashAttentionFunction as its causal flag.
        q_bucket_size: default number of query rows per tile.
        k_bucket_size: default number of key / value rows per tile.
    """

    def __init__(self, *, dim, heads = 8, dim_head = 64, causal = False, q_bucket_size = 512, k_bucket_size = 1024):
        super().__init__()

        self.causal = causal
        self.heads = heads

        # default tile sizes for the memory efficient attention
        # either one can be overridden per call in forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

        # projections stay in the order q, kv, out so parameter registration is unchanged
        projected_dim = dim_head * heads
        self.to_q = nn.Linear(dim, projected_dim, bias = False)
        self.to_kv = nn.Linear(dim, 2 * projected_dim, bias = False)
        self.to_out = nn.Linear(projected_dim, dim, bias = False)

    def forward(self, x, context = None, mask = None, q_bucket_size = None, k_bucket_size = None):
        """Attend from x to context (or to x itself when no context is given).

        Args:
            x: input that is projected to queries.
            context: optional input projected to keys and values; falls back to x.
            mask: handed to FlashAttentionFunction unchanged.
            q_bucket_size: per-call override of the query tile size.
            k_bucket_size: per-call override of the key / value tile size.

        Returns:
            The result of the output projection applied to the merged heads.
        """
        q_tile = default(q_bucket_size, self.q_bucket_size)
        k_tile = default(k_bucket_size, self.k_bucket_size)
        kv_input = default(context, x)

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        # split the projected feature axis into separate heads
        queries, keys, values = (
            rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
            for t in (queries, keys, values)
        )

        attended = FlashAttentionFunction.apply(queries, keys, values, mask, self.causal, q_tile, k_tile)

        # fold the heads back into a single feature axis
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)