Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jdsgomes committed Apr 7, 2022
1 parent f061544 commit 41faba2
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torchvision/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,20 @@ def __init__(
relative_coords[:, :, 1] += self.window_size - 1
relative_coords[:, :, 0] *= 2 * self.window_size - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww

# define a parameter table of relative position bias
relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) # 2*Wh-1 * 2*Ww-1, nH
relative_position_bias_table = torch.zeros(
(2 * window_size - 1) * (2 * window_size - 1), num_heads
) # 2*Wh-1 * 2*Ww-1, nH
nn.init.trunc_normal_(relative_position_bias_table, std=0.02)

relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(
self.window_size * self.window_size, self.window_size * self.window_size, -1
)
self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0))

def forward(self, x: Tensor):


return shifted_window_attention(
x,
Expand Down

0 comments on commit 41faba2

Please sign in to comment.