Skip to content

Commit

Permalink
add ability to turn off batchnorm, readying it for use in GANs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 21, 2020
1 parent cee78ea commit c31594a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions gsa_pytorch/gsa_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def calc_reindexing_tensor(l, L, device):
# classes

class GSA(nn.Module):
def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim_key = 64, norm_queries = False):
def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim_key = 64, norm_queries = False, batch_norm = True):
super().__init__()
dim_out = default(dim_out, dim)
dim_hidden = dim_key * heads
Expand All @@ -43,7 +43,7 @@ def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim
self.rel_pos_length = rel_pos_length
if exists(rel_pos_length):
num_rel_shifts = 2 * rel_pos_length - 1
self.norm = nn.BatchNorm2d(dim_key)
self.norm = nn.BatchNorm2d(dim_key) if batch_norm else None
self.rel_rows = nn.Parameter(torch.randn(num_rel_shifts, dim_key))
self.rel_columns = nn.Parameter(torch.randn(num_rel_shifts, dim_key))

Expand Down Expand Up @@ -71,7 +71,8 @@ def forward(self, img):
Sx = einsum('ndxy,xid->nixy', q, Px)
Yh = einsum('nixy,neiy->nexy', Sx, v)

Yh = self.norm(Yh)
if exists(self.norm):
Yh = self.norm(Yh)

Iy = calc_reindexing_tensor(y, L, device)
Py = einsum('yir,rd->yid', Iy, self.rel_columns)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'gsa-pytorch',
packages = find_packages(),
version = '0.2.1',
version = '0.2.2',
license='MIT',
description = 'Global Self-attention Network (GSA) - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit c31594a

Please sign in to comment.