From 533d2c04078d0f6512d01289e8a11a33500eee64 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 27 Feb 2022 22:31:42 +0800 Subject: [PATCH 01/92] add swin transformer --- torchvision/models/swin_transformer.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 torchvision/models/swin_transformer.py diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py new file mode 100644 index 00000000000..aaa328524fc --- /dev/null +++ b/torchvision/models/swin_transformer.py @@ -0,0 +1 @@ +import torch From 311751eba57a7a9dd189e431dddb383b43de299f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 15:58:25 +0800 Subject: [PATCH 02/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 466 +++++++++++++++++++++++++ 1 file changed, 466 insertions(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index aaa328524fc..3bca302c8b6 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1 +1,467 @@ +from typing import Tuple, Optional, Callable, List, Any +from functools import partial + import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from torchvision.ops import StochasticDepth +from .._internally_replaced_utils import load_state_dict_from_url + + +__all__ = [ + "SwinTransformer", + "swin_tiny", +] + + +_MODELS_URLS = { + "swin_tiny": "" +} + + +class MLPBlock(nn.Sequential): + """Transformer MLP block.""" + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float, + act_layer: Callable[..., nn.Module] = nn.GELU): + super().__init__() + self.linear_1 = nn.Linear(in_dim, mlp_dim) + self.act = act_layer() + self.dropout_1 = nn.Dropout(dropout) + self.linear_2 = nn.Linear(mlp_dim, in_dim) + self.dropout_2 = nn.Dropout(dropout) + + +class Permute(nn.Module): + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.permute(x, self.dims) + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor): + B, H, W, C = x.shape + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + x = x.view(B, H // 2, W // 2, 2 * C) + + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (int)): The size of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. + attention_dropout (float, optional): Dropout ratio of attention weight. Default: 0.0. + dropout (float, optional): Dropout ratio of output. Default: 0.0. 
+ """ + + def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, + attention_dropout: float = 0., dropout: float = 0.): + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size) + coords_w = torch.arange(self.window_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size - 1 + relative_coords[:, :, 0] *= 2 * self.window_size - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attention_dropout = nn.Dropout(attention_dropout) + self.proj = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + self.softmax = nn.Softmax(dim=-1) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x: Tensor, mask: Tensor = None): + """ + Args: + x (Tensor): input features with shape of (num_windows*B, N, C) + mask (Tensor): (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size * self.window_size, self.window_size * self.window_size, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attention_dropout(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.dropout(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. + dropout (float, optional): Dropout rate. Default: 0.0 + attention_dropout (float, optional): Attention dropout rate. Default: 0.0 + stochastic_depth_prob: (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, dim: int, num_heads: int, window_size: int = 7, shift_size: int = 0, + mlp_ratio: float = 4., qkv_bias: bool = True, dropout: float = 0., + attention_dropout: float = 0., + stochastic_depth_prob: float = 0., + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size, num_heads, + qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout) + + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.norm2 = norm_layer(dim) + + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) + + + def generate_attention_mask(self, shift_size: int, H: int, W: int, device): + if shift_size > 0: + # calculate attention mask for SW-MSA + mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = self.partition_window(mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def forward(self, x: Tensor): + B, H, W, C = x.shape + shortcut = x + x = self.norm1(x) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + if self.window_size == min(Hp, Wp): + shift_size = 0 + else: + shift_size = self.shift_size + + # cyclic shift + if shift_size > 0: + shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = self.partition_window(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = self.reverse_window(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if shift_size > 0: + x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + x = x.view(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # FFN + x = shortcut + self.stochastic_depth(x) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + return x + + def partition_window(self, x: Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, 
window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return x + + def reverse_window(self, x: Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(x.shape[0] / (H * W / window_size / window_size)) + x = x.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SwinTransformer(nn.Module): + """ + Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using + Shifted Windows" `_ paper. + + Args: + patch_size (int): Patch size. Default: 4. + num_classes (int): Number of classes for classification head. Default: 1000. + embed_dim (int): Patch embedding dimension. Default: 96. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + dropout (float): Dropout rate. Default: 0. + attention_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.1. + block (nn.Module): SwinTransformer Block. + norm_layer (nn.Module): Normalization layer. + """ + + def __init__(self, + patch_size: int = 4, + num_classes: int = 1000, + embed_dim: int = 96, + depths: List[int] = [2, 2, 6, 2], + num_heads: List[int] = [3, 6, 12, 24], + window_size: int = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0., + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + + self.num_classes = num_classes + + if block is None: + block = SwinTransformerBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + layers: List[nn.Module] = [] + # split image into non-overlapping patches + layers.append(nn.Sequential( + nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + )) + + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2 ** i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=0 if i_layer % 2 == 0 else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append( + #nn.Sequential( + # norm_layer(dim), + # Permute([0, 3, 1, 2]), + # nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2, bias=False), + # Permute([0, 2, 3, 1]), + 
#) + PatchMerging(dim, norm_layer) + ) + + self.features = nn.Sequential(*layers) + + num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(num_features) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.head = nn.Linear(num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.features(x) + x = self.norm(x) + x = x.permute(0, 3, 1, 2) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.head(x) + return x + + +def _swin_transformer( + arch: str, + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: int, + stochastic_depth_prob: float, + pretrained: bool, + progress: bool, + **kwargs: Any, +) -> SwinTransformer: + model = SwinTransformer( + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs + ) + if pretrained: + if arch not in _MODELS_URLS: + raise ValueError(f"No checkpoint is available for model type {arch}") + state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_tiny architecture from + `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _swin_transformer( + arch="swin_tiny", + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + stochastic_depth_prob=0.2, + pretrained=pretrained, + progress=progress, + **kwargs + ) From d478852454fbeab21be2747de17a901c4e830d8f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 16:18:23 +0800 Subject: [PATCH 03/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 144 ++++++++++++++----------- 1 file changed, 79 insertions(+), 65 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 3bca302c8b6..5c148b9362c 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,11 +1,12 @@ -from typing import Tuple, Optional, Callable, List, Any from functools import partial +from typing import Tuple, Optional, Callable, List, Any import torch -from torch import nn, Tensor import torch.nn.functional as F +from torch import nn, Tensor from torchvision.ops import StochasticDepth + from .._internally_replaced_utils import load_state_dict_from_url @@ -16,15 +17,15 @@ _MODELS_URLS = { - "swin_tiny": "" + "swin_tiny": "", + "swin_base": "", } class MLPBlock(nn.Sequential): """Transformer MLP block.""" - def __init__(self, in_dim: int, mlp_dim: int, dropout: float, - act_layer: Callable[..., nn.Module] = nn.GELU): + def __init__(self, in_dim: int, mlp_dim: int, dropout: float, act_layer: Callable[..., nn.Module] = nn.GELU): super().__init__() self.linear_1 = nn.Linear(in_dim, mlp_dim) self.act = act_layer() @@ -44,7 +45,7 @@ def forward(self, x): class PatchMerging(nn.Module): """Patch Merging Layer. - + Args: dim (int): Number of input channels. 
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm @@ -75,9 +76,9 @@ def forward(self, x: Tensor): class WindowAttention(nn.Module): - """ Window based multi-head self attention (W-MSA) module with relative position bias. + """Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. - + Args: dim (int): Number of input channels. window_size (int)): The size of the window. @@ -87,8 +88,15 @@ class WindowAttention(nn.Module): dropout (float, optional): Dropout ratio of output. Default: 0.0. """ - def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, - attention_dropout: float = 0., dropout: float = 0.): + def __init__( + self, + dim: int, + window_size: int, + num_heads: int, + qkv_bias: bool = True, + attention_dropout: float = 0., + dropout: float = 0. + ): super().__init__() self.dim = dim self.window_size = window_size @@ -120,7 +128,7 @@ def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = self.dropout = nn.Dropout(dropout) self.softmax = nn.Softmax(dim=-1) - nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor, mask: Tensor = None): """ @@ -133,10 +141,11 @@ def forward(self, x: Tensor, mask: Tensor = None): q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size * self.window_size, self.window_size * self.window_size, -1) # Wh*Ww,Wh*Ww,nH + self.window_size * self.window_size, self.window_size * self.window_size, -1 + ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) @@ -157,8 +166,8 @@ def forward(self, x: Tensor, mask: Tensor = None): class SwinTransformerBlock(nn.Module): - """ Swin Transformer Block. - + """Swin Transformer Block. + Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. @@ -173,12 +182,20 @@ class SwinTransformerBlock(nn.Module): norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ - def __init__(self, dim: int, num_heads: int, window_size: int = 7, shift_size: int = 0, - mlp_ratio: float = 4., qkv_bias: bool = True, dropout: float = 0., - attention_dropout: float = 0., - stochastic_depth_prob: float = 0., - act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + def __init__( + self, + dim: int, + num_heads: int, + window_size: int = 7, + shift_size: int = 0, + mlp_ratio: float = 4., + qkv_bias: bool = True, + dropout: float = 0., + attention_dropout: float = 0., + stochastic_depth_prob: float = 0., + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm + ): super().__init__() self.dim = dim self.num_heads = num_heads @@ -188,25 +205,20 @@ def __init__(self, dim: int, num_heads: int, window_size: int = 7, shift_size: i self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size, num_heads, - qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout) + dim, window_size, num_heads, qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout + ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) - - + def generate_attention_mask(self, shift_size: int, H: int, W: int, device): if shift_size > 0: # calculate attention mask for SW-MSA mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -shift_size), - slice(-shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -shift_size), - slice(-shift_size, None)) + h_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: @@ -225,14 +237,14 @@ def forward(self, x: Tensor): B, H, W, C = x.shape shortcut = x x = self.norm1(x) - + # pad feature maps to multiples of window size pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape - + if self.window_size == min(Hp, Wp): shift_size = 0 else: @@ -261,7 +273,7 @@ def forward(self, x: Tensor): x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) else: x = shifted_x - + x = x.view(B, Hp, Wp, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() @@ -283,7 +295,7 @@ def partition_window(self, x: Tensor, window_size: int): x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return x - + def reverse_window(self, x: Tensor, window_size: int, H: int, W: int): """ Args: @@ -321,25 +333,26 @@ class SwinTransformer(nn.Module): norm_layer (nn.Module): Normalization layer. 
""" - def __init__(self, - patch_size: int = 4, - num_classes: int = 1000, - embed_dim: int = 96, - depths: List[int] = [2, 2, 6, 2], - num_heads: List[int] = [3, 6, 12, 24], - window_size: int = 7, - mlp_ratio: float = 4., - qkv_bias: bool = True, - dropout: float = 0.0, - attention_dropout: float = 0.0, - stochastic_depth_prob: float = 0., - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, + def __init__( + self, + patch_size: int = 4, + num_classes: int = 1000, + embed_dim: int = 96, + depths: List[int] = [2, 2, 6, 2], + num_heads: List[int] = [3, 6, 12, 24], + window_size: int = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + dropout: float = 0., + attention_dropout: float = 0., + stochastic_depth_prob: float = 0., + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, ): super().__init__() self.num_classes = num_classes - + if block is None: block = SwinTransformerBlock @@ -348,10 +361,11 @@ def __init__(self, layers: List[nn.Module] = [] # split image into non-overlapping patches - layers.append(nn.Sequential( - nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), - Permute([0, 2, 3, 1]), - norm_layer(embed_dim), + layers.append( + nn.Sequential( + nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), )) total_stage_blocks = sum(depths) @@ -374,7 +388,7 @@ def __init__(self, dropout=dropout, attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, - norm_layer=norm_layer + norm_layer=norm_layer, ) ) stage_block_id += 1 @@ -382,17 +396,17 @@ def __init__(self, # add patch merging layer if i_stage < (len(depths) - 1): layers.append( - #nn.Sequential( - # norm_layer(dim), - # Permute([0, 3, 1, 2]), - # nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2, bias=False), - # Permute([0, 2, 3, 1]), - #) + # nn.Sequential( + # norm_layer(dim), + # Permute([0, 3, 1, 2]), + # nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2, bias=False), + # Permute([0, 2, 3, 1]), + # ) PatchMerging(dim, norm_layer) ) self.features = nn.Sequential(*layers) - + num_features = embed_dim * 2 ** (len(depths) - 1) self.norm = norm_layer(num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) @@ -402,7 +416,7 @@ def __init__(self, def _init_weights(self, m): if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=.02) + nn.init.trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -436,7 +450,7 @@ def _swin_transformer( num_heads=num_heads, window_size=window_size, stochastic_depth_prob=stochastic_depth_prob, - **kwargs + **kwargs, ) if pretrained: if arch not in _MODELS_URLS: @@ -463,5 +477,5 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> stochastic_depth_prob=0.2, pretrained=pretrained, progress=progress, - **kwargs + **kwargs, ) From 92a1cf53d6dc6c02560209823e169c139d4b70b4 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 16:52:18 +0800 Subject: [PATCH 04/92] fix lint --- torchvision/models/swin_transformer.py | 43 +++++++++++++------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5c148b9362c..5661e5b5c6e 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -94,8 +94,8 @@ def __init__( window_size: int, 
num_heads: int, qkv_bias: bool = True, - attention_dropout: float = 0., - dropout: float = 0. + attention_dropout: float = 0.0, + dropout: float = 0.0, ): super().__init__() self.dim = dim @@ -127,7 +127,7 @@ def __init__( self.proj = nn.Linear(dim, dim) self.dropout = nn.Dropout(dropout) self.softmax = nn.Softmax(dim=-1) - + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor, mask: Tensor = None): @@ -145,18 +145,16 @@ def forward(self, x: Tensor, mask: Tensor = None): relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size * self.window_size, self.window_size * self.window_size, -1 - ) # Wh*Ww,Wh*Ww,nH + ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: - nW = mask.shape[0] - attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + num_windows = mask.shape[0] + attn = attn.view(B // num_windows, num_windows, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) + attn = self.softmax(attn) attn = self.attention_dropout(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) @@ -188,13 +186,13 @@ def __init__( num_heads: int, window_size: int = 7, shift_size: int = 0, - mlp_ratio: float = 4., + mlp_ratio: float = 4.0, qkv_bias: bool = True, - dropout: float = 0., - attention_dropout: float = 0., - stochastic_depth_prob: float = 0., + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() self.dim = dim @@ -210,10 +208,10 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) - def generate_attention_mask(self, shift_size: int, H: int, W: int, device): + def generate_attention_mask(self, shift_size: int, H: int, W: int, device: torch.device): if shift_size > 0: # calculate attention mask for SW-MSA mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1 @@ -341,11 +339,11 @@ def __init__( depths: List[int] = [2, 2, 6, 2], num_heads: List[int] = [3, 6, 12, 24], window_size: int = 7, - mlp_ratio: float = 4., + mlp_ratio: float = 4.0, qkv_bias: bool = True, - dropout: float = 0., - attention_dropout: float = 0., - stochastic_depth_prob: float = 0., + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ): @@ -366,7 +364,8 @@ def __init__( nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), Permute([0, 2, 3, 1]), norm_layer(embed_dim), - )) + ) + ) total_stage_blocks = sum(depths) stage_block_id = 0 @@ -404,7 +403,7 @@ def __init__( # ) PatchMerging(dim, norm_layer) ) - + self.features = nn.Sequential(*layers) num_features = embed_dim * 2 ** (len(depths) - 1) From 8ac8077a33fa4941bf828ed4238e3bde0b204d52 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 16:56:26 +0800 Subject: [PATCH 05/92] fix lint --- torchvision/models/swin_transformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 
deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5661e5b5c6e..3bb19298596 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -5,8 +5,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from torchvision.ops import StochasticDepth - +from ..ops.stochastic_depth import StochasticDepth from .._internally_replaced_utils import load_state_dict_from_url From c4445a789f87b6b7b83a4947aa14193221876e80 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 19:27:00 +0800 Subject: [PATCH 06/92] refactor code --- torchvision/models/swin_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 3bb19298596..5f7aba89182 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..ops.stochastic_depth import StochasticDepth from .._internally_replaced_utils import load_state_dict_from_url +from ..ops.stochastic_depth import StochasticDepth __all__ = [ @@ -143,7 +143,7 @@ def forward(self, x: Tensor, mask: Tensor = None): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size * self.window_size, self.window_size * self.window_size, -1 + self.window_size ** 2, self.window_size ** 2, -1 ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) @@ -223,7 +223,7 @@ def generate_attention_mask(self, shift_size: int, H: int, W: int, device: torch cnt += 1 mask_windows = self.partition_window(mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + mask_windows = mask_windows.view(-1, self.window_size ** 2) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: @@ -255,7 +255,7 @@ def forward(self, x: Tensor): # partition windows x_windows = self.partition_window(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = x_windows.view(-1, self.window_size ** 2, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) From 97e22d787a8a9ead50bf18595935954d51e9befb Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 19:27:34 +0800 Subject: [PATCH 07/92] add swin_transformer --- torchvision/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 16495e8552e..93cd782d0a1 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -12,6 +12,7 @@ from .efficientnet import * from .regnet import * from .vision_transformer import * +from .swin_transformer import * from . import detection from . import feature_extraction from . 
import optical_flow From 8599a4b92590c967a686027b82adaa16e0dcabf6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 19:49:20 +0800 Subject: [PATCH 08/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 70 +++++++++++--------------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5f7aba89182..63e7d18fc02 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple, Optional, Callable, List, Any +from typing import Optional, Callable, List, Any import torch import torch.nn.functional as F @@ -47,7 +47,7 @@ class PatchMerging(nn.Module): Args: dim (int): Number of input channels. - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): @@ -58,7 +58,7 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm def forward(self, x: Tensor): B, H, W, C = x.shape - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + assert H % 2 == 0 and W % 2 == 0, f"input size ({H}*{W}) are not even." x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C @@ -82,9 +82,9 @@ class WindowAttention(nn.Module): dim (int): Number of input channels. window_size (int)): The size of the window. num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. - attention_dropout (float, optional): Dropout ratio of attention weight. Default: 0.0. - dropout (float, optional): Dropout ratio of output. Default: 0.0. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. """ def __init__( @@ -132,8 +132,8 @@ def __init__( def forward(self, x: Tensor, mask: Tensor = None): """ Args: - x (Tensor): input features with shape of (num_windows*B, N, C) - mask (Tensor): (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + x (Tensor): input features with shape of (num_windows*B, N, C). + mask (Tensor): (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None. """ B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) @@ -171,12 +171,12 @@ class SwinTransformerBlock(nn.Module): window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. - dropout (float, optional): Dropout rate. Default: 0.0 - attention_dropout (float, optional): Attention dropout rate. Default: 0.0 - stochastic_depth_prob: (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. 
+ act_layer (nn.Module): Activation layer. Default: nn.GELU. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ def __init__( @@ -210,19 +210,19 @@ def __init__( self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) - def generate_attention_mask(self, shift_size: int, H: int, W: int, device: torch.device): + def generate_attention_mask(self, shift_size: int, height: int, width: int, device: torch.device): if shift_size > 0: # calculate attention mask for SW-MSA - mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1 + mask = torch.zeros((1, height, width, 1), device=device) h_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) - cnt = 0 + count = 0 for h in h_slices: for w in w_slices: - mask[:, h, w, :] = cnt - cnt += 1 + mask[:, h, w, :] = count + count += 1 - mask_windows = self.partition_window(mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = self.partition_window(mask) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size ** 2) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) @@ -254,7 +254,7 @@ def forward(self, x: Tensor): shifted_x = x # partition windows - x_windows = self.partition_window(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = self.partition_window(shifted_x) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size ** 2, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA @@ -280,32 +280,22 @@ def forward(self, x: Tensor): x = x + self.stochastic_depth(self.mlp(self.norm2(x))) return x - def partition_window(self, x: Tensor, window_size: int): + def partition_window(self, x: Tensor): """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) + Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). """ B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + x = x.view(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.window_size, self.window_size, C) return x - def reverse_window(self, x: Tensor, window_size: int, H: int, W: int): + def reverse_window(self, x: Tensor, height: int, width: int): """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) + The Inverse operation of `partition_window`. 
""" - B = int(x.shape[0] / (H * W / window_size / window_size)) - x = x.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + B = int(x.shape[0] / (height * width / (self.window_size ** 2))) + x = x.view(B, height // self.window_size, width // self.window_size, self.window_size, self.window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, height, width, -1) return x From c378934cccb00f03bb08829e611f231cb4b01878 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 6 Mar 2022 20:01:31 +0800 Subject: [PATCH 09/92] fix bug --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 63e7d18fc02..9751264f652 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -263,7 +263,7 @@ def forward(self, x: Tensor): # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = self.reverse_window(attn_windows, self.window_size, Hp, Wp) # B H' W' C + shifted_x = self.reverse_window(attn_windows, Hp, Wp) # B H' W' C # reverse cyclic shift if shift_size > 0: From c8e8fe2ba0e0c20f47dac1edf81f103143c678d5 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 14:34:43 +0800 Subject: [PATCH 10/92] refactor code --- torchvision/models/swin_transformer.py | 29 ++++---------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 9751264f652..4994f2df772 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -7,6 +7,8 @@ from .._internally_replaced_utils import load_state_dict_from_url from ..ops.stochastic_depth import StochasticDepth +from .vision_transformer import MLPBlock +from .convnext import Permute __all__ = [ @@ -21,27 +23,6 @@ } -class MLPBlock(nn.Sequential): - """Transformer MLP block.""" - - def __init__(self, in_dim: int, mlp_dim: int, dropout: float, act_layer: Callable[..., nn.Module] = nn.GELU): - super().__init__() - self.linear_1 = nn.Linear(in_dim, mlp_dim) - self.act = act_layer() - self.dropout_1 = nn.Dropout(dropout) - self.linear_2 = nn.Linear(mlp_dim, in_dim) - self.dropout_2 = nn.Dropout(dropout) - - -class Permute(nn.Module): - def __init__(self, dims: List[int]): - super().__init__() - self.dims = dims - - def forward(self, x): - return torch.permute(x, self.dims) - - class PatchMerging(nn.Module): """Patch Merging Layer. @@ -111,7 +92,7 @@ def __init__( # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size) coords_w = torch.arange(self.window_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij')) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 @@ -175,7 +156,6 @@ class SwinTransformerBlock(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. - act_layer (nn.Module): Activation layer. Default: nn.GELU. 
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ @@ -190,7 +170,6 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() @@ -208,7 +187,7 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) def generate_attention_mask(self, shift_size: int, height: int, width: int, device: torch.device): if shift_size > 0: From 45bbbfcb2740bc570ed5850c382dd9ca2bb16bf1 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 15:54:08 +0800 Subject: [PATCH 11/92] fix lint --- torchvision/models/swin_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 4994f2df772..919cb7634a4 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -7,8 +7,8 @@ from .._internally_replaced_utils import load_state_dict_from_url from ..ops.stochastic_depth import StochasticDepth -from .vision_transformer import MLPBlock from .convnext import Permute +from .vision_transformer import MLPBlock __all__ = [ @@ -92,7 +92,7 @@ def __init__( # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size) coords_w = torch.arange(self.window_size) - coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij')) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 @@ -240,7 +240,7 @@ def forward(self, x: Tensor): attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C - # merge windows + # reverse windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = self.reverse_window(attn_windows, Hp, Wp) # B H' W' C From ebae8b131fbe72ef285da54b73103e2fa16f72ed Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 16:44:33 +0800 Subject: [PATCH 12/92] update init_weights --- torchvision/models/swin_transformer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 919cb7634a4..e2c589eb772 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -379,16 +379,14 @@ def __init__( self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(num_features, num_classes) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - 
nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.weight, 1.0) def forward(self, x): x = self.features(x) From 0e76444565a79ea71b74920c1e050735750afbf7 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 20:45:39 +0800 Subject: [PATCH 13/92] move shift_window into attention --- torchvision/models/swin_transformer.py | 208 +++++++++++-------------- 1 file changed, 95 insertions(+), 113 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e2c589eb772..fdd53e36851 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -55,13 +55,13 @@ def forward(self, x: Tensor): return x -class WindowAttention(nn.Module): +class ShiftedWindowAttention(nn.Module): """Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. - Args: dim (int): Number of input channels. window_size (int)): The size of the window. + shift_size (int): Shift size for SW-MSA. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. @@ -72,6 +72,7 @@ def __init__( self, dim: int, window_size: int, + shift_size: int, num_heads: int, qkv_bias: bool = True, attention_dropout: float = 0.0, @@ -80,9 +81,9 @@ def __init__( super().__init__() self.dim = dim self.window_size = window_size + self.shift_size = shift_size self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.scale = (dim // num_heads) ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( @@ -99,7 +100,7 @@ def __init__( relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) @@ -110,42 +111,109 @@ def __init__( nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) - def forward(self, x: Tensor, mask: Tensor = None): - """ - Args: - x (Tensor): input features with shape of (num_windows*B, N, C). - mask (Tensor): (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None. 
- """ - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] + def forward(self, x: Tensor): + B, H, W, C = x.shape + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + if self.window_size == min(Hp, Wp): + shift_size = 0 + else: + shift_size = self.shift_size + + # cyclic shift + if shift_size > 0: + x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + # partition windows + x = self.partition_window(x) # nW*B, window_size, window_size, C + x = x.view(-1, self.window_size ** 2, C) # nW*B, window_size*window_size, C + + # multi-head attention + attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) + qkv = self.qkv(x).reshape(x.size(0), x.size(1), 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size ** 2, self.window_size ** 2, -1 ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0) - if mask is not None: - num_windows = mask.shape[0] - attn = attn.view(B // num_windows, num_windows, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) + if attn_mask is not None: + num_windows = attn_mask.shape[0] + attn = attn.view(x.size(0) // num_windows, num_windows, self.num_heads, x.size(1), x.size(1)) + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, x.size(1), x.size(1)) attn = self.softmax(attn) attn = self.attention_dropout(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = self.proj(x) x = self.dropout(x) + + # reverse windows + x = x.view(-1, self.window_size, self.window_size, C) + x = self.reverse_window(x, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if shift_size > 0: + x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) + + x = x.view(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + return x + + @torch.no_grad() + def generate_attention_mask(self, shift_size: int, height: int, width: int, device: torch.device): + if shift_size > 0: + # calculate attention mask for SW-MSA + mask = torch.zeros((1, height, width, 1), device=device) + h_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) + count = 0 + for h in h_slices: + for w in w_slices: + mask[:, h, w, :] = count + count += 1 + + mask_windows = self.partition_window(mask) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size ** 2) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def partition_window(self, x: Tensor): + 
""" + Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). + """ + B, H, W, C = x.shape + x = x.view(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, self.window_size, self.window_size, C) + return x + + def reverse_window(self, x: Tensor, height: int, width: int): + """ + The Inverse operation of `partition_window`. + """ + B = int(x.shape[0] / (height * width / (self.window_size ** 2))) + x = x.view(B, height // self.window_size, width // self.window_size, self.window_size, self.window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, height, width, -1) return x class SwinTransformerBlock(nn.Module): """Swin Transformer Block. - Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. @@ -156,6 +224,7 @@ class SwinTransformerBlock(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + act_layer (nn.Module): Activation layer. Default: nn.GELU. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ @@ -170,113 +239,26 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() - self.dim = dim - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size, num_heads, qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout + self.attn = ShiftedWindowAttention( + dim, window_size, shift_size, num_heads, qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) - - def generate_attention_mask(self, shift_size: int, height: int, width: int, device: torch.device): - if shift_size > 0: - # calculate attention mask for SW-MSA - mask = torch.zeros((1, height, width, 1), device=device) - h_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) - w_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) - count = 0 - for h in h_slices: - for w in w_slices: - mask[:, h, w, :] = count - count += 1 - - mask_windows = self.partition_window(mask) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size ** 2) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - return attn_mask + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) def forward(self, x: Tensor): - B, H, W, C = x.shape - shortcut = x - x = self.norm1(x) - - # pad feature maps to multiples of window size - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - if self.window_size == min(Hp, Wp): - shift_size = 0 - else: - shift_size = self.shift_size - - 
# cyclic shift - if shift_size > 0: - shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = self.partition_window(shifted_x) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size ** 2, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) - attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C - - # reverse windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = self.reverse_window(attn_windows, Hp, Wp) # B H' W' C - - # reverse cyclic shift - if shift_size > 0: - x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) - else: - x = shifted_x - - x = x.view(B, Hp, Wp, C) - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - # FFN - x = shortcut + self.stochastic_depth(x) + x = x + self.stochastic_depth(self.attn(self.norm1(x))) x = x + self.stochastic_depth(self.mlp(self.norm2(x))) return x - def partition_window(self, x: Tensor): - """ - Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). - """ - B, H, W, C = x.shape - x = x.view(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.window_size, self.window_size, C) - return x - - def reverse_window(self, x: Tensor, height: int, width: int): - """ - The Inverse operation of `partition_window`. - """ - B = int(x.shape[0] / (height * width / (self.window_size ** 2))) - x = x.view(B, height // self.window_size, width // self.window_size, self.window_size, self.window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, height, width, -1) - return x - class SwinTransformer(nn.Module): """ From 9a953c3106f3bfee242d90d74a9e6d0eeb480c1f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 21:28:12 +0800 Subject: [PATCH 14/92] refactor code --- torchvision/models/swin_transformer.py | 35 +++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index fdd53e36851..aabc3e2ab6e 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -131,8 +131,8 @@ def forward(self, x: Tensor): # partition windows x = self.partition_window(x) # nW*B, window_size, window_size, C - x = x.view(-1, self.window_size ** 2, C) # nW*B, window_size*window_size, C - + x = x.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C + # multi-head attention attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) qkv = self.qkv(x).reshape(x.size(0), x.size(1), 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) @@ -141,14 +141,16 @@ def forward(self, x: Tensor): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size ** 2, self.window_size ** 2, -1 + int(self.window_size ** 2), int(self.window_size ** 2), -1 ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0) if attn_mask is not None: num_windows = attn_mask.shape[0] - attn = attn.view(x.size(0) // num_windows, num_windows, self.num_heads, x.size(1), x.size(1)) + attn_mask.unsqueeze(1).unsqueeze(0) + 
attn = attn.view( + x.size(0) // num_windows, num_windows, self.num_heads, x.size(1), x.size(1) + ) + attn_mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, x.size(1), x.size(1)) attn = self.softmax(attn) @@ -157,7 +159,7 @@ def forward(self, x: Tensor): x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = self.proj(x) x = self.dropout(x) - + # reverse windows x = x.view(-1, self.window_size, self.window_size, C) x = self.reverse_window(x, Hp, Wp) # B H' W' C @@ -169,7 +171,7 @@ def forward(self, x: Tensor): x = x.view(B, Hp, Wp, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() - + return x @torch.no_grad() @@ -177,22 +179,21 @@ def generate_attention_mask(self, shift_size: int, height: int, width: int, devi if shift_size > 0: # calculate attention mask for SW-MSA mask = torch.zeros((1, height, width, 1), device=device) - h_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) - w_slices = (slice(0, -self.window_size), slice(-self.window_size, -shift_size), slice(-shift_size, None)) + slices = ((0, -self.window_size), (-self.window_size, -shift_size), (-shift_size, None)) count = 0 - for h in h_slices: - for w in w_slices: - mask[:, h, w, :] = count + for h in slices: + for w in slices: + mask[:, h[0]:h[1], w[0]:w[1], :] = count count += 1 mask_windows = self.partition_window(mask) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size ** 2) + mask_windows = mask_windows.view(-1, int(self.window_size ** 2)) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None return attn_mask - + def partition_window(self, x: Tensor): """ Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). @@ -246,7 +247,13 @@ def __init__( self.norm1 = norm_layer(dim) self.attn = ShiftedWindowAttention( - dim, window_size, shift_size, num_heads, qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + attention_dropout=attention_dropout, + dropout=dropout, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") From b9321c75fb7fc2715ab97ac2035e98a8100bd90b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 21:38:30 +0800 Subject: [PATCH 15/92] fix bug --- torchvision/models/swin_transformer.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index aabc3e2ab6e..0a7c26734f0 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -173,7 +173,7 @@ def forward(self, x: Tensor): x = x[:, :H, :W, :].contiguous() return x - + @torch.no_grad() def generate_attention_mask(self, shift_size: int, height: int, width: int, device: torch.device): if shift_size > 0: @@ -183,7 +183,7 @@ def generate_attention_mask(self, shift_size: int, height: int, width: int, devi count = 0 for h in slices: for w in slices: - mask[:, h[0]:h[1], w[0]:w[1], :] = count + mask[:, h[0] : h[1], w[0] : w[1], :] = count count += 1 mask_windows = self.partition_window(mask) # nW, window_size, window_size, 1 @@ -225,7 +225,6 @@ class SwinTransformerBlock(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. 
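An illustrative aside (not part of the patches above or below): the stochastic_depth_prob argument gates the block's residual branches through torchvision.ops.StochasticDepth, which this series applies as x = x + stochastic_depth(attn(norm1(x))). A minimal, self-contained sketch of that gating:

import torch
from torchvision.ops import StochasticDepth

sd = StochasticDepth(p=0.2, mode="row")   # "row" drops the branch independently per sample
x = torch.randn(4, 56, 56, 96)            # stand-in for the block input
branch = torch.randn(4, 56, 56, 96)       # stand-in for attn(norm1(x)) or mlp(norm2(x))
out = x + sd(branch)                      # training: roughly 20% of samples skip the branch, survivors are rescaled

sd.eval()
assert torch.equal(x + sd(branch), x + branch)   # evaluation: stochastic depth is a no-op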
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. - act_layer (nn.Module): Activation layer. Default: nn.GELU. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ @@ -240,7 +239,6 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() @@ -259,7 +257,7 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) def forward(self, x: Tensor): x = x + self.stochastic_depth(self.attn(self.norm1(x))) From f33d1cd10c5dcde6bfcfac8f40d08a4d6e940582 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 23:13:31 +0800 Subject: [PATCH 16/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 125 +++++++++---------------- 1 file changed, 43 insertions(+), 82 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 0a7c26734f0..6df4645887f 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -23,36 +23,19 @@ } -class PatchMerging(nn.Module): - """Patch Merging Layer. +def generate_attention_mask(height: int, width: int, window_size: int, shift_size: int, device: torch.device): + """Generate shifted window attention mask""" + mask = torch.zeros((1, height, width, 1), device=device) + slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) + count = 0 + for h in slices: + for w in slices: + mask[:, h[0] : h[1], w[0] : w[1], :] = count + count += 1 + return mask + - Args: - dim (int): Number of input channels. - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - """ - - def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x: Tensor): - B, H, W, C = x.shape - assert H % 2 == 0 and W % 2 == 0, f"input size ({H}*{W}) are not even." 
- - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - x = x.view(B, H // 2, W // 2, 2 * C) - - return x +torch.fx.wrap("generate_attention_mask") class ShiftedWindowAttention(nn.Module): @@ -85,6 +68,12 @@ def __init__( self.num_heads = num_heads self.scale = (dim // num_heads) ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attention_dropout = nn.Dropout(attention_dropout) + self.proj = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + self.softmax = nn.Softmax(dim=-1) + # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) @@ -103,12 +92,6 @@ def __init__( relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attention_dropout = nn.Dropout(attention_dropout) - self.proj = nn.Linear(dim, dim) - self.dropout = nn.Dropout(dropout) - self.softmax = nn.Softmax(dim=-1) - nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor): @@ -120,21 +103,15 @@ def forward(self, x: Tensor): x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape - if self.window_size == min(Hp, Wp): - shift_size = 0 - else: - shift_size = self.shift_size - # cyclic shift - if shift_size > 0: - x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + if self.shift_size > 0: + x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # partition windows x = self.partition_window(x) # nW*B, window_size, window_size, C x = x.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C - + # multi-head attention - attn_mask = self.generate_attention_mask(shift_size, Hp, Wp, x.device) qkv = self.qkv(x).reshape(x.size(0), x.size(1), 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale @@ -146,7 +123,13 @@ def forward(self, x: Tensor): relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0) - if attn_mask is not None: + if self.shift_size > 0: + # generate attention mask + attn_mask = generate_attention_mask(Hp, Wp, self.window_size, self.shift_size, x.device) + attn_mask = self.partition_window(attn_mask) + attn_mask = attn_mask.view(-1, int(self.window_size ** 2)) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) num_windows = attn_mask.shape[0] attn = attn.view( x.size(0) // num_windows, num_windows, self.num_heads, x.size(1), x.size(1) @@ -159,41 +142,20 @@ def forward(self, x: Tensor): x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = self.proj(x) x = self.dropout(x) - + # reverse windows x = x.view(-1, self.window_size, self.window_size, C) x = self.reverse_window(x, Hp, Wp) # B H' W' C # reverse cyclic shift - if shift_size > 0: - x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) + if self.shift_size > 0: + x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) x = x.view(B, Hp, Wp, C) - if 
pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :].contiguous() return x - - @torch.no_grad() - def generate_attention_mask(self, shift_size: int, height: int, width: int, device: torch.device): - if shift_size > 0: - # calculate attention mask for SW-MSA - mask = torch.zeros((1, height, width, 1), device=device) - slices = ((0, -self.window_size), (-self.window_size, -shift_size), (-shift_size, None)) - count = 0 - for h in slices: - for w in slices: - mask[:, h[0] : h[1], w[0] : w[1], :] = count - count += 1 - - mask_windows = self.partition_window(mask) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, int(self.window_size ** 2)) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - return attn_mask - + def partition_window(self, x: Tensor): """ Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). @@ -207,7 +169,7 @@ def reverse_window(self, x: Tensor, height: int, width: int): """ The Inverse operation of `partition_window`. """ - B = int(x.shape[0] / (height * width / (self.window_size ** 2))) + B = x.shape[0] // (height * width) // int((self.window_size ** 2)) x = x.view(B, height // self.window_size, width // self.window_size, self.window_size, self.window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, height, width, -1) return x @@ -225,6 +187,7 @@ class SwinTransformerBlock(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + act_layer (nn.Module): Activation layer. Default: nn.GELU. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ @@ -239,25 +202,20 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() self.norm1 = norm_layer(dim) self.attn = ShiftedWindowAttention( - dim, - window_size, - shift_size, - num_heads, - qkv_bias=qkv_bias, - attention_dropout=attention_dropout, - dropout=dropout, + dim, window_size, shift_size, num_heads, qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) def forward(self, x: Tensor): x = x + self.stochastic_depth(self.attn(self.norm1(x))) @@ -269,7 +227,6 @@ class SwinTransformer(nn.Module): """ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_ paper. - Args: patch_size (int): Patch size. Default: 4. num_classes (int): Number of classes for classification head. Default: 1000. @@ -277,6 +234,7 @@ class SwinTransformer(nn.Module): depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7. + shift_size (List(int)): Shift size of each stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. dropout (float): Dropout rate. Default: 0. 
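For reference, the nine-region labelling built by generate_attention_mask in this patch can be visualised directly. The sketch below is an editor's illustration (not code from the patch) that inlines the same slice loop for a 4x4 map with window_size=2 and shift_size=1; windows that mix different labels after the cyclic shift are the ones whose cross-region attention is later filled with -100.

import torch

height = width = 4
window_size, shift_size = 2, 1
mask = torch.zeros((1, height, width, 1))
slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None))
count = 0
for h in slices:
    for w in slices:
        mask[:, h[0]:h[1], w[0]:w[1], :] = count
        count += 1
print(mask[0, :, :, 0])
# tensor([[0., 0., 1., 2.],
#         [0., 0., 1., 2.],
#         [3., 3., 4., 5.],
#         [6., 6., 7., 8.]])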
@@ -294,6 +252,7 @@ def __init__( depths: List[int] = [2, 2, 6, 2], num_heads: List[int] = [3, 6, 12, 24], window_size: int = 7, + shift_sizes: List[int] = [3, 3, 3, 0], mlp_ratio: float = 4.0, qkv_bias: bool = True, dropout: float = 0.0, @@ -336,7 +295,7 @@ def __init__( dim, num_heads[i_stage], window_size=window_size, - shift_size=0 if i_layer % 2 == 0 else window_size // 2, + shift_size=shift_sizes[i_stage] if i_layer % 2 == 0 else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, dropout=dropout, @@ -391,6 +350,7 @@ def _swin_transformer( depths: List[int], num_heads: List[int], window_size: int, + shift_sizes: List[int], stochastic_depth_prob: float, pretrained: bool, progress: bool, @@ -426,6 +386,7 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, + shift_sizes=[3, 3, 3, 0], stochastic_depth_prob=0.2, pretrained=pretrained, progress=progress, From 41e54b8367108084b5a1528e727482d8dd0dc6a8 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 23:14:11 +0800 Subject: [PATCH 17/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 6df4645887f..51acb1e8ec5 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -202,7 +202,6 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() @@ -215,7 +214,7 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), act_layer=act_layer, dropout=dropout) + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout=dropout) def forward(self, x: Tensor): x = x + self.stochastic_depth(self.attn(self.norm1(x))) From 71ef0112b7f14dc3cb2076bd84c3025d5e97890f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 23:23:41 +0800 Subject: [PATCH 18/92] fix lint --- torchvision/models/swin_transformer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 51acb1e8ec5..a19914e80ee 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -33,7 +33,7 @@ def generate_attention_mask(height: int, width: int, window_size: int, shift_siz mask[:, h[0] : h[1], w[0] : w[1], :] = count count += 1 return mask - + torch.fx.wrap("generate_attention_mask") @@ -110,7 +110,7 @@ def forward(self, x: Tensor): # partition windows x = self.partition_window(x) # nW*B, window_size, window_size, C x = x.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C - + # multi-head attention qkv = self.qkv(x).reshape(x.size(0), x.size(1), 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] @@ -142,7 +142,7 @@ def forward(self, x: Tensor): x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = self.proj(x) x = self.dropout(x) - + # reverse windows x = x.view(-1, self.window_size, self.window_size, C) x = self.reverse_window(x, Hp, Wp) # B H' W' C @@ -155,7 +155,7 @@ def forward(self, x: Tensor): x = x[:, :H, :W, :].contiguous() return x - + def 
partition_window(self, x: Tensor): """ Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). @@ -187,7 +187,6 @@ class SwinTransformerBlock(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. - act_layer (nn.Module): Activation layer. Default: nn.GELU. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. """ @@ -208,7 +207,13 @@ def __init__( self.norm1 = norm_layer(dim) self.attn = ShiftedWindowAttention( - dim, window_size, shift_size, num_heads, qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + attention_dropout=attention_dropout, + dropout=dropout, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") From 6af4964f6ea4cc4a2af6b56472c483887515bb53 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 7 Mar 2022 23:49:08 +0800 Subject: [PATCH 19/92] add patch_merge --- torchvision/models/swin_transformer.py | 29 ++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a19914e80ee..b14c8e1d19a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -23,6 +23,35 @@ } +class PatchMerging(nn.Module): + """Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor): + B, H, W, C = x.shape + # assert H % 2 == 0 and W % 2 == 0, f"input size ({H}*{W}) are not even." + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + x = x.view(B, H // 2, W // 2, 2 * C) + + def generate_attention_mask(height: int, width: int, window_size: int, shift_size: int, device: torch.device): """Generate shifted window attention mask""" mask = torch.zeros((1, height, width, 1), device=device) From 1689dd97d201cf303f92e69bb7b6cdb0d6c30bf2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 8 Mar 2022 12:31:10 +0800 Subject: [PATCH 20/92] fix bug --- torchvision/models/swin_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index b14c8e1d19a..66cdb4ca76d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -50,6 +50,7 @@ def forward(self, x: Tensor): x = self.norm(x) x = self.reduction(x) x = x.view(B, H // 2, W // 2, 2 * C) + return x def generate_attention_mask(height: int, width: int, window_size: int, shift_size: int, device: torch.device): @@ -198,7 +199,7 @@ def reverse_window(self, x: Tensor, height: int, width: int): """ The Inverse operation of `partition_window`. 
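As a quick check of the reshapes in partition_window and reverse_window, here is a self-contained round trip through the same view/permute logic (an editor's illustration, not code from the patch); partitioning is a pure rearrangement, so reversing it recovers the input exactly.

import torch

B, H, W, C, ws = 2, 8, 8, 3, 4
x = torch.randn(B, H, W, C)

# partition: (B, H, W, C) -> (B*nW, ws, ws, C)
windows = x.view(B, H // ws, ws, W // ws, ws, C).permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

# reverse: (B*nW, ws, ws, C) -> (B, H, W, C)
restored = windows.view(B, H // ws, W // ws, ws, ws, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)

assert torch.equal(restored, x)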
""" - B = x.shape[0] // (height * width) // int((self.window_size ** 2)) + B = x.shape[0] // ((height * width) // int((self.window_size ** 2))) x = x.view(B, height // self.window_size, width // self.window_size, self.window_size, self.window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, height, width, -1) return x From 3891aad14ed678b0b435edcdb6a87fdcdfc883e4 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 8 Mar 2022 14:01:16 +0800 Subject: [PATCH 21/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 28 ++++++++++++++------------ 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 66cdb4ca76d..cd87ccd51ac 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -53,9 +53,9 @@ def forward(self, x: Tensor): return x -def generate_attention_mask(height: int, width: int, window_size: int, shift_size: int, device: torch.device): +def generate_attention_mask(x: Tensor, window_size: int, shift_size: int): """Generate shifted window attention mask""" - mask = torch.zeros((1, height, width, 1), device=device) + mask = x_new_zeros((1, x.size(1), x.size(2), 1), device=device) slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) count = 0 for h in slices: @@ -138,11 +138,13 @@ def forward(self, x: Tensor): x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # partition windows - x = self.partition_window(x) # nW*B, window_size, window_size, C - x = x.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C + x_windows = self.partition_window(x) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C # multi-head attention - qkv = self.qkv(x).reshape(x.size(0), x.size(1), 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = self.qkv(x_windows).reshape( + x_windows.size(0), x_windows.size(1), 3, self.num_heads, C // self.num_heads + ).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = q @ k.transpose(-2, -1) @@ -155,27 +157,27 @@ def forward(self, x: Tensor): if self.shift_size > 0: # generate attention mask - attn_mask = generate_attention_mask(Hp, Wp, self.window_size, self.shift_size, x.device) + attn_mask = generate_attention_mask(x, self.window_size, self.shift_size) attn_mask = self.partition_window(attn_mask) attn_mask = attn_mask.view(-1, int(self.window_size ** 2)) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) num_windows = attn_mask.shape[0] attn = attn.view( - x.size(0) // num_windows, num_windows, self.num_heads, x.size(1), x.size(1) + x_windows.size(0) // num_windows, num_windows, self.num_heads, x_windows.size(1), x_windows.size(1) ) + attn_mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, x.size(1), x.size(1)) + attn = attn.view(-1, self.num_heads, x_windows.size(1), x_windows.size(1)) attn = self.softmax(attn) attn = self.attention_dropout(attn) - x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) - x = self.proj(x) - x = self.dropout(x) + x_windows = (attn @ v).transpose(1, 2).reshape(x_windows.size(0), x_windows.size(1), C) + x_windows = self.proj(x_windows) + x_windows = self.dropout(x_windows) # reverse windows - x = x.view(-1, self.window_size, self.window_size, C) - x = 
self.reverse_window(x, Hp, Wp) # B H' W' C + x_windows = x_windows.view(-1, self.window_size, self.window_size, C) + x = self.reverse_window(x_windows, Hp, Wp) # B H' W' C # reverse cyclic shift if self.shift_size > 0: From 86f6d6b93983b8544da47e0f82e714f1c0f309d7 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 8 Mar 2022 14:54:59 +0800 Subject: [PATCH 22/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index cd87ccd51ac..3de5bb421e1 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -55,7 +55,7 @@ def forward(self, x: Tensor): def generate_attention_mask(x: Tensor, window_size: int, shift_size: int): """Generate shifted window attention mask""" - mask = x_new_zeros((1, x.size(1), x.size(2), 1), device=device) + mask = x.new_zeros((1, x.size(1), x.size(2), 1)) slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) count = 0 for h in slices: From dd9b12102f99eb835063e823b4b63a4da9cfa6d2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 8 Mar 2022 17:00:10 +0800 Subject: [PATCH 23/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 3de5bb421e1..7a6142afa9f 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -119,7 +119,7 @@ def __init__( relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) @@ -142,14 +142,16 @@ def forward(self, x: Tensor): x_windows = x_windows.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C # multi-head attention - qkv = self.qkv(x_windows).reshape( - x_windows.size(0), x_windows.size(1), 3, self.num_heads, C // self.num_heads - ).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x_windows) + .reshape(x_windows.size(0), x_windows.size(1), 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = q @ k.transpose(-2, -1) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view( int(self.window_size ** 2), int(self.window_size ** 2), -1 ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() From f869896c01b1d8a347379fd2d5223d88ba6015ab Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 8 Mar 2022 22:29:11 +0800 Subject: [PATCH 24/92] refactor code --- torchvision/models/swin_transformer.py | 37 +++++++++----------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 7a6142afa9f..01e1a31c613 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -119,7 +119,7 @@ def __init__( relative_coords[:, :, 
0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) @@ -127,19 +127,17 @@ def __init__( def forward(self, x: Tensor): B, H, W, C = x.shape # pad feature maps to multiples of window size - pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape # cyclic shift if self.shift_size > 0: x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # partition windows - x_windows = self.partition_window(x) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, int(self.window_size ** 2), C) # nW*B, window_size*window_size, C + x_windows = self.partition_window(x) # nW*B, window_size*window_size, C # multi-head attention qkv = ( @@ -161,10 +159,10 @@ def forward(self, x: Tensor): # generate attention mask attn_mask = generate_attention_mask(x, self.window_size, self.shift_size) attn_mask = self.partition_window(attn_mask) - attn_mask = attn_mask.view(-1, int(self.window_size ** 2)) + num_windows = attn_mask.size(0) + attn_mask = attn_mask.view(num_windows, -1) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - num_windows = attn_mask.shape[0] attn = attn.view( x_windows.size(0) // num_windows, num_windows, self.num_heads, x_windows.size(1), x_windows.size(1) ) + attn_mask.unsqueeze(1).unsqueeze(0) @@ -178,34 +176,24 @@ def forward(self, x: Tensor): x_windows = self.dropout(x_windows) # reverse windows - x_windows = x_windows.view(-1, self.window_size, self.window_size, C) - x = self.reverse_window(x_windows, Hp, Wp) # B H' W' C + x = x_windows.view(B, pad_H // self.window_size, pad_W // self.window_size, self.window_size, self.window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) # reverse cyclic shift if self.shift_size > 0: x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - x = x.view(B, Hp, Wp, C) + # unpad features x = x[:, :H, :W, :].contiguous() - return x def partition_window(self, x: Tensor): """ - Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size, window_size, C). + Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size*window_size, C). """ B, H, W, C = x.shape x = x.view(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, self.window_size, self.window_size, C) - return x - - def reverse_window(self, x: Tensor, height: int, width: int): - """ - The Inverse operation of `partition_window`. 
- """ - B = x.shape[0] // ((height * width) // int((self.window_size ** 2))) - x = x.view(B, height // self.window_size, width // self.window_size, self.window_size, self.window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, height, width, -1) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, int(self.window_size ** 2), C) return x @@ -333,7 +321,7 @@ def __init__( dim, num_heads[i_stage], window_size=window_size, - shift_size=shift_sizes[i_stage] if i_layer % 2 == 0 else window_size // 2, + shift_size=0 if i_layer % 2 == 0 else shift_sizes[i_stage], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, dropout=dropout, @@ -399,6 +387,7 @@ def _swin_transformer( depths=depths, num_heads=num_heads, window_size=window_size, + shift_sizes=shift_sizes, stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) From f3ae314aea67cb3797f0e63cc91574fb99abcfe7 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 8 Mar 2022 22:45:42 +0800 Subject: [PATCH 25/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 01e1a31c613..49d927502dc 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -119,7 +119,7 @@ def __init__( relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1).view(-1) self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) @@ -176,7 +176,9 @@ def forward(self, x: Tensor): x_windows = self.dropout(x_windows) # reverse windows - x = x_windows.view(B, pad_H // self.window_size, pad_W // self.window_size, self.window_size, self.window_size, C) + x = x_windows.view( + B, pad_H // self.window_size, pad_W // self.window_size, self.window_size, self.window_size, C + ) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) # reverse cyclic shift From 4ec87106481ef10806cbcd27542d68429c80b445 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 9 Mar 2022 22:17:27 +0800 Subject: [PATCH 26/92] refactor code --- torchvision/models/swin_transformer.py | 201 +++++++++++++------------ 1 file changed, 102 insertions(+), 99 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 49d927502dc..65845b776df 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -53,19 +53,88 @@ def forward(self, x: Tensor): return x -def generate_attention_mask(x: Tensor, window_size: int, shift_size: int): - """Generate shifted window attention mask""" - mask = x.new_zeros((1, x.size(1), x.size(2), 1)) - slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) - count = 0 - for h in slices: - for w in slices: - mask[:, h[0] : h[1], w[0] : w[1], :] = count - count += 1 - return mask +def shifted_window_attention( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: int, + num_heads: int, + shift_size: int = 0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Tensor = None, + proj_bias: Tensor = None, +): + B, H, W, C = input.shape + # pad feature maps to multiples of window size + pad_r = 
(window_size - W % window_size) % window_size + pad_b = (window_size - H % window_size) % window_size + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape + + # If window size is larger than feature size, there is no need to shift window. + if window_size == min(pad_H, pad_W): + shift_size = 0 + + # cyclic shift + if shift_size > 0: + x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + # partition windows + num_windows = (pad_H // window_size) * (pad_W // window_size) + x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size ** 2, C) # B*nW, Ws*Ws, C + + # multi-head attention + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * (C // num_heads) ** -0.5 + attn = q @ k.transpose(-2, -1) + + # add relative position bias + attn = attn + relative_position_bias + + if shift_size > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_H, pad_W)) + slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) + count = 0 + for h in slices: + for w in slices: + attn_mask[h[0] : h[1], w[0] : w[1]] = count + count += 1 + attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size ** 2) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn = attn.view( + x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1) + ) + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) + + x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) + # reverse windows + x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) -torch.fx.wrap("generate_attention_mask") + # reverse cyclic shift + if shift_size > 0: + x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) + + # unpad features + x = x[:, :H, :W, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention") class ShiftedWindowAttention(nn.Module): @@ -77,6 +146,7 @@ class ShiftedWindowAttention(nn.Module): shift_size (int): Shift size for SW-MSA. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + proj_bias (bool): If True, add a learnable bias to projection. Default: True. attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. dropout (float): Dropout ratio of output. Default: 0.0. 
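A worked example of the relative-position indexing that this module's __init__ sets up (editor's illustration using window_size=2; the actual models use window_size=7): the bias table has (2*2 - 1)**2 = 9 rows, and the index matrix below maps each (query, key) pair inside a window to one of those rows.

import torch

ws = 2
coords = torch.stack(torch.meshgrid([torch.arange(ws), torch.arange(ws)]))  # 2, ws, ws
flat = torch.flatten(coords, 1)                                             # 2, ws*ws
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0)                # ws*ws, ws*ws, 2
rel[:, :, 0] += ws - 1
rel[:, :, 1] += ws - 1
rel[:, :, 0] *= 2 * ws - 1
print(rel.sum(-1))
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])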
""" @@ -88,21 +158,19 @@ def __init__( shift_size: int, num_heads: int, qkv_bias: bool = True, + proj_bias: bool = True, attention_dropout: float = 0.0, dropout: float = 0.0, ): super().__init__() - self.dim = dim self.window_size = window_size self.shift_size = shift_size self.num_heads = num_heads - self.scale = (dim // num_heads) ** -0.5 + self.attention_dropout = attention_dropout + self.dropout = dropout self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attention_dropout = nn.Dropout(attention_dropout) - self.proj = nn.Linear(dim, dim) - self.dropout = nn.Dropout(dropout) - self.softmax = nn.Softmax(dim=-1) + self.proj = nn.Linear(dim, dim, bias=proj_bias) # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( @@ -119,84 +187,30 @@ def __init__( relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1).view(-1) + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor): - B, H, W, C = x.shape - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) - _, pad_H, pad_W, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - - # partition windows - x_windows = self.partition_window(x) # nW*B, window_size*window_size, C - - # multi-head attention - qkv = ( - self.qkv(x_windows) - .reshape(x_windows.size(0), x_windows.size(1), 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = qkv[0], qkv[1], qkv[2] - q = q * self.scale - attn = q @ k.transpose(-2, -1) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view( int(self.window_size ** 2), int(self.window_size ** 2), -1 ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if self.shift_size > 0: - # generate attention mask - attn_mask = generate_attention_mask(x, self.window_size, self.shift_size) - attn_mask = self.partition_window(attn_mask) - num_windows = attn_mask.size(0) - attn_mask = attn_mask.view(num_windows, -1) - attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - attn = attn.view( - x_windows.size(0) // num_windows, num_windows, self.num_heads, x_windows.size(1), x_windows.size(1) - ) + attn_mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, x_windows.size(1), x_windows.size(1)) - - attn = self.softmax(attn) - attn = self.attention_dropout(attn) - - x_windows = (attn @ v).transpose(1, 2).reshape(x_windows.size(0), x_windows.size(1), C) - x_windows = self.proj(x_windows) - x_windows = self.dropout(x_windows) - - # reverse windows - x = x_windows.view( - B, pad_H // self.window_size, pad_W // self.window_size, self.window_size, self.window_size, C + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + + return shifted_window_attention( + 
x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias ) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - - # unpad features - x = x[:, :H, :W, :].contiguous() - return x - - def partition_window(self, x: Tensor): - """ - Partition the input tensor into windows: (B, H, W, C) -> (B*nW, window_size*window_size, C). - """ - B, H, W, C = x.shape - x = x.view(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, int(self.window_size ** 2), C) - return x class SwinTransformerBlock(nn.Module): @@ -207,7 +221,6 @@ class SwinTransformerBlock(nn.Module): window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. @@ -221,7 +234,6 @@ def __init__( window_size: int = 7, shift_size: int = 0, mlp_ratio: float = 4.0, - qkv_bias: bool = True, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -235,7 +247,6 @@ def __init__( window_size, shift_size, num_heads, - qkv_bias=qkv_bias, attention_dropout=attention_dropout, dropout=dropout, ) @@ -243,7 +254,7 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout=dropout) + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) def forward(self, x: Tensor): x = x + self.stochastic_depth(self.attn(self.norm1(x))) @@ -262,9 +273,7 @@ class SwinTransformer(nn.Module): depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7. - shift_size (List(int)): Shift size of each stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. dropout (float): Dropout rate. Default: 0. attention_drop_rate (float): Attention dropout rate. Default: 0. drop_path_rate (float): Stochastic depth rate. Default: 0.1. 
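A hypothetical end-to-end use of the SwinTransformer class as it stands after this patch (an editor's sketch; the keyword arguments mirror the swin_tiny configuration used elsewhere in the file, and the class is assumed to be available from this module):

import torch

model = SwinTransformer(
    patch_size=4,
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    stochastic_depth_prob=0.2,
)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # expected: torch.Size([1, 1000])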
@@ -280,9 +289,7 @@ def __init__( depths: List[int] = [2, 2, 6, 2], num_heads: List[int] = [3, 6, 12, 24], window_size: int = 7, - shift_sizes: List[int] = [3, 3, 3, 0], mlp_ratio: float = 4.0, - qkv_bias: bool = True, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -323,9 +330,8 @@ def __init__( dim, num_heads[i_stage], window_size=window_size, - shift_size=0 if i_layer % 2 == 0 else shift_sizes[i_stage], + shift_size=0 if i_layer % 2 == 0 else window_size // 2, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, dropout=dropout, attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, @@ -378,7 +384,6 @@ def _swin_transformer( depths: List[int], num_heads: List[int], window_size: int, - shift_sizes: List[int], stochastic_depth_prob: float, pretrained: bool, progress: bool, @@ -389,7 +394,6 @@ def _swin_transformer( depths=depths, num_heads=num_heads, window_size=window_size, - shift_sizes=shift_sizes, stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) @@ -415,7 +419,6 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, - shift_sizes=[3, 3, 3, 0], stochastic_depth_prob=0.2, pretrained=pretrained, progress=progress, From 20b4eeed0883a060f1930ce4087002df2065d2f0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 9 Mar 2022 22:26:53 +0800 Subject: [PATCH 27/92] fix lint --- torchvision/models/swin_transformer.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 65845b776df..e3675c107b9 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -72,7 +72,7 @@ def shifted_window_attention( pad_b = (window_size - H % window_size) % window_size x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape - + # If window size is larger than feature size, there is no need to shift window. 
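# (Editor's note, not part of the diff: with the swin_tiny defaults a 224x224 input reaches
# 7x7 feature maps in the last stage, so min(pad_H, pad_W) == window_size == 7 and the shift
# is disabled there. The padding computed a few lines up follows the same arithmetic, e.g.
# W = 13 with window_size = 7 gives pad_r = (7 - 13 % 7) % 7 = 1, i.e. a padded width of 14.)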
if window_size == min(pad_H, pad_W): shift_size = 0 @@ -84,7 +84,7 @@ def shifted_window_attention( # partition windows num_windows = (pad_H // window_size) * (pad_W // window_size) x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size ** 2, C) # B*nW, Ws*Ws, C + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size ** 2, C) # B*nW, Ws*Ws, C # multi-head attention qkv = F.linear(x, qkv_weight, qkv_bias) @@ -92,10 +92,9 @@ def shifted_window_attention( q, k, v = qkv[0], qkv[1], qkv[2] q = q * (C // num_heads) ** -0.5 attn = q @ k.transpose(-2, -1) - # add relative position bias attn = attn + relative_position_bias - + if shift_size > 0: # generate attention mask attn_mask = x.new_zeros((pad_H, pad_W)) @@ -109,11 +108,11 @@ def shifted_window_attention( attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size ** 2) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - attn = attn.view( - x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1) - ) + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn_mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, num_heads, x.size(1), x.size(1)) - + attn = F.softmax(attn, dim=-1) attn = F.dropout(attn, p=attention_dropout) @@ -132,7 +131,7 @@ def shifted_window_attention( # unpad features x = x[:, :H, :W, :].contiguous() return x - + torch.fx.wrap("shifted_window_attention") @@ -187,7 +186,7 @@ def __init__( relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) @@ -209,7 +208,7 @@ def forward(self, x: Tensor): attention_dropout=self.attention_dropout, dropout=self.dropout, qkv_bias=self.qkv.bias, - proj_bias=self.proj.bias + proj_bias=self.proj.bias, ) From cb802ecbda3d874204a1dc1d727f840c4498d1f3 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 10 Mar 2022 18:51:16 +0800 Subject: [PATCH 28/92] refactor code --- torchvision/models/swin_transformer.py | 67 ++++++++++++-------------- 1 file changed, 30 insertions(+), 37 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e3675c107b9..3ee8466d059 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -66,6 +66,26 @@ def shifted_window_attention( qkv_bias: Tensor = None, proj_bias: Tensor = None, ): + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (int)): The size of the window. + shift_size (int): Shift size for SW-MSA. 
+ num_heads (int): Number of attention heads. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_weight (Tensor[out_dim]): The bias tensor of query, key, value. Default: None. + proj_weight (Tensor[out_dim]): The bias tensor of projection. Default: None. + + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. + """ B, H, W, C = input.shape # pad feature maps to multiples of window size pad_r = (window_size - W % window_size) % window_size @@ -84,8 +104,8 @@ def shifted_window_attention( # partition windows num_windows = (pad_H // window_size) * (pad_W // window_size) x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size ** 2, C) # B*nW, Ws*Ws, C - + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C + # multi-head attention qkv = F.linear(x, qkv_weight, qkv_bias) qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) @@ -105,7 +125,7 @@ def shifted_window_attention( attn_mask[h[0] : h[1], w[0] : w[1]] = count count += 1 attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size) - attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size ** 2) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn_mask.unsqueeze( @@ -137,17 +157,8 @@ def shifted_window_attention( class ShiftedWindowAttention(nn.Module): - """Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (int)): The size of the window. - shift_size (int): Shift size for SW-MSA. - num_heads (int): Number of attention heads. - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. - proj_bias (bool): If True, add a learnable bias to projection. Default: True. - attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. - dropout (float): Dropout ratio of output. Default: 0.0. + """ + See :func:`shifted_window_attention`. """ def __init__( @@ -193,7 +204,7 @@ def __init__( def forward(self, x: Tensor): relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view( - int(self.window_size ** 2), int(self.window_size ** 2), -1 + self.window_size * self.window_size, self.window_size * self.window_size, -1 ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) @@ -213,7 +224,9 @@ def forward(self, x: Tensor): class SwinTransformerBlock(nn.Module): - """Swin Transformer Block. + """ + Swin Transformer Block. + Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. @@ -265,6 +278,7 @@ class SwinTransformer(nn.Module): """ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_ paper. + Args: patch_size (int): Patch size. Default: 4. num_classes (int): Number of classes for classification head. Default: 1000. 
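The functional shifted_window_attention documented earlier in this patch can also be exercised directly; below is a minimal sketch (an editor's illustration: the random weights stand in for the qkv/proj Linear parameters, and a zero tensor stands in for the learned relative-position bias).

import torch

B, H, W, C, num_heads, ws = 1, 14, 14, 96, 3, 7
x = torch.randn(B, H, W, C)
qkv_weight = torch.randn(3 * C, C) * 0.02
proj_weight = torch.randn(C, C) * 0.02
relative_position_bias = torch.zeros(1, num_heads, ws * ws, ws * ws)

out = shifted_window_attention(          # the function defined earlier in this file
    x, qkv_weight, proj_weight, relative_position_bias,
    window_size=ws, num_heads=num_heads, shift_size=ws // 2,
)
print(out.shape)   # expected: torch.Size([1, 14, 14, 96])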
@@ -402,24 +416,3 @@ def _swin_transformer( state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) model.load_state_dict(state_dict) return model - - -def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: - """ - Constructs a swin_tiny architecture from - `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _swin_transformer( - arch="swin_tiny", - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - stochastic_depth_prob=0.2, - pretrained=pretrained, - progress=progress, - **kwargs, - ) From 113b07449bf41189c4870822d2517a4d0d08dae3 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 10 Mar 2022 18:53:30 +0800 Subject: [PATCH 29/92] add swin_tiny --- torchvision/models/swin_transformer.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 3ee8466d059..f3ba05cd7c2 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -416,3 +416,24 @@ def _swin_transformer( state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) model.load_state_dict(state_dict) return model + + +def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_tiny architecture from + `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _swin_transformer( + arch="swin_tiny", + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + stochastic_depth_prob=0.2, + pretrained=pretrained, + progress=progress, + **kwargs, + ) From d92a49067c52e46a332914caab7d8f623dca86bf Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 10 Mar 2022 20:29:19 +0800 Subject: [PATCH 30/92] add swin_tiny.pkl --- test/expect/ModelTester.test_swin_tiny_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_tiny_expect.pkl diff --git a/test/expect/ModelTester.test_swin_tiny_expect.pkl b/test/expect/ModelTester.test_swin_tiny_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..06a0bfaf9cf6b52f8ec359070945a1287da66249 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=Vjao}Jjo4R&#-`0Y@86t6`%^+o_Vtz?w5yI)*za&Fea|`;iTw>t_pFm9EwIh`qi(+~ntA_* z_W!%zxJ|K}Bb#8iWRHV=z~-jCg7pIX6n5nAo#3?2R&0mu{s~im+qOt?+09!Nus>Zs zaDPMGSKG3PgnduGaN1wdsIzZ4dTyVFou2*XTMz8A%J18K+AeP2-XUhUVfr-tv>I9a z#my`Bx(5H+2MR68m-)>nfFT9KxI>Gd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P Date: Thu, 10 Mar 2022 20:30:40 +0800 Subject: [PATCH 31/92] fix lint --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index f3ba05cd7c2..03fb79781cd 100644 --- a/torchvision/models/swin_transformer.py +++ 
b/torchvision/models/swin_transformer.py @@ -69,7 +69,7 @@ def shifted_window_attention( """ Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. - + Args: input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. From 05dd1e2fc74acd460580f8c996e390d1735fd6b9 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 10 Mar 2022 21:00:58 +0800 Subject: [PATCH 32/92] Delete ModelTester.test_swin_tiny_expect.pkl --- test/expect/ModelTester.test_swin_tiny_expect.pkl | Bin 939 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/expect/ModelTester.test_swin_tiny_expect.pkl diff --git a/test/expect/ModelTester.test_swin_tiny_expect.pkl b/test/expect/ModelTester.test_swin_tiny_expect.pkl deleted file mode 100644 index 06a0bfaf9cf6b52f8ec359070945a1287da66249..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=Vjao}Jjo4R&#-`0Y@86t6`%^+o_Vtz?w5yI)*za&Fea|`;iTw>t_pFm9EwIh`qi(+~ntA_* z_W!%zxJ|K}Bb#8iWRHV=z~-jCg7pIX6n5nAo#3?2R&0mu{s~im+qOt?+09!Nus>Zs zaDPMGSKG3PgnduGaN1wdsIzZ4dTyVFou2*XTMz8A%J18K+AeP2-XUhUVfr-tv>I9a z#my`Bx(5H+2MR68m-)>nfFT9KxI>Gd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P Date: Thu, 10 Mar 2022 21:01:26 +0800 Subject: [PATCH 33/92] add swin_tiny --- test/expect/ModelTester.test_swin_tiny_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_tiny_expect.pkl diff --git a/test/expect/ModelTester.test_swin_tiny_expect.pkl b/test/expect/ModelTester.test_swin_tiny_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9dff738da6cc3085f1e9401911cf2995f4f7961e GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~r=~qMc8Mu{~GsihXgBm3FF+8SP)V&bHfgP-K6@@yB-F8|C-=95=G73CZ36 z`f8*7BBdGovTil(*k9SizW>FBTYFod*X%DAliT-euH?S* zcz!$6zzDm0Gu!s2G40!D)ID+Ef|Rv;bQrhV9XoJhU!IuF{zuXJ_FBqUY&Yo7+PC7( zahvPD6ZgGX8@b==)SrDms`G4D<#5}VMEtQUlN7gqS9{jhYV+iM3!;Pf-C4$I?=Z`J z-@14q``TI2`$3^K;d#v46TpxHVcem`&tMG?t+LdjVqh$|Ihhj~Tu31YVH$IRY`!=z zJ(LM(D+mX8GlD31nnaF60gwa=Ku@9Qx{>|FhobWpkcX^W-vC`Nva9$}^hy9-2-6D< zg8*+fHXW!UIc8nBa!_Ih0T{g Date: Thu, 10 Mar 2022 21:38:58 +0800 Subject: [PATCH 34/92] add --- .../ModelTester.test_swin_tiny_expect.pkl | Bin 939 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/test/expect/ModelTester.test_swin_tiny_expect.pkl b/test/expect/ModelTester.test_swin_tiny_expect.pkl index 9dff738da6cc3085f1e9401911cf2995f4f7961e..11fa82056db454ef3553e8d1fedd869dab1df94b 100644 GIT binary patch delta 230 zcmVq`R+Q&W!r2{_BuKzxPz#TsbPnNz`c?iFJc0oVfr!l`I+nhd4Y6w3I&?Y|o zHUqyWm&d-&!4H%|sD+QcspnC@%aJg? 
zUf69v_-E@rLRUe*0T6G#elUqX&L6tI1SSJNP;sKY46n1kiPG>sg!-7jKZ_^7=utnv gOfSSfEK3(YP)i30bjw8~lMn*X1a!+qB$MO<$8ghY*#H0l delta 230 zcmV Date: Thu, 10 Mar 2022 22:17:19 +0800 Subject: [PATCH 35/92] add Optional to bias --- torchvision/models/swin_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 03fb79781cd..a00798a4f67 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -63,8 +63,8 @@ def shifted_window_attention( shift_size: int = 0, attention_dropout: float = 0.0, dropout: float = 0.0, - qkv_bias: Tensor = None, - proj_bias: Tensor = None, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, ): """ Window based multi-head self attention (W-MSA) module with relative position bias. From 4ed22c0712215f30fdbcd17687fc65d4570b1844 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 16 Mar 2022 18:10:59 +0800 Subject: [PATCH 36/92] update init weights --- torchvision/models/swin_transformer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a00798a4f67..710a064aaff 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -372,14 +372,16 @@ def __init__( self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(num_features, num_classes) - for m in self.modules(): - if isinstance(m, (nn.Conv2d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0.0) + nn.init.constant_(m.weight, 1.0) def forward(self, x): x = self.features(x) From bccc2b4125d12a7e119d86befc498ccbb277c47c Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 17 Mar 2022 16:32:50 +0800 Subject: [PATCH 37/92] update init_weights and add no weight decay --- torchvision/models/swin_transformer.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 710a064aaff..4fedd286101 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -268,6 +268,13 @@ def __init__( self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) + for block in [self.attn, self.mlp]: + for m in block.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + def forward(self, x: Tensor): x = x + self.stochastic_depth(self.attn(self.norm1(x))) x = x + self.stochastic_depth(self.mlp(self.norm2(x))) @@ -372,16 +379,12 @@ def __init__( self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(num_features, num_classes) - self.apply(self._init_weights) + nn.init.trunc_normal_(self.head.weight, std=0.02) + nn.init.zeros_(self.head.bias) - def _init_weights(self, m): - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0.0) - elif isinstance(m, 
nn.LayerNorm): - nn.init.constant_(m.bias, 0.0) - nn.init.constant_(m.weight, 1.0) + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} def forward(self, x): x = self.features(x) From 2098b2469c1707e47c19ef746ce634863759570a Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 17 Mar 2022 17:19:54 +0800 Subject: [PATCH 38/92] add no weight decay --- torchvision/models/swin_transformer.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 4fedd286101..595bf07e862 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -268,13 +268,6 @@ def __init__( self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) - for block in [self.attn, self.mlp]: - for m in block.modules(): - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - def forward(self, x: Tensor): x = x + self.stochastic_depth(self.attn(self.norm1(x))) x = x + self.stochastic_depth(self.mlp(self.norm2(x))) @@ -379,12 +372,15 @@ def __init__( self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(num_features, num_classes) - nn.init.trunc_normal_(self.head.weight, std=0.02) - nn.init.zeros_(self.head.bias) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) @torch.jit.ignore def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} + return {"relative_position_bias_table"} def forward(self, x): x = self.features(x) From 6b0b6c2d66fc79c3a8768a6649584a54ebf8ebcf Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 17 Mar 2022 21:36:47 +0800 Subject: [PATCH 39/92] add set_weight_decay --- references/classification/train.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 569cf3009e7..ecc31568e75 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -247,12 +247,17 @@ def main(args): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - if args.norm_weight_decay is None: - parameters = model.parameters() + if hasattr(model, 'no_weight_decay_keywords'): + custom_keys_weight_decay = {k: 0.0 for k in model.no_weight_decay_keywords()} else: - param_groups = torchvision.ops._utils.split_normalization_params(model) - wd_groups = [args.norm_weight_decay, args.weight_decay] - parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p] + custom_keys_weight_decay = None + parameters = utils.set_weight_decay( + model, + args.weight_decay, + norm_weight_decay=args.norm_weight_decay, + bias_weight_decay=args.bias_weight_decay, + custom_keys_weight_decay=custom_keys_weight_decay, + ) opt_name = args.opt.lower() if opt_name.startswith("sgd"): @@ -411,6 +416,12 @@ def get_args_parser(add_help=True): type=float, help="weight decay for Normalization layers (default: None, same value as --wd)", ) + parser.add_argument( + "--bias-weight-decay", + default=None, + type=float, + help="weight decay for all bias parameters (default: None, same value as --wd)", + ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" ) From 991e4c18ba3d2614d6482017f5da390fbba55378 Mon Sep 17 00:00:00 2001 From: Hu 
Ye Date: Thu, 17 Mar 2022 21:41:07 +0800 Subject: [PATCH 40/92] add set_weight_decay --- references/classification/utils.py | 55 ++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/references/classification/utils.py b/references/classification/utils.py index 7f573415c4c..69153033f6e 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -5,6 +5,7 @@ import os import time from collections import defaultdict, deque, OrderedDict +from typing import Optional, Dict import torch import torch.distributed as dist @@ -400,3 +401,57 @@ def reduce_across_processes(val): dist.barrier() dist.all_reduce(t) return t + + +def set_weight_decay( + model: torch.nn.Module, + weight_decay: float, + norm_weight_decay: Optional[float] = None, + bias_weight_decay: Optional[float] = None, + custom_keys_weight_decay: Optional[Dict[str, float]] = None +): + norm_classes = (torch.nn.modules.batchnorm._BatchNorm, torch.nn.LayerNorm, torch.nn.GroupNorm) + + norm_params = [] + bias_params = [] + other_params = [] + custom_params = {} + if custom_keys_weight_decay is not None: + for key in custom_keys_weight_decay: + custom_params[key] = [] + + for module in model.modules(): + if next(module.children(), None): + for name, p in module.named_parameters(recurse=False): + if not p.requires_grad: + continue + is_custom_key = False + for key in custom_params: + if key in name: + custom_params[key].append(p) + is_custom_key = True + if not is_custom_key: + other_params.append(p) + elif isinstance(module, norm_classes): + if norm_weight_decay is not None: + norm_params.extend(p for p in module.parameters() if p.requires_grad) + else: + other_params.extend(p for p in module.parameters() if p.requires_grad) + else: + for name, p in module.named_parameters(): + if not p.requires_grad: + continue + if name == "bias" and (bias_weight_decay is not None): + bias_params.append(p) + else: + other_params.append(p) + + param_groups = [] + if norm_weight_decay is not None: + param_groups.append({"params": norm_params, "weight_decay": norm_weight_decay}) + if bias_weight_decay is not None: + param_groups.append({"params": bias_params, "weight_decay": bias_weight_decay}) + for key in custom_params: + param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) + param_groups.append({"params": other_params, "weight_decay": weight_decay}) + return param_groups From f1ec5c87f98df2c5d6f880055e8955b6583d7202 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 17 Mar 2022 21:46:14 +0800 Subject: [PATCH 41/92] fix lint --- references/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index 69153033f6e..a50322f5f6f 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -408,7 +408,7 @@ def set_weight_decay( weight_decay: float, norm_weight_decay: Optional[float] = None, bias_weight_decay: Optional[float] = None, - custom_keys_weight_decay: Optional[Dict[str, float]] = None + custom_keys_weight_decay: Optional[Dict[str, float]] = None, ): norm_classes = (torch.nn.modules.batchnorm._BatchNorm, torch.nn.LayerNorm, torch.nn.GroupNorm) From 3c2a44d2aa9f57a34e171a6182318d49ec3bfeee Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 17 Mar 2022 21:46:52 +0800 Subject: [PATCH 42/92] fix lint --- references/classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/references/classification/train.py b/references/classification/train.py index ecc31568e75..8dc34a81eac 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -247,7 +247,7 @@ def main(args): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - if hasattr(model, 'no_weight_decay_keywords'): + if hasattr(model, "no_weight_decay_keywords"): custom_keys_weight_decay = {k: 0.0 for k in model.no_weight_decay_keywords()} else: custom_keys_weight_decay = None From e8b528ffa5fd4998c3e8b3f117c45976e0ceb863 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 18 Mar 2022 19:53:05 +0800 Subject: [PATCH 43/92] add lr_cos_min --- references/classification/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index 8dc34a81eac..4480c1e037d 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -284,7 +284,7 @@ def main(args): main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs - args.lr_warmup_epochs + optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_cos_min ) elif args.lr_scheduler == "exponentiallr": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) @@ -435,6 +435,9 @@ def get_args_parser(add_help=True): parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument( + "--lr-cos-min", default=0.0, type=float, help="minimum lr of cosineannealing lr scheduler (default: 0.0)" + ) parser.add_argument("--print-freq", default=10, type=int, help="print frequency") parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") From 023ceb0fc8f1212596715a5f24eab14f3e8a48f2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 18 Mar 2022 22:36:36 +0800 Subject: [PATCH 44/92] add other swin models --- torchvision/models/swin_transformer.py | 63 ++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 595bf07e862..43eed800057 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -438,3 +438,66 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> progress=progress, **kwargs, ) + + +def swin_samll(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_small architecture from + `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. 
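
A sketch, not part of the diff: the cosine-schedule change a few hunks above (CosineAnnealingLR with eta_min plus the --lr-cos-min flag) is easier to read in isolation. The optimizer, epoch counts and learning-rate values below are made-up placeholders, not values taken from this patch.

import torch

model = torch.nn.Linear(8, 8)  # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

lr_warmup_epochs, epochs, lr_cos_min = 20, 300, 1e-5
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, total_iters=lr_warmup_epochs
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs - lr_warmup_epochs, eta_min=lr_cos_min
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[lr_warmup_epochs]
)

for epoch in range(epochs):
    # ... run one training epoch with optimizer.step() calls here ...
    scheduler.step()  # lr now decays towards lr_cos_min instead of 0
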
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _swin_transformer( + arch="swin_tiny", + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + stochastic_depth_prob=0.3, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def swin_base(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_base architecture from + `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _swin_transformer( + arch="swin_tiny", + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + stochastic_depth_prob=0.5, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def swin_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_large architecture from + `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _swin_transformer( + arch="swin_tiny", + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=7, + stochastic_depth_prob=0.2, + pretrained=pretrained, + progress=progress, + **kwargs, + ) From 113fd094b23951d5de947e3f88aa37e4cceb2113 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 24 Mar 2022 13:00:42 +0800 Subject: [PATCH 45/92] Update torchvision/models/swin_transformer.py Co-authored-by: Vasilis Vryniotis --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 43eed800057..d6cc213792a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -440,7 +440,7 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ) -def swin_samll(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_small architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. From e91d607d258342ae5bee955b350d2fec660b9e0d Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 24 Mar 2022 13:19:43 +0800 Subject: [PATCH 46/92] refactor doc --- torchvision/models/swin_transformer.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index d6cc213792a..5c6764407d6 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -75,13 +75,13 @@ def shifted_window_attention( qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. relative_position_bias (Tensor): The learned relative position bias added to attention. - window_size (int)): The size of the window. - shift_size (int): Shift size for SW-MSA. + window_size (int): Window size. 
num_heads (int): Number of attention heads. + shift_size (int): Shift size for shifted window attention. Default: 0. attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. dropout (float): Dropout ratio of output. Default: 0.0. - qkv_weight (Tensor[out_dim]): The bias tensor of query, key, value. Default: None. - proj_weight (Tensor[out_dim]): The bias tensor of projection. Default: None. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. Returns: Tensor[N, H, W, C]: The output tensor after shifted window attention. @@ -230,9 +230,9 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + window_size (int): Window size. Default: 7. + shift_size (int): Shift size for shifted window attention. Default: 0. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. @@ -283,15 +283,15 @@ class SwinTransformer(nn.Module): patch_size (int): Patch size. Default: 4. num_classes (int): Number of classes for classification head. Default: 1000. embed_dim (int): Patch embedding dimension. Default: 96. - depths (List(int)): Depth of each Swin Transformer layer. - num_heads (List(int)): Number of attention heads in different layers. + depths (List(int)): Depth of each Swin Transformer layer. Default: [2, 2, 6, 2]. + num_heads (List(int)): Number of attention heads in different layers. Default: [3, 6, 12, 24]. window_size (int): Window size. Default: 7. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - dropout (float): Dropout rate. Default: 0. - attention_drop_rate (float): Attention dropout rate. Default: 0. - drop_path_rate (float): Stochastic depth rate. Default: 0.1. - block (nn.Module): SwinTransformer Block. - norm_layer (nn.Module): Normalization layer. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_drop_rate (float): Attention dropout rate. Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default: 0.0. + block (nn.Module, optional): SwinTransformer Block. Default: None. + norm_layer (nn.Module, optional): Normalization layer. Default: None. 
""" def __init__( From 78fb3ce3e4e0cff9e83794045f9d13e3e3ec7b12 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 25 Mar 2022 19:40:12 +0800 Subject: [PATCH 47/92] Update utils.py --- references/classification/utils.py | 55 ------------------------------ 1 file changed, 55 deletions(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index f8c5d0c4117..32658a7c137 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -5,7 +5,6 @@ import os import time from collections import defaultdict, deque, OrderedDict -from typing import Optional, Dict import torch import torch.distributed as dist @@ -401,57 +400,3 @@ def reduce_across_processes(val): dist.barrier() dist.all_reduce(t) return t - - -def set_weight_decay( - model: torch.nn.Module, - weight_decay: float, - norm_weight_decay: Optional[float] = None, - bias_weight_decay: Optional[float] = None, - custom_keys_weight_decay: Optional[Dict[str, float]] = None, -): - norm_classes = (torch.nn.modules.batchnorm._BatchNorm, torch.nn.LayerNorm, torch.nn.GroupNorm) - - norm_params = [] - bias_params = [] - other_params = [] - custom_params = {} - if custom_keys_weight_decay is not None: - for key in custom_keys_weight_decay: - custom_params[key] = [] - - for module in model.modules(): - if next(module.children(), None): - for name, p in module.named_parameters(recurse=False): - if not p.requires_grad: - continue - is_custom_key = False - for key in custom_params: - if key in name: - custom_params[key].append(p) - is_custom_key = True - if not is_custom_key: - other_params.append(p) - elif isinstance(module, norm_classes): - if norm_weight_decay is not None: - norm_params.extend(p for p in module.parameters() if p.requires_grad) - else: - other_params.extend(p for p in module.parameters() if p.requires_grad) - else: - for name, p in module.named_parameters(): - if not p.requires_grad: - continue - if name == "bias" and (bias_weight_decay is not None): - bias_params.append(p) - else: - other_params.append(p) - - param_groups = [] - if norm_weight_decay is not None: - param_groups.append({"params": norm_params, "weight_decay": norm_weight_decay}) - if bias_weight_decay is not None: - param_groups.append({"params": bias_params, "weight_decay": bias_weight_decay}) - for key in custom_params: - param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) - param_groups.append({"params": other_params, "weight_decay": weight_decay}) - return param_groups From b3b9a20f641700eb4ad196213034411b44517348 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 25 Mar 2022 19:45:08 +0800 Subject: [PATCH 48/92] Update train.py --- references/classification/train.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index f517b9ae2f0..ed24239f0a2 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -229,17 +229,12 @@ def main(args): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - if hasattr(model, "no_weight_decay_keywords"): - custom_keys_weight_decay = {k: 0.0 for k in model.no_weight_decay_keywords()} + if args.norm_weight_decay is None: + parameters = model.parameters() else: - custom_keys_weight_decay = None - parameters = utils.set_weight_decay( - model, - args.weight_decay, - norm_weight_decay=args.norm_weight_decay, - bias_weight_decay=args.bias_weight_decay, - 
custom_keys_weight_decay=custom_keys_weight_decay, - ) + param_groups = torchvision.ops._utils.split_normalization_params(model) + wd_groups = [args.norm_weight_decay, args.weight_decay] + parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p] opt_name = args.opt.lower() if opt_name.startswith("sgd"): @@ -266,7 +261,7 @@ def main(args): main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_cos_min + optimizer, T_max=args.epochs - args.lr_warmup_epochs ) elif args.lr_scheduler == "exponentiallr": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) @@ -398,12 +393,6 @@ def get_args_parser(add_help=True): type=float, help="weight decay for Normalization layers (default: None, same value as --wd)", ) - parser.add_argument( - "--bias-weight-decay", - default=None, - type=float, - help="weight decay for all bias parameters (default: None, same value as --wd)", - ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" ) From bb255c1ccb6ccb2cf8ca6a52c87072239d65e5e6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 25 Mar 2022 19:45:59 +0800 Subject: [PATCH 49/92] Update train.py --- references/classification/train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index ed24239f0a2..eb8b56c1ad0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -406,9 +406,6 @@ def get_args_parser(add_help=True): parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") - parser.add_argument( - "--lr-cos-min", default=0.0, type=float, help="minimum lr of cosineannealing lr scheduler (default: 0.0)" - ) parser.add_argument("--print-freq", default=10, type=int, help="print frequency") parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") From 02f500659f83692a2b55f631d5b8551731be175a Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 11:35:20 +0800 Subject: [PATCH 50/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 84 +++++++++++++++----------- 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5c6764407d6..5e955f246e7 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -9,6 +9,11 @@ from ..ops.stochastic_depth import StochasticDepth from .convnext import Permute from .vision_transformer import MLPBlock +from ..transforms._presets import ImageClassification, InterpolationMode +from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -17,12 +22,6 @@ ] -_MODELS_URLS = { - "swin_tiny": "", - "swin_base": "", -} - - class PatchMerging(nn.Module): 
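
A quick shape check, not part of the diff, of what PatchMerging computes; it is handy when reviewing the layer-building simplification later in this patch. It assumes the class is importable straight from this module at this commit, and the tensor sizes are arbitrary examples.

import torch
from torchvision.models.swin_transformer import PatchMerging

merge = PatchMerging(dim=96)    # uses the default nn.LayerNorm
x = torch.randn(2, 56, 56, 96)  # (B, H, W, C) layout used throughout this file
print(merge(x).shape)           # torch.Size([2, 28, 28, 192]): H and W halved, C doubled
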
"""Patch Merging Layer. Args: @@ -69,7 +68,6 @@ def shifted_window_attention( """ Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. - Args: input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. @@ -82,7 +80,6 @@ def shifted_window_attention( dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. - Returns: Tensor[N, H, W, C]: The output tensor after shifted window attention. """ @@ -128,9 +125,8 @@ def shifted_window_attention( attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn_mask.unsqueeze( - 1 - ).unsqueeze(0) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, num_heads, x.size(1), x.size(1)) attn = F.softmax(attn, dim=-1) @@ -226,7 +222,6 @@ def forward(self, x: Tensor): class SwinTransformerBlock(nn.Module): """ Swin Transformer Block. - Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. @@ -278,7 +273,6 @@ class SwinTransformer(nn.Module): """ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_ paper. - Args: patch_size (int): Patch size. Default: 4. num_classes (int): Number of classes for classification head. Default: 1000. 
@@ -310,7 +304,7 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, ): super().__init__() - + _log_api_usage_once(self) self.num_classes = num_classes if block is None: @@ -355,15 +349,7 @@ def __init__( layers.append(nn.Sequential(*stage)) # add patch merging layer if i_stage < (len(depths) - 1): - layers.append( - # nn.Sequential( - # norm_layer(dim), - # Permute([0, 3, 1, 2]), - # nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2, bias=False), - # Permute([0, 2, 3, 1]), - # ) - PatchMerging(dim, norm_layer) - ) + layers.append(PatchMerging(dim, norm_layer)) self.features = nn.Sequential(*layers) @@ -378,10 +364,6 @@ def __init__( if m.bias is not None: nn.init.zeros_(m.bias) - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {"relative_position_bias_table"} - def forward(self, x): x = self.features(x) x = self.norm(x) @@ -393,16 +375,18 @@ def forward(self, x): def _swin_transformer( - arch: str, embed_dim: int, depths: List[int], num_heads: List[int], window_size: int, stochastic_depth_prob: float, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> SwinTransformer: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = SwinTransformer( embed_dim=embed_dim, depths=depths, @@ -411,15 +395,41 @@ def _swin_transformer( stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) - if pretrained: - if arch not in _MODELS_URLS: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: +_COMMON_META = { + "task": "image_classification", + "architecture": "SwinTransformer", + "publication_year": 2021, + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BICUBIC, +} + + +class Swin_Tiny_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 86567656, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swin_tiny", + "acc@1": 81.072, + "acc@5": 95.318, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", Swin_Tiny_Weights.IMAGENET1K_V1)) +def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. 
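
A rough sketch, not part of the diff, of how the rewritten builder is meant to be driven after this hunk. At this point in the series the weight entry still carries a borrowed placeholder URL (it is emptied again two patches later), so only weights=None can actually be exercised; the pretrained branch is kept in comments purely to show the intended API.

from torchvision.models.swin_transformer import SwinTransformer, Swin_Tiny_Weights, swin_tiny

# Randomly initialised model, e.g. for fine-tuning on a 10-class problem.
model = swin_tiny(weights=None, num_classes=10)

# Intended pretrained path once a real checkpoint URL is published:
# weights = Swin_Tiny_Weights.IMAGENET1K_V1
# model = swin_tiny(weights=weights)  # num_classes is checked/filled in by _ovewrite_named_param
# preprocess = weights.transforms()   # ImageClassification preset with crop_size=224

# The builder is thin sugar over the class itself; the tiny configuration is simply:
tiny = SwinTransformer(
    embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
    window_size=7, stochastic_depth_prob=0.2,
)
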
@@ -427,6 +437,8 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ + weights = Swin_Tiny_Weights.verify(weights) + return _swin_transformer( arch="swin_tiny", embed_dim=96, @@ -434,7 +446,7 @@ def swin_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> num_heads=[3, 6, 12, 24], window_size=7, stochastic_depth_prob=0.2, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) From df626aa2782a9f99c1a1a74c78f7d3183005b286 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 11:39:48 +0800 Subject: [PATCH 51/92] update model builder --- torchvision/models/swin_transformer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5e955f246e7..bb61785d661 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -7,13 +7,13 @@ from .._internally_replaced_utils import load_state_dict_from_url from ..ops.stochastic_depth import StochasticDepth -from .convnext import Permute -from .vision_transformer import MLPBlock from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param +from .convnext import Permute +from .vision_transformer import MLPBlock __all__ = [ @@ -438,9 +438,8 @@ def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = T progress (bool): If True, displays a progress bar of the download to stderr """ weights = Swin_Tiny_Weights.verify(weights) - + return _swin_transformer( - arch="swin_tiny", embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], From 438a0dd584d46a2f6dc1383582cc773fa9dd17e1 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 11:57:12 +0800 Subject: [PATCH 52/92] fix lint --- torchvision/models/swin_transformer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index bb61785d661..7831f07beaa 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once @@ -413,16 +412,16 @@ def _swin_transformer( class Swin_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", + url="", transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, - "num_params": 86567656, + "num_params": 28288354, "size": (224, 224), "min_size": (224, 224), "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swin_tiny", - "acc@1": 81.072, - "acc@5": 95.318, + "acc@1": 81.222, + "acc@5": 95.332, }, ) DEFAULT = IMAGENET1K_V1 @@ -434,7 +433,7 @@ def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = T Constructs a swin_tiny architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. 
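
The corrected metadata above is meant to be consumed programmatically. A small sketch of that, not part of the diff, using the prototype weights API of this era; the checkpoint URL is still empty, so only the meta and transforms fields are touched here.

from torchvision.models.swin_transformer import Swin_Tiny_Weights

w = Swin_Tiny_Weights.IMAGENET1K_V1      # also reachable as Swin_Tiny_Weights.DEFAULT
print(w.meta["num_params"])              # 28288354
print(w.meta["acc@1"], w.meta["acc@5"])  # 81.222 95.332
preprocess = w.transforms()              # ImageClassification preset with crop_size=224
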
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Swin_Tiny_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = Swin_Tiny_Weights.verify(weights) From 070aebdf6b9047b82970ed98ca0468867756ace8 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 13:23:06 +0800 Subject: [PATCH 53/92] add --- .../ModelTester.test_swin_tiny_expect.pkl | Bin 939 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/test/expect/ModelTester.test_swin_tiny_expect.pkl b/test/expect/ModelTester.test_swin_tiny_expect.pkl index 11fa82056db454ef3553e8d1fedd869dab1df94b..25ee179b47d24aea13343807275f3811d4bdd61c 100644 GIT binary patch delta 230 zcmVc)U)Mr1pKeN8^j|{%M-YUOd zcQU_%^DRCQd%HeF4-r40Q?@;OwU561fY?10>yEwRR|LO+a!@}Q?ae-*$@IQ<*~7j- zv*12}WEelT&(yw&3e3KyDmTB}u^>M+sxUqWYUVz6A=W-5bAVC5OSTF>Ad#KE1U#6% z5a(gP@qSxBipr6`z|y`utoh(Q+11Iu;HcleeeT!3%sqrY$S?9f8XHSLME1eHL=N;m g2!t9wRq`R+Q&W!r2{_BuKzxPz#TsbPnNz`c?iFJc0oVfr!l`I+nhd4Y6w3I&?Y|o zHUqyWm&d-&!4H%|sD+QcspnC@%aJg? zUf69v_-E@rLRUe*0T6G#elUqX&L6tI1SSJNP;sKY46n1kiPG>sg!-7jKZ_^7=utnv gOfSSfEK3(YP)i30bjw8~lMn*X1a!+qB$MO<$8ghY*#H0l From 0cd82e17e69443f6a6a639e4b2e2d84753fa4d61 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 20:38:35 +0800 Subject: [PATCH 54/92] Update torchvision/models/swin_transformer.py Co-authored-by: Vasilis Vryniotis --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 7831f07beaa..4e6528fddea 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -395,7 +395,7 @@ def _swin_transformer( **kwargs, ) - if weights: + if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) return model From 8fde8ad83e3d4a35d84933861cba7a147054c9a9 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 20:38:51 +0800 Subject: [PATCH 55/92] Update torchvision/models/swin_transformer.py Co-authored-by: Vasilis Vryniotis --- torchvision/models/swin_transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 4e6528fddea..02217fc4395 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -427,7 +427,6 @@ class Swin_Tiny_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -@handle_legacy_interface(weights=("pretrained", Swin_Tiny_Weights.IMAGENET1K_V1)) def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from From 412ad1528a4c0c61949cb421045b8e71bd902147 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 21:08:06 +0800 Subject: [PATCH 56/92] update other model --- torchvision/models/swin_transformer.py | 42 ++++++++++++++++++-------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 02217fc4395..75d8c710847 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -18,6 +18,9 @@ __all__ = [ "SwinTransformer", "swin_tiny", + "swin_small", + "swin_base", + "swin_large", ] @@ -425,6 +428,18 @@ class Swin_Tiny_Weights(WeightsEnum): }, ) DEFAULT = IMAGENET1K_V1 + + +class Swin_Small_Weights(WeightsEnum): + pass + + 
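
These empty enum classes are deliberate placeholders: until checkpoints exist, the new small/base/large builders can only be called without weights. A minimal sketch of what that means in practice, not part of the diff, importing the module directly since the variants may not be re-exported from torchvision.models at this commit.

from torchvision.models import swin_transformer as st

model = st.swin_small(weights=None)       # random initialisation is the only working path for now
# st.swin_small(weights="IMAGENET1K_V1")  # would fail: Swin_Small_Weights defines no entries yet
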
+class Swin_Base_Weights(WeightsEnum): + pass + + +class Swin_Large_Weights(WeightsEnum): + pass def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: @@ -449,64 +464,67 @@ def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = T ) -def swin_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_small(*, weights: Optional[Swin_Small_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_small architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Swin_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = Swin_Small_Weights.verify(weights) + return _swin_transformer( - arch="swin_tiny", embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7, stochastic_depth_prob=0.3, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def swin_base(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_base(*, weights: Optional[Swin_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_base architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Swin_Base_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = Swin_Base_Weights.verify(weights) + return _swin_transformer( - arch="swin_tiny", embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7, stochastic_depth_prob=0.5, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def swin_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_large(*, weights: Optional[Swin_Large_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_large architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Swin_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = Swin_Large_Weights.verify(weights) + return _swin_transformer( - arch="swin_tiny", embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7, stochastic_depth_prob=0.2, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) From 9539c1dd3b2167e145a02f319f4269393171eefa Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 10:27:53 +0800 Subject: [PATCH 57/92] simplify the model name just like ViT --- torchvision/models/swin_transformer.py | 44 ++++++++++++++------------ 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 75d8c710847..421a570926d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -17,10 +17,14 @@ __all__ = [ "SwinTransformer", - "swin_tiny", - "swin_small", - "swin_base", - "swin_large", + "Swin_T_Weights", + "Swin_S_Weights", + "Swin_B_Weights", + "Swin_L_Weights", + "swin_t", + "swin_s", + "swin_b", + "swin_l", ] @@ -413,7 +417,7 @@ def _swin_transformer( } -class Swin_Tiny_Weights(WeightsEnum): +class Swin_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="", transforms=partial(ImageClassification, crop_size=224), @@ -430,27 +434,27 @@ class Swin_Tiny_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -class Swin_Small_Weights(WeightsEnum): +class Swin_S_Weights(WeightsEnum): pass -class Swin_Base_Weights(WeightsEnum): +class Swin_B_Weights(WeightsEnum): pass -class Swin_Large_Weights(WeightsEnum): +class Swin_L_Weights(WeightsEnum): pass -def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. Args: - weights (Swin_Tiny_Weights, optional): The pretrained weights for the model + weights (Swin_T_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - weights = Swin_Tiny_Weights.verify(weights) + weights = Swin_T_Weights.verify(weights) return _swin_transformer( embed_dim=96, @@ -464,15 +468,15 @@ def swin_tiny(*, weights: Optional[Swin_Tiny_Weights] = None, progress: bool = T ) -def swin_small(*, weights: Optional[Swin_Small_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_small architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. 
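
For review convenience, the four configurations introduced so far differ only in the values collected below (taken from the builder bodies in this file; the dict itself is illustrative and not part of the module). All variants share patch_size 4 (the class default) and window_size 7.

swin_configs = {
    "tiny":  dict(embed_dim=96,  depths=[2, 2, 6, 2],  num_heads=[3, 6, 12, 24],  stochastic_depth_prob=0.2),
    "small": dict(embed_dim=96,  depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],  stochastic_depth_prob=0.3),
    "base":  dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],  stochastic_depth_prob=0.5),
    "large": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], stochastic_depth_prob=0.2),  # 0.2 as written in this patch
}
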
Args: - weights (Swin_Small_Weights, optional): The pretrained weights for the model + weights (Swin_S_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - weights = Swin_Small_Weights.verify(weights) + weights = Swin_S_Weights.verify(weights) return _swin_transformer( embed_dim=96, @@ -486,15 +490,15 @@ def swin_small(*, weights: Optional[Swin_Small_Weights] = None, progress: bool = ) -def swin_base(*, weights: Optional[Swin_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_base architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. Args: - weights (Swin_Base_Weights, optional): The pretrained weights for the model + weights (Swin_B_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - weights = Swin_Base_Weights.verify(weights) + weights = Swin_B_Weights.verify(weights) return _swin_transformer( embed_dim=128, @@ -508,15 +512,15 @@ def swin_base(*, weights: Optional[Swin_Base_Weights] = None, progress: bool = T ) -def swin_large(*, weights: Optional[Swin_Large_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_l(*, weights: Optional[Swin_L_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_large architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. Args: - weights (Swin_Large_Weights, optional): The pretrained weights for the model + weights (Swin_L_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - weights = Swin_Large_Weights.verify(weights) + weights = Swin_L_Weights.verify(weights) return _swin_transformer( embed_dim=192, From 04bf82c47929ff99d244db01a3aeeb74f13e6824 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 10:36:10 +0800 Subject: [PATCH 58/92] add lr_cos_min --- references/classification/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index 6a3c289bc04..b6009ef87c8 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -267,7 +267,7 @@ def main(args): main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs - args.lr_warmup_epochs + optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_cos_min ) elif args.lr_scheduler == "exponentiallr": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) @@ -424,6 +424,7 @@ def get_args_parser(add_help=True): parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--lr-cos-min", default=0.0, type=float, help="minimum lr of cosine annealing schedule (default: 0.0)") parser.add_argument("--print-freq", 
default=10, type=int, help="print frequency") parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") From b24b8d974f852ba3ad33130a79cf8f306e057db0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 11:02:41 +0800 Subject: [PATCH 59/92] fix lint --- references/classification/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index b6009ef87c8..e3b9cbc6d2a 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -424,7 +424,9 @@ def get_args_parser(add_help=True): parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") - parser.add_argument("--lr-cos-min", default=0.0, type=float, help="minimum lr of cosine annealing schedule (default: 0.0)") + parser.add_argument( + "--lr-cos-min", default=0.0, type=float, help="minimum lr of cosine annealing schedule (default: 0.0)" + ) parser.add_argument("--print-freq", default=10, type=int, help="print frequency") parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") From 54d01f722a41064143f1bd4e07153c4768e31ad4 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 11:04:49 +0800 Subject: [PATCH 60/92] fix lint --- torchvision/models/swin_transformer.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 421a570926d..6e8d114b7c4 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -418,22 +418,9 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="", - transforms=partial(ImageClassification, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 28288354, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swin_tiny", - "acc@1": 81.222, - "acc@5": 95.332, - }, - ) - DEFAULT = IMAGENET1K_V1 - - + pass + + class Swin_S_Weights(WeightsEnum): pass @@ -477,7 +464,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * progress (bool): If True, displays a progress bar of the download to stderr """ weights = Swin_S_Weights.verify(weights) - + return _swin_transformer( embed_dim=96, depths=[2, 2, 18, 2], @@ -499,7 +486,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * progress (bool): If True, displays a progress bar of the download to stderr """ weights = Swin_B_Weights.verify(weights) - + return _swin_transformer( embed_dim=128, depths=[2, 2, 18, 2], @@ -521,7 +508,7 @@ def swin_l(*, weights: Optional[Swin_L_Weights] = None, progress: bool = True, * progress (bool): If True, displays a progress bar of the download to stderr """ weights = Swin_L_Weights.verify(weights) - + return _swin_transformer( embed_dim=192, depths=[2, 2, 18, 2], From 961d1b55053c6ade8643d193a2d0308e80b660c0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 11:24:05 +0800 Subject: [PATCH 61/92] Update 
swin_transformer.py --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 6e8d114b7c4..e04de94c5d3 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -420,7 +420,7 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): pass - + class Swin_S_Weights(WeightsEnum): pass From 38279eda500fe6b48855ad63e70a8c4bbcb2a7d9 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 16:55:53 +0800 Subject: [PATCH 62/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e04de94c5d3..c3220456e2d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -6,11 +6,10 @@ from torch import nn, Tensor from ..ops.stochastic_depth import StochasticDepth -from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param from .convnext import Permute from .vision_transformer import MLPBlock From b1dcf5ee7bccf68b88be9ab4c147ce5cfa92a6cd Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 17:37:08 +0800 Subject: [PATCH 63/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index c3220456e2d..f5a8375a309 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -6,6 +6,7 @@ from torch import nn, Tensor from ..ops.stochastic_depth import StochasticDepth +from ..transforms._presets import InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum from ._meta import _IMAGENET_CATEGORIES From 0d4014233899012edd203c3e752e7460730d80aa Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 20:23:26 +0800 Subject: [PATCH 64/92] Delete ModelTester.test_swin_tiny_expect.pkl --- test/expect/ModelTester.test_swin_tiny_expect.pkl | Bin 939 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/expect/ModelTester.test_swin_tiny_expect.pkl diff --git a/test/expect/ModelTester.test_swin_tiny_expect.pkl b/test/expect/ModelTester.test_swin_tiny_expect.pkl deleted file mode 100644 index 25ee179b47d24aea13343807275f3811d4bdd61c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66f(F@qN*^&FwVJj@z*mEwpVuc6^`Om2~^<_M7*8?C06H`<~YRxN@WYO`mn` z1gdx2x$q0xF9_adTfMb^-=BsXwj!_l_C5+{+22qUU@!6Ztlff>U-p&VJiO0g^8>qv z6bbun=dbMR=k&P zxc_5aw0+m9iTe&*+GoAy=L6fDS5NMHu;Tu{+IQFYow04PJE8x{PD^NGa?7}u2wF8CLQ~9!GCx9UZ!ni|=pTQa)T4kw4#lTo_b229~xR62)!ZhXr*?e(c zdMFdnRuB&GW&~02G>IIC0w4(#fSy9pbtC(U4@KuIAP-r$z5%*kWLNQ{=#>Dv5T+Lz z1_9n|Y&uXya?HAL<)Fk20x)_zgv&4q>`9P!*+6-N!4s+glnDa7S=m5h%s>cI4^ayM Dd6oVN From 358c6be48fb710fad28117e6965c49cb2484dd19 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 2 Apr 2022 20:50:32 
+0800 Subject: [PATCH 65/92] add swin_t --- test/expect/ModelTester.test_swin_t_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_t_expect.pkl diff --git a/test/expect/ModelTester.test_swin_t_expect.pkl b/test/expect/ModelTester.test_swin_t_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7326683b7a5adbc48f0f84087f5a4de56f585d64 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~n3--@ealx9*qGklB}*JbPbo4d*_sb7FQ?LWlPnOx&_NJ9U>`s({2kD}627 zAf|4+U+Yfp69_)DFTH8I-TbEFeV-G~@4MQ!)9&wze!GCrQu}T4TlcZsF0k`m)@~P{ z%4^?#Yl0mI-yys85^44hZaeoCDQMYe9b>jHdfc{eMTx-vZ{phf7$ytvyRjk4UQ)Eg zzUQRl{-FPC`^-cX_ZqWv?wh=wXaBj*;{E@hJ-1OilViWzS!QpnV%Yu#>WA$nWTfrq zwr;gE4R~q`3atV;A=?wckOE=cp~cT&4G*od)S_ZwEVwzD6B%4cAqQa^bAfEWI4?bv z31}+_2Y54rD0rGgjza;E1PVY;q3F7i{ltf&^A(VXtXtmzT`#h$_)+vq09^>v3k`z+ zZ#Fg^s3JLLUAS^kVg>;iy&b}3m<0AD$h&NyyusiJRRGEa0p6@^ATeej1gVFp1prXS B@fZLA literal 0 HcmV?d00001 From 07410bd6a301978af8238bedbc5156dcda7ef52c Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 20:54:43 +0800 Subject: [PATCH 66/92] refactor code --- torchvision/models/swin_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index f5a8375a309..820acb027a2 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -114,7 +114,7 @@ def shifted_window_attention( qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * (C // num_heads) ** -0.5 - attn = q @ k.transpose(-2, -1) + attn = q.matmul(k.transpose(-2, -1)) # add relative position bias attn = attn + relative_position_bias @@ -138,7 +138,7 @@ def shifted_window_attention( attn = F.softmax(attn, dim=-1) attn = F.dropout(attn, p=attention_dropout) - x = (attn @ v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = F.linear(x, proj_weight, proj_bias) x = F.dropout(x, p=dropout) From 8c6d9103898a8e447c32c9f53a75014a4a92795e Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 21:02:40 +0800 Subject: [PATCH 67/92] Update train.py --- references/classification/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index e3b9cbc6d2a..374eb7bf16c 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -267,7 +267,7 @@ def main(args): main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_cos_min + optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min ) elif args.lr_scheduler == "exponentiallr": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) @@ -424,9 +424,7 @@ def get_args_parser(add_help=True): parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every 
step-size epochs") parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") - parser.add_argument( - "--lr-cos-min", default=0.0, type=float, help="minimum lr of cosine annealing schedule (default: 0.0)" - ) + parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") parser.add_argument("--print-freq", default=10, type=int, help="print frequency") parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") From 7c9ffd36c9d11765d74d57a4ff96c8893bad99d2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 21:22:52 +0800 Subject: [PATCH 68/92] add swin_s --- test/expect/ModelTester.test_swin_s_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_s_expect.pkl diff --git a/test/expect/ModelTester.test_swin_s_expect.pkl b/test/expect/ModelTester.test_swin_s_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..77a3810cc94d334fe7773f606346cdba2bf3fc03 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=VD($bPO~r~L~$!}h;o2;KjD{aQO;{cgKGuY31@6@R@~r7&~1x|5b2SF+;% zSqt~=d+xf_j_oaD8ki<+6D1)G{BtgK2B*j{RD(*HhJb|Ek&cb~57MZIqWO+5h4F zxUV3-ZC}uPak~=nn|7ZBTlZHyRJWh${&(MWrD*$EbDH*tozt^yNVx0uLRJAFul+) z2=HcO(}60IW7dT$2PI|@fYI9_T!u+tPlCM52Fe=@o=^p#Oc3DB$_5f+211Z}h*|&> Cz4!S5 literal 0 HcmV?d00001 From e94fdfd0c66340075df4917474f0deb5ee3210e0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 21:39:59 +0800 Subject: [PATCH 69/92] ignore a error of mypy --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 820acb027a2..60ce296c50a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -207,7 +207,7 @@ def __init__( def forward(self, x: Tensor): relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view( self.window_size * self.window_size, self.window_size * self.window_size, -1 - ) # Wh*Ww,Wh*Ww,nH + ) # type: ignore[index] relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) return shifted_window_attention( From 1021fd2ca6e713a3102323d7cc40ebcdda2f4023 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 21:51:50 +0800 Subject: [PATCH 70/92] Update swin_transformer.py --- torchvision/models/swin_transformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 60ce296c50a..6d4242072d3 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -205,9 +205,10 @@ def __init__( nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor): - relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view( + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view( self.window_size * self.window_size, self.window_size * self.window_size, -1 - ) # type: 
ignore[index] + ) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) return shifted_window_attention( From 88a3e034f0a5149525bf1b9b78a1a25439029c0b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 22:10:09 +0800 Subject: [PATCH 71/92] fix lint --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 6d4242072d3..c38d8caa353 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -205,7 +205,7 @@ def __init__( nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor): - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view( self.window_size * self.window_size, self.window_size * self.window_size, -1 ) From 535cc6a0843be20460d1b6d71a9d2e06325a8b1b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 22:36:14 +0800 Subject: [PATCH 72/92] add swin_b --- test/expect/ModelTester.test_swin_b_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_b_expect.pkl diff --git a/test/expect/ModelTester.test_swin_b_expect.pkl b/test/expect/ModelTester.test_swin_b_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..006af1155597a8b27483a373f0a9dd9e2a3c83f2 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~oDJ(k|}Y9XqyrEp{7q+U+9c18l=}1?@JyJ#8!U$6>$eyGr|p^&We7wEOMh zeXx07g7=NxOQHhyi%8n>ri)?hn&gW?NF|U~BbV(C(tT z#{TZ0<93W4op#2;OYEYmEcVUYlWahw2#6bR!EEq(@TcxaWS78L_y!Oh8>$lyW>ISA933uN=fdFi1{ zKwCjLz?%_7!P6vi914IWPyl)gMc0k&Cq5LNuYf#c-TDUTdXZhlkD^xs=t7uYXcz=| zv$5$w70EH{!j*#(GYG)w?GP@*B(Ntz-em*j4F*rB0#GIh@MdKLi7^8qNIgU?0A*tE An*aa+ literal 0 HcmV?d00001 From 92ae7dd71cedc48e674260a7868e695389410b32 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 4 Apr 2022 22:37:05 +0800 Subject: [PATCH 73/92] add swin_l --- test/expect/ModelTester.test_swin_l_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_l_expect.pkl diff --git a/test/expect/ModelTester.test_swin_l_expect.pkl b/test/expect/ModelTester.test_swin_l_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..17d51e057c24c2189890914378b79dcf51a729e3 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5@+&Hz5Uu!CH9*VX78V;Ja<3i?4A3Sedp~x+~ij&+9FKW2^C|J`Qg zz8w#v_h*z$w*N0b#oldWt!?l52X^b#XW8pGx!db3G2MT9daM02xn29zUT?Oy*pzSo z&1!<(lTKcHmE8sVSG>2k|2Ls+-`l`*b_@Tv?CXx%V|RM~UwgA)r+ptb$?kL5p0WSj z%q_Ot9kus$WPY?ellj>GNRZz?yB&w^%ANP`|9*Yd-bz3H{j)=6?X%sNV87{&$^IA1 zllD(k-f!>sqi#Paw0_;bwdVvdq(B&VXz?>x!$YeqwWt^v3vN#4L()0w*Nf~beiXeDKo`RF zLc<`yn~hBesz{Dm7p@$Xm_YzWZ-;OhCV@Q(@-7=FZ!ma56@W59fHx}}NQ@Z>LFyrD E0XzN$mH+?% literal 0 HcmV?d00001 From bb337375bc727e7a966de93b7d701da26b02f1a1 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 5 Apr 2022 10:36:41 +0800 
Subject: [PATCH 74/92] refactor code --- torchvision/models/swin_transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index c38d8caa353..59274a5feda 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -43,7 +43,6 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm def forward(self, x: Tensor): B, H, W, C = x.shape - # assert H % 2 == 0 and W % 2 == 0, f"input size ({H}*{W}) are not even." x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C @@ -264,10 +263,8 @@ def __init__( attention_dropout=attention_dropout, dropout=dropout, ) - self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) def forward(self, x: Tensor): @@ -357,7 +354,6 @@ def __init__( # add patch merging layer if i_stage < (len(depths) - 1): layers.append(PatchMerging(dim, norm_layer)) - self.features = nn.Sequential(*layers) num_features = embed_dim * 2 ** (len(depths) - 1) From 2500ff3763146d57d959f358dad1a667c4db0db8 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 5 Apr 2022 10:37:46 +0800 Subject: [PATCH 75/92] Update train.py --- references/classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index 374eb7bf16c..96703bfdf85 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -233,7 +233,7 @@ def main(args): if args.bias_weight_decay is not None: custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) if args.transformer_embedding_decay is not None: - for key in ["class_token", "position_embedding", "relative_position_bias"]: + for key in ["class_token", "position_embedding", "relative_position_bias_table"]: custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) parameters = utils.set_weight_decay( model, From f0615440bf18617dc0e5dc4839bd5ed27e5ed010 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 7 Apr 2022 12:47:12 +0000 Subject: [PATCH 76/92] move relative_position_bias to __init__ --- torchvision/models/swin_transformer.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 59274a5feda..a5e36a9218d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -179,15 +179,9 @@ def __init__( self.num_heads = num_heads self.attention_dropout = attention_dropout self.dropout = dropout - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) - ) # 2*Wh-1 * 2*Ww-1, nH - # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size) coords_w = torch.arange(self.window_size) @@ -199,22 +193,25 @@ def __init__( relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - nn.init.trunc_normal_(self.relative_position_bias_table, 
std=0.02) - - def forward(self, x: Tensor): - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] + + # define a parameter table of relative position bias + relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) # 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(relative_position_bias_table, std=0.02) + + relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view( self.window_size * self.window_size, self.window_size * self.window_size, -1 ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)) + + def forward(self, x: Tensor): + return shifted_window_attention( x, self.qkv.weight, self.proj.weight, - relative_position_bias, + self.relative_position_bias, self.window_size, self.num_heads, shift_size=self.shift_size, From 41faba232668f7ac4273a0cf632c0d0130c7ce9c Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 7 Apr 2022 15:33:25 +0100 Subject: [PATCH 77/92] fix formatting --- torchvision/models/swin_transformer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a5e36a9218d..69c5f654e60 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -193,11 +193,13 @@ def __init__( relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww - + # define a parameter table of relative position bias - relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) # 2*Wh-1 * 2*Ww-1, nH + relative_position_bias_table = torch.zeros( + (2 * window_size - 1) * (2 * window_size - 1), num_heads + ) # 2*Wh-1 * 2*Ww-1, nH nn.init.trunc_normal_(relative_position_bias_table, std=0.02) - + relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view( self.window_size * self.window_size, self.window_size * self.window_size, -1 @@ -205,7 +207,6 @@ def __init__( self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)) def forward(self, x: Tensor): - return shifted_window_attention( x, From e338dbea822cc451eef78c9014a970ef12868f07 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Apr 2022 09:26:07 +0000 Subject: [PATCH 78/92] Revert "fix formatting" This reverts commit 41faba232668f7ac4273a0cf632c0d0130c7ce9c. 
--- torchvision/models/swin_transformer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 69c5f654e60..a5e36a9218d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -193,13 +193,11 @@ def __init__( relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww - + # define a parameter table of relative position bias - relative_position_bias_table = torch.zeros( - (2 * window_size - 1) * (2 * window_size - 1), num_heads - ) # 2*Wh-1 * 2*Ww-1, nH + relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) # 2*Wh-1 * 2*Ww-1, nH nn.init.trunc_normal_(relative_position_bias_table, std=0.02) - + relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view( self.window_size * self.window_size, self.window_size * self.window_size, -1 @@ -207,6 +205,7 @@ def __init__( self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)) def forward(self, x: Tensor): + return shifted_window_attention( x, From 1b8ffb112951e70bd3ed8b78421e2439f08dbfca Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Apr 2022 09:26:29 +0000 Subject: [PATCH 79/92] Revert "move relative_position_bias to __init__" This reverts commit f0615440bf18617dc0e5dc4839bd5ed27e5ed010. --- torchvision/models/swin_transformer.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a5e36a9218d..59274a5feda 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -179,9 +179,15 @@ def __init__( self.num_heads = num_heads self.attention_dropout = attention_dropout self.dropout = dropout + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size) coords_w = torch.arange(self.window_size) @@ -193,25 +199,22 @@ def __init__( relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww - - # define a parameter table of relative position bias - relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) # 2*Wh-1 * 2*Ww-1, nH - nn.init.trunc_normal_(relative_position_bias_table, std=0.02) - - relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] + self.register_buffer("relative_position_index", relative_position_index) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x: Tensor): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view( self.window_size * self.window_size, self.window_size * self.window_size, -1 ) - 
self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)) - - def forward(self, x: Tensor): - + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) return shifted_window_attention( x, self.qkv.weight, self.proj.weight, - self.relative_position_bias, + relative_position_bias, self.window_size, self.num_heads, shift_size=self.shift_size, From affd0dfe0a1a5983eaaf03610574349a46bd3b8d Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 21 Apr 2022 19:10:50 +0800 Subject: [PATCH 80/92] refactor code --- torchvision/models/swin_transformer.py | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 59274a5feda..1a2007e8c45 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -278,34 +278,34 @@ class SwinTransformer(nn.Module): Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_ paper. Args: - patch_size (int): Patch size. Default: 4. - num_classes (int): Number of classes for classification head. Default: 1000. - embed_dim (int): Patch embedding dimension. Default: 96. - depths (List(int)): Depth of each Swin Transformer layer. Default: [2, 2, 6, 2]. - num_heads (List(int)): Number of attention heads in different layers. Default: [3, 6, 12, 24]. + patch_size (int): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. - attention_drop_rate (float): Attention dropout rate. Default: 0.0. - drop_path_rate (float): Stochastic depth rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. + num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. 
""" def __init__( self, - patch_size: int = 4, - num_classes: int = 1000, - embed_dim: int = 96, - depths: List[int] = [2, 2, 6, 2], - num_heads: List[int] = [3, 6, 12, 24], + patch_size: int, + embed_dim: int, + depths: List[int], + num_heads: List[int], window_size: int = 7, mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - block: Optional[Callable[..., nn.Module]] = None, + num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, ): super().__init__() _log_api_usage_once(self) @@ -378,6 +378,7 @@ def forward(self, x): def _swin_transformer( + patch_size: int, embed_dim: int, depths: List[int], num_heads: List[int], @@ -391,6 +392,7 @@ def _swin_transformer( _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = SwinTransformer( + patch_size=patch_size, embed_dim=embed_dim, depths=depths, num_heads=num_heads, @@ -441,6 +443,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * weights = Swin_T_Weights.verify(weights) return _swin_transformer( + patch_size=4, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], @@ -463,6 +466,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * weights = Swin_S_Weights.verify(weights) return _swin_transformer( + patch_size=4, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], @@ -485,6 +489,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * weights = Swin_B_Weights.verify(weights) return _swin_transformer( + patch_size=4, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], @@ -507,6 +512,7 @@ def swin_l(*, weights: Optional[Swin_L_Weights] = None, progress: bool = True, * weights = Swin_L_Weights.verify(weights) return _swin_transformer( + patch_size=4, embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], From 565203b8bb73c0a00edd13e122b4119ce7730181 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 22 Apr 2022 11:06:10 +0100 Subject: [PATCH 81/92] Remove deprecated meta-data from `_COMMON_META` --- torchvision/models/swin_transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1a2007e8c45..e5a7a4f3699 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -408,11 +408,7 @@ def _swin_transformer( _COMMON_META = { - "task": "image_classification", - "architecture": "SwinTransformer", - "publication_year": 2021, "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BICUBIC, } From 09d63f5d00529b367a8286746a929a2ff8cf9eea Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 26 Apr 2022 09:16:00 +0100 Subject: [PATCH 82/92] fix linter --- torchvision/models/swin_transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e5a7a4f3699..cec4ca98465 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -6,7 +6,6 @@ from torch import nn, Tensor from ..ops.stochastic_depth import StochasticDepth -from ..transforms._presets import InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum from ._meta import _IMAGENET_CATEGORIES From b6fec69947683e9249e275c454216b11df35d73e Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: 
Tue, 26 Apr 2022 09:09:13 +0000 Subject: [PATCH 83/92] add pretrained weights for swin_t --- torchvision/models/swin_transformer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 59274a5feda..3d6c345821e 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -6,9 +6,9 @@ from torch import nn, Tensor from ..ops.stochastic_depth import StochasticDepth -from ..transforms._presets import InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param from .convnext import Permute @@ -415,17 +415,21 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): - pass - + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_t-81486767.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), + meta={ + **_COMMON_META, + }, + ) + DEFAULT = IMAGENET1K_V1 class Swin_S_Weights(WeightsEnum): pass - class Swin_B_Weights(WeightsEnum): pass - class Swin_L_Weights(WeightsEnum): pass From 64af984147764aad1508febe8419445a24bae5de Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 26 Apr 2022 09:14:04 +0000 Subject: [PATCH 84/92] fix format --- torchvision/models/swin_transformer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index ecb6b8453d7..c739a9de003 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -422,12 +422,15 @@ class Swin_T_Weights(WeightsEnum): ) DEFAULT = IMAGENET1K_V1 + class Swin_S_Weights(WeightsEnum): pass + class Swin_B_Weights(WeightsEnum): pass + class Swin_L_Weights(WeightsEnum): pass From 1528ca8cf539c420ecba5e6d7a6158bc814e6d8f Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 26 Apr 2022 10:49:12 +0100 Subject: [PATCH 85/92] apply ufmt --- torchvision/models/swin_transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index c739a9de003..40f9cf0088f 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -415,7 +415,9 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/swin_t-81486767.pth", - transforms=partial(ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META, }, From e6e9ffebb8a89d641dc60cd210b552aa99c96454 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 09:32:50 +0100 Subject: [PATCH 86/92] add documentation --- docs/source/models.rst | 17 +++++++++++++++++ torchvision/models/swin_transformer.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index f84d9c7fd1a..9fe8be374e5 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -42,6 +42,7 @@ architectures for image classification: - `RegNet`_ - `VisionTransformer`_ - `ConvNeXt`_ +- `SwinTransformer`_ You can construct a 
model with random weights by calling its constructor: @@ -97,6 +98,9 @@ You can construct a model with random weights by calling its constructor: convnext_small = models.convnext_small() convnext_base = models.convnext_base() convnext_large = models.convnext_large() + swin_t = models.swin_t() + swin_s = models.swin_s() + swin_b = models.swin_b() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. @@ -219,6 +223,7 @@ convnext_tiny 82.520 96.146 convnext_small 83.616 96.650 convnext_base 84.062 96.870 convnext_large 84.414 96.976 +swin_t 81.358 95.526 ================================ ============= ============= @@ -238,6 +243,7 @@ convnext_large 84.414 96.976 .. _RegNet: https://arxiv.org/abs/2003.13678 .. _VisionTransformer: https://arxiv.org/abs/2010.11929 .. _ConvNeXt: https://arxiv.org/abs/2201.03545 +.. _SwinTransformer: https://arxiv.org/abs/2103.14030 .. currentmodule:: torchvision.models @@ -450,6 +456,17 @@ ConvNeXt convnext_base convnext_large +SwinTransformer +-------- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + swin_t + swin_s + swin_b + Quantized Models ---------------- diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 40f9cf0088f..44d614f95c8 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -416,7 +416,7 @@ class Swin_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/swin_t-81486767.pth", transforms=partial( - ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META, From 137d63422165ba7056de849d88cd3747a963d655 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 09:35:30 +0100 Subject: [PATCH 87/92] update references README --- references/classification/README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/references/classification/README.md b/references/classification/README.md index c274c997791..758c57ce27d 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -224,6 +224,18 @@ Note that the above command corresponds to training on a single node with 8 GPUs For generatring the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs), and `--batch_size 64`. + +### SwinTransformer +``` +torchrun --nproc_per_node=8 train.py\ +--model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\ +--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\ +--lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\ +--clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra +``` +Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value. + + ## Mixed precision training Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp). 
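For anyone who wants to exercise the weights documented above, the following is a minimal inference sketch, not part of the patch series. It assumes this branch of torchvision is installed so that the `swin_t` builder and `Swin_T_Weights` enum are importable from `torchvision.models`, and it uses the bicubic `ImageClassification` preset attached to the weights entry; the random tensor merely stands in for a real image.

```python
# Minimal sketch: classify one input with the pretrained swin_t weights.
# Assumes this branch of torchvision is installed (swin_t / Swin_T_Weights importable).
import torch
from torchvision.models import swin_t, Swin_T_Weights

weights = Swin_T_Weights.IMAGENET1K_V1
model = swin_t(weights=weights).eval()

preprocess = weights.transforms()      # bicubic resize + 224 center crop + normalization, per the weights entry
img = torch.rand(3, 300, 300)          # stand-in for a real image tensor in [0, 1]
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    logits = model(batch)              # shape (1, 1000)

top = logits.argmax(dim=-1).item()
print(weights.meta["categories"][top])
```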
From 3457abbb4a422407ebd84ad6f242f1b7ccd809f5 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 10:16:50 +0100 Subject: [PATCH 88/92] adding new style docs --- docs/source/models.rst | 2 + docs/source/models_new.rst | 1 + torchvision/models/swin_transformer.py | 70 ++++++++++++++++++++++---- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 9fe8be374e5..13e5dff4637 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -101,6 +101,7 @@ You can construct a model with random weights by calling its constructor: swin_t = models.swin_t() swin_s = models.swin_s() swin_b = models.swin_b() + swin_l = models.swin_l() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. @@ -466,6 +467,7 @@ SwinTransformer swin_t swin_s swin_b + swin_l Quantized Models ---------------- diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst index d3132639be5..e45fddddc4d 100644 --- a/docs/source/models_new.rst +++ b/docs/source/models_new.rst @@ -44,6 +44,7 @@ weights: models/resnet models/resnext models/squeezenet + models/swin_transformer models/vgg models/vision_transformer diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 44d614f95c8..ad6f05832b6 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -440,10 +440,23 @@ class Swin_L_Weights(WeightsEnum): def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from - `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + Args: - weights (Swin_T_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_T_Weights + :members: """ weights = Swin_T_Weights.verify(weights) @@ -464,9 +477,22 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * """ Constructs a swin_small architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: - weights (Swin_S_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_S_Weights + :members: """ weights = Swin_S_Weights.verify(weights) @@ -487,9 +513,22 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * """ Constructs a swin_base architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: - weights (Swin_B_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_B_Weights + :members: """ weights = Swin_B_Weights.verify(weights) @@ -510,9 +549,22 @@ def swin_l(*, weights: Optional[Swin_L_Weights] = None, progress: bool = True, * """ Constructs a swin_large architecture from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. + Args: - weights (Swin_L_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.Swin_L_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_L_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_L_Weights + :members: """ weights = Swin_L_Weights.verify(weights) From d3599efe4ae188c6667b71f28fb89d4e5113c78f Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 10:33:13 +0100 Subject: [PATCH 89/92] update pre-trained weights values --- torchvision/models/swin_transformer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index ad6f05832b6..4f3519199de 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -420,6 +420,13 @@ class Swin_T_Weights(WeightsEnum): ), meta={ **_COMMON_META, + "num_params": 28288354, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swin_t", + "metrics": { + "acc@1": 81.358, + "acc@5": 95.526, + }, }, ) DEFAULT = IMAGENET1K_V1 From 6addd1b46f8e5c096ee487f8bf433360c57151df Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 11:09:09 +0100 Subject: [PATCH 90/92] remove other variants --- docs/source/models.rst | 6 -- docs/source/models/swin_transformer.rst | 25 +++++ torchvision/models/swin_transformer.py | 126 ------------------------ 3 files changed, 25 insertions(+), 132 deletions(-) create mode 100644 docs/source/models/swin_transformer.rst diff --git a/docs/source/models.rst b/docs/source/models.rst index 13e5dff4637..51881e505e4 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -99,9 +99,6 @@ You can construct a model with random weights by calling its constructor: convnext_base = models.convnext_base() convnext_large = models.convnext_large() swin_t = models.swin_t() - swin_s = models.swin_s() - swin_b = models.swin_b() - swin_l = models.swin_l() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. @@ -465,9 +462,6 @@ SwinTransformer :template: function.rst swin_t - swin_s - swin_b - swin_l Quantized Models ---------------- diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst new file mode 100644 index 00000000000..b8726d71d2a --- /dev/null +++ b/docs/source/models/swin_transformer.rst @@ -0,0 +1,25 @@ +SwinTransformer +=============== + +.. currentmodule:: torchvision.models + +The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision +Transformer using Shifted Windows `__ +paper. + + +Model builders +-------------- + +The following model builders can be used to instanciate an SwinTransformer model. +`swin_t` can be instantiated with pre-trained weights and all others without. +All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer`` +base class. Please refer to the `source code +`_ for +more details about this class. + +.. 
autosummary:: + :toctree: generated/ + :template: function.rst + + swin_t diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 4f3519199de..73900cedba7 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -18,13 +18,7 @@ __all__ = [ "SwinTransformer", "Swin_T_Weights", - "Swin_S_Weights", - "Swin_B_Weights", - "Swin_L_Weights", "swin_t", - "swin_s", - "swin_b", - "swin_l", ] @@ -432,18 +426,6 @@ class Swin_T_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -class Swin_S_Weights(WeightsEnum): - pass - - -class Swin_B_Weights(WeightsEnum): - pass - - -class Swin_L_Weights(WeightsEnum): - pass - - def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from @@ -478,111 +460,3 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * progress=progress, **kwargs, ) - - -def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: - """ - Constructs a swin_small architecture from - `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. - - Args: - weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.Swin_S_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.Swin_S_Weights - :members: - """ - weights = Swin_S_Weights.verify(weights) - - return _swin_transformer( - patch_size=4, - embed_dim=96, - depths=[2, 2, 18, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - stochastic_depth_prob=0.3, - weights=weights, - progress=progress, - **kwargs, - ) - - -def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: - """ - Constructs a swin_base architecture from - `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. - - Args: - weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.Swin_B_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.Swin_B_Weights - :members: - """ - weights = Swin_B_Weights.verify(weights) - - return _swin_transformer( - patch_size=4, - embed_dim=128, - depths=[2, 2, 18, 2], - num_heads=[4, 8, 16, 32], - window_size=7, - stochastic_depth_prob=0.5, - weights=weights, - progress=progress, - **kwargs, - ) - - -def swin_l(*, weights: Optional[Swin_L_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: - """ - Constructs a swin_large architecture from - `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. 
- - Args: - weights (:class:`~torchvision.models.Swin_L_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.Swin_L_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.Swin_L_Weights - :members: - """ - weights = Swin_L_Weights.verify(weights) - - return _swin_transformer( - patch_size=4, - embed_dim=192, - depths=[2, 2, 18, 2], - num_heads=[6, 12, 24, 48], - window_size=7, - stochastic_depth_prob=0.2, - weights=weights, - progress=progress, - **kwargs, - ) From ca59aafd9cdeb3b36f325cd99af52b8c3e795564 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 11:19:28 +0100 Subject: [PATCH 91/92] fix typo --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 73900cedba7..455397c8403 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -439,7 +439,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_trasformer.SwinTransformer`` + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` base class. Please refer to the `source code `_ for more details about this class. 
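The removal above narrows the public builders down to `swin_t`, but the generic `SwinTransformer` class stays in place, so the larger variants can still be constructed by hand. A sketch under that assumption, reusing the hyper-parameters the removed `swin_s` builder had passed through; no pretrained weights are implied.

```python
# Sketch: build a "small" configuration directly from the generic class,
# using the hyper-parameters of the removed swin_s builder.
from torchvision.models.swin_transformer import SwinTransformer

swin_s_like = SwinTransformer(
    patch_size=4,
    embed_dim=96,
    depths=[2, 2, 18, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    stochastic_depth_prob=0.3,  # value the removed swin_s builder used
)
```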
From e4c96463442824d8f39b798a4b973a82fc9be53b Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Apr 2022 12:01:26 +0100 Subject: [PATCH 92/92] Remove expect for the variants not yet supported --- test/expect/ModelTester.test_swin_b_expect.pkl | Bin 939 -> 0 bytes test/expect/ModelTester.test_swin_l_expect.pkl | Bin 939 -> 0 bytes test/expect/ModelTester.test_swin_s_expect.pkl | Bin 939 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/expect/ModelTester.test_swin_b_expect.pkl delete mode 100644 test/expect/ModelTester.test_swin_l_expect.pkl delete mode 100644 test/expect/ModelTester.test_swin_s_expect.pkl diff --git a/test/expect/ModelTester.test_swin_b_expect.pkl b/test/expect/ModelTester.test_swin_b_expect.pkl deleted file mode 100644 index 006af1155597a8b27483a373f0a9dd9e2a3c83f2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~oDJ(k|}Y9XqyrEp{7q+U+9c18l=}1?@JyJ#8!U$6>$eyGr|p^&We7wEOMh zeXx07g7=NxOQHhyi%8n>ri)?hn&gW?NF|U~BbV(C(tT z#{TZ0<93W4op#2;OYEYmEcVUYlWahw2#6bR!EEq(@TcxaWS78L_y!Oh8>$lyW>ISA933uN=fdFi1{ zKwCjLz?%_7!P6vi914IWPyl)gMc0k&Cq5LNuYf#c-TDUTdXZhlkD^xs=t7uYXcz=| zv$5$w70EH{!j*#(GYG)w?GP@*B(Ntz-em*j4F*rB0#GIh@MdKLi7^8qNIgU?0A*tE An*aa+ diff --git a/test/expect/ModelTester.test_swin_l_expect.pkl b/test/expect/ModelTester.test_swin_l_expect.pkl deleted file mode 100644 index 17d51e057c24c2189890914378b79dcf51a729e3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5@+&Hz5Uu!CH9*VX78V;Ja<3i?4A3Sedp~x+~ij&+9FKW2^C|J`Qg zz8w#v_h*z$w*N0b#oldWt!?l52X^b#XW8pGx!db3G2MT9daM02xn29zUT?Oy*pzSo z&1!<(lTKcHmE8sVSG>2k|2Ls+-`l`*b_@Tv?CXx%V|RM~UwgA)r+ptb$?kL5p0WSj z%q_Ot9kus$WPY?ellj>GNRZz?yB&w^%ANP`|9*Yd-bz3H{j)=6?X%sNV87{&$^IA1 zllD(k-f!>sqi#Paw0_;bwdVvdq(B&VXz?>x!$YeqwWt^v3vN#4L()0w*Nf~beiXeDKo`RF zLc<`yn~hBesz{Dm7p@$Xm_YzWZ-;OhCV@Q(@-7=FZ!ma56@W59fHx}}NQ@Z>LFyrD E0XzN$mH+?% diff --git a/test/expect/ModelTester.test_swin_s_expect.pkl b/test/expect/ModelTester.test_swin_s_expect.pkl deleted file mode 100644 index 77a3810cc94d334fe7773f606346cdba2bf3fc03..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=VD($bPO~r~L~$!}h;o2;KjD{aQO;{cgKGuY31@6@R@~r7&~1x|5b2SF+;% zSqt~=d+xf_j_oaD8ki<+6D1)G{BtgK2B*j{RD(*HhJb|Ek&cb~57MZIqWO+5h4F zxUV3-ZC}uPak~=nn|7ZBTlZHyRJWh${&(MWrD*$EbDH*tozt^yNVx0uLRJAFul+) z2=HcO(}60IW7dT$2PI|@fYI9_T!u+tPlCM52Fe=@o=^p#Oc3DB$_5f+211Z}h*|&> Cz4!S5
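With the expect files for the other variants deleted, only `swin_t` remains covered by the model tests. A shape-level smoke test in the same spirit might look like the sketch below; it assumes the branch is installed and checks output shape only, not the stored expected values.

```python
# Sketch: shape-level smoke test for the remaining swin_t model.
import torch
from torchvision.models import swin_t

model = swin_t().eval()                # random init, no weight download
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
assert out.shape == (1, 1000), out.shape
```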