
Commit

make sure text receives global attention for microsoft deepspeed sparse attention
lucidrains committed Feb 25, 2021
1 parent 0106211 commit 8a0c45a
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions dalle_pytorch/attention.py
@@ -1,4 +1,5 @@
 from inspect import isfunction
+from math import ceil
 
 import torch
 from torch import nn, einsum
@@ -287,6 +288,7 @@ def __init__(
         self,
         *args,
         block_size = 16,
+        text_seq_len = 256,
         num_random_blocks = None,
         **kwargs
     ):
@@ -301,6 +303,7 @@ def __init__(
                 num_heads = self.heads,
                 block = self.block_size,
                 num_random_blocks = num_random_blocks,
+                global_block_indices = list(range(ceil(text_seq_len / block_size))),
                 attention = 'unidirectional' if self.causal else 'bidirectional'
             ),
             max_seq_length = self.seq_len,
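
A minimal sketch (not part of the commit; names taken from the diff above) of what the added global_block_indices argument computes: with the default text_seq_len = 256 and block_size = 16, the first ceil(256 / 16) = 16 blocks are marked as global, so the text tokens at the start of the sequence attend to and are attended by every position under DeepSpeed's sparse attention pattern.

from math import ceil

text_seq_len = 256   # new keyword argument introduced in this commit (default value)
block_size = 16      # existing sparse-attention block size (default value)

# every block overlapping the leading text positions is marked as a global block
global_block_indices = list(range(ceil(text_seq_len / block_size)))

print(global_block_indices)  # [0, 1, 2, ..., 15] -> 16 blocks covering all 256 text tokens
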
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.1.26',
+  version = '0.1.27',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
