Add SwinV2 #6246

Merged on Aug 10, 2022 (61 commits)
Changes from 14 commits

Commits
6eeecf9
init submit
ain-soph Jul 7, 2022
a104e79
Merge branch 'pytorch:main' into swin_transfomer_v2
ain-soph Jul 7, 2022
7ad94a7
fix typo
ain-soph Jul 7, 2022
4f1a59b
Merge branch 'swin_transfomer_v2' of https://github.com/ain-soph/visi…
ain-soph Jul 7, 2022
e28ec45
support ufmt and mypy
ain-soph Jul 7, 2022
ff44832
fix 2 unittest errors
ain-soph Jul 7, 2022
7d84b31
fix ufmt issue
ain-soph Jul 8, 2022
4a21e98
Apply suggestions from code review
ain-soph Jul 8, 2022
6e488d5
Merge branch 'swin_transfomer_v2' of https://github.com/ain-soph/visi…
ain-soph Jul 8, 2022
8e0f8f6
unify codes
ain-soph Jul 8, 2022
284ca50
fix meshgrid indexing
ain-soph Jul 8, 2022
e801222
fix a bug
ain-soph Jul 8, 2022
32bee48
Merge branch 'pytorch:main' into swin_transfomer_v2
ain-soph Jul 8, 2022
56cb30b
Merge branch 'swin_transfomer_v2' of https://github.com/ain-soph/visi…
ain-soph Jul 8, 2022
eb06414
fix type check
ain-soph Jul 8, 2022
5deccd5
add type_annotation
ain-soph Jul 8, 2022
75bcbc7
add slow model
ain-soph Jul 9, 2022
084833e
fix device issue
ain-soph Jul 9, 2022
a0498a9
fix ufmt issue
ain-soph Jul 9, 2022
c9b77c8
add expect pickle file
ain-soph Jul 9, 2022
81e2a2e
fix jit script issue
ain-soph Jul 9, 2022
3eb0de8
fix type check
ain-soph Jul 9, 2022
d7a4ca2
keep consistent argument order
ain-soph Jul 9, 2022
005bb13
add support for pretrained_window_size
ain-soph Jul 9, 2022
69bad17
avoid code duplication
ain-soph Jul 9, 2022
6717145
a better code reuse
ain-soph Jul 9, 2022
0dc1b22
update window_size argument
ain-soph Jul 9, 2022
8b85859
make permute and flatten operations modular
ain-soph Jul 9, 2022
b9bba94
add PatchMergingV2
ain-soph Jul 10, 2022
4d8de8a
modify expect.pkl
ain-soph Jul 10, 2022
db94095
use None as default argument value
ain-soph Jul 11, 2022
54cf584
fix type check
ain-soph Jul 11, 2022
568731c
fix indent
ain-soph Jul 11, 2022
b04b9c7
fix window_size (temporarily)
ain-soph Jul 11, 2022
a0e7a40
remove "v2_" related prefix and add v2 builder
ain-soph Jul 12, 2022
f643ff0
remove v2 builder
ain-soph Jul 12, 2022
e2b338b
keep default value consistent with official repo
ain-soph Jul 12, 2022
8a13f93
deprecate dropout
ain-soph Jul 12, 2022
f4ea5df
deprecate pretrained_window_size
ain-soph Jul 12, 2022
1c7579c
fix dynamic padding edge case
ain-soph Jul 13, 2022
daf8e19
remove unused imports
ain-soph Jul 13, 2022
0587039
remove doc modification
ain-soph Jul 13, 2022
1526658
Revert "deprecate dropout"
ain-soph Jul 13, 2022
a7b18d8
Revert "fix dynamic padding edge case"
ain-soph Jul 13, 2022
5353afb
remove unused kwargs
ain-soph Jul 13, 2022
83f9d3d
add downsample docs
ain-soph Jul 13, 2022
e07de70
revert block default value
ain-soph Jul 13, 2022
ce0103a
revert argument order change
ain-soph Jul 14, 2022
07fb86b
explicitly specify start_dim
ain-soph Jul 14, 2022
2670ae1
Merge branch 'main' into swin_transfomer_v2
ain-soph Jul 14, 2022
fba200b
add small and base variants
ain-soph Jul 20, 2022
e3f8935
Merge branch 'main' into swin_transfomer_v2
ain-soph Jul 20, 2022
f0a53b9
add expect files and slow_models
ain-soph Jul 20, 2022
da56c95
Merge branch 'main' into swin_transfomer_v2
jdsgomes Jul 22, 2022
af2f491
Merge branch 'main' into swin_transfomer_v2
ain-soph Aug 4, 2022
1266a4c
Merge branch 'main' into swin_transfomer_v2
jdsgomes Aug 8, 2022
3fd6968
Merge branch 'main' into swin_transfomer_v2
jdsgomes Aug 9, 2022
72cd9d8
Add model weights and documentation for swin v2
jdsgomes Aug 9, 2022
68ffa1c
fix lint
jdsgomes Aug 10, 2022
73373a5
fix end of files line
jdsgomes Aug 10, 2022
c13c5ef
Merge branch 'main' into swin_transfomer_v2
jdsgomes Aug 10, 2022
240 changes: 228 additions & 12 deletions torchvision/models/swin_transformer.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Optional, Callable, List, Any
from typing import Any, Callable, List, Optional

import torch
import torch.nn.functional as F
@@ -9,7 +9,7 @@
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._api import Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param

@@ -19,9 +19,15 @@
"Swin_T_Weights",
"Swin_S_Weights",
"Swin_B_Weights",
"Swin_V2_T_Weights",
# "Swin_V2_S_Weights",
# "Swin_V2_B_Weights",
"swin_t",
"swin_s",
"swin_b",
"swin_v2_t",
# "swin_v2_s",
# "swin_v2_b",
]


@@ -80,6 +86,8 @@ def shifted_window_attention(
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
v2: bool = False,
logit_scale: Optional[torch.Tensor] = None,
):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -122,11 +130,21 @@
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C

# multi-head attention
if v2 and qkv_bias is not None:
qkv_bias = qkv_bias.clone()
length = qkv_bias.numel() // 3
qkv_bias[length : 2 * length].zero_()
qkv = F.linear(x, qkv_weight, qkv_bias)
qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
if v2:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp()
attn = attn * logit_scale
else:
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
# add relative position bias
attn = attn + relative_position_bias
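
The v2 branch above replaces scaled dot-product logits with cosine attention plus a learned, clamped per-head logit scale. A minimal standalone sketch of that computation (illustrative only, not part of the diff), assuming q and k of shape (B*nW, num_heads, Ws*Ws, head_dim) and logit_scale of shape (num_heads, 1, 1):

import torch
import torch.nn.functional as F

def cosine_attention_logits(q: torch.Tensor, k: torch.Tensor, logit_scale: torch.Tensor) -> torch.Tensor:
    # cosine similarity between queries and keys (unit-normalized along the channel dimension)
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    # learned per-head temperature, clamped so exp(logit_scale) never exceeds 1 / 0.01 = 100
    scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp()
    return attn * scale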

@@ -199,9 +217,12 @@ def __init__(
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)

self.define_relative_coords()

def define_relative_coords(self):
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
) # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
@@ -214,24 +235,127 @@ def __init__(
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

def get_relative_position_bias(self, relative_position_bias_table: Optional[torch.Tensor] = None) -> torch.Tensor:
if relative_position_bias_table is None:
relative_position_bias_table = self.relative_position_bias_table
N = self.window_size[0] * self.window_size[1]
relative_position_bias = relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias

def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
)

N = self.window_size[0] * self.window_size[1]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)

class ShiftedWindowAttentionV2(ShiftedWindowAttention):
"""
See :func:`shifted_window_attention_v2`.
"""

def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
pretrained_window_size: List[int] = [0, 0],
):
self.pretrained_window_size = pretrained_window_size # TODO: unsafe, need copy?
super().__init__(
dim,
window_size,
shift_size,
num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout=attention_dropout,
dropout=dropout,
)

self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
)
if qkv_bias:
length = self.qkv.bias.numel() // 3
self.qkv.bias[length : 2 * length].data.zero_()

def define_relative_coords(self):
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if self.pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1
else:
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
)

self.register_buffer("relative_coords_table", relative_coords_table)

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

def get_relative_position_bias(self) -> torch.Tensor:
relative_position_bias_table: torch.Tensor = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
@ain-soph (Contributor, Author) commented on Jul 8, 2022:
Shall we use flatten(end_dim=-2) here instead of view(-1, self.num_heads)?
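
For reference, a minimal sketch of the equivalence being asked about (illustrative only, assuming a table shaped (1, 2*Wh-1, 2*Ww-1, num_heads) as produced by cpb_mlp):

import torch

num_heads = 3
table = torch.randn(1, 13, 13, num_heads)  # 2*Wh-1 = 2*Ww-1 = 13 for window_size (7, 7)
a = table.view(-1, num_heads)      # current form in the diff
b = table.flatten(end_dim=-2)      # suggested form: collapse every dim except the last
assert a.shape == b.shape == (169, num_heads)
assert torch.equal(a, b)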

relative_position_bias = super().get_relative_position_bias(relative_position_bias_table)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
return relative_position_bias

def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
@@ -244,6 +368,8 @@ def forward(self, x: Tensor):
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
v2=True,
logit_scale=self.logit_scale,
)


@@ -274,7 +400,7 @@ def __init__(
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
attn_layer: Optional[Callable[..., nn.Module]] = ShiftedWindowAttention,
):
super().__init__()
_log_api_usage_once(self)
@@ -304,6 +430,54 @@ def forward(self, x: Tensor):
return x


class SwinTransformerBlockV2(SwinTransformerBlock):
"""
Swin Transformer V2 Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
"""

def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Optional[Callable[..., nn.Module]] = ShiftedWindowAttentionV2,
):
super().__init__(
dim,
num_heads,
window_size,
shift_size,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=stochastic_depth_prob,
norm_layer=norm_layer,
attn_layer=attn_layer,
)

def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.norm1(self.attn(x)))
x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
return x


class SwinTransformer(nn.Module):
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
@@ -337,13 +511,14 @@ def __init__(
num_classes: int = 1000,
norm_layer: Optional[Callable[..., nn.Module]] = None,
block: Optional[Callable[..., nn.Module]] = None,
v2: bool = False,
):
super().__init__()
_log_api_usage_once(self)
self.num_classes = num_classes

if block is None:
block = SwinTransformerBlock
block = SwinTransformerBlockV2 if v2 else SwinTransformerBlock

if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-5)
@@ -514,6 +689,10 @@ class Swin_B_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


class Swin_V2_T_Weights(WeightsEnum):
pass


def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_tiny architecture from
@@ -620,3 +799,40 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
progress=progress,
**kwargs,
)


def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_tiny architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.

Args:
weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_T_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.

.. autoclass:: torchvision.models.Swin_V2_T_Weights
:members:
"""
weights = Swin_V2_T_Weights.verify(weights)

return _swin_transformer(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[7, 7],
stochastic_depth_prob=0.2,
weights=weights,
progress=progress,
v2=True,
**kwargs,
)
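
A minimal usage sketch for the new builder (assuming a torchvision build that includes this change; no pretrained weights are registered for swin_v2_t at this point, so weights stays None):

import torch
from torchvision.models import swin_v2_t

model = swin_v2_t(weights=None)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # standard ImageNet-sized input
print(out.shape)  # expected: torch.Size([1, 1000])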