Skip to content

Commit

Permalink
Move out the pad operation from PatchMerging in swin transformer to m…
Browse files Browse the repository at this point in the history
…ake it fx compatible (pytorch#6252)
  • Loading branch information
YosuaMichael authored and atalman committed Jul 26, 2022
1 parent 055b6f4 commit d772041
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchvision/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
]


def _patch_merging_pad(x):
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
return x


torch.fx.wrap("_patch_merging_pad")


class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
Expand All @@ -46,8 +55,7 @@ def forward(self, x: Tensor):
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x = _patch_merging_pad(x)

x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
Expand Down

0 comments on commit d772041

Please sign in to comment.