From 6eeecf97374e8c26580a8fecc9ec1131c819f81d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 7 Jul 2022 10:38:29 -0700 Subject: [PATCH 01/49] init submit --- torchvision/models/swin_transformer.py | 429 ++++++++++++++++++++++++- 1 file changed, 425 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 2f2cfd44445..97e19e28a34 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional, Callable, List, Any +from typing import Any, Callable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -19,9 +19,15 @@ "Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights", + "Swin_V2_T_Weights", + "Swin_V2_S_Weights", + "Swin_V2_B_Weights", "swin_t", "swin_s", "swin_b", + "swin_v2_t", + "swin_v2_s", + "swin_v2_b", ] @@ -239,6 +245,215 @@ def forward(self, x: Tensor): ) +def shifted_window_attention_v2( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + logit_scale: torch.Tensor, + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, +): + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (List[int]): Window size. + num_heads (int): Number of attention heads. + shift_size (List[int]): Shift size for shifted window attention. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. 
+ """ + B, H, W, C = input.shape + # pad feature maps to multiples of window size + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape + + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + + # partition windows + num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) + x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C + + # multi-head attention + if qkv_bias is not None: # v2 ignores k_bias + qkv_bias= qkv_bias.clone() + length = qkv_bias.numel() // 3 + qkv_bias[length: 2*length].zero_() + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + # v2 cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + attn = attn * logit_scale + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_H, pad_W)) + h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) + w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) + count = 0 + for h in h_slices: + for w in w_slices: + attn_mask[h[0] : h[1], w[0] : w[1]] = count + count += 1 + attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) + + # reverse windows + x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + + # unpad features + x = x[:, :H, :W, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention_v2") + + +class ShiftedWindowAttentionV2(nn.Module): + """ + See :func:`shifted_window_attention_v2`. 
+ """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + pretrained_window_size: Tuple[int] = (0, 0), + ): + super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # V2 + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / 3.0 + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + if qkv_bias: + length = self.qkv.bias.numel() // 3 + self.qkv.bias[length: 2*length].zero_() + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. 
[B, H, W, C] + """ + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias).unsqueeze(0) + + return shifted_window_attention_v2( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + logit_scale=self.logit_scale, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + ) + class SwinTransformerBlock(nn.Module): """ Swin Transformer Block. @@ -253,6 +468,8 @@ class SwinTransformerBlock(nn.Module): stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention + pretrained_window_size (int): (V2) Window size in pre-training. Default: 0. + version (int): SwinTransformer version. Default: 1. """ def __init__( @@ -266,11 +483,21 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, + attn_layer: Optional[Callable[..., nn.Module]] = None, + pretrained_window_size: int = 0, + version: int = 1, ): super().__init__() _log_api_usage_once(self) + if attn_layer is None: + if self.version==1: # TODO: switch after python 3.10 + attn_layer = ShiftedWindowAttention + elif self.version==2: + attn_layer = ShiftedWindowAttentionV2 + else: + raise NotImplementedError(self.version) + self.version=version self.norm1 = norm_layer(dim) self.attn = attn_layer( dim, @@ -279,6 +506,7 @@ def __init__( num_heads, attention_dropout=attention_dropout, dropout=dropout, + pretrained_window_size=(pretrained_window_size, pretrained_window_size), ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) @@ -291,8 +519,14 @@ def __init__( nn.init.normal_(m.bias, std=1e-6) def forward(self, x: Tensor): - x = x + self.stochastic_depth(self.attn(self.norm1(x))) - x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + if self.version==1: # TODO: switch after python 3.10 + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + elif self.version==2: + x = x + self.stochastic_depth(self.norm1(self.attn(x))) + x = x + self.stochastic_depth(self.norm2(self.mlp(x))) + else: + raise NotImplementedError(self.version) return x @@ -310,6 +544,7 @@ class SwinTransformer(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. + version (int): SwinTransformer version. Default: 1. num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. 
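The core change in shifted_window_attention_v2 above is replacing the v1 scaled dot-product scores with cosine attention scaled by a learned, clamped per-head logit scale. The sketch below is a minimal standalone illustration of that contrast; it is not the torchvision API, and the helper name toy_attention_scores and the tensor shapes are made up for the example.

import math

import torch
import torch.nn.functional as F


def toy_attention_scores(q, k, logit_scale=None):
    # q, k: (batch * num_windows, num_heads, tokens, head_dim)
    if logit_scale is None:
        # v1: scale q by 1/sqrt(head_dim), then a plain dot product
        q = q * (q.size(-1) ** -0.5)
        return q.matmul(k.transpose(-2, -1))
    # v2: cosine similarity between q and k, multiplied by a learnable
    # per-head temperature that is clamped at log(1/0.01) before exp()
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    logit_scale = torch.clamp(logit_scale, max=math.log(1.0 / 0.01)).exp()
    return attn * logit_scale


q = torch.randn(2, 3, 49, 32)  # e.g. a 7x7 window -> 49 tokens, 3 heads, head_dim 32
k = torch.randn(2, 3, 49, 32)
logit_scale = torch.log(10 * torch.ones(3, 1, 1))  # per-head init, as in the patch
v1_scores = toy_attention_scores(q, k)
v2_scores = toy_attention_scores(q, k, logit_scale)

In the patch, logit_scale is initialized to log(10) per head and clamped at log(1/0.01) before exponentiation, so the effective scale never exceeds 100, which is the stabilization the Swin Transformer V2 paper introduces for scaling up capacity and resolution. The patch also zeroes the key third of the qkv bias before the projection (the "v2 ignores k_bias" comment above).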
@@ -326,6 +561,7 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, + version: int = 1, num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, @@ -372,6 +608,7 @@ def __init__( attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, + version=version, ) ) stage_block_id += 1 @@ -409,6 +646,7 @@ def _swin_transformer( num_heads: List[int], window_size: List[int], stochastic_depth_prob: float, + version: int, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, @@ -423,6 +661,7 @@ def _swin_transformer( num_heads=num_heads, window_size=window_size, stochastic_depth_prob=stochastic_depth_prob, + version=version, **kwargs, ) @@ -506,6 +745,75 @@ class Swin_B_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +class Swin_V2_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_t-123456.pth", # TODO + transforms=partial( + ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC + ), # TODO + meta={ + **_COMMON_META, + "num_params": 28288354, # TODO + "min_size": (224, 224), # TODO + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 0.100, # TODO + "acc@5": 0.100, # TODO + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_s-123456.pth", # TODO + transforms=partial( + ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC + ), # TODO + meta={ + **_COMMON_META, + "num_params": 49606258, # TODO + "min_size": (224, 224), # TODO + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 0.100, # TODO + "acc@5": 0.100, # TODO + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_B_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_b-123456.pth", # TODO + transforms=partial( + ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC + ), # TODO + meta={ + **_COMMON_META, + "num_params": 87768224, # TODO + "min_size": (224, 224), # TODO + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 0.100, # TODO + "acc@5": 0.100, # TODO + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from @@ -536,6 +844,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * num_heads=[3, 6, 12, 24], window_size=[7, 7], stochastic_depth_prob=0.2, + version=1, weights=weights, progress=progress, **kwargs, @@ -572,6 +881,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * num_heads=[3, 6, 12, 24], window_size=[7, 
7], stochastic_depth_prob=0.3, + version=1, weights=weights, progress=progress, **kwargs, @@ -608,6 +918,117 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * num_heads=[4, 8, 16, 32], window_size=[7, 7], stochastic_depth_prob=0.5, + version=1, + weights=weights, + progress=progress, + **kwargs, + ) + +def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_tiny architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_V2_T_Weights + :members: + """ + weights = Swin_V2_T_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.2, + version=2, + weights=weights, + progress=progress, + **kwargs, + ) + + +def swin_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_small architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_V2_S_Weights + :members: + """ + weights = Swin_V2_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.3, + version=2, + weights=weights, + progress=progress, + **kwargs, + ) + + +def swin_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_base architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_V2_B_Weights + :members: + """ + weights = Swin_V2_B_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[7, 7], + stochastic_depth_prob=0.5, + version=2, weights=weights, progress=progress, **kwargs, From 7ad94a77a11662cd0bfb29598a024c6c2329956f Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 7 Jul 2022 10:46:43 -0700 Subject: [PATCH 02/49] fix typo --- torchvision/models/swin_transformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 97e19e28a34..acc67bea54a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -939,7 +939,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` base class. Please refer to the `source code - `_ + `_ for more details about this class. .. autoclass:: torchvision.models.Swin_V2_T_Weights @@ -961,7 +961,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T ) -def swin_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_v2_small architecture from `Swin Transformer V2: Scaling Up Capacity and Resolution `_. @@ -976,7 +976,7 @@ def swin_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` base class. Please refer to the `source code - `_ + `_ for more details about this class. .. autoclass:: torchvision.models.Swin_V2_S_Weights @@ -998,7 +998,7 @@ def swin_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True ) -def swin_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_v2_base architecture from `Swin Transformer V2: Scaling Up Capacity and Resolution `_. @@ -1013,7 +1013,7 @@ def swin_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` base class. Please refer to the `source code - `_ + `_ for more details about this class. .. 
autoclass:: torchvision.models.Swin_V2_B_Weights From e28ec45973a85b6908c571af9c7c6518e040da4c Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 7 Jul 2022 11:56:02 -0700 Subject: [PATCH 03/49] support ufmt and mypy --- torchvision/models/swin_transformer.py | 65 ++++++++++++++------------ 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index acc67bea54a..958b68634a7 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -9,7 +9,7 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param @@ -300,16 +300,16 @@ def shifted_window_attention_v2( x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention - if qkv_bias is not None: # v2 ignores k_bias - qkv_bias= qkv_bias.clone() + if qkv_bias is not None: # v2 ignores k_bias + qkv_bias = qkv_bias.clone() length = qkv_bias.numel() // 3 - qkv_bias[length: 2*length].zero_() + qkv_bias[length : 2 * length].zero_() qkv = F.linear(x, qkv_weight, qkv_bias) qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # v2 cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp() attn = attn * logit_scale # add relative position bias attn = attn + relative_position_bias @@ -370,7 +370,7 @@ def __init__( proj_bias: bool = True, attention_dropout: float = 0.0, dropout: float = 0.0, - pretrained_window_size: Tuple[int] = (0, 0), + pretrained_window_size: Tuple[int, int] = (0, 0), ): super().__init__() if len(window_size) != 2 or len(shift_size) != 2: @@ -388,18 +388,22 @@ def __init__( # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_table = ( + torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 
torch.abs(relative_coords_table) + 1.0) / 3.0 + relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0 + ) self.register_buffer("relative_coords_table", relative_coords_table) @@ -418,12 +422,12 @@ def __init__( self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) if qkv_bias: length = self.qkv.bias.numel() // 3 - self.qkv.bias[length: 2*length].zero_() + self.qkv.bias[length : 2 * length].zero_() def forward(self, x: Tensor): """ @@ -434,8 +438,9 @@ def forward(self, x: Tensor): """ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( # type: ignore[operator] + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = 16 * torch.sigmoid(relative_position_bias).unsqueeze(0) @@ -454,6 +459,7 @@ def forward(self, x: Tensor): proj_bias=self.proj.bias, ) + class SwinTransformerBlock(nn.Module): """ Swin Transformer Block. @@ -491,13 +497,13 @@ def __init__( _log_api_usage_once(self) if attn_layer is None: - if self.version==1: # TODO: switch after python 3.10 + if version == 1: # TODO: switch after python 3.10 attn_layer = ShiftedWindowAttention - elif self.version==2: + elif version == 2: attn_layer = ShiftedWindowAttentionV2 else: - raise NotImplementedError(self.version) - self.version=version + raise NotImplementedError(version) + self.version = version self.norm1 = norm_layer(dim) self.attn = attn_layer( dim, @@ -519,10 +525,10 @@ def __init__( nn.init.normal_(m.bias, std=1e-6) def forward(self, x: Tensor): - if self.version==1: # TODO: switch after python 3.10 + if self.version == 1: # TODO: switch after python 3.10 x = x + self.stochastic_depth(self.attn(self.norm1(x))) x = x + self.stochastic_depth(self.mlp(self.norm2(x))) - elif self.version==2: + elif self.version == 2: x = x + self.stochastic_depth(self.norm1(self.attn(x))) x = x + self.stochastic_depth(self.norm2(self.mlp(x))) else: @@ -593,7 +599,7 @@ def __init__( # build SwinTransformer blocks for i_stage in range(len(depths)): stage: List[nn.Module] = [] - dim = embed_dim * 2 ** i_stage + dim = embed_dim * 2**i_stage for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) @@ -924,6 +930,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * **kwargs, ) + def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_v2_tiny architecture from From ff448320f9227de3d5ed5547a135794f0eb89924 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 7 Jul 2022 12:13:31 -0700 
Subject: [PATCH 04/49] fix 2 unittest errors --- torchvision/models/swin_transformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 958b68634a7..a7e81a28141 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -427,7 +427,7 @@ def __init__( ) if qkv_bias: length = self.qkv.bias.numel() // 3 - self.qkv.bias[length : 2 * length].zero_() + self.qkv.bias[length : 2 * length].data.zero_() def forward(self, x: Tensor): """ @@ -474,7 +474,6 @@ class SwinTransformerBlock(nn.Module): stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention - pretrained_window_size (int): (V2) Window size in pre-training. Default: 0. version (int): SwinTransformer version. Default: 1. """ @@ -490,7 +489,6 @@ def __init__( stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Optional[Callable[..., nn.Module]] = None, - pretrained_window_size: int = 0, version: int = 1, ): super().__init__() @@ -512,7 +510,6 @@ def __init__( num_heads, attention_dropout=attention_dropout, dropout=dropout, - pretrained_window_size=(pretrained_window_size, pretrained_window_size), ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) From 7d84b318473dafb153b17261882c3ec62443c56e Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 7 Jul 2022 20:26:49 -0700 Subject: [PATCH 05/49] fix ufmt issue --- torchvision/models/swin_transformer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a7e81a28141..0858172e15a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -388,12 +388,8 @@ def __init__( # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = ( - torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) - .permute(1, 2, 0) - .contiguous() - .unsqueeze(0) - ) # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 @@ -596,7 +592,7 @@ def __init__( # build SwinTransformer blocks for i_stage in range(len(depths)): stage: List[nn.Module] = [] - dim = embed_dim * 2**i_stage + dim = embed_dim * 2 ** i_stage for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) From 4a21e9823a43b3764d6b4844922a22c32337ebd4 Mon Sep 17 00:00:00 2001 From: Local State Date: Fri, 8 Jul 2022 09:33:35 -0700 Subject: [PATCH 06/49] Apply suggestions from code review Co-authored-by: Vasilis Vryniotis --- torchvision/models/swin_transformer.py | 65 +------------------------- 1 file changed, 1 insertion(+), 64 deletions(-) 
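The other v2 ingredient handled in the hunks above is the continuous relative position bias: instead of a learned table with one entry per discrete offset as in v1, the relative (dh, dw) offsets are normalized, passed through a signed log transform, and fed to a small MLP (cpb_mlp) that produces one bias value per head. The following standalone sketch reproduces that table construction for a 7x7 window; the variable names (table, cpb_mlp, bias_table) and the head count are illustrative, not part of torchvision.

import torch
import torch.nn as nn

window_size = (7, 7)
num_heads = 3

# all pairwise offsets dh in [-(Wh-1), Wh-1] and dw in [-(Ww-1), Ww-1]
relative_coords_h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32)
table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
table = table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

# normalize to [-1, 1], rescale to [-8, 8], then compress with a signed log
table[:, :, :, 0] /= window_size[0] - 1
table[:, :, :, 1] /= window_size[1] - 1
table *= 8
table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / 3.0

# a small MLP maps each 2-d offset to one bias value per attention head
cpb_mlp = nn.Sequential(
    nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
)
# bounded bias; the patch applies 16 * sigmoid after the index lookup, but the
# transform is elementwise, so the order does not change the result
bias_table = 16 * torch.sigmoid(cpb_mlp(table).view(-1, num_heads))  # (2*Wh-1)*(2*Ww-1), nH

The bias for each query/key pair is then gathered from this table with the same relative_position_index lookup as in v1; in the unified code introduced later in the series (PATCH 07/49), that lookup lives in get_relative_position_bias, and the v2 attention class overrides it only to generate the table from cpb_mlp and bound it with 16 * sigmoid.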
diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a7e81a28141..f075fad72d1 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -749,72 +749,9 @@ class Swin_B_Weights(WeightsEnum): class Swin_V2_T_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_v2_t-123456.pth", # TODO - transforms=partial( - ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC - ), # TODO - meta={ - **_COMMON_META, - "num_params": 28288354, # TODO - "min_size": (224, 224), # TODO - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", - "_metrics": { - "ImageNet-1K": { - "acc@1": 0.100, # TODO - "acc@5": 0.100, # TODO - } - }, - "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class Swin_V2_S_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_v2_s-123456.pth", # TODO - transforms=partial( - ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC - ), # TODO - meta={ - **_COMMON_META, - "num_params": 49606258, # TODO - "min_size": (224, 224), # TODO - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", - "_metrics": { - "ImageNet-1K": { - "acc@1": 0.100, # TODO - "acc@5": 0.100, # TODO - } - }, - "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", - }, - ) - DEFAULT = IMAGENET1K_V1 + pass -class Swin_V2_B_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_v2_b-123456.pth", # TODO - transforms=partial( - ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC - ), # TODO - meta={ - **_COMMON_META, - "num_params": 87768224, # TODO - "min_size": (224, 224), # TODO - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", - "_metrics": { - "ImageNet-1K": { - "acc@1": 0.100, # TODO - "acc@5": 0.100, # TODO - } - }, - "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", - }, - ) - DEFAULT = IMAGENET1K_V1 def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: From 8e0f8f6d73c1fdd92424067d5fb42a1a558d84e9 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 13:06:06 -0700 Subject: [PATCH 07/49] unify codes --- torchvision/models/swin_transformer.py | 387 ++++++++----------------- 1 file changed, 122 insertions(+), 265 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index aea86b23808..e5012d4993c 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional import torch import torch.nn.functional as F @@ -20,14 +20,14 @@ "Swin_S_Weights", "Swin_B_Weights", "Swin_V2_T_Weights", - "Swin_V2_S_Weights", - "Swin_V2_B_Weights", + # "Swin_V2_S_Weights", + # "Swin_V2_B_Weights", "swin_t", "swin_s", "swin_b", "swin_v2_t", - "swin_v2_s", - "swin_v2_b", + # "swin_v2_s", + # "swin_v2_b", ] @@ -78,6 +78,8 @@ def 
shifted_window_attention( dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None, + v2: bool = False, + logit_scale: torch.Tensor = None, ): """ Window based multi-head self attention (W-MSA) module with relative position bias. @@ -120,11 +122,21 @@ def shifted_window_attention( x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention + if v2 and qkv_bias is not None: + qkv_bias = qkv_bias.clone() + length = qkv_bias.numel() // 3 + qkv_bias[length : 2 * length].zero_() qkv = F.linear(x, qkv_weight, qkv_bias) qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - q = q * (C // num_heads) ** -0.5 - attn = q.matmul(k.transpose(-2, -1)) + if v2: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp() + attn = attn * logit_scale + else: + q = q * (C // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) # add relative position bias attn = attn + relative_position_bias @@ -197,9 +209,12 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.define_relative_coords() + + def define_relative_coords(self): # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window @@ -212,11 +227,19 @@ def __init__( relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor]) -> torch.Tensor: + relative_position_bias_table = relative_position_bias_table or self.relative_position_bias_table + N = self.window_size[0] * self.window_size[1] + relative_position_bias = relative_position_bias_table[self.relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + return relative_position_bias + def forward(self, x: Tensor): """ Args: @@ -224,12 +247,7 @@ def forward(self, x: Tensor): Returns: Tensor with same layout as input, i.e. 
[B, H, W, C] """ - - N = self.window_size[0] * self.window_size[1] - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] - relative_position_bias = relative_position_bias.view(N, N, -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) - + relative_position_bias = self.get_relative_position_bias() return shifted_window_attention( x, self.qkv.weight, @@ -245,117 +263,7 @@ def forward(self, x: Tensor): ) -def shifted_window_attention_v2( - input: Tensor, - qkv_weight: Tensor, - proj_weight: Tensor, - relative_position_bias: Tensor, - window_size: List[int], - num_heads: int, - shift_size: List[int], - logit_scale: torch.Tensor, - attention_dropout: float = 0.0, - dropout: float = 0.0, - qkv_bias: Optional[Tensor] = None, - proj_bias: Optional[Tensor] = None, -): - """ - Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. - qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. - proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. - relative_position_bias (Tensor): The learned relative position bias added to attention. - window_size (List[int]): Window size. - num_heads (int): Number of attention heads. - shift_size (List[int]): Shift size for shifted window attention. - attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. - dropout (float): Dropout ratio of output. Default: 0.0. - qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. - proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. - Returns: - Tensor[N, H, W, C]: The output tensor after shifted window attention. 
- """ - B, H, W, C = input.shape - # pad feature maps to multiples of window size - pad_r = (window_size[1] - W % window_size[1]) % window_size[1] - pad_b = (window_size[0] - H % window_size[0]) % window_size[0] - x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) - _, pad_H, pad_W, _ = x.shape - - # If window size is larger than feature size, there is no need to shift window - if window_size[0] >= pad_H: - shift_size[0] = 0 - if window_size[1] >= pad_W: - shift_size[1] = 0 - - # cyclic shift - if sum(shift_size) > 0: - x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) - - # partition windows - num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) - x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C - - # multi-head attention - if qkv_bias is not None: # v2 ignores k_bias - qkv_bias = qkv_bias.clone() - length = qkv_bias.numel() // 3 - qkv_bias[length : 2 * length].zero_() - qkv = F.linear(x, qkv_weight, qkv_bias) - qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - # v2 cosine attention - attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) - logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp() - attn = attn * logit_scale - # add relative position bias - attn = attn + relative_position_bias - - if sum(shift_size) > 0: - # generate attention mask - attn_mask = x.new_zeros((pad_H, pad_W)) - h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) - w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) - count = 0 - for h in h_slices: - for w in w_slices: - attn_mask[h[0] : h[1], w[0] : w[1]] = count - count += 1 - attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) - attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) - attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) - attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, num_heads, x.size(1), x.size(1)) - - attn = F.softmax(attn, dim=-1) - attn = F.dropout(attn, p=attention_dropout) - - x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) - x = F.linear(x, proj_weight, proj_bias) - x = F.dropout(x, p=dropout) - - # reverse windows - x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) - - # reverse cyclic shift - if sum(shift_size) > 0: - x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) - - # unpad features - x = x[:, :H, :W, :].contiguous() - return x - - -torch.fx.wrap("shifted_window_attention_v2") - - -class ShiftedWindowAttentionV2(nn.Module): +class ShiftedWindowAttentionV2(ShiftedWindowAttention): """ See :func:`shifted_window_attention_v2`. 
""" @@ -370,29 +278,38 @@ def __init__( proj_bias: bool = True, attention_dropout: float = 0.0, dropout: float = 0.0, - pretrained_window_size: Tuple[int, int] = (0, 0), + pretrained_window_size: List[int] = [0, 0], ): - super().__init__() - if len(window_size) != 2 or len(shift_size) != 2: - raise ValueError("window_size and shift_size must be of length 2") - self.window_size = window_size - self.shift_size = shift_size - self.num_heads = num_heads - self.attention_dropout = attention_dropout - self.dropout = dropout + self.pretrained_window_size = pretrained_window_size # TODO: unsafe, need copy? + super().__init__( + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attention_dropout=attention_dropout, + dropout=dropout, + ) - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + if qkv_bias: + length = self.qkv.bias.numel() // 3 + self.qkv.bias[length : 2 * length].data.zero_() - # V2 + def define_relative_coords(self): # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 - relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 + if self.pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1 else: relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 @@ -413,17 +330,14 @@ def __init__( relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential( - nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) - ) - if qkv_bias: - length = self.qkv.bias.numel() // 3 - self.qkv.bias[length : 2 * length].data.zero_() + def get_relative_position_bias(self) -> torch.Tensor: + relative_position_bias_table: torch.Tensor = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = super().get_relative_position_bias(relative_position_bias_table) + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + return relative_position_bias def forward(self, x: Tensor): """ @@ -432,15 +346,8 @@ def forward(self, x: Tensor): Returns: Tensor with same layout as input, i.e. 
[B, H, W, C] """ - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( # type: ignore[operator] - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 - ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias).unsqueeze(0) - - return shifted_window_attention_v2( + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( x, self.qkv.weight, self.proj.weight, @@ -448,11 +355,12 @@ def forward(self, x: Tensor): self.window_size, self.num_heads, shift_size=self.shift_size, - logit_scale=self.logit_scale, attention_dropout=self.attention_dropout, dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, + v2=True, + logit_scale=self.logit_scale, ) @@ -470,7 +378,6 @@ class SwinTransformerBlock(nn.Module): stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention - version (int): SwinTransformer version. Default: 1. """ def __init__( @@ -484,20 +391,11 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_layer: Optional[Callable[..., nn.Module]] = None, - version: int = 1, + attn_layer: Optional[Callable[..., nn.Module]] = ShiftedWindowAttention, ): super().__init__() _log_api_usage_once(self) - if attn_layer is None: - if version == 1: # TODO: switch after python 3.10 - attn_layer = ShiftedWindowAttention - elif version == 2: - attn_layer = ShiftedWindowAttentionV2 - else: - raise NotImplementedError(version) - self.version = version self.norm1 = norm_layer(dim) self.attn = attn_layer( dim, @@ -518,14 +416,56 @@ def __init__( nn.init.normal_(m.bias, std=1e-6) def forward(self, x: Tensor): - if self.version == 1: # TODO: switch after python 3.10 - x = x + self.stochastic_depth(self.attn(self.norm1(x))) - x = x + self.stochastic_depth(self.mlp(self.norm2(x))) - elif self.version == 2: - x = x + self.stochastic_depth(self.norm1(self.attn(x))) - x = x + self.stochastic_depth(self.norm2(self.mlp(x))) - else: - raise NotImplementedError(self.version) + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + return x + + +class SwinTransformerBlockV2(SwinTransformerBlock): + """ + Swin Transformer V2 Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. 
Default: ShiftedWindowAttentionV2 + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Optional[Callable[..., nn.Module]] = ShiftedWindowAttentionV2, + ): + super().__init__( + dim, + num_heads, + window_size, + shift_size, + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=stochastic_depth_prob, + norm_layer=norm_layer, + attn_layer=attn_layer, + ) + + def forward(self, x: Tensor): + x = x + self.stochastic_depth(self.norm1(self.attn(x))) + x = x + self.stochastic_depth(self.norm2(self.mlp(x))) return x @@ -543,7 +483,6 @@ class SwinTransformer(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. - version (int): SwinTransformer version. Default: 1. num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. @@ -560,17 +499,17 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - version: int = 1, num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, + v2: bool = False, ): super().__init__() _log_api_usage_once(self) self.num_classes = num_classes if block is None: - block = SwinTransformerBlock + block = SwinTransformerBlockV2 if v2 else SwinTransformerBlock if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) @@ -607,7 +546,6 @@ def __init__( attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, - version=version, ) ) stage_block_id += 1 @@ -645,7 +583,6 @@ def _swin_transformer( num_heads: List[int], window_size: List[int], stochastic_depth_prob: float, - version: int, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, @@ -660,7 +597,6 @@ def _swin_transformer( num_heads=num_heads, window_size=window_size, stochastic_depth_prob=stochastic_depth_prob, - version=version, **kwargs, ) @@ -748,8 +684,6 @@ class Swin_V2_T_Weights(WeightsEnum): pass - - def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from @@ -780,7 +714,6 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * num_heads=[3, 6, 12, 24], window_size=[7, 7], stochastic_depth_prob=0.2, - version=1, weights=weights, progress=progress, **kwargs, @@ -817,7 +750,6 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * num_heads=[3, 6, 12, 24], window_size=[7, 7], stochastic_depth_prob=0.3, - version=1, weights=weights, progress=progress, **kwargs, @@ -854,7 +786,6 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * num_heads=[4, 8, 16, 32], window_size=[7, 7], stochastic_depth_prob=0.5, - version=1, weights=weights, progress=progress, **kwargs, @@ -891,82 +822,8 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T num_heads=[3, 6, 12, 24], window_size=[7, 7], stochastic_depth_prob=0.2, - version=2, - weights=weights, - 
progress=progress, - **kwargs, - ) - - -def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: - """ - Constructs a swin_v2_small architecture from - `Swin Transformer V2: Scaling Up Capacity and Resolution `_. - - Args: - weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.Swin_V2_S_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.Swin_V2_S_Weights - :members: - """ - weights = Swin_V2_S_Weights.verify(weights) - - return _swin_transformer( - patch_size=[4, 4], - embed_dim=96, - depths=[2, 2, 18, 2], - num_heads=[3, 6, 12, 24], - window_size=[7, 7], - stochastic_depth_prob=0.3, - version=2, - weights=weights, - progress=progress, - **kwargs, - ) - - -def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: - """ - Constructs a swin_v2_base architecture from - `Swin Transformer V2: Scaling Up Capacity and Resolution `_. - - Args: - weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.Swin_V2_B_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. 
autoclass:: torchvision.models.Swin_V2_B_Weights - :members: - """ - weights = Swin_V2_B_Weights.verify(weights) - - return _swin_transformer( - patch_size=[4, 4], - embed_dim=128, - depths=[2, 2, 18, 2], - num_heads=[4, 8, 16, 32], - window_size=[7, 7], - stochastic_depth_prob=0.5, - version=2, weights=weights, progress=progress, + v2=True, **kwargs, ) From 284ca506b3468f0e192ffe8cf167fcdec988bb5d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 13:13:38 -0700 Subject: [PATCH 08/49] fix meshgrid indexing --- torchvision/models/swin_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e5012d4993c..44c46a3b3a3 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -305,7 +305,7 @@ def define_relative_coords(self): # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if self.pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1 @@ -323,7 +323,7 @@ def define_relative_coords(self): # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 From e801222f4af320a38d6b3c78c97b9367c589b484 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 13:15:37 -0700 Subject: [PATCH 09/49] fix a bug --- torchvision/models/swin_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 44c46a3b3a3..87e260c206b 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -233,7 +233,8 @@ def define_relative_coords(self): nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor]) -> torch.Tensor: - relative_position_bias_table = relative_position_bias_table or self.relative_position_bias_table + if relative_position_bias_table is None: + relative_position_bias_table = self.relative_position_bias_table N = self.window_size[0] * self.window_size[1] relative_position_bias = relative_position_bias_table[self.relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view(N, N, -1) From eb0641489300558ddea3624d38e96348e3b84c29 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 13:45:55 -0700 Subject: [PATCH 10/49] fix type check --- torchvision/models/swin_transformer.py | 33 +++++++++++++------------- 1 file changed, 16 
insertions(+), 17 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 298a04b8b68..26330633666 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -86,8 +86,7 @@ def shifted_window_attention( dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None, - v2: bool = False, - logit_scale: torch.Tensor = None, + v2_logit_scale: Optional[torch.Tensor] = None, ): """ Window based multi-head self attention (W-MSA) module with relative position bias. @@ -130,18 +129,18 @@ def shifted_window_attention( x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention - if v2 and qkv_bias is not None: + if v2_logit_scale is not None and qkv_bias is not None: qkv_bias = qkv_bias.clone() length = qkv_bias.numel() // 3 qkv_bias[length : 2 * length].zero_() qkv = F.linear(x, qkv_weight, qkv_bias) qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - if v2: + if v2_logit_scale is not None: # cosine attention attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) - logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp() - attn = attn * logit_scale + v2_logit_scale = torch.clamp(v2_logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp() + attn = attn * v2_logit_scale else: q = q * (C // num_heads) ** -0.5 attn = q.matmul(k.transpose(-2, -1)) @@ -240,7 +239,7 @@ def define_relative_coords(self): nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) - def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor]) -> torch.Tensor: + def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor] = None) -> torch.Tensor: if relative_position_bias_table is None: relative_position_bias_table = self.relative_position_bias_table N = self.window_size[0] * self.window_size[1] @@ -301,7 +300,7 @@ def __init__( dropout=dropout, ) - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + self.v2_logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential( nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) @@ -342,8 +341,9 @@ def define_relative_coords(self): relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) - def get_relative_position_bias(self) -> torch.Tensor: - relative_position_bias_table: torch.Tensor = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor] = None) -> torch.Tensor: + if relative_position_bias_table is None: + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) relative_position_bias = super().get_relative_position_bias(relative_position_bias_table) relative_position_bias = 16 * torch.sigmoid(relative_position_bias) return relative_position_bias @@ -368,8 +368,7 @@ def forward(self, x: Tensor): dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, - v2=True, - logit_scale=self.logit_scale, + v2_logit_scale=self.v2_logit_scale, ) @@ -400,7 +399,7 @@ def __init__( attention_dropout: float = 
0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_layer: Optional[Callable[..., nn.Module]] = ShiftedWindowAttention, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, ): super().__init__() _log_api_usage_once(self) @@ -457,7 +456,7 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_layer: Optional[Callable[..., nn.Module]] = ShiftedWindowAttentionV2, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, ): super().__init__( dim, @@ -511,14 +510,14 @@ def __init__( num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, - v2: bool = False, + use_v2: bool = False, ): super().__init__() _log_api_usage_once(self) self.num_classes = num_classes if block is None: - block = SwinTransformerBlockV2 if v2 else SwinTransformerBlock + block = SwinTransformerBlockV2 if use_v2 else SwinTransformerBlock if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) @@ -833,6 +832,6 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T stochastic_depth_prob=0.2, weights=weights, progress=progress, - v2=True, + use_v2=True, **kwargs, ) From 5deccd5e317c615a881a097d0d32a357690ecf47 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 16:24:46 -0700 Subject: [PATCH 11/49] add type_annotation --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 26330633666..08bf750c4eb 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -31,7 +31,7 @@ ] -def _patch_merging_pad(x): +def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor: H, W, _ = x.shape[-3:] x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) return x From 75bcbc7b834acbce4bd6060219086426fef170a9 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 18:27:53 -0700 Subject: [PATCH 12/49] add slow model --- test/test_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index 866fafae5f6..5668c157769 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -332,6 +332,9 @@ def _check_input_backprop(model, inputs): "swin_t", "swin_s", "swin_b", + "swin_v2_t", + # "swin_v2_s", + # "swin_v2_b", ] for m in slow_models: _model_params[m] = {"input_shape": (1, 3, 64, 64)} From 084833eb4f6645c7758fc69047e290aef9163d2a Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 18:28:00 -0700 Subject: [PATCH 13/49] fix device issue --- torchvision/models/swin_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 08bf750c4eb..1baa3c38d8f 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F from torch import nn, Tensor +import math from ..ops.misc import MLP, Permute from ..ops.stochastic_depth import StochasticDepth @@ -139,7 +140,7 @@ def shifted_window_attention( if v2_logit_scale is not None: # cosine attention attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) - v2_logit_scale = torch.clamp(v2_logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp() + v2_logit_scale = torch.clamp(v2_logit_scale, 
max=math.log(100.0)).exp() attn = attn * v2_logit_scale else: q = q * (C // num_heads) ** -0.5 From a0498a9bb820e94ad73385fb1c3d9907efa25242 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 18:35:35 -0700 Subject: [PATCH 14/49] fix ufmt issue --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1baa3c38d8f..3575913441e 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,10 +1,10 @@ +import math from functools import partial from typing import Any, Callable, List, Optional import torch import torch.nn.functional as F from torch import nn, Tensor -import math from ..ops.misc import MLP, Permute from ..ops.stochastic_depth import StochasticDepth From c9b77c81fe3a21cadcc85dc92778e2e821cd6855 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 20:18:55 -0700 Subject: [PATCH 15/49] add expect pickle file --- .../expect/ModelTester.test_swin_v2_t_expect.pkl | Bin 0 -> 1081 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_v2_t_expect.pkl diff --git a/test/expect/ModelTester.test_swin_v2_t_expect.pkl b/test/expect/ModelTester.test_swin_v2_t_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e35e680eb0296c49ace65195ad3223615d84d5e4 GIT binary patch literal 1081 zcmWIWW@cev;NW1u09p(d48Hj(sW~C3#U-gldL=+AzPLOyFTTtuz9c@iq98T7L_a05 zBvG#*JIBq*gdvIy(7=>jl3$dZp%-6Tl9^M?6>g z9q7JdpvuJb)Ivsx7B07({KS%Ah#*%XQv@SWjX@!E1T&CME+{Qz@iri!vXHrTlyEwE=leP&Y~anRV$_SnMy!nLBkQA^+LOP;i3-`1(s_VVi) z?dPzl+Rr(8#8ze#@4f(!Cw3Bw)^?X?PPTJ7lxhbGt)8F{4JUvh1;V&POOrv7h|nra zEh+}akDHS$N^oJMFk>!|%@^mThcW?e1>pd1Mi2#0)5vkC1(HAk#HVm{bCG?kfnuf> zFd-6WuD$`fsmMOjL@_l2=qtjeLc=Y Date: Fri, 8 Jul 2022 20:34:01 -0700 Subject: [PATCH 16/49] fix jit script issue --- torchvision/models/swin_transformer.py | 35 +++++++++++++++++--------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 3575913441e..5b2a8fa0c95 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -41,6 +41,19 @@ def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor: torch.fx.wrap("_patch_merging_pad") +def _get_relative_position_bias( + relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int] +) -> torch.Tensor: + N = window_size[0] * window_size[1] + relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + return relative_position_bias + + +torch.fx.wrap("_get_relative_position_bias") + + class PatchMerging(nn.Module): """Patch Merging Layer. 
Args: @@ -240,14 +253,10 @@ def define_relative_coords(self): nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) - def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor] = None) -> torch.Tensor: - if relative_position_bias_table is None: - relative_position_bias_table = self.relative_position_bias_table - N = self.window_size[0] * self.window_size[1] - relative_position_bias = relative_position_bias_table[self.relative_position_index] # type: ignore[index] - relative_position_bias = relative_position_bias.view(N, N, -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) - return relative_position_bias + def get_relative_position_bias(self) -> torch.Tensor: + return _get_relative_position_bias( + self.relative_position_bias_table, self.relative_position_index, self.window_size + ) def forward(self, x: Tensor): """ @@ -342,10 +351,12 @@ def define_relative_coords(self): relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) - def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor] = None) -> torch.Tensor: - if relative_position_bias_table is None: - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = super().get_relative_position_bias(relative_position_bias_table) + def get_relative_position_bias(self) -> torch.Tensor: + relative_position_bias = _get_relative_position_bias( + self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads), + self.relative_position_index, + self.window_size, + ) relative_position_bias = 16 * torch.sigmoid(relative_position_bias) return relative_position_bias From 3eb0de8da041bfd5011080e5d69f41687029a07d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 21:36:30 -0700 Subject: [PATCH 17/49] fix type check --- torchvision/models/swin_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5b2a8fa0c95..bd123bb02ec 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -255,7 +255,7 @@ def define_relative_coords(self): def get_relative_position_bias(self) -> torch.Tensor: return _get_relative_position_bias( - self.relative_position_bias_table, self.relative_position_index, self.window_size + self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] ) def forward(self, x: Tensor): @@ -354,7 +354,7 @@ def define_relative_coords(self): def get_relative_position_bias(self) -> torch.Tensor: relative_position_bias = _get_relative_position_bias( self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads), - self.relative_position_index, + self.relative_position_index, # type: ignore[arg-type] self.window_size, ) relative_position_bias = 16 * torch.sigmoid(relative_position_bias) From d7a4ca275f82e9e8783ae1f22955697c29199fc4 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 21:40:33 -0700 Subject: [PATCH 18/49] keep consistent argument order --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index bd123bb02ec..e7e8876924c 100644 --- a/torchvision/models/swin_transformer.py +++ 
b/torchvision/models/swin_transformer.py @@ -520,8 +520,8 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, num_classes: int = 1000, - norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, use_v2: bool = False, ): super().__init__() From 005bb13b1240d04e1f67cd18aea975ccfef57b56 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 21:53:03 -0700 Subject: [PATCH 19/49] add support for pretrained_window_size --- torchvision/models/swin_transformer.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e7e8876924c..514ae06f97c 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,6 +1,6 @@ import math from functools import partial -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch import torch.nn.functional as F @@ -412,6 +412,7 @@ def __init__( stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, + **kwargs, ): super().__init__() _log_api_usage_once(self) @@ -424,6 +425,7 @@ def __init__( num_heads, attention_dropout=attention_dropout, dropout=dropout, + **kwargs, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) @@ -454,7 +456,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock): attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2 + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2. + pretrained_window_size (int): Local window size in pre-training. Default: 0. """ def __init__( @@ -469,6 +472,7 @@ def __init__( stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, + pretrained_window_size: int = 0, ): super().__init__( dim, @@ -481,6 +485,7 @@ def __init__( stochastic_depth_prob=stochastic_depth_prob, norm_layer=norm_layer, attn_layer=attn_layer, + pretrained_window_size=[pretrained_window_size, pretrained_window_size], ) def forward(self, x: Tensor): @@ -506,6 +511,7 @@ class SwinTransformer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. + v2_pretrained_window_sizes (List[int]): Pretrained window sizes of each layer. Default: [0, 0, 0, 0]. 
""" def __init__( @@ -523,6 +529,7 @@ def __init__( block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, use_v2: bool = False, + v2_pretrained_window_sizes: List[int] = [0, 0, 0, 0], ): super().__init__() _log_api_usage_once(self) @@ -552,6 +559,9 @@ def __init__( for i_stage in range(len(depths)): stage: List[nn.Module] = [] dim = embed_dim * 2 ** i_stage + kwargs: Dict[str, Any] = {} + if use_v2: + kwargs["pretrained_window_size"] = v2_pretrained_window_sizes[i_stage] for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) @@ -566,6 +576,7 @@ def __init__( attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, + **kwargs, ) ) stage_block_id += 1 From 69bad17ac8d5ec7b5230c2330e4fe6a086549e6d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 22:27:34 -0700 Subject: [PATCH 20/49] avoid code duplication --- torchvision/models/swin_transformer.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 514ae06f97c..1ccdda4212d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -237,7 +237,11 @@ def define_relative_coords(self): self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) ) # 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + self.define_relative_position_index() + def define_relative_position_index(self): # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) @@ -251,8 +255,6 @@ def define_relative_coords(self): relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) - nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) - def get_relative_position_bias(self) -> torch.Tensor: return _get_relative_position_bias( self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] @@ -335,21 +337,9 @@ def define_relative_coords(self): relative_coords_table = ( torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0 ) - self.register_buffer("relative_coords_table", relative_coords_table) - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) + self.define_relative_position_index() def 
get_relative_position_bias(self) -> torch.Tensor: relative_position_bias = _get_relative_position_bias( From 671714566e86697af47773481875ac28ce1bd7c1 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Fri, 8 Jul 2022 22:43:07 -0700 Subject: [PATCH 21/49] a better code reuse --- torchvision/models/swin_transformer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1ccdda4212d..565d0f1061a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -230,17 +230,16 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.define_relative_coords() + self.define_relative_position_bias_table() + self.define_relative_position_index() - def define_relative_coords(self): + def define_relative_position_bias_table(self): # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) ) # 2*Wh-1 * 2*Ww-1, nH nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) - self.define_relative_position_index() - def define_relative_position_index(self): # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) @@ -321,7 +320,7 @@ def __init__( length = self.qkv.bias.numel() // 3 self.qkv.bias[length : 2 * length].data.zero_() - def define_relative_coords(self): + def define_relative_position_bias_table(self): # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) @@ -339,8 +338,6 @@ def define_relative_coords(self): ) self.register_buffer("relative_coords_table", relative_coords_table) - self.define_relative_position_index() - def get_relative_position_bias(self) -> torch.Tensor: relative_position_bias = _get_relative_position_bias( self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads), From 0dc1b223906f7713d6ef0076ae7ea039c04a6f4b Mon Sep 17 00:00:00 2001 From: ain-soph Date: Sat, 9 Jul 2022 10:49:28 -0700 Subject: [PATCH 22/49] update window_size argument --- torchvision/models/swin_transformer.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 565d0f1061a..1f8a637c9cb 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -702,12 +702,13 @@ class Swin_V2_T_Weights(WeightsEnum): pass -def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. Args: + window_size (List[int]): Window size. Default: [7, 7]. weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The pretrained weights to use. 
See :class:`~torchvision.models.Swin_T_Weights` below for @@ -730,7 +731,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=[7, 7], + window_size=window_size, stochastic_depth_prob=0.2, weights=weights, progress=progress, @@ -738,12 +739,13 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * ) -def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any) -> SwinTransformer: """ Constructs a swin_small architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. Args: + window_size (List[int]): Window size. Default: [7, 7]. weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.Swin_S_Weights` below for @@ -766,7 +768,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], - window_size=[7, 7], + window_size=window_size, stochastic_depth_prob=0.3, weights=weights, progress=progress, @@ -774,12 +776,13 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * ) -def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any) -> SwinTransformer: """ Constructs a swin_base architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. Args: + window_size (List[int]): Window size. Default: [7, 7]. weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.Swin_B_Weights` below for @@ -802,7 +805,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], - window_size=[7, 7], + window_size=window_size, stochastic_depth_prob=0.5, weights=weights, progress=progress, @@ -810,12 +813,13 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * ) -def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: +def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, window_size: List[int] = [8, 8], **kwargs: Any) -> SwinTransformer: """ Constructs a swin_v2_tiny architecture from `Swin Transformer V2: Scaling Up Capacity and Resolution `_. Args: + window_size (List[int]): Window size. Default: [8, 8]. weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The pretrained weights to use. 
See :class:`~torchvision.models.Swin_V2_T_Weights` below for @@ -838,7 +842,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=[7, 7], + window_size=window_size, stochastic_depth_prob=0.2, weights=weights, progress=progress, From 8b858594aafa52efd56764dcf49237bcb0eb595b Mon Sep 17 00:00:00 2001 From: ain-soph Date: Sat, 9 Jul 2022 10:49:45 -0700 Subject: [PATCH 23/49] make permute and flatten operations modular --- torchvision/models/swin_transformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1f8a637c9cb..62b06f101e6 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -575,7 +575,9 @@ def __init__( num_features = embed_dim * 2 ** (len(depths) - 1) self.norm = norm_layer(num_features) + self.permute = Permute([0, 3, 1, 2]) self.avgpool = nn.AdaptiveAvgPool2d(1) + self.flatten = nn.Flatten() self.head = nn.Linear(num_features, num_classes) for m in self.modules(): @@ -587,9 +589,9 @@ def __init__( def forward(self, x): x = self.features(x) x = self.norm(x) - x = x.permute(0, 3, 1, 2) + x = self.permute(x) x = self.avgpool(x) - x = torch.flatten(x, 1) + x = self.flatten(x) x = self.head(x) return x From b9bba94e545de2e1ca0d25620c9a88524b5de7a3 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Sat, 9 Jul 2022 20:56:06 -0700 Subject: [PATCH 24/49] add PatchMergingV2 --- torchvision/models/swin_transformer.py | 61 +++++++++++++++++++++----- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 62b06f101e6..ee1bf37e759 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -35,6 +35,11 @@ def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor: H, W, _ = x.shape[-3:] x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C return x @@ -76,15 +81,35 @@ def forward(self, x: Tensor): Tensor with layout of [..., H/2, W/2, 2*C] """ x = _patch_merging_pad(x) + x = self.norm(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return x - x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C - x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C - x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C - x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C - x = self.norm(x) +class PatchMergingV2(nn.Module): + """Patch Merging Layer for Swin Transformer V2. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) # difference + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) x = self.reduction(x) # ... 
H/2 W/2 2*C + x = self.norm(x) return x @@ -528,6 +553,8 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) + downsample_layer = PatchMergingV2 if use_v2 else PatchMerging + layers: List[nn.Module] = [] # split image into non-overlapping patches layers.append( @@ -570,7 +597,7 @@ def __init__( layers.append(nn.Sequential(*stage)) # add patch merging layer if i_stage < (len(depths) - 1): - layers.append(PatchMerging(dim, norm_layer)) + layers.append(downsample_layer(dim, norm_layer)) self.features = nn.Sequential(*layers) num_features = embed_dim * 2 ** (len(depths) - 1) @@ -704,7 +731,9 @@ class Swin_V2_T_Weights(WeightsEnum): pass -def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any) -> SwinTransformer: +def swin_t( + *, weights: Optional[Swin_T_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any +) -> SwinTransformer: """ Constructs a swin_tiny architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. @@ -741,7 +770,9 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, w ) -def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any) -> SwinTransformer: +def swin_s( + *, weights: Optional[Swin_S_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any +) -> SwinTransformer: """ Constructs a swin_small architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. @@ -778,7 +809,9 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, w ) -def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any) -> SwinTransformer: +def swin_b( + *, weights: Optional[Swin_B_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any +) -> SwinTransformer: """ Constructs a swin_base architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. @@ -815,7 +848,13 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, w ) -def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, window_size: List[int] = [8, 8], **kwargs: Any) -> SwinTransformer: +def swin_v2_t( + *, + weights: Optional[Swin_V2_T_Weights] = None, + progress: bool = True, + window_size: List[int] = [8, 8], + **kwargs: Any, +) -> SwinTransformer: """ Constructs a swin_v2_tiny architecture from `Swin Transformer V2: Scaling Up Capacity and Resolution `_. 
From 4d8de8af658e7e46eadd60c20e8d2a490c810575 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Sun, 10 Jul 2022 09:33:32 -0700 Subject: [PATCH 25/49] modify expect.pkl --- .../ModelTester.test_swin_v2_t_expect.pkl | Bin 1081 -> 1081 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/test/expect/ModelTester.test_swin_v2_t_expect.pkl b/test/expect/ModelTester.test_swin_v2_t_expect.pkl index e35e680eb0296c49ace65195ad3223615d84d5e4..f3752af6265bf31ca02d48f5219e291625403498 100644 GIT binary patch delta 230 zcmVX%V9ryjj+C&gS|d<<1apZT(mtg2e!WBP$fSM-yl9LRU=WqPovLL^jeH%YE*t)%Lf?vM|WYE2V z&Vawp?P@>t>XpBkoUlGc#c)3d%gnxn#{RvO`By&DTqiyj*Qh?J^Nhc>Y$HFEBu+lM zcY?m8+|xdYK9#=Od;2|X0bf2F83;eH8V|nQrKvvpz4bnP92!x+<;JZ(WT#fY2;2-m zW^RZ+j6f*A2IDwC(AILjTBYv3W0Iu4wUv86AFlyFoCGC5oXNyJ9I_3*P)OuH7$G}8 g)0vY#M8ai0P)i30j8gD`lTZTc1dLMffRjE0`-FFJ-~a#s From db94095629a4ab0c5d40802d64357f62bef3bbb7 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Mon, 11 Jul 2022 09:35:36 -0700 Subject: [PATCH 26/49] use None as default argument value --- torchvision/models/swin_transformer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index ee1bf37e759..96fc67cc695 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -322,9 +322,11 @@ def __init__( proj_bias: bool = True, attention_dropout: float = 0.0, dropout: float = 0.0, - pretrained_window_size: List[int] = [0, 0], + pretrained_window_size: Optional[List[int]] = None, ): - self.pretrained_window_size = pretrained_window_size # TODO: unsafe, need copy? + if pretrained_window_size is None: + pretrained_window_size = [0, 0] + self.pretrained_window_size = pretrained_window_size super().__init__( dim, window_size, @@ -541,7 +543,7 @@ def __init__( block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, use_v2: bool = False, - v2_pretrained_window_sizes: List[int] = [0, 0, 0, 0], + v2_pretrained_window_sizes: Optional[List[int]] = None, ): super().__init__() _log_api_usage_once(self) @@ -553,7 +555,12 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) - downsample_layer = PatchMergingV2 if use_v2 else PatchMerging + if use_v2: + downsample_layer = PatchMergingV2 + if v2_pretrained_window_sizes is None: + v2_pretrained_window_sizes = [0, 0, 0, 0] + else: + downsample_layer = PatchMerging layers: List[nn.Module] = [] # split image into non-overlapping patches From 54cf58465239657680fa1746dd2dc0f63cc28565 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Mon, 11 Jul 2022 09:45:10 -0700 Subject: [PATCH 27/49] fix type check --- torchvision/models/swin_transformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 96fc67cc695..cc3297c6f06 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -555,12 +555,10 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) - if use_v2: - downsample_layer = PatchMergingV2 + downsample_layer = PatchMergingV2 if use_v2 else PatchMerging + if v2_pretrained_window_sizes is None: v2_pretrained_window_sizes = [0, 0, 0, 0] - else: - downsample_layer = PatchMerging layers: List[nn.Module] = [] # split image into non-overlapping patches From 568731c0f3e4c3ffac7fcde4966941c9cea63e83 Mon Sep 17 00:00:00 2001 From: ain-soph Date: 
Mon, 11 Jul 2022 10:02:43 -0700 Subject: [PATCH 28/49] fix indent --- torchvision/models/swin_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index cc3297c6f06..f3449b30683 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -557,8 +557,8 @@ def __init__( downsample_layer = PatchMergingV2 if use_v2 else PatchMerging - if v2_pretrained_window_sizes is None: - v2_pretrained_window_sizes = [0, 0, 0, 0] + if v2_pretrained_window_sizes is None: + v2_pretrained_window_sizes = [0, 0, 0, 0] layers: List[nn.Module] = [] # split image into non-overlapping patches From b04b9c77f565083ff356323a4b2634fe48fa8e96 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Mon, 11 Jul 2022 11:40:16 -0700 Subject: [PATCH 29/49] fix window_size (temporarily) --- torchvision/models/swin_transformer.py | 32 +++++++------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index f3449b30683..e5cfab9e25f 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -736,15 +736,12 @@ class Swin_V2_T_Weights(WeightsEnum): pass -def swin_t( - *, weights: Optional[Swin_T_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any -) -> SwinTransformer: +def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. Args: - window_size (List[int]): Window size. Default: [7, 7]. weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.Swin_T_Weights` below for @@ -767,7 +764,7 @@ def swin_t( embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=window_size, + window_size=[7, 7], stochastic_depth_prob=0.2, weights=weights, progress=progress, @@ -775,15 +772,12 @@ def swin_t( ) -def swin_s( - *, weights: Optional[Swin_S_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any -) -> SwinTransformer: +def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_small architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. Args: - window_size (List[int]): Window size. Default: [7, 7]. weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.Swin_S_Weights` below for @@ -806,7 +800,7 @@ def swin_s( embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], - window_size=window_size, + window_size=[7, 7], stochastic_depth_prob=0.3, weights=weights, progress=progress, @@ -814,15 +808,12 @@ def swin_s( ) -def swin_b( - *, weights: Optional[Swin_B_Weights] = None, progress: bool = True, window_size: List[int] = [7, 7], **kwargs: Any -) -> SwinTransformer: +def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_base architecture from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. Args: - window_size (List[int]): Window size. Default: [7, 7]. 
weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.Swin_B_Weights` below for @@ -845,7 +836,7 @@ def swin_b( embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], - window_size=window_size, + window_size=[7, 7], stochastic_depth_prob=0.5, weights=weights, progress=progress, @@ -853,19 +844,12 @@ def swin_b( ) -def swin_v2_t( - *, - weights: Optional[Swin_V2_T_Weights] = None, - progress: bool = True, - window_size: List[int] = [8, 8], - **kwargs: Any, -) -> SwinTransformer: +def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_v2_tiny architecture from `Swin Transformer V2: Scaling Up Capacity and Resolution `_. Args: - window_size (List[int]): Window size. Default: [8, 8]. weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.Swin_V2_T_Weights` below for @@ -888,7 +872,7 @@ def swin_v2_t( embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=window_size, + window_size=[8, 8], stochastic_depth_prob=0.2, weights=weights, progress=progress, From a0e7a40ad9edaa95236022df17f9a0102779109e Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 14:54:39 -0700 Subject: [PATCH 30/49] remove "v2_" related prefix and add v2 builder --- torchvision/models/swin_transformer.py | 66 +++++++++++++++++--------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e5cfab9e25f..e8a24c2d6c0 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -125,7 +125,7 @@ def shifted_window_attention( dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None, - v2_logit_scale: Optional[torch.Tensor] = None, + logit_scale: Optional[torch.Tensor] = None, ): """ Window based multi-head self attention (W-MSA) module with relative position bias. @@ -142,6 +142,7 @@ def shifted_window_attention( dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None. Returns: Tensor[N, H, W, C]: The output tensor after shifted window attention. 
""" @@ -168,18 +169,18 @@ def shifted_window_attention( x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention - if v2_logit_scale is not None and qkv_bias is not None: + if logit_scale is not None and qkv_bias is not None: qkv_bias = qkv_bias.clone() length = qkv_bias.numel() // 3 qkv_bias[length : 2 * length].zero_() qkv = F.linear(x, qkv_weight, qkv_bias) qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - if v2_logit_scale is not None: + if logit_scale is not None: # cosine attention attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) - v2_logit_scale = torch.clamp(v2_logit_scale, max=math.log(100.0)).exp() - attn = attn * v2_logit_scale + logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp() + attn = attn * logit_scale else: q = q * (C // num_heads) ** -0.5 attn = q.matmul(k.transpose(-2, -1)) @@ -338,7 +339,7 @@ def __init__( dropout=dropout, ) - self.v2_logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential( nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) @@ -394,7 +395,7 @@ def forward(self, x: Tensor): dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, - v2_logit_scale=self.v2_logit_scale, + logit_scale=self.logit_scale, ) @@ -525,7 +526,7 @@ class SwinTransformer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. - v2_pretrained_window_sizes (List[int]): Pretrained window sizes of each layer. Default: [0, 0, 0, 0]. + pretrained_window_sizes (List[int]): Pretrained window sizes of each layer for Swin Transformer V2. Default: [0, 0, 0, 0]. 
""" def __init__( @@ -540,26 +541,18 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, + block: Callable[..., nn.Module] = SwinTransformerBlock, norm_layer: Optional[Callable[..., nn.Module]] = None, - use_v2: bool = False, - v2_pretrained_window_sizes: Optional[List[int]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, + pretrained_window_sizes: Optional[List[int]] = None, ): super().__init__() _log_api_usage_once(self) self.num_classes = num_classes - if block is None: - block = SwinTransformerBlockV2 if use_v2 else SwinTransformerBlock - if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) - downsample_layer = PatchMergingV2 if use_v2 else PatchMerging - - if v2_pretrained_window_sizes is None: - v2_pretrained_window_sizes = [0, 0, 0, 0] - layers: List[nn.Module] = [] # split image into non-overlapping patches layers.append( @@ -579,8 +572,8 @@ def __init__( stage: List[nn.Module] = [] dim = embed_dim * 2 ** i_stage kwargs: Dict[str, Any] = {} - if use_v2: - kwargs["pretrained_window_size"] = v2_pretrained_window_sizes[i_stage] + if pretrained_window_sizes is not None: + kwargs["pretrained_window_size"] = pretrained_window_sizes[i_stage] for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) @@ -658,6 +651,34 @@ def _swin_transformer( return model +def _swin_transformer_v2( + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + block: Callable[..., nn.Module] = SwinTransformerBlockV2, + downsample_layer: Callable[..., nn.Module] = PatchMergingV2, + **kwargs: Any, +) -> SwinTransformer: + return _swin_transformer( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + weights=weights, + progress=progress, + block=block, + downsample_layer=downsample_layer, + **kwargs, + ) + + _COMMON_META = { "categories": _IMAGENET_CATEGORIES, } @@ -867,7 +888,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T """ weights = Swin_V2_T_Weights.verify(weights) - return _swin_transformer( + return _swin_transformer_v2( patch_size=[4, 4], embed_dim=96, depths=[2, 2, 6, 2], @@ -876,6 +897,5 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T stochastic_depth_prob=0.2, weights=weights, progress=progress, - use_v2=True, **kwargs, ) From f643ff05132775b3c2a2f0ef22611620695d2c3d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 15:04:26 -0700 Subject: [PATCH 31/49] remove v2 builder --- torchvision/models/swin_transformer.py | 32 +++----------------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e8a24c2d6c0..2848188b8e4 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -651,34 +651,6 @@ def _swin_transformer( return model -def _swin_transformer_v2( - patch_size: List[int], - embed_dim: int, - depths: List[int], - num_heads: List[int], - window_size: List[int], - stochastic_depth_prob: float, - weights: Optional[WeightsEnum], - progress: bool, - 
block: Callable[..., nn.Module] = SwinTransformerBlockV2, - downsample_layer: Callable[..., nn.Module] = PatchMergingV2, - **kwargs: Any, -) -> SwinTransformer: - return _swin_transformer( - patch_size=patch_size, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - window_size=window_size, - stochastic_depth_prob=stochastic_depth_prob, - weights=weights, - progress=progress, - block=block, - downsample_layer=downsample_layer, - **kwargs, - ) - - _COMMON_META = { "categories": _IMAGENET_CATEGORIES, } @@ -888,7 +860,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T """ weights = Swin_V2_T_Weights.verify(weights) - return _swin_transformer_v2( + return _swin_transformer( patch_size=[4, 4], embed_dim=96, depths=[2, 2, 6, 2], @@ -897,5 +869,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T stochastic_depth_prob=0.2, weights=weights, progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, **kwargs, ) From e2b338b8bdc808c5ad0f354239a4732ea048ca4d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 16:29:11 -0700 Subject: [PATCH 32/49] keep default value consistent with official repo --- torchvision/models/swin_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 2848188b8e4..01424d69e4c 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -522,7 +522,7 @@ class SwinTransformer(nn.Module): mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. - stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. @@ -539,7 +539,7 @@ def __init__( mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, - stochastic_depth_prob: float = 0.0, + stochastic_depth_prob: float = 0.1, num_classes: int = 1000, block: Callable[..., nn.Module] = SwinTransformerBlock, norm_layer: Optional[Callable[..., nn.Module]] = None, From 8a13f932815ae25655c07430d52929f86b1ca479 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 16:32:58 -0700 Subject: [PATCH 33/49] deprecate dropout --- torchvision/models/swin_transformer.py | 38 +------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 01424d69e4c..695368065df 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -121,8 +121,6 @@ def shifted_window_attention( window_size: List[int], num_heads: int, shift_size: List[int], - attention_dropout: float = 0.0, - dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None, logit_scale: Optional[torch.Tensor] = None, @@ -138,8 +136,6 @@ def shifted_window_attention( window_size (List[int]): Window size. num_heads (int): Number of attention heads. shift_size (List[int]): Shift size for shifted window attention. - attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. 
- dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None. @@ -206,11 +202,9 @@ def shifted_window_attention( attn = attn.view(-1, num_heads, x.size(1), x.size(1)) attn = F.softmax(attn, dim=-1) - attn = F.dropout(attn, p=attention_dropout) x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = F.linear(x, proj_weight, proj_bias) - x = F.dropout(x, p=dropout) # reverse windows x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) @@ -241,8 +235,6 @@ def __init__( num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, - attention_dropout: float = 0.0, - dropout: float = 0.0, ): super().__init__() if len(window_size) != 2 or len(shift_size) != 2: @@ -250,8 +242,6 @@ def __init__( self.window_size = window_size self.shift_size = shift_size self.num_heads = num_heads - self.attention_dropout = attention_dropout - self.dropout = dropout self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) @@ -301,8 +291,6 @@ def forward(self, x: Tensor): self.window_size, self.num_heads, shift_size=self.shift_size, - attention_dropout=self.attention_dropout, - dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, ) @@ -321,8 +309,6 @@ def __init__( num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, - attention_dropout: float = 0.0, - dropout: float = 0.0, pretrained_window_size: Optional[List[int]] = None, ): if pretrained_window_size is None: @@ -335,8 +321,6 @@ def __init__( num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, - attention_dropout=attention_dropout, - dropout=dropout, ) self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) @@ -391,8 +375,6 @@ def forward(self, x: Tensor): self.window_size, self.num_heads, shift_size=self.shift_size, - attention_dropout=self.attention_dropout, - dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, logit_scale=self.logit_scale, @@ -408,8 +390,6 @@ class SwinTransformerBlock(nn.Module): window_size (List[int]): Window size. shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. - dropout (float): Dropout rate. Default: 0.0. - attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. 
Default: ShiftedWindowAttention @@ -422,8 +402,6 @@ def __init__( window_size: List[int], shift_size: List[int], mlp_ratio: float = 4.0, - dropout: float = 0.0, - attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, @@ -438,13 +416,11 @@ def __init__( window_size, shift_size, num_heads, - attention_dropout=attention_dropout, - dropout=dropout, **kwargs, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None) for m in self.mlp.modules(): if isinstance(m, nn.Linear): @@ -467,8 +443,6 @@ class SwinTransformerBlockV2(SwinTransformerBlock): window_size (List[int]): Window size. shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. - dropout (float): Dropout rate. Default: 0.0. - attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2. @@ -482,8 +456,6 @@ def __init__( window_size: List[int], shift_size: List[int], mlp_ratio: float = 4.0, - dropout: float = 0.0, - attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, @@ -495,8 +467,6 @@ def __init__( window_size, shift_size, mlp_ratio=mlp_ratio, - dropout=dropout, - attention_dropout=attention_dropout, stochastic_depth_prob=stochastic_depth_prob, norm_layer=norm_layer, attn_layer=attn_layer, @@ -520,8 +490,6 @@ class SwinTransformer(nn.Module): num_heads (List(int)): Number of attention heads in different layers. window_size (List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. - dropout (float): Dropout rate. Default: 0.0. - attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. 
@@ -537,8 +505,6 @@ def __init__( num_heads: List[int], window_size: List[int], mlp_ratio: float = 4.0, - dropout: float = 0.0, - attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, block: Callable[..., nn.Module] = SwinTransformerBlock, @@ -584,8 +550,6 @@ def __init__( window_size=window_size, shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], mlp_ratio=mlp_ratio, - dropout=dropout, - attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, **kwargs, From f4ea5df3d8b54c9fa338277abc471ca2ecc7d783 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 16:34:24 -0700 Subject: [PATCH 34/49] deprecate pretrained_window_size --- torchvision/models/swin_transformer.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 695368065df..70df8b99ea0 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -309,11 +309,7 @@ def __init__( num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, - pretrained_window_size: Optional[List[int]] = None, ): - if pretrained_window_size is None: - pretrained_window_size = [0, 0] - self.pretrained_window_size = pretrained_window_size super().__init__( dim, window_size, @@ -338,12 +334,10 @@ def define_relative_position_bias_table(self): relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if self.pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1 - relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1 - else: - relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 - relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = ( torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0 @@ -446,7 +440,6 @@ class SwinTransformerBlockV2(SwinTransformerBlock): stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2. - pretrained_window_size (int): Local window size in pre-training. Default: 0. """ def __init__( @@ -459,7 +452,6 @@ def __init__( stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, - pretrained_window_size: int = 0, ): super().__init__( dim, @@ -470,7 +462,6 @@ def __init__( stochastic_depth_prob=stochastic_depth_prob, norm_layer=norm_layer, attn_layer=attn_layer, - pretrained_window_size=[pretrained_window_size, pretrained_window_size], ) def forward(self, x: Tensor): @@ -494,7 +485,6 @@ class SwinTransformer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. 
- pretrained_window_sizes (List[int]): Pretrained window sizes of each layer for Swin Transformer V2. Default: [0, 0, 0, 0]. """ def __init__( @@ -510,7 +500,6 @@ def __init__( block: Callable[..., nn.Module] = SwinTransformerBlock, norm_layer: Optional[Callable[..., nn.Module]] = None, downsample_layer: Callable[..., nn.Module] = PatchMerging, - pretrained_window_sizes: Optional[List[int]] = None, ): super().__init__() _log_api_usage_once(self) @@ -537,9 +526,6 @@ def __init__( for i_stage in range(len(depths)): stage: List[nn.Module] = [] dim = embed_dim * 2 ** i_stage - kwargs: Dict[str, Any] = {} - if pretrained_window_sizes is not None: - kwargs["pretrained_window_size"] = pretrained_window_sizes[i_stage] for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) @@ -552,7 +538,6 @@ def __init__( mlp_ratio=mlp_ratio, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, - **kwargs, ) ) stage_block_id += 1 From 1c7579cb1bd7bf2f0f94907f39bee6ed707a97a8 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 18:43:30 -0700 Subject: [PATCH 35/49] fix dynamic padding edge case --- torchvision/models/swin_transformer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 70df8b99ea0..eeff9ca6beb 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,6 +1,6 @@ import math from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, List, Optional, Union import torch import torch.nn.functional as F @@ -143,18 +143,21 @@ def shifted_window_attention( Tensor[N, H, W, C]: The output tensor after shifted window attention. """ B, H, W, C = input.shape + + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= H: + shift_size[0] = 0 + window_size[0] = H + if window_size[1] >= W: + shift_size[1] = 0 + window_size[1] = W + # pad feature maps to multiples of window size pad_r = (window_size[1] - W % window_size[1]) % window_size[1] pad_b = (window_size[0] - H % window_size[0]) % window_size[0] x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape - # If window size is larger than feature size, there is no need to shift window - if window_size[0] >= pad_H: - shift_size[0] = 0 - if window_size[1] >= pad_W: - shift_size[1] = 0 - # cyclic shift if sum(shift_size) > 0: x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) @@ -479,7 +482,7 @@ class SwinTransformer(nn.Module): embed_dim (int): Patch embedding dimension. depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. - window_size (List[int]): Window size. + window_size (int, List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 1000. 
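The "fix dynamic padding edge case" change above reorders the window-size check so that a feature map smaller than the window is no longer padded up to the window size; instead the window (and its shift) are clamped to the feature size before the padding is computed. A small, self-contained sketch of that arithmetic with made-up sizes (note this patch is reverted again later in the series):

```python
# Illustrative numbers only: a 6x6 feature map attended with an 8x8 window.
H, W = 6, 6
window_size = [8, 8]
shift_size = [4, 4]

# Before this patch: pad the 6x6 map up to the 8x8 window, then disable the shift.
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]   # 2 rows of padding
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]   # 2 cols of padding

# After this patch: clamp the window (and zero the shift) first, so no padding is needed
# and the whole 6x6 map is attended as one un-shifted window.
ws = [min(window_size[0], H), min(window_size[1], W)]            # [6, 6]
ss = [0 if window_size[0] >= H else shift_size[0],
      0 if window_size[1] >= W else shift_size[1]]               # [0, 0]
pad_b_new = (ws[0] - H % ws[0]) % ws[0]                          # 0
pad_r_new = (ws[1] - W % ws[1]) % ws[1]                          # 0
print(pad_b, pad_r, pad_b_new, pad_r_new)                        # 2 2 0 0
```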
From daf8e19af04121e695985e938707858695fcd470 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 18:52:29 -0700 Subject: [PATCH 36/49] remove unused imports --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index eeff9ca6beb..9ef90b499dd 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,6 +1,6 @@ import math from functools import partial -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional import torch import torch.nn.functional as F From 058703914d9086199a67efa54dd5db09b62cf1d6 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 12 Jul 2022 18:58:32 -0700 Subject: [PATCH 37/49] remove doc modification --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 9ef90b499dd..04be04bb2d2 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -482,7 +482,7 @@ class SwinTransformer(nn.Module): embed_dim (int): Patch embedding dimension. depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. - window_size (int, List[int]): Window size. + window_size (List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 1000. From 1526658714baad111b07052182c2452b5cc046c8 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 13 Jul 2022 08:47:29 -0700 Subject: [PATCH 38/49] Revert "deprecate dropout" This reverts commit 8a13f932815ae25655c07430d52929f86b1ca479. --- torchvision/models/swin_transformer.py | 38 +++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 04be04bb2d2..1f7ac1444e9 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -121,6 +121,8 @@ def shifted_window_attention( window_size: List[int], num_heads: int, shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None, logit_scale: Optional[torch.Tensor] = None, @@ -136,6 +138,8 @@ def shifted_window_attention( window_size (List[int]): Window size. num_heads (int): Number of attention heads. shift_size (List[int]): Shift size for shifted window attention. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None. 
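The hunk above restores `attention_dropout` and `dropout` to the functional `shifted_window_attention` op (the following hunks do the same for the module wrappers). A minimal shape-level sketch of calling the op with the restored arguments, assuming a build that contains this patched module; the weights are random stand-ins and the bias is zeroed purely for illustration:

```python
import torch
from torchvision.models.swin_transformer import shifted_window_attention  # assumes this patched module

B, H, W, C, num_heads = 1, 8, 8, 16, 4
window_size, shift_size = [4, 4], [2, 2]

x = torch.randn(B, H, W, C)                       # Tensor[N, H, W, C], as in the docstring above
qkv_weight = torch.randn(3 * C, C)                # random stand-ins for trained parameters
proj_weight = torch.randn(C, C)
seq = window_size[0] * window_size[1]
relative_position_bias = torch.zeros(num_heads, seq, seq)  # one bias per head and position pair

out = shifted_window_attention(
    x, qkv_weight, proj_weight, relative_position_bias,
    window_size, num_heads, shift_size=shift_size,
    attention_dropout=0.1, dropout=0.1,           # the arguments this revert brings back
)
print(out.shape)                                  # torch.Size([1, 8, 8, 16])
```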
@@ -205,9 +209,11 @@ def shifted_window_attention( attn = attn.view(-1, num_heads, x.size(1), x.size(1)) attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) # reverse windows x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) @@ -238,6 +244,8 @@ def __init__( num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, ): super().__init__() if len(window_size) != 2 or len(shift_size) != 2: @@ -245,6 +253,8 @@ def __init__( self.window_size = window_size self.shift_size = shift_size self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) @@ -294,6 +304,8 @@ def forward(self, x: Tensor): self.window_size, self.num_heads, shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, ) @@ -312,6 +324,8 @@ def __init__( num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, ): super().__init__( dim, @@ -320,6 +334,8 @@ def __init__( num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, + attention_dropout=attention_dropout, + dropout=dropout, ) self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) @@ -372,6 +388,8 @@ def forward(self, x: Tensor): self.window_size, self.num_heads, shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, logit_scale=self.logit_scale, @@ -387,6 +405,8 @@ class SwinTransformerBlock(nn.Module): window_size (List[int]): Window size. shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention @@ -399,6 +419,8 @@ def __init__( window_size: List[int], shift_size: List[int], mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, @@ -413,11 +435,13 @@ def __init__( window_size, shift_size, num_heads, + attention_dropout=attention_dropout, + dropout=dropout, **kwargs, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) - self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None) + self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) for m in self.mlp.modules(): if isinstance(m, nn.Linear): @@ -440,6 +464,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock): window_size (List[int]): Window size. shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. 
+ attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2. @@ -452,6 +478,8 @@ def __init__( window_size: List[int], shift_size: List[int], mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, @@ -462,6 +490,8 @@ def __init__( window_size, shift_size, mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, stochastic_depth_prob=stochastic_depth_prob, norm_layer=norm_layer, attn_layer=attn_layer, @@ -484,6 +514,8 @@ class SwinTransformer(nn.Module): num_heads (List(int)): Number of attention heads in different layers. window_size (List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. @@ -498,6 +530,8 @@ def __init__( num_heads: List[int], window_size: List[int], mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, block: Callable[..., nn.Module] = SwinTransformerBlock, @@ -539,6 +573,8 @@ def __init__( window_size=window_size, shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, ) From a7b18d8d35f18cfba2286c56ae99a9f7199ecf43 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 13 Jul 2022 11:36:56 -0700 Subject: [PATCH 39/49] Revert "fix dynamic padding edge case" This reverts commit 1c7579cb1bd7bf2f0f94907f39bee6ed707a97a8. --- torchvision/models/swin_transformer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1f7ac1444e9..e3bc0c02bc5 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -147,21 +147,18 @@ def shifted_window_attention( Tensor[N, H, W, C]: The output tensor after shifted window attention. 
""" B, H, W, C = input.shape - - # If window size is larger than feature size, there is no need to shift window - if window_size[0] >= H: - shift_size[0] = 0 - window_size[0] = H - if window_size[1] >= W: - shift_size[1] = 0 - window_size[1] = W - # pad feature maps to multiples of window size pad_r = (window_size[1] - W % window_size[1]) % window_size[1] pad_b = (window_size[0] - H % window_size[0]) % window_size[0] x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 + # cyclic shift if sum(shift_size) > 0: x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) From 5353afbbddb4a96feb22fec0f42141857bf0a161 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 13 Jul 2022 12:58:06 -0700 Subject: [PATCH 40/49] remove unused kwargs --- torchvision/models/swin_transformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e3bc0c02bc5..8cdbd0fd0b8 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -421,7 +421,6 @@ def __init__( stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, - **kwargs, ): super().__init__() _log_api_usage_once(self) @@ -434,7 +433,6 @@ def __init__( num_heads, attention_dropout=attention_dropout, dropout=dropout, - **kwargs, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) From 83f9d3d5006c6f08889b138fe9a69008e6e009e3 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 13 Jul 2022 13:00:56 -0700 Subject: [PATCH 41/49] add downsample docs --- torchvision/models/swin_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 8cdbd0fd0b8..7b8c9a98f2c 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -515,6 +515,7 @@ class SwinTransformer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. + downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. 
""" def __init__( From e07de70d1e6def4c9bf6585eaf58787d95118918 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 13 Jul 2022 13:01:21 -0700 Subject: [PATCH 42/49] revert block default value --- torchvision/models/swin_transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 7b8c9a98f2c..55132266d45 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -530,7 +530,7 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, - block: Callable[..., nn.Module] = SwinTransformerBlock, + block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, downsample_layer: Callable[..., nn.Module] = PatchMerging, ): @@ -538,6 +538,8 @@ def __init__( _log_api_usage_once(self) self.num_classes = num_classes + if block is None: + block = SwinTransformerBlock if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) From ce0103affe9ffbd538e496da8f5d34ee00d3d4d9 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 14 Jul 2022 10:07:23 -0700 Subject: [PATCH 43/49] revert argument order change --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 55132266d45..e52529d3e08 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -530,8 +530,8 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, downsample_layer: Callable[..., nn.Module] = PatchMerging, ): super().__init__() From 07fb86bea73f00a120c0bc6bc439e3ae1a124ed8 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Thu, 14 Jul 2022 10:07:53 -0700 Subject: [PATCH 44/49] explicitly specify start_dim --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e52529d3e08..8333cbfa05b 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -588,7 +588,7 @@ def __init__( self.norm = norm_layer(num_features) self.permute = Permute([0, 3, 1, 2]) self.avgpool = nn.AdaptiveAvgPool2d(1) - self.flatten = nn.Flatten() + self.flatten = nn.Flatten(1) self.head = nn.Linear(num_features, num_classes) for m in self.modules(): From fba200beeb2e36946a354190af6ecfd6ff5dae6d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 19 Jul 2022 21:12:35 -0700 Subject: [PATCH 45/49] add small and base variants --- torchvision/models/swin_transformer.py | 92 ++++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 5e881eb9b7c..16369bf9a99 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -21,14 +21,14 @@ "Swin_S_Weights", "Swin_B_Weights", "Swin_V2_T_Weights", - # "Swin_V2_S_Weights", - # "Swin_V2_B_Weights", + "Swin_V2_S_Weights", + "Swin_V2_B_Weights", "swin_t", "swin_s", "swin_b", "swin_v2_t", - # "swin_v2_s", - # "swin_v2_b", + "swin_v2_s", + "swin_v2_b", ] @@ -716,6 +716,14 @@ class 
Swin_V2_T_Weights(WeightsEnum): pass +class Swin_V2_S_Weights(WeightsEnum): + pass + + +class Swin_V2_B_Weights(WeightsEnum): + pass + + def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from @@ -860,3 +868,79 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T downsample_layer=PatchMergingV2, **kwargs, ) + + +def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_small architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_V2_S_Weights + :members: + """ + weights = Swin_V2_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.3, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_base architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_V2_B_Weights + :members: + """ + weights = Swin_V2_B_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[8, 8], + stochastic_depth_prob=0.5, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) From f0a53b9243f41c616aa5ad221e2734cb65361805 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Tue, 19 Jul 2022 21:29:01 -0700 Subject: [PATCH 46/49] add expect files and slow_models --- .../expect/ModelTester.test_swin_v2_b_expect.pkl | Bin 0 -> 1081 bytes .../expect/ModelTester.test_swin_v2_s_expect.pkl | Bin 0 -> 1081 bytes test/test_models.py | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 test/expect/ModelTester.test_swin_v2_b_expect.pkl create mode 100644 test/expect/ModelTester.test_swin_v2_s_expect.pkl diff --git a/test/expect/ModelTester.test_swin_v2_b_expect.pkl b/test/expect/ModelTester.test_swin_v2_b_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..5b2be51a2a95d21275fe7962b81eabc72621aca6 GIT binary patch literal 1081 zcmWIWW@cev;NW1u09p(d48Hj(sW~C3#U-gldL=+AzPLOyFTTtuJ}Ex6q98T7L_a05 zBvG#*JIBq*gdvIy(7=>jl3$dZp%-6Tl9^M?6>g z9q7JdpvuJb)Ivsx7B07({KS%Ah#*%XQv@SWjX@!E1T&CME+{Qz@ir<(J*wsLK6iul)8gynSzXHp9Vg&z)ua zmh2ba-|>6$-bWih*oAIiWXI_fxL-@=&Ath3$M+oglCWRp;MIM7PtWYHNS|)M=ylRQ zMU{$u=Y#t8eP&x{FRZArUxi=HKA`O9-gU1S_BZd}YHwnoY1h&D$*%5Z^Zq3vV*9PK zyzE>y1nqOqeYmg5bgsPx&nCMRsZ9G*_tfmwE_v*)c(1WvaLNn2nmtDLzy5pg-?fb2 zK2)mDo?-Wmy$16%>OAYp42L+a45J0bcwUP5?s+gmH(KCW9gop;eYz zR1AzCHz!$?;KE2@##|tqFV0I3Wdhm?!U5ilAPSzQk>gMcB!L2mPvPk1BKuYY#Y`<= zLL|;yeFJn;k$s|xVrm4?SAjl3$dZp%-6Tl9^M?6>g z9q7JdpvuJb)Ivsx7B07({KS%Ah#*%XQv@SWjX@!E1T&CME+{Qz@ir2?e(tot)gzO(tbH)G#k|MGn+q%HRyGT`6;V`t_*gTQP%K5ZY{a~Iz3n=&`v{>`DD zeSzC#_c7eQzSlT|+5U~T+dirF|LoZExa}FvaoU%xdTy)lGRN*x=bU}*t8Uv>_%Ygj z^0{SKl5etqzpae@tQwj9965h%^`%?wUNOq;1%;NLAmghOz>orA+@Yn(ph!e$m8BLH z1LMcdNfsrzFjAN?7s%#|^U_0^fVP5gfHxzEf~RTZIMf13pa9}iIJ&vWzSTf6Qwx|7 zi8EK<0Nqq%pJ<|(8UgebVN;>u7U0dsrUTV3$E*uC2$VHI07kQd@(d7M0tXo=@Yp~( Xjo~^}0Z2N)o0SbD#teiY^$@iHm?J1> literal 0 HcmV?d00001 diff --git a/test/test_models.py b/test/test_models.py index 5668c157769..72a7665b0bc 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -333,8 +333,8 @@ def _check_input_backprop(model, inputs): "swin_s", "swin_b", "swin_v2_t", - # "swin_v2_s", - # "swin_v2_b", + "swin_v2_s", + "swin_v2_b", ] for m in slow_models: _model_params[m] = {"input_shape": (1, 3, 64, 64)} From 72cd9d8eec591ac2d2d21fd4a7c66c72a94ce8fd Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 10 Aug 2022 01:03:15 +0800 Subject: [PATCH 47/49] Add model weights and documentation for swin v2 --- docs/source/models/swin_transformer.rst | 11 +++-- references/classification/README.md | 11 +++++ torchvision/models/swin_transformer.py | 65 +++++++++++++++++++++++-- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst index 3eb74069176..e841578011b 100644 --- a/docs/source/models/swin_transformer.rst +++ b/docs/source/models/swin_transformer.rst @@ -3,16 +3,18 @@ SwinTransformer .. currentmodule:: torchvision.models -The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision +The SwinTransformer models are based on the `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `__ paper. 
+SwinTransformer V2 models are based on the `Swin Transformer V2: Scaling Up Capacity +and Resolution `__ +paper. Model builders -------------- -The following model builders can be used to instantiate an SwinTransformer model. -`swin_t` can be instantiated with pre-trained weights and all others without. +The following model builders can be used to instantiate an SwinTransformer model (original and V2) with and without pre-trained weights. All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer`` base class. Please refer to the `source code `_ for @@ -25,3 +27,6 @@ more details about this class. swin_t swin_s swin_b + swin_v2_t + swin_v2_s + swin_v2_b \ No newline at end of file diff --git a/references/classification/README.md b/references/classification/README.md index da30159542b..e8d62134ca2 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -236,6 +236,17 @@ Note that `--val-resize-size` was optimized in a post-training step, see their ` +### SwinTransformer V2 +``` +torchrun --nproc_per_node=8 train.py\ +--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256 +``` +Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`. +Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value. 
+ + + + ### ShuffleNet V2 ``` torchrun --nproc_per_node=8 train.py \ diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1471edfd9ed..9fd8c512628 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -561,7 +561,7 @@ def __init__( # build SwinTransformer blocks for i_stage in range(len(depths)): stage: List[nn.Module] = [] - dim = embed_dim * 2**i_stage + dim = embed_dim * 2 ** i_stage for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) @@ -713,15 +713,72 @@ class Swin_B_Weights(WeightsEnum): class Swin_V2_T_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 28351570, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.072, + "acc@5": 96.132, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 class Swin_V2_S_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 49737442, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.712, + "acc@5": 96.816, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 class Swin_V2_B_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 87930848, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 84.112, + "acc@5": 96.864, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 @register_model() From 68ffa1ce4e54877a35feb5ae57ace05331abb010 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 10 Aug 2022 20:09:21 +0800 Subject: [PATCH 48/49] fix lint --- docs/source/models/swin_transformer.rst | 3 ++- torchvision/models/swin_transformer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst index e841578011b..c8119fae7a1 100644 --- a/docs/source/models/swin_transformer.rst +++ b/docs/source/models/swin_transformer.rst @@ -29,4 +29,5 @@ more details about this class. 
swin_b swin_v2_t swin_v2_s - swin_v2_b \ No newline at end of file + swin_v2_b + diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 9fd8c512628..9f43b546d59 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -561,7 +561,7 @@ def __init__( # build SwinTransformer blocks for i_stage in range(len(depths)): stage: List[nn.Module] = [] - dim = embed_dim * 2 ** i_stage + dim = embed_dim * 2**i_stage for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) From 73373a5d04167052d31e67ce92ee12dea91e0db6 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 10 Aug 2022 20:34:08 +0800 Subject: [PATCH 49/49] fix end of files line --- docs/source/models/swin_transformer.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst index c8119fae7a1..35b52987954 100644 --- a/docs/source/models/swin_transformer.rst +++ b/docs/source/models/swin_transformer.rst @@ -30,4 +30,3 @@ more details about this class. swin_v2_t swin_v2_s swin_v2_b -
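With the weights registered above, the new V2 variants can be used like any other torchvision classification model. A short usage sketch, assuming a torchvision build that includes this patch series (the input tensor stands in for a real image):

```python
import torch
from torchvision.models import swin_v2_b, Swin_V2_B_Weights

weights = Swin_V2_B_Weights.IMAGENET1K_V1          # 84.112 acc@1 per the meta above
model = swin_v2_b(weights=weights).eval()
preprocess = weights.transforms()                  # bicubic resize to 272, center crop to 256

img = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)  # stand-in for a real image
with torch.no_grad():
    logits = model(preprocess(img).unsqueeze(0))
print(logits.shape)                                # torch.Size([1, 1000])
```

`swin_v2_t` and `swin_v2_s` work the same way with their respective weight enums.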