flash attention all the way
lucidrains committed Oct 11, 2023
1 parent 58e068b commit 62c2e69
Showing 4 changed files with 165 additions and 13 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -81,3 +81,12 @@ preds = model(time_series)
     url = {https://api.semanticscholar.org/CorpusID:263134283}
 }
 ```
+
+```bibtex
+@inproceedings{dao2022flashattention,
+    title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+    author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+    booktitle = {Advances in Neural Information Processing Systems},
+    year = {2022}
+}
+```
146 changes: 146 additions & 0 deletions iTransformer/attend.py
@@ -0,0 +1,146 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat

# constants

EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout = 0.,
        heads = None,
        scale = None,
        flash = False,
    ):
        super().__init__()
        self.scale = scale

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        major, minor = device_properties.major, device_properties.minor

        if (major, minor) == (8, 0):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (9, 0):
            print_once('H100 GPU detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def flash_attn(
        self,
        q, k, v,
        mask = None
    ):
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # expand key padding mask

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(
        self,
        q, k, v,
        mask = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        i, j, dtype = *sim.shape[-2:], sim.dtype

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = attn.type(dtype)

        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        return out
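
For reference, a minimal usage sketch of the new `Attend` module (not part of the commit; the shapes and hyperparameters below are illustrative assumptions, and `flash = True` assumes PyTorch 2.0+ per the assert above):

```python
import torch
from iTransformer.attend import Attend

# flash = True routes through F.scaled_dot_product_attention inside the sdp_kernel
# context selected at init; flash = False keeps the explicit einsum + softmax path
attend = Attend(flash = True, dropout = 0.1)

# q, k, v follow the (batch, heads, seq_len, dim_head) layout assumed by the einsum notation
q = torch.randn(2, 4, 128, 32)
k = torch.randn(2, 4, 128, 32)
v = torch.randn(2, 4, 128, 32)

out = attend(q, k, v)  # same shape as q: (2, 4, 128, 32)
```

Either path should return a tensor shaped like `q`; only the kernel used to compute it differs.
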
21 changes: 9 additions & 12 deletions iTransformer/iTransformer.py
@@ -9,6 +9,8 @@
 from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
+from iTransformer.attend import Attend
+
 # helper functions
 
 def exists(v):
@@ -28,7 +30,8 @@ def __init__(
         dim,
         dim_head = 32,
         heads = 4,
-        dropout = 0.
+        dropout = 0.,
+        flash = True
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -39,7 +42,7 @@ def __init__(
             Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
         )
 
-        self.dropout = nn.Dropout(dropout)
+        self.attend = Attend(flash = flash, dropout = dropout)
 
         self.to_out = nn.Sequential(
             Rearrange('b h n d -> b n (h d)'),
@@ -50,14 +53,7 @@ def __init__(
     def forward(self, x):
         q, k, v = self.to_qkv(x)
 
-        q = q * self.scale
-
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
-
-        attn = sim.softmax(dim = -1)
-        attn = self.dropout(attn)
-
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = self.attend(q, k, v)
 
         return self.to_out(out)
 
@@ -90,7 +86,8 @@ def __init__(
         attn_dropout = 0.,
         ff_mult = 4,
         ff_dropout = 0.,
-        num_mem_tokens = 4
+        num_mem_tokens = 4,
+        flash_attn = True
     ):
         super().__init__()
         self.num_variates = num_variates
@@ -104,7 +101,7 @@ def __init__(
         self.layers = ModuleList([])
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
                 nn.LayerNorm(dim),
                 FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
                 nn.LayerNorm(dim)
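
As a quick sanity check of the rewired `Attention` block, a hypothetical snippet (not from this commit; the import path and tensor shapes are assumptions based on the file layout and einops patterns above):

```python
import torch
from iTransformer.iTransformer import Attention  # assumed import path for the module shown above

# the new flash flag is simply forwarded to Attend; the attention math itself is unchanged
attn = Attention(dim = 256, dim_head = 32, heads = 4, dropout = 0., flash = True)

x = torch.randn(1, 12, 256)  # (batch, num variates, dim) - each variate is one token in iTransformer
out = attn(x)                # projected back to (batch, num variates, dim) through to_out
```

Setting `flash_attn = False` on the top-level `iTransformer` constructor would propagate `flash = False` to every `Attention` layer, which then falls back to the explicit einsum path in `Attend`.
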
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'iTransformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
   author = 'Phil Wang',
