Commit

cleanup conv-like attention
lucidrains committed Mar 24, 2022

1 parent 3064403 commit 2969103
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions dalle_pytorch/attention.py
@@ -177,18 +177,18 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
 dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
 dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

-# calculate causal attention for local convolution
+# use padding of 0 on tensor of 1s and unfold for padding mask

 i, j = dots_image.shape[-2:]
-img_seq = torch.arange(img_seq_len, device = device)
-k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
-k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
-k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
+ones = torch.ones((img_seq_len,), device = device)
+ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
+ones = F.pad(ones, causal_padding, value = 0.)
+ones = F.unfold(ones, kernel_size, dilation = dilation)
+ones = rearrange(ones, 'b j i -> b i j')

 # mask image attention

-padding_mask = k_img_indices == img_seq_len
+padding_mask = ones == 0.

 # concat text mask with image causal mask

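For context, the change reads as a behavior-preserving cleanup: the old code unfolded a grid of key indices padded with the sentinel value img_seq_len and masked wherever that sentinel appeared, while the new code unfolds a tensor of ones padded with zeros and masks wherever a zero reads back. A minimal sketch of the idea follows; it is not part of the repository, and img_size, kernel_size, dilation, and causal_padding are hypothetical stand-ins for values the attention class computes elsewhere.

import torch
import torch.nn.functional as F
from einops import rearrange

# hypothetical hyperparameters for illustration only
img_size = 4
img_seq_len = img_size * img_size
kernel_size = 3
dilation = 1
causal_padding = (kernel_size - 1, 0, kernel_size - 1, 0)  # assumed (left, right, top, bottom) padding

# old approach: pad an index grid with the max index and mask where it appears
img_seq = torch.arange(img_seq_len)
k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len)
k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
old_mask = k_img_indices == img_seq_len

# new approach: pad a tensor of ones with zeros and mask where a zero reads back
ones = torch.ones((img_seq_len,))
ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
ones = F.pad(ones, causal_padding, value = 0.)
ones = F.unfold(ones, kernel_size, dilation = dilation)
ones = rearrange(ones, 'b j i -> b i j')
new_mask = ones == 0.

assert torch.equal(old_mask, new_mask)  # both constructions mark the same positions
print(new_mask.shape)  # torch.Size([1, 16, 9]) -> (1, img_seq_len, kernel_size ** 2)

Both constructions flag the window slots that fall inside the padding so those positions can be excluded from the local attention, which is why only the mask-building code, and not the attention math, changes in this commit.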
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
 name = 'dalle-pytorch',
 packages = find_packages(),
 include_package_data = True,
-version = '1.5.0',
+version = '1.5.1',
 license='MIT',
 description = 'DALL-E - Pytorch',
 author = 'Phil Wang',
