Skip to content

Commit

Permalink
Small Patch SwinTransformer for FX compatibility (#6252) (#6301)
Browse files Browse the repository at this point in the history
* Move out the pad operation from PatchMerging in swin transformer to make it fx compatible (#6252)

* empty commmit

* empty commmit

Co-authored-by: YosuaMichael <[email protected]>
  • Loading branch information
NicolasHug and YosuaMichael authored Jul 26, 2022
1 parent 055b6f4 commit 9b6233d
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchvision/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
]


def _patch_merging_pad(x):
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
return x


torch.fx.wrap("_patch_merging_pad")


class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
Expand All @@ -46,8 +55,7 @@ def forward(self, x: Tensor):
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x = _patch_merging_pad(x)

x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
Expand Down

0 comments on commit 9b6233d

Please sign in to comment.