
Adding Swin Transformer architecture #5491

Merged
merged 116 commits on Apr 27, 2022
Changes from 114 commits
Commits
533d2c0
add swin transformer
xiaohu2015 Feb 27, 2022
311751e
Update swin_transformer.py
xiaohu2015 Mar 6, 2022
8db4fcd
Merge branch 'main' into main
xiaohu2015 Mar 6, 2022
d478852
Update swin_transformer.py
xiaohu2015 Mar 6, 2022
92a1cf5
fix lint
xiaohu2015 Mar 6, 2022
8ac8077
fix lint
xiaohu2015 Mar 6, 2022
c4445a7
refactor code
xiaohu2015 Mar 6, 2022
97e22d7
add swin_transformer
xiaohu2015 Mar 6, 2022
8599a4b
Update swin_transformer.py
xiaohu2015 Mar 6, 2022
c378934
fix bug
xiaohu2015 Mar 6, 2022
c8e8fe2
refactor code
xiaohu2015 Mar 7, 2022
45bbbfc
fix lint
xiaohu2015 Mar 7, 2022
ebae8b1
update init_weights
xiaohu2015 Mar 7, 2022
0e76444
move shift_window into attention
xiaohu2015 Mar 7, 2022
9a953c3
refactor code
xiaohu2015 Mar 7, 2022
b9321c7
fix bug
xiaohu2015 Mar 7, 2022
f33d1cd
Update swin_transformer.py
xiaohu2015 Mar 7, 2022
41e54b8
Update swin_transformer.py
xiaohu2015 Mar 7, 2022
71ef011
fix lint
xiaohu2015 Mar 7, 2022
6af4964
add patch_merge
xiaohu2015 Mar 7, 2022
1689dd9
fix bug
xiaohu2015 Mar 8, 2022
3891aad
Update swin_transformer.py
xiaohu2015 Mar 8, 2022
c0e88af
Merge branch 'main' into main
xiaohu2015 Mar 8, 2022
86f6d6b
Update swin_transformer.py
xiaohu2015 Mar 8, 2022
dd9b121
Update swin_transformer.py
xiaohu2015 Mar 8, 2022
f869896
refactor code
xiaohu2015 Mar 8, 2022
f3ae314
Update swin_transformer.py
xiaohu2015 Mar 8, 2022
86a745d
Merge branch 'pytorch:main' into main
xiaohu2015 Mar 9, 2022
4ec8710
refactor code
xiaohu2015 Mar 9, 2022
333660f
Merge branch 'pytorch:main' into main
xiaohu2015 Mar 9, 2022
20b4eee
fix lint
xiaohu2015 Mar 9, 2022
f580cda
Merge branch 'main' into main
xiaohu2015 Mar 10, 2022
cb802ec
refactor code
xiaohu2015 Mar 10, 2022
113b074
add swin_tiny
xiaohu2015 Mar 10, 2022
d92a490
add swin_tiny.pkl
xiaohu2015 Mar 10, 2022
a1032a0
fix lint
xiaohu2015 Mar 10, 2022
05dd1e2
Delete ModelTester.test_swin_tiny_expect.pkl
xiaohu2015 Mar 10, 2022
210b629
add swin_tiny
xiaohu2015 Mar 10, 2022
267fbda
add
xiaohu2015 Mar 10, 2022
f9e6f8a
add Optional to bias
xiaohu2015 Mar 10, 2022
8c4f875
Merge branch 'main' into main
xiaohu2015 Mar 11, 2022
4ed22c0
update init weights
xiaohu2015 Mar 16, 2022
02a0a90
Merge branch 'pytorch:main' into main
xiaohu2015 Mar 16, 2022
b3a61ac
Merge branch 'pytorch:main' into main
xiaohu2015 Mar 17, 2022
bccc2b4
update init_weights and add no weight decay
xiaohu2015 Mar 17, 2022
2098b24
add no weight decay
xiaohu2015 Mar 17, 2022
71ea6bf
Merge branch 'pytorch:main' into main
xiaohu2015 Mar 17, 2022
6b0b6c2
add set_weight_decay
xiaohu2015 Mar 17, 2022
991e4c1
add set_weight_decay
xiaohu2015 Mar 17, 2022
f1ec5c8
fix lint
xiaohu2015 Mar 17, 2022
3c2a44d
fix lint
xiaohu2015 Mar 17, 2022
8e5f08b
Merge branch 'main' into main
xiaohu2015 Mar 18, 2022
e8b528f
add lr_cos_min
xiaohu2015 Mar 18, 2022
023ceb0
add other swin models
xiaohu2015 Mar 18, 2022
643ad6e
Merge branch 'main' into main
xiaohu2015 Mar 21, 2022
997587a
Merge branch 'main' into main
xiaohu2015 Mar 22, 2022
caad59e
Merge branch 'main' into main
xiaohu2015 Mar 23, 2022
113fd09
Update torchvision/models/swin_transformer.py
xiaohu2015 Mar 24, 2022
e91d607
refactor doc
xiaohu2015 Mar 24, 2022
ad1b5f6
Merge branch 'main' into main
xiaohu2015 Mar 24, 2022
78fb3ce
Update utils.py
xiaohu2015 Mar 25, 2022
b3b9a20
Update train.py
xiaohu2015 Mar 25, 2022
bb255c1
Update train.py
xiaohu2015 Mar 25, 2022
6c9b0c2
Merge branch 'main' into main
xiaohu2015 Mar 25, 2022
270360e
Merge branch 'main' into main
xiaohu2015 Mar 27, 2022
7db62b8
Merge branch 'main' into main
xiaohu2015 Mar 31, 2022
02f5006
Update swin_transformer.py
xiaohu2015 Apr 1, 2022
df626aa
update model builder
xiaohu2015 Apr 1, 2022
438a0dd
fix lint
xiaohu2015 Apr 1, 2022
070aebd
add
xiaohu2015 Apr 1, 2022
0cd82e1
Update torchvision/models/swin_transformer.py
xiaohu2015 Apr 1, 2022
8fde8ad
Update torchvision/models/swin_transformer.py
xiaohu2015 Apr 1, 2022
412ad15
update other model
xiaohu2015 Apr 1, 2022
9539c1d
simplify the model name just like ViT
xiaohu2015 Apr 2, 2022
16a8feb
Merge branch 'main' into main
xiaohu2015 Apr 2, 2022
04bf82c
add lr_cos_min
xiaohu2015 Apr 2, 2022
b24b8d9
fix lint
xiaohu2015 Apr 2, 2022
54d01f7
fix lint
xiaohu2015 Apr 2, 2022
961d1b5
Update swin_transformer.py
xiaohu2015 Apr 2, 2022
38279ed
Update swin_transformer.py
xiaohu2015 Apr 2, 2022
b1dcf5e
Update swin_transformer.py
xiaohu2015 Apr 2, 2022
0d40142
Delete ModelTester.test_swin_tiny_expect.pkl
xiaohu2015 Apr 2, 2022
358c6be
add swin_t
xiaohu2015 Apr 2, 2022
07410bd
refactor code
xiaohu2015 Apr 4, 2022
8c6d910
Update train.py
xiaohu2015 Apr 4, 2022
7c9ffd3
add swin_s
xiaohu2015 Apr 4, 2022
e94fdfd
ignore a error of mypy
xiaohu2015 Apr 4, 2022
1021fd2
Update swin_transformer.py
xiaohu2015 Apr 4, 2022
88a3e03
fix lint
xiaohu2015 Apr 4, 2022
535cc6a
add swin_b
xiaohu2015 Apr 4, 2022
92ae7dd
add swin_l
xiaohu2015 Apr 4, 2022
bb33737
refactor code
xiaohu2015 Apr 5, 2022
2500ff3
Update train.py
xiaohu2015 Apr 5, 2022
f061544
move relative_position_bias to __init__
jdsgomes Apr 7, 2022
41faba2
fix formatting
jdsgomes Apr 7, 2022
e338dbe
Revert "fix formatting"
jdsgomes Apr 20, 2022
1b8ffb1
Revert "move relative_position_bias to __init__"
jdsgomes Apr 20, 2022
89fc8f1
Merge branch 'main' into main
jdsgomes Apr 20, 2022
affd0df
refactor code
xiaohu2015 Apr 21, 2022
565203b
Remove deprecated meta-data from `_COMMON_META`
datumbox Apr 22, 2022
09d63f5
fix linter
datumbox Apr 26, 2022
b6fec69
add pretrained weights for swin_t
jdsgomes Apr 26, 2022
64b52d4
merge upstream changes
jdsgomes Apr 26, 2022
64af984
fix format
jdsgomes Apr 26, 2022
9e0bfcb
Merge branch 'main' into main
jdsgomes Apr 26, 2022
1528ca8
apply ufmt
jdsgomes Apr 26, 2022
e6e9ffe
add documentation
jdsgomes Apr 27, 2022
137d634
update references README
jdsgomes Apr 27, 2022
3457abb
adding new style docs
jdsgomes Apr 27, 2022
d3599ef
update pre-trained weights values
jdsgomes Apr 27, 2022
4e05993
Merge branch 'main' into main
jdsgomes Apr 27, 2022
6addd1b
remove other variants
jdsgomes Apr 27, 2022
6c328f9
Merge branch 'main' of github.com:xiaohu2015/vision into xiaohu2015/main
jdsgomes Apr 27, 2022
ca59aaf
fix typo
jdsgomes Apr 27, 2022
e4c9646
Remove expect for the variants not yet supported
jdsgomes Apr 27, 2022
9999e64
Merge branch 'main' into main
jdsgomes Apr 27, 2022
13 changes: 13 additions & 0 deletions docs/source/models.rst
@@ -42,6 +42,7 @@ architectures for image classification:
- `RegNet`_
- `VisionTransformer`_
- `ConvNeXt`_
- `SwinTransformer`_

You can construct a model with random weights by calling its constructor:

@@ -97,6 +98,7 @@ You can construct a model with random weights by calling its constructor:
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
convnext_large = models.convnext_large()
swin_t = models.swin_t()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.

@@ -219,6 +221,7 @@ convnext_tiny 82.520 96.146
convnext_small 83.616 96.650
convnext_base 84.062 96.870
convnext_large 84.414 96.976
swin_t 81.358 95.526
================================ ============= =============


@@ -238,6 +241,7 @@ convnext_large 84.414 96.976
.. _RegNet: https://arxiv.org/abs/2003.13678
.. _VisionTransformer: https://arxiv.org/abs/2010.11929
.. _ConvNeXt: https://arxiv.org/abs/2201.03545
.. _SwinTransformer: https://arxiv.org/abs/2103.14030

.. currentmodule:: torchvision.models

@@ -450,6 +454,15 @@ ConvNeXt
convnext_base
convnext_large

SwinTransformer
---------------

.. autosummary::
:toctree: generated/
:template: function.rst

swin_t

Quantized Models
----------------

25 changes: 25 additions & 0 deletions docs/source/models/swin_transformer.rst
@@ -0,0 +1,25 @@
SwinTransformer
===============

.. currentmodule:: torchvision.models

The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision
Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`__
paper.


Model builders
--------------

The following model builders can be used to instantiate a SwinTransformer model.
`swin_t` can be instantiated with pre-trained weights and all others without.
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

swin_t
1 change: 1 addition & 0 deletions docs/source/models_new.rst
@@ -46,6 +46,7 @@ weights:
models/resnet
models/resnext
models/squeezenet
models/swin_transformer
models/vgg
models/vision_transformer

12 changes: 12 additions & 0 deletions references/classification/README.md
@@ -224,6 +224,18 @@ Note that the above command corresponds to training on a single node with 8 GPUs
For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
and `--batch_size 64`.


### SwinTransformer
```
torchrun --nproc_per_node=8 train.py\
--model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\
--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\
--lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\
--clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra
```
Note that `--val-resize-size` was optimized in a post-training step; see the model's `Weights` entry for the exact value.


## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU in PyTorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp) module.

5 changes: 3 additions & 2 deletions references/classification/train.py
@@ -233,7 +233,7 @@ def main(args):
if args.bias_weight_decay is not None:
custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
if args.transformer_embedding_decay is not None:
for key in ["class_token", "position_embedding", "relative_position_bias"]:
for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
parameters = utils.set_weight_decay(
model,
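The effect of the corrected key can be sketched without torch: the helper matches each custom key against a parameter's local name, so after the rename only `relative_position_bias_table` parameters (and biases) are exempted from weight decay. A simplified, hypothetical re-implementation of that grouping logic (the real helper lives in `references/classification/utils.py`):

```python
def group_weight_decay(param_names, default_decay, custom_keys_weight_decay):
    """Map each parameter name to a weight-decay value: the first custom key
    equal to the parameter's local (last dotted) name wins, otherwise the
    default applies.  Simplified sketch of set_weight_decay's key matching."""
    groups = {}
    for name in param_names:
        decay = default_decay
        local_name = name.rsplit(".", 1)[-1]
        for key, custom_decay in custom_keys_weight_decay:
            if local_name == key:
                decay = custom_decay
                break
        groups[name] = decay
    return groups

# Hypothetical Swin parameter names for illustration.
names = [
    "features.1.0.attn.relative_position_bias_table",
    "features.1.0.attn.qkv.weight",
    "features.1.0.attn.qkv.bias",
]
custom = [("bias", 0.0), ("relative_position_bias_table", 0.0)]
groups = group_weight_decay(names, 0.05, custom)
print(groups)  # only qkv.weight keeps the default 0.05 decay
```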
@@ -267,7 +267,7 @@ def main(args):
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
elif args.lr_scheduler == "cosineannealinglr":
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs - args.lr_warmup_epochs
optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
)
elif args.lr_scheduler == "exponentiallr":
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
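The combined effect of the new `eta_min` argument and the warmup logic can be reproduced in closed form. A pure-Python sketch using the standard CosineAnnealingLR formula and the hyper-parameters from the SwinTransformer recipe (base lr 0.001, `--lr-min` 1e-5, 300 epochs, 20 linear warmup epochs, warmup decay 0.01):

```python
import math

def cosine_lr(epoch, base_lr=0.001, eta_min=1e-5, epochs=300,
              warmup_epochs=20, warmup_decay=0.01):
    """LR at the start of `epoch`: linear warmup followed by cosine annealing
    with T_max = epochs - warmup_epochs (simplified sketch)."""
    if epoch < warmup_epochs:
        # LinearLR: ramp from base_lr * warmup_decay up to base_lr.
        start = base_lr * warmup_decay
        return start + (base_lr - start) * epoch / warmup_epochs
    t = epoch - warmup_epochs
    t_max = epochs - warmup_epochs
    # lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

print(cosine_lr(0))    # warmup start: base_lr * warmup_decay = 1e-05
print(cosine_lr(20))   # cosine phase starts at the base lr, 0.001
print(cosine_lr(300))  # end of schedule floors at eta_min, 1e-05
```

Without the `eta_min=args.lr_min` fix above, the schedule would have decayed all the way to 0 regardless of `--lr-min`.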
@@ -424,6 +424,7 @@ def get_args_parser(add_help=True):
parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
Binary file added test/expect/ModelTester.test_swin_b_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_swin_l_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_swin_s_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_swin_t_expect.pkl
Binary file not shown.
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -12,6 +12,7 @@
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from . import detection
from . import optical_flow
from . import quantization