diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py index 85a5c152..0a954748 100644 --- a/dalle_pytorch/attention.py +++ b/dalle_pytorch/attention.py @@ -1,4 +1,5 @@ from inspect import isfunction +from math import ceil import torch from torch import nn, einsum @@ -287,6 +288,7 @@ def __init__( self, *args, block_size = 16, + text_seq_len = 256, num_random_blocks = None, **kwargs ): @@ -301,6 +303,7 @@ def __init__( num_heads = self.heads, block = self.block_size, num_random_blocks = num_random_blocks, + global_block_indices = list(range(ceil(text_seq_len / block_size))), attention = 'unidirectional' if self.causal else 'bidirectional' ), max_seq_length = self.seq_len, diff --git a/setup.py b/setup.py index 7b40bf43..5338b856 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'dalle-pytorch', packages = find_packages(), include_package_data = True, - version = '0.1.26', + version = '0.1.27', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',