Skip to content

Commit

Permalink
add a gateloop block for 2d itransformer, time axis
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 21, 2023
1 parent f74ba9d commit 2538414
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
7 changes: 6 additions & 1 deletion iTransformer/iTransformer2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from iTransformer.attend import Attend
from iTransformer.revin import RevIN

from gateloop_transformer import SimpleGateLoopLayer
from rotary_embedding_torch import RotaryEmbedding

# helper functions
Expand Down Expand Up @@ -194,6 +195,7 @@ def __init__(

for _ in range(depth):
self.layers.append(ModuleList([
SimpleGateLoopLayer(dim = dim),
TransformerBlock(causal = True, rotary_emb = rotary_emb, **block_kwargs),
TransformerBlock(causal = False, **block_kwargs)
]))
Expand Down Expand Up @@ -269,9 +271,12 @@ def forward(

# attention and feedforward layers

for time_attn_block, variate_attn_block in self.layers:
for gateloop_block, time_attn_block, variate_attn_block in self.layers:
x, ps = pack_one(x, '* t d')

# gateloop block
x = gateloop_block(x) + x

# causal attention across time for each variate
x = time_attn_block(x)

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'iTransformer',
packages = find_packages(exclude=[]),
version = '0.4.4',
version = '0.5.0',
license='MIT',
description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
author = 'Phil Wang',
Expand All @@ -20,6 +20,7 @@
install_requires=[
'beartype',
'einops>=0.7.0',
'gateloop-transformer>=0.5.1',
'rotary-embedding-torch',
'torch>=1.6',
],
Expand Down

0 comments on commit 2538414

Please sign in to comment.