
Adding Swin Transformer architecture #5491

Merged: 116 commits into pytorch:main on Apr 27, 2022

Conversation

@xiaohu2015 (Contributor) commented Feb 27, 2022

This work relates to #2707 and #5410: add Swin Transformer to the torchvision model zoo.

  • Implement swin_transformer.
  • Check model accuracy from the converted weights (swin_tiny_weights):
torchrun --nproc_per_node=8 train.py --model swin_tiny --interpolation bicubic --test-only --pretrained
  • Refactor code.
    I made some modifications compared to the official code:

  • remove absolute position embedding: as Table 4 in the paper shows, the Swin model with relative position bias gets the best results, so the default model does not use absolute position embedding. A further problem with absolute position embedding is that a fixed input_size has to be passed to the model in order to initialize the pos_embedding.

  • remove the input_resolution parameter: input of arbitrary shape can then be handled by the Swin model, which is necessary for some tasks, e.g. segmentation and object detection. Compared to the official code, we keep tensors in shape [B, H, W, C] instead of [B, N, C], so we can get the width and height without input_resolution. After doing that, however, the window size and input size must be compared dynamically inside the shifted window attention: for example, if the input size is no larger than the window size (with a 224 image, the feature map of the last stage is 7x7), no shift operation is needed. Because this dynamic behavior is not well supported by torch.fx, I created a shifted_window_attention function and wrapped it (see the sketch after the training results below). Note: this modification adds some runtime since the attention_mask has to be generated dynamically, but the extra cost is insignificant.

  • Validate the training.

torchrun --nproc_per_node=8 train.py\
    --model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\
    --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\
    --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\
    --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra

which gives: Acc@1 81.222 Acc@5 95.332 (train logs)
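
As a reference for the input_resolution bullet above, here is a minimal sketch of the dynamic window/shift adjustment and of registering the function as a torch.fx leaf. It is illustrative only, not the PR's exact code; the function name and the omitted steps are simplified.

import torch
import torch.fx

def shifted_window_attention_sketch(x, window_size, shift_size):
    # x is kept as [B, H, W, C], so H and W come from the tensor itself
    # rather than from a fixed input_resolution parameter.
    _, H, W, _ = x.shape
    window_size, shift_size = list(window_size), list(shift_size)
    for i, dim in enumerate((H, W)):
        # If the feature map is no larger than the window along this axis,
        # a single window covers it, so clamp the window and skip the shift
        # (e.g. the 7x7 last stage of a 224 input with window size 7).
        if dim <= window_size[i]:
            window_size[i] = dim
            shift_size[i] = 0
    if sum(shift_size) > 0:
        # Cyclic shift; the attention mask must then be built dynamically
        # for the current H and W (the small runtime cost mentioned above).
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
    # ... window partition, attention with relative position bias,
    # reverse shift ... (omitted)
    return x

# The data-dependent control flow above breaks torch.fx symbolic tracing,
# so the function is declared a leaf and traced as a single opaque call:
torch.fx.wrap("shifted_window_attention_sketch")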

I also modified the reference code (https://github.com/xiaohu2015/vision/blob/main/references/classification/utils.py#L406), since the current code only supports disabling weight decay for norm layers.
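
For reference, below is a rough sketch of the kind of parameter grouping this enables: a no-decay group for norm parameters, biases, and custom keys such as relative_position_bias_table. The helper name is hypothetical, and it collapses the reference script's separate --norm-weight-decay, --bias-weight-decay and --transformer-embedding-decay flags into a single no-decay group.

import torch
from torch import nn

def split_params_for_weight_decay(model, weight_decay,
                                  custom_keys=("relative_position_bias_table",)):
    # Hypothetical helper: norm-layer parameters, biases, and any parameter
    # whose name contains one of custom_keys get weight_decay=0.0.
    norm_classes = (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)
    decay, no_decay = [], []
    for module_name, module in model.named_modules():
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            full_name = f"{module_name}.{name}" if module_name else name
            if (isinstance(module, norm_classes) or name == "bias"
                    or any(k in full_name for k in custom_keys)):
                no_decay.append(p)
            else:
                decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# usage: optimizer = torch.optim.AdamW(split_params_for_weight_decay(model, 0.05), lr=1e-3)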


@facebook-github-bot commented Feb 27, 2022

💊 CI failures summary and remediations

As of commit 92ae7dd (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@xiaohu2015 xiaohu2015 marked this pull request as ready for review March 6, 2022 12:23
@xiaohu2015 (Contributor, Author) commented Apr 26, 2022

> @xiaohu2015 Can you please share the logs with @jdsgomes and all the information that will allow us to reproduce your experiment (for example the git commit hashes you used, etc.)? It's also unclear to me whether you used TorchVision's reference scripts or something else in your experiments. Could you please clarify?

For swin_t, I have shared the training logs (I just used TorchVision's reference script) with @jdsgomes. I got Acc@1 81.222 Acc@5 95.332, and @jdsgomes could reproduce the training (81.204); the result matches the official one. But for swin_s and swin_b, accuracy drops about 0.5 points compared to the official results; in fact, I only trained swin_t myself. I plan to check these two models. Thanks to the training logs from @jdsgomes, I found that these two models seem to fall behind in the last few epochs.

@jdsgomes Yes, an iteration-based LR scheduler and an epoch-based LR scheduler should be equivalent. I just suspect that might be the reason, because even a minor difference can change the result.
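
As a toy illustration of that suspicion (not the reference script's actual scheduler): an epoch-based cosine schedule holds the LR constant within an epoch, while an iteration-based one decays it every step, so mid-epoch values differ slightly even though both schedules agree at the start and end. The iterations-per-epoch value below is an assumption.

import math

def cosine_lr(step, total_steps, lr_max=1e-3, lr_min=1e-5):
    # Cosine annealing from lr_max down to lr_min over total_steps.
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))

epochs, iters_per_epoch = 300, 1250  # iters_per_epoch chosen for illustration

# LR seen halfway through epoch 150 under each stepping scheme:
lr_by_epoch = cosine_lr(150, epochs)  # stepped once per epoch
lr_by_iter = cosine_lr(150 * iters_per_epoch + 625, epochs * iters_per_epoch)
print(f"epoch-based: {lr_by_epoch:.3e}, iteration-based: {lr_by_iter:.3e}")
# ~5.050e-04 vs ~5.024e-04: a small but persistent difference during training.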

@jdsgomes (Contributor) commented

After discussing offline with @datumbox, I think we should proceed to merge the PR with swin_t only, since it is clear that we can reproduce the result. So great work @xiaohu2015!

After that we can continue investigating to close the gap, and aim to merge the other variants in a separate PR.

I will do the final cleanups between today and tomorrow.

@datumbox (Contributor) commented

Just wanted to echo what Joao said. A big massive thank you, @xiaohu2015, for your awesome contribution: top-notch code and excellent research reproduction skills. Also, apologies that it took us so long to review and reproduce the PR; it's something we want to improve upon.

Looking forward to seeing this merged!

@xiaohu2015 (Contributor, Author) commented

@datumbox @jdsgomes Thanks very much!

@@ -416,7 +416,7 @@ class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/swin_t-81486767.pth",
         transforms=partial(
-            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
+            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC

This value was determined by post-training optimisation, similar to what we did for ConvNeXt.
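
For context, here is a hedged sketch of what such a post-training sweep could look like: only resize_size varies while the checkpoint stays fixed. evaluate() and make_val_loader() are hypothetical stand-ins for the reference evaluation harness, not torchvision APIs.

import torchvision
from torchvision import transforms

def build_eval_transform(resize_size, crop_size=224):
    # Same preprocessing family as the ImageClassification preset above.
    return transforms.Compose([
        transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

model = torchvision.models.swin_t(weights=torchvision.models.Swin_T_Weights.IMAGENET1K_V1).eval()

# Sweep candidate resize sizes on the validation set and keep the best;
# evaluate() and make_val_loader() are assumed helpers.
best_acc, best_size = max(
    (evaluate(model, make_val_loader(build_eval_transform(s))), s)
    for s in range(224, 260, 2)
)
print(f"best resize_size={best_size} with Acc@1={best_acc:.3f}")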

@datumbox (Contributor) left a comment

Thanks @jdsgomes. I know you are looking into all of these; I just added a few comments about the changes made on the documentation side last week so that we don't forget to include them.

docs/source/models.rst (outdated; resolved)
torchvision/models/swin_transformer.py (resolved)
@xiaohu2015 (Contributor, Author) commented

For the other models, I can convert the official weights to the torchvision version, just like EfficientNet.

@jdsgomes (Contributor) commented

@xiaohu2015 I understand it would be useful to include pre-trained weights converted from the initial implementation, but we would prefer to add the other variants once we can replicate the results fully. I am running a few experiments now and hopefully we will get good results, but for now I will remove even the constructors so this PR can be merged.

@datumbox (Contributor) left a comment

LGTM, thanks again @xiaohu2015 for the awesome contribution.

@jdsgomes thanks as well for your support and guidance.

I think we are good to merge. Just make sure we remove the unnecessary expect files (ModelTester.test_swin_*_expect.pkl) for the s/b/l variants that were removed.

@jdsgomes (Contributor) left a comment

Thanks @xiaohu2015 for the great contribution and @datumbox for the feedback.

@jdsgomes jdsgomes merged commit e288f6c into pytorch:main Apr 27, 2022
facebook-github-bot pushed a commit that referenced this pull request May 6, 2022
Summary:
* add swin transformer

* Update swin_transformer.py

* Update swin_transformer.py

* fix lint

* fix lint

* refactor code

* add swin_transformer

* Update swin_transformer.py

* fix bug

* refactor code

* fix lint

* update init_weights

* move shift_window into attention

* refactor code

* fix bug

* Update swin_transformer.py

* Update swin_transformer.py

* fix lint

* add patch_merge

* fix bug

* Update swin_transformer.py

* Update swin_transformer.py

* Update swin_transformer.py

* refactor code

* Update swin_transformer.py

* refactor code

* fix lint

* refactor code

* add swin_tiny

* add swin_tiny.pkl

* fix lint

* Delete ModelTester.test_swin_tiny_expect.pkl

* add swin_tiny

* add

* add Optional to bias

* update init weights

* update init_weights and add no weight decay

* add no weight decay

* add set_weight_decay

* add set_weight_decay

* fix lint

* fix lint

* add lr_cos_min

* add other swin models

* Update torchvision/models/swin_transformer.py

* refactor doc

* Update utils.py

* Update train.py

* Update train.py

* Update swin_transformer.py

* update model builder

* fix lint

* add

* Update torchvision/models/swin_transformer.py

* Update torchvision/models/swin_transformer.py

* update other model

* simplify the model name just like ViT

* add lr_cos_min

* fix lint

* fix lint

* Update swin_transformer.py

* Update swin_transformer.py

* Update swin_transformer.py

* Delete ModelTester.test_swin_tiny_expect.pkl

* add swin_t

* refactor code

* Update train.py

* add swin_s

* ignore a error of mypy

* Update swin_transformer.py

* fix lint

* add swin_b

* add swin_l

* refactor code

* Update train.py

* move relative_position_bias to __init__

* fix formatting

* Revert "fix formatting"

This reverts commit 41faba2.

* Revert "move relative_position_bias to __init__"

This reverts commit f061544.

* refactor code

* Remove deprecated meta-data from `_COMMON_META`

* fix linter

* add pretrained weights for swin_t

* fix format

* apply ufmt

* add documentation

* update references README

* adding new style docs

* update pre-trained weights values

* remove other variants

* fix typo

* Remove expect for the variants not yet supported

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095689

fbshipit-source-id: d387402233977b1628efe72f98341822602d5b81

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>