move relative_position_bias to __init__
jdsgomes committed Apr 7, 2022
1 parent 2500ff3 commit f061544
Showing 1 changed file with 11 additions and 14 deletions.
torchvision/models/swin_transformer.py — 25 changes: 11 additions & 14 deletions

@@ -179,15 +179,9 @@ def __init__(
         self.num_heads = num_heads
         self.attention_dropout = attention_dropout
         self.dropout = dropout

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)

-        # define a parameter table of relative position bias
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
-        )  # 2*Wh-1 * 2*Ww-1, nH
-
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size)
         coords_w = torch.arange(self.window_size)
@@ -199,22 +193,25 @@ def __init__(
         relative_coords[:, :, 1] += self.window_size - 1
         relative_coords[:, :, 0] *= 2 * self.window_size - 1
         relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
-
-    def forward(self, x: Tensor):
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
+
+        # define a parameter table of relative position bias
+        relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)  # 2*Wh-1 * 2*Ww-1, nH
+        nn.init.trunc_normal_(relative_position_bias_table, std=0.02)
+        relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(
             self.window_size * self.window_size, self.window_size * self.window_size, -1
         )
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
+        self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0))

+    def forward(self, x: Tensor):
+
         return shifted_window_attention(
             x,
             self.qkv.weight,
             self.proj.weight,
-            relative_position_bias,
+            self.relative_position_bias,
             self.window_size,
             self.num_heads,
             shift_size=self.shift_size,
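For reference, below is a minimal, self-contained sketch of the pattern the diff applies, not the torchvision module itself: the relative position bias is expanded from the table once in __init__ and stored directly as an nn.Parameter, so forward no longer re-indexes the table on every call. The class name WindowBias, the defaults window_size=7 and num_heads=3, and the forward(attn) signature are illustrative assumptions; the index computation mirrors the context lines shown above.

import torch
from torch import nn


class WindowBias(nn.Module):
    # Toy module illustrating the pattern in this commit: the relative position
    # bias is expanded from the table once in __init__ and kept as a learnable
    # nn.Parameter, instead of being rebuilt from the table in every forward.

    def __init__(self, window_size: int = 7, num_heads: int = 3):
        super().__init__()
        # bias table: one row per relative offset, one column per head
        table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        nn.init.trunc_normal_(table, std=0.02)

        # pair-wise relative position index for each token pair inside the window
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flat = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww

        # expand the table to (1, nH, Wh*Ww, Wh*Ww) once and learn that tensor directly
        bias = table[relative_position_index].view(window_size * window_size, window_size * window_size, -1)
        self.relative_position_bias = nn.Parameter(bias.permute(2, 0, 1).contiguous().unsqueeze(0))

    def forward(self, attn: torch.Tensor) -> torch.Tensor:
        # attn: per-window attention logits of shape (num_windows * B, nH, Wh*Ww, Wh*Ww)
        return attn + self.relative_position_bias

With the illustrative defaults, relative_position_bias has shape (1, 3, 49, 49), matching the precomputed tensor that the changed __init__ now hands to shifted_window_attention instead of a value recomputed in forward.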
