From 3dc0f0db7714e0a50397645bd0f905b74ecb52bd Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 23 Dec 2024 15:37:34 -0800
Subject: [PATCH] add value residual learning

---
 README.md                    | 11 +++++++++
 iTransformer/attend.py       | 13 ++++++++++-
 iTransformer/iTransformer.py | 43 +++++++++++++++++++++++++++++++-----
 setup.py                     |  6 ++---
 4 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index a4c1eb1..ad2c1f3 100644
--- a/README.md
+++ b/README.md
@@ -201,3 +201,14 @@ preds = model(time_series)
     url = {https://api.semanticscholar.org/CorpusID:265018962}
 }
 ```
+
+```bibtex
+@article{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.17897},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
diff --git a/iTransformer/attend.py b/iTransformer/attend.py
index 39c61e1..9e880c5 100644
--- a/iTransformer/attend.py
+++ b/iTransformer/attend.py
@@ -10,6 +10,8 @@
 
 from einops import rearrange, repeat
 
+from torch.nn.attention import SDPBackend
+
 # constants
 
 EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
@@ -92,9 +94,18 @@ def flash_attn(
 
         config = self.cuda_config if is_cuda else self.cpu_config
 
+        str_to_backend = dict(
+            enable_flash = SDPBackend.FLASH_ATTENTION,
+            enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+            enable_math = SDPBackend.MATH,
+            enable_cudnn = SDPBackend.CUDNN_ATTENTION
+        )
+
+        sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in config._asdict().items() if enable]
+
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+        with torch.nn.attention.sdpa_kernel(sdpa_backends):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 is_causal = self.causal,
diff --git a/iTransformer/iTransformer.py b/iTransformer/iTransformer.py
index 5aabce2..3a2362c 100644
--- a/iTransformer/iTransformer.py
+++ b/iTransformer/iTransformer.py
@@ -35,7 +35,8 @@ def __init__(
         dim_head = 32,
         heads = 4,
         dropout = 0.,
-        flash = True
+        flash = True,
+        learned_value_residual_mix = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -46,6 +47,12 @@ def __init__(
             Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
         )
 
+        self.to_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads, bias = False),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if learned_value_residual_mix else None
+
         self.to_v_gates = nn.Sequential(
             nn.Linear(dim, heads, bias = False),
             nn.Sigmoid(),
@@ -60,13 +67,25 @@ def __init__(
             nn.Dropout(dropout)
         )
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        value_residual = None
+    ):
         q, k, v = self.to_qkv(x)
 
+        orig_v = v
+
+        if exists(self.to_value_residual_mix):
+            assert exists(value_residual)
+            mix = self.to_value_residual_mix(x)
+            v = v.lerp(value_residual, mix)
+
         out = self.attend(q, k, v)
 
         out = out * self.to_v_gates(x)
-        return self.to_out(out)
+
+        return self.to_out(out), orig_v
 
 # feedforward
 
@@ -120,9 +139,11 @@ def __init__(
         self.num_tokens_per_variate = num_tokens_per_variate
         self.layers = ModuleList([])
 
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
+
             self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
+                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn, learned_value_residual_mix = not is_first),
                 nn.LayerNorm(dim),
                 FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
                 nn.LayerNorm(dim)
@@ -180,10 +201,20 @@ def forward(
             m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
             x, mem_ps = pack([m, x], 'b * d')
 
+        # value residual learning
+        # https://arxiv.org/abs/2410.17897
+
+        first_values = None
+
         # attention and feedforward layers
 
         for attn, attn_post_norm, ff, ff_post_norm in self.layers:
-            x = attn(x) + x
+
+            attn_out, values = attn(x, value_residual = first_values)
+            first_values = default(first_values, values)
+
+            x = x + attn_out
+
             x = attn_post_norm(x)
             x = ff(x) + x
             x = ff_post_norm(x)
diff --git a/setup.py b/setup.py
index 2904162..2199234 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'iTransformer',
   packages = find_packages(exclude=[]),
-  version = '0.6.0',
+  version = '0.7.0',
   license='MIT',
   description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
   author = 'Phil Wang',
@@ -19,10 +19,10 @@
   ],
   install_requires=[
     'beartype',
-    'einops>=0.7.0',
+    'einops>=0.8.0',
     'gateloop-transformer>=0.2.3',
     'rotary-embedding-torch',
-    'torch>=2.1',
+    'torch>=2.3',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
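
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the value residual scheme the patch wires into `Attention` (arXiv 2410.17897): the first layer passes its values forward, and each later layer learns a per-token, per-head sigmoid mix between its own values and those first-layer values via `lerp`. The class and attribute names here (`ValueResidualAttention`, `to_mix`) are hypothetical, not taken from the repository.

```python
import torch
from torch import nn
from einops import rearrange

class ValueResidualAttention(nn.Module):
    # toy attention layer; later layers blend their values toward the first layer's values
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, 3 * heads * dim_head, bias = False)

        # per-token, per-head mix coefficient in [0, 1], predicted from the layer input
        self.to_mix = nn.Sequential(nn.Linear(dim, heads, bias = False), nn.Sigmoid())

        self.to_out = nn.Linear(heads * dim_head, dim, bias = False)

    def forward(self, x, value_residual = None):
        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
        orig_v = v

        if value_residual is not None:
            mix = rearrange(self.to_mix(x), 'b n h -> b h n 1')
            v = v.lerp(value_residual, mix)  # v * (1 - mix) + value_residual * mix

        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim = -1)
        out = rearrange(attn @ v, 'b h n d -> b n (h d)')

        # return this layer's own (pre-mix) values so the first layer can seed the residual
        return self.to_out(out), orig_v

x = torch.randn(2, 16, 128)
layers = [ValueResidualAttention(128) for _ in range(3)]

first_values = None
for layer in layers:
    out, values = layer(x, value_residual = first_values)
    first_values = values if first_values is None else first_values  # keep layer one's values
    x = x + out
```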
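
Likewise illustrative only: the attend.py hunk swaps the deprecated `torch.backends.cuda.sdp_kernel(**kwargs)` context manager for `torch.nn.attention.sdpa_kernel`, which takes an explicit list of `SDPBackend` members; this is what motivates the `torch>=2.3` bump in setup.py. A minimal usage sketch, with arbitrary tensor shapes:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 4, 128, 32)
k = torch.randn(1, 4, 128, 32)
v = torch.randn(1, 4, 128, 32)

# restrict scaled_dot_product_attention to the listed backends;
# on CPU (or when flash is unavailable) it falls back to the math backend here
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 128, 32])
```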