diff --git a/gsa_pytorch/gsa_pytorch.py b/gsa_pytorch/gsa_pytorch.py
index 55a265c..a0a1b6d 100644
--- a/gsa_pytorch/gsa_pytorch.py
+++ b/gsa_pytorch/gsa_pytorch.py
@@ -27,7 +27,7 @@ def calc_reindexing_tensor(l, L, device):
 # classes
 
 class GSA(nn.Module):
-    def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim_key = 64, norm_queries = False):
+    def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim_key = 64, norm_queries = False, batch_norm = True):
         super().__init__()
         dim_out = default(dim_out, dim)
         dim_hidden = dim_key * heads
@@ -43,7 +43,7 @@ def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim
         self.rel_pos_length = rel_pos_length
         if exists(rel_pos_length):
             num_rel_shifts = 2 * rel_pos_length - 1
-            self.norm = nn.BatchNorm2d(dim_key)
+            self.norm = nn.BatchNorm2d(dim_key) if batch_norm else None
             self.rel_rows = nn.Parameter(torch.randn(num_rel_shifts, dim_key))
             self.rel_columns = nn.Parameter(torch.randn(num_rel_shifts, dim_key))
 
@@ -71,7 +71,8 @@ def forward(self, img):
             Sx = einsum('ndxy,xid->nixy', q, Px)
             Yh = einsum('nixy,neiy->nexy', Sx, v)
 
-            Yh = self.norm(Yh)
+            if exists(self.norm):
+                Yh = self.norm(Yh)
 
             Iy = calc_reindexing_tensor(y, L, device)
             Py = einsum('yir,rd->yid', Iy, self.rel_columns)
diff --git a/setup.py b/setup.py
index ff5bb18..aa5e84e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'gsa-pytorch',
   packages = find_packages(),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Global Self-attention Network (GSA) - Pytorch',
   author = 'Phil Wang',
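
For reference, a minimal usage sketch of the new batch_norm flag. The constructor arguments follow the signature in the diff above; the concrete dims, input shape, and import path are illustrative assumptions, not part of the patch.

    import torch
    from gsa_pytorch import GSA

    # batch_norm = False skips the nn.BatchNorm2d(dim_key) otherwise applied to Yh.
    # It only takes effect when rel_pos_length is set, since self.norm is created
    # and used only in the relative-positional branch.
    gsa = GSA(
        dim = 3,
        dim_out = 64,
        dim_key = 32,
        heads = 8,
        rel_pos_length = 256,  # enables the relative positional branch
        batch_norm = False     # new flag introduced by this diff; defaults to True
    )

    x = torch.randn(1, 3, 256, 256)  # (batch, dim, height, width), assumed input layout
    out = gsa(x)                     # (1, 64, 256, 256)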