Commit

complete halo attention
lucidrains committed Mar 24, 2021
1 parent f6817ae commit 942e002
Showing 3 changed files with 182 additions and 5 deletions.
28 changes: 26 additions & 2 deletions README.md
@@ -1,8 +1,32 @@
<img src="./halonet.png" width="500px"></img>

- ## HaloNet - Pytorch (wip)
+ ## HaloNet - Pytorch

- Implementation of the Attention layer from the paper, <a href="https://arxiv.org/abs/2103.12731">Scaling Local Self-Attention For Parameter Efficient Visual Backbones</a>
+ Implementation of the Attention layer from the paper, <a href="https://arxiv.org/abs/2103.12731">Scaling Local Self-Attention For Parameter Efficient Visual Backbones</a>. This repository will only house the attention layer and not much more.


## Install

```bash
$ pip install halonet-pytorch
```

## Usage

```python
import torch
from halonet_pytorch import HaloAttention

attn = HaloAttention(
    dim = 512,          # dimension of the feature map
    fmap_size = 32,     # feature map height and width
    block_size = 8,     # neighborhood block size (feature map must be divisible by this)
    halo_size = 4       # halo size around each block
).cuda()

fmap = torch.randn(1, 512, 32, 32).cuda()
attn(fmap) # (1, 512, 32, 32)
```
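As a rough sketch of how the layer might be slotted into a convolutional backbone, the `HaloBlock` wrapper below pairs the attention with a normalization layer and a residual connection. This wrapper is an editor's illustration and an assumption, not part of this repository or the paper:

```python
import torch
from torch import nn
from halonet_pytorch import HaloAttention

class HaloBlock(nn.Module):
    # hypothetical wrapper (not in this library): batch norm -> halo attention -> residual
    def __init__(self, dim, fmap_size, block_size, halo_size):
        super().__init__()
        self.norm = nn.BatchNorm2d(dim)
        self.attn = HaloAttention(
            dim = dim,
            fmap_size = fmap_size,
            block_size = block_size,
            halo_size = halo_size
        )

    def forward(self, x):
        return self.attn(self.norm(x)) + x  # residual connection keeps the (b, c, h, w) shape

block = HaloBlock(dim = 512, fmap_size = 32, block_size = 8, halo_size = 4)
fmap = torch.randn(1, 512, 32, 32)
out = block(fmap) # (1, 512, 32, 32)
```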

## Citations

2 changes: 1 addition & 1 deletion halonet_pytorch/__init__.py
@@ -1 +1 @@
- from halonet_pytorch.halonet_pytorch import HaloNet
+ from halonet_pytorch.halonet_pytorch import HaloAttention
157 changes: 155 additions & 2 deletions halonet_pytorch/halonet_pytorch.py
@@ -9,18 +9,171 @@
def exists(val):
    return val is not None

# relative positional embedding

def to(x):
    return {'device': x.device, 'dtype': x.dtype}

def pair(x):
    return (x, x) if not isinstance(x, tuple) else x

def expand_dim(t, dim, k):
    t = t.unsqueeze(dim = dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

def rel_to_abs(x):
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim = 2)
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x
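# editor's note (not in the original commit): with the README defaults, each block row of
# 8 queries carries 2 * 16 - 1 = 31 relative logits over the halo-ed window, and rel_to_abs
# keeps the last r = 16 absolute key positions per query, e.g.
#   rel_to_abs(torch.randn(1, 8, 31)).shape  # -> torch.Size([1, 8, 16])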

def relative_logits_1d(q, rel_k):
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim = 2, k = r)
    return logits
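# editor's note: for q of shape (b, x, y, d) and rel_k of shape (2r - 1, d), the function
# above returns logits of shape (b, x, r, y, r) -- one logit per query position and key
# column, broadcast over the r key rows; the other axis is covered by the second call in
# RelPosEmb.forward on the transposed q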

class RelPosEmb(nn.Module):
    def __init__(
        self,
        block_size,
        fmap_size,
        dim_head
    ):
        super().__init__()
        fmap_size = pair(fmap_size)
        height, width = fmap_size
        scale = dim_head ** -0.5

        self.fmap_size = fmap_size
        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, 'b (h w) c -> b h w c', h = block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j -> b (x y) (i j)')

        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
        return rel_logits_w + rel_logits_h
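# editor's usage sketch (not in the original commit): with block_size 8, fmap_size 16
# (block + 2 * halo) and dim_head 64, queries of shape (batch, 8 * 8, 64) map to a
# relative position bias over the 16 x 16 halo-ed key window:
#   emb = RelPosEmb(block_size = 8, fmap_size = 16, dim_head = 64)
#   emb(torch.randn(2, 64, 64)).shape  # -> torch.Size([2, 64, 256])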

# classes

- class HaloNet(nn.Module):
+ class HaloAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        fmap_size,
        block_size,
        halo_size,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        assert fmap_size % block_size == 0, 'feature map height or width must be divisible by block size'
        assert halo_size > 0, 'halo size must be greater than 0'

        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.block_size = block_size
        self.halo_size = halo_size

        inner_dim = dim_head * heads

        self.rel_pos_emb = RelPosEmb(
            block_size = block_size,
            fmap_size = block_size + (halo_size * 2),
            dim_head = dim_head
        )

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # prepare a mask for removing attention to padding, cached for performance

        mask = torch.ones(1, 1, fmap_size, fmap_size)
        mask = F.unfold(mask, kernel_size = block_size + (halo_size * 2), stride = block_size, padding = halo_size)
        mask = repeat(mask, 'b j i -> (b i h) j', h = heads)
        self.register_buffer('mask', mask == 0)
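        # editor's note: for the README defaults (fmap_size 32, block 8, halo 4) this buffer has
        # shape (num_blocks * heads, (block + 2 * halo) ** 2) = (128, 256); True entries mark
        # halo positions that fall in the zero padding around the feature map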

    def forward(self, x):
-       return x
        shape = x.shape
        b, c, h, w, block, halo, heads, device = *shape, self.block_size, self.halo_size, self.heads, x.device
        assert h == w, 'dimensions of fmap must be the same on both sides, for now'
        assert c == self.dim, f'channels of input ({c}) must equal the dim specified at init ({self.dim})'

        # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values

        q_inp = F.unfold(x, kernel_size = block, stride = block)
        kv_inp = F.unfold(x, kernel_size = block + halo * 2, stride = block, padding = halo)

        q_inp, kv_inp = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = c), (q_inp, kv_inp))

        # derive queries, keys, values

        q = self.to_q(q_inp)
        k, v = self.to_kv(kv_inp).chunk(2, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = heads), (q, k, v))

        # scale

        q *= self.scale

        # attention

        sim = einsum('b i d, b j d -> b i j', q, k)

        # add relative positional bias

        sim += self.rel_pos_emb(q)

        # mask out padding (in the paper, they claim to not need masks, but what about padding?)

        mask = repeat(self.mask, 'h j -> (b h) () j', b = b)
        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate

        out = einsum('b i j, b j d -> b i d', attn, v)

        # merge and combine heads

        out = rearrange(out, '(b h) n d -> b n (h d)', h = heads)
        out = self.to_out(out)

# merge blocks back to original feature map

out = rearrange(out, '(b i) j c -> b c j i', i = (h // block) * (w // block))
return out.reshape(shape)
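# editor's smoke-test sketch (not part of the original commit), mirroring the README example:
#   attn = HaloAttention(dim = 512, fmap_size = 32, block_size = 8, halo_size = 4)
#   attn(torch.randn(1, 512, 32, 32)).shape  # -> torch.Size([1, 512, 32, 32])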
