From 25384140540c5a1f1d8eee77bfba8b84277b766c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 21 Dec 2023 11:20:36 -0800
Subject: [PATCH] add a gateloop block for 2d itransformer, time axis

---
 iTransformer/iTransformer2D.py | 7 ++++++-
 setup.py                       | 3 ++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/iTransformer/iTransformer2D.py b/iTransformer/iTransformer2D.py
index 6d9d0c5..0288c47 100644
--- a/iTransformer/iTransformer2D.py
+++ b/iTransformer/iTransformer2D.py
@@ -15,6 +15,7 @@ from iTransformer.attend import Attend
 from iTransformer.revin import RevIN
 
+from gateloop_transformer import SimpleGateLoopLayer
 from rotary_embedding_torch import RotaryEmbedding
 
 # helper functions
@@ -194,6 +195,7 @@ def __init__(
         for _ in range(depth):
             self.layers.append(ModuleList([
+                SimpleGateLoopLayer(dim = dim),
                 TransformerBlock(causal = True, rotary_emb = rotary_emb, **block_kwargs),
                 TransformerBlock(causal = False, **block_kwargs)
             ]))
 
@@ -269,9 +271,12 @@ def forward(
 
         # attention and feedforward layers
 
-        for time_attn_block, variate_attn_block in self.layers:
+        for gateloop_block, time_attn_block, variate_attn_block in self.layers:
             x, ps = pack_one(x, '* t d')
 
+            # gateloop block
+            x = gateloop_block(x) + x
+
             # causal attention across time for each variate
 
             x = time_attn_block(x)
diff --git a/setup.py b/setup.py
index 7d49c1e..bd63799 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'iTransformer',
   packages = find_packages(exclude=[]),
-  version = '0.4.4',
+  version = '0.5.0',
   license='MIT',
   description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
   author = 'Phil Wang',
@@ -20,6 +20,7 @@
   install_requires=[
     'beartype',
     'einops>=0.7.0',
+    'gateloop-transformer>=0.5.1',
     'rotary-embedding-torch',
     'torch>=1.6',
   ],
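
Note (not part of the patch): a minimal sketch of the residual gateloop
application this change introduces. It assumes gateloop-transformer >= 0.5.1
is installed, and relies on SimpleGateLoopLayer mapping (batch, seq, dim) to
(batch, seq, dim), which is what lets it compose with a plain residual as in
the forward pass above; the shapes and dim here are illustrative:

    import torch
    from gateloop_transformer import SimpleGateLoopLayer

    # same dim as the transformer blocks it sits in front of
    gateloop = SimpleGateLoopLayer(dim = 64)

    # (batch, time, dim) - the patch applies this along the time axis,
    # after packing the variate axis into the batch with pack_one
    x = torch.randn(2, 16, 64)

    # residual application, mirroring `x = gateloop_block(x) + x`
    out = gateloop(x) + x
    assert out.shape == x.shape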