Skip to content

Commit

Permalink
address #44
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 23, 2024
1 parent e085ec3 commit 06edbd3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,12 +1233,12 @@ def __init__(

encoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond = dim_cond)))
Residual(TokenShift(FeedForward(dim)))
)

decoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond = dim_cond)))
Residual(TokenShift(FeedForward(dim)))
)

elif layer_type == 'cond_attend_space':
Expand All @@ -1255,12 +1255,12 @@ def __init__(

encoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim))
Residual(FeedForward(dim, dim_cond = dim_cond))
)

decoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim))
Residual(FeedForward(dim, dim_cond = dim_cond))
)

elif layer_type == 'cond_linear_attend_space':
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.8'
__version__ = '0.4.9'

0 comments on commit 06edbd3

Please sign in to comment.