add value residual learning
lucidrains committed Dec 23, 2024
1 parent fa2773a commit 3dc0f0d
Showing 4 changed files with 63 additions and 10 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -201,3 +201,14 @@ preds = model(time_series)
url = {https://api.semanticscholar.org/CorpusID:265018962}
}
```

```bibtex
@article{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
journal = {ArXiv},
year = {2024},
volume = {abs/2410.17897},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
13 changes: 12 additions & 1 deletion iTransformer/attend.py
@@ -10,6 +10,8 @@

from einops import rearrange, repeat

from torch.nn.attention import SDPBackend

# constants

EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
@@ -92,9 +94,18 @@ def flash_attn(

config = self.cuda_config if is_cuda else self.cpu_config

str_to_backend = dict(
enable_flash = SDPBackend.FLASH_ATTENTION,
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
enable_math = SDPBackend.MATH,
enable_cudnn = SDPBackend.CUDNN_ATTENTION
)

sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in config._asdict().items() if enable]

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

- with torch.backends.cuda.sdp_kernel(**config._asdict()):
+ with torch.nn.attention.sdpa_kernel(sdpa_backends):
out = F.scaled_dot_product_attention(
q, k, v,
is_causal = self.causal,
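For reference, a minimal standalone sketch of the API swap this hunk performs (my own example and tensor shapes, not code from the repository): the deprecated `torch.backends.cuda.sdp_kernel(enable_flash=..., ...)` context manager is replaced by `torch.nn.attention.sdpa_kernel`, which takes an explicit list of `SDPBackend` members, here built from the existing config namedtuple.

```python
# Hedged sketch of the torch >= 2.3 backend-selection API used above.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# toy (batch, heads, seq, dim_head) tensors -- shapes are illustrative only
q = k = v = torch.randn(1, 4, 16, 32)

# allow flash / memory-efficient kernels, with the math backend as a fallback
backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = False)

print(out.shape)  # torch.Size([1, 4, 16, 32])
```

The list-of-backends form is what motivates the `str_to_backend` lookup above: the boolean flags in `EfficientAttentionConfig` are translated into the corresponding `SDPBackend` members before entering the context manager.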
43 changes: 37 additions & 6 deletions iTransformer/iTransformer.py
@@ -35,7 +35,8 @@ def __init__(
dim_head = 32,
heads = 4,
dropout = 0.,
- flash = True
+ flash = True,
+ learned_value_residual_mix = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -46,6 +47,12 @@ def __init__(
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
)

self.to_value_residual_mix = nn.Sequential(
nn.Linear(dim, heads, bias = False),
Rearrange('b n h -> b h n 1'),
nn.Sigmoid()
) if learned_value_residual_mix else None

self.to_v_gates = nn.Sequential(
nn.Linear(dim, heads, bias = False),
nn.Sigmoid(),
@@ -60,13 +67,25 @@ def __init__(
nn.Dropout(dropout)
)

- def forward(self, x):
+ def forward(
+     self,
+     x,
+     value_residual = None
+ ):
q, k, v = self.to_qkv(x)

orig_v = v

if exists(self.to_value_residual_mix):
assert exists(value_residual)
mix = self.to_value_residual_mix(x)
v = v.lerp(value_residual, mix)

out = self.attend(q, k, v)

out = out * self.to_v_gates(x)
- return self.to_out(out)
+
+ return self.to_out(out), orig_v

# feedforward

@@ -120,9 +139,11 @@ def __init__(
self.num_tokens_per_variate = num_tokens_per_variate

self.layers = ModuleList([])
- for _ in range(depth):
+ for i in range(depth):
+     is_first = i == 0
+
self.layers.append(ModuleList([
- Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
+ Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn, learned_value_residual_mix = not is_first),
nn.LayerNorm(dim),
FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
nn.LayerNorm(dim)
@@ -180,10 +201,20 @@ def forward(
m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
x, mem_ps = pack([m, x], 'b * d')

# value residual learning
# https://arxiv.org/abs/2410.17897

first_values = None

# attention and feedforward layers

for attn, attn_post_norm, ff, ff_post_norm in self.layers:
- x = attn(x) + x
+
+ attn_out, values = attn(x, value_residual = first_values)
+ first_values = default(first_values, values)
+
+ x = x + attn_out

x = attn_post_norm(x)
x = ff(x) + x
x = ff_post_norm(x)
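To make the new data flow concrete, here is a hedged, self-contained sketch of the value residual mix (arXiv:2410.17897) with toy tensors and a made-up `to_mix` module standing in for `to_value_residual_mix`; it is not the repository's `Attention` class. Each layer after the first forms a convex combination of its own values and the first layer's values, with a per-head gate in (0, 1) learned from the token embeddings.

```python
# Sketch of the learned value-residual mix wired in above (assumed toy shapes).
import torch
import torch.nn as nn

heads, seq, dim, dim_head = 4, 16, 64, 32

# per-head mixing gate, analogous to to_value_residual_mix: Linear -> Sigmoid
to_mix = nn.Sequential(nn.Linear(dim, heads, bias = False), nn.Sigmoid())

x       = torch.randn(1, seq, dim)                # token embeddings
v_first = torch.randn(1, heads, seq, dim_head)    # values kept from the first layer
v       = torch.randn(1, heads, seq, dim_head)    # values of the current layer

# (b, n, h) -> (b, h, n, 1), same effect as Rearrange('b n h -> b h n 1')
mix = to_mix(x).transpose(1, 2).unsqueeze(-1)

# v.lerp(v_first, mix) == (1 - mix) * v + mix * v_first
v_mixed = v.lerp(v_first, mix)
assert torch.allclose(v_mixed, (1 - mix) * v + mix * v_first, atol = 1e-5)

# In the stack itself (per the diff): the first block is built with
# learned_value_residual_mix = False, returns its values untouched, and those
# values are captured once into first_values; every later block receives them
# as value_residual and mixes as above.
```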
6 changes: 3 additions & 3 deletions setup.py
@@ -3,7 +3,7 @@
setup(
name = 'iTransformer',
packages = find_packages(exclude=[]),
- version = '0.6.0',
+ version = '0.7.0',
license='MIT',
description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
author = 'Phil Wang',
@@ -19,10 +19,10 @@
],
install_requires=[
'beartype',
- 'einops>=0.7.0',
+ 'einops>=0.8.0',
'gateloop-transformer>=0.2.3',
'rotary-embedding-torch',
- 'torch>=2.1',
+ 'torch>=2.3',
],
classifiers=[
'Development Status :: 4 - Beta',
