Adding Swin Transformer architecture #5491
Conversation
💊 CI failures summary and remediations — as of commit 92ae7dd (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚 (This comment was automatically generated by Dr. CI.)
For swin_t, I have shared the training logs (just using TorchVision's reference script) with @jdsgomes, I got

@jdsgomes yes, the iteration-based lr scheduler and the epoch-based lr scheduler should be equivalent. I just suspect that might be the reason, because a minor difference can change the result.
After discussing offline with @datumbox, I think we should proceed to merge the PR with swin_t only, since it is clear that we can reproduce the result, so great work @xiaohu2015! After that we can continue investigating to close the gap, and aim to merge the other variants in a different PR. I will do the final cleanups between today and tomorrow.
Just wanted to echo what Joao said. A big massive thank you @xiaohu2015 for your awesome contribution. Top-notch code and excellent research reproduction skills. Also apologies for taking so long to review and reproduce the PR; it's something we want to improve upon. Looking forward to seeing this merged!
@@ -416,7 +416,7 @@ class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/swin_t-81486767.pth",
         transforms=partial(
-            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
+            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
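For context, the `ImageClassification` preset resizes the shorter image side to `resize_size` and then takes a `crop_size` × `crop_size` center crop, so this diff only changes the pre-crop scale. A pure-Python sketch of that arithmetic (helper names here are hypothetical, not torchvision's API):

```python
def resized_dims(h, w, resize_size):
    """Scale so the shorter side equals resize_size, keeping aspect ratio."""
    if h <= w:
        return resize_size, round(w * resize_size / h)
    return round(h * resize_size / w), resize_size

def center_crop_corner(h, w, crop_size):
    """Top-left corner of a centered crop_size x crop_size window."""
    return (h - crop_size) // 2, (w - crop_size) // 2

# A 480x640 image: shorter side scaled to 238, then a 224x224 center crop.
h, w = resized_dims(480, 640, 238)        # -> (238, 317)
top, left = center_crop_corner(h, w, 224)  # -> (7, 46)
```

With `resize_size=238` the crop keeps 224/238 ≈ 94% of the resized short side, versus 224/256 ≈ 87.5% with the previous value.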
This value was determined in post-training optimisation, similarly to what we did in ConvNeXt.
Thanks @jdsgomes. I know you are looking into all of these; I just added a few comments about changes made on the documentation side last week so that we don't forget to include them.
For the other models, I can convert the official weights to the torchvision format, just like EfficientNet.
@xiaohu2015 I understand that would be useful to include pre-trained weights from the initial implementation, but we would prefer to include the other variants once we can replicate the results fully. I am running a few experiments now, and hopefully we can get good results, but for now I will remove even the constructors so this PR can be merged. |
LGTM, thanks again @xiaohu2015 for the awesome contribution.
@jdsgomes thanks as well for your support and guidance.
I think we are good to merge. Just make sure we remove the unnecessary expect files `ModelTester.test_swin_*_expect.pkl` for the variants s/b/l that were removed.
Thanks @xiaohu2015 for the great contribution and @datumbox for the feedback
Summary: * add swin transformer * Update swin_transformer.py * Update swin_transformer.py * fix lint * fix lint * refactor code * add swin_transformer * Update swin_transformer.py * fix bug * refactor code * fix lint * update init_weights * move shift_window into attention * refactor code * fix bug * Update swin_transformer.py * Update swin_transformer.py * fix lint * add patch_merge * fix bug * Update swin_transformer.py * Update swin_transformer.py * Update swin_transformer.py * refactor code * Update swin_transformer.py * refactor code * fix lint * refactor code * add swin_tiny * add swin_tiny.pkl * fix lint * Delete ModelTester.test_swin_tiny_expect.pkl * add swin_tiny * add * add Optional to bias * update init weights * update init_weights and add no weight decay * add no weight decay * add set_weight_decay * add set_weight_decay * fix lint * fix lint * add lr_cos_min * add other swin models * Update torchvision/models/swin_transformer.py * refactor doc * Update utils.py * Update train.py * Update train.py * Update swin_transformer.py * update model builder * fix lint * add * Update torchvision/models/swin_transformer.py * Update torchvision/models/swin_transformer.py * update other model * simplify the model name just like ViT * add lr_cos_min * fix lint * fix lint * Update swin_transformer.py * Update swin_transformer.py * Update swin_transformer.py * Delete ModelTester.test_swin_tiny_expect.pkl * add swin_t * refactor code * Update train.py * add swin_s * ignore a error of mypy * Update swin_transformer.py * fix lint * add swin_b * add swin_l * refactor code * Update train.py * move relative_position_bias to __init__ * fix formatting * Revert "fix formatting" This reverts commit 41faba2. * Revert "move relative_position_bias to __init__" This reverts commit f061544. 
* refactor code * Remove deprecated meta-data from `_COMMON_META` * fix linter * add pretrained weights for swin_t * fix format * apply ufmt * add documentation * update references README * adding new style docs * update pre-trained weights values * remove other variants * fix typo * Remove expect for the variants not yet supported Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095689 fbshipit-source-id: d387402233977b1628efe72f98341822602d5b81 Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Joao Gomes <[email protected]>
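The commit trail above includes moving the shift-window handling into the attention layer; since the model accepts inputs of arbitrary size, the window size and shift have to be clamped against the actual feature-map size at run time. A minimal sketch of that clamping logic (function name and signature are hypothetical, independent of torchvision's actual implementation):

```python
def effective_window(input_size, window_size, shift_size):
    """Clamp the window to the feature map and drop the shift when one
    window already covers the whole input (so no cyclic shift is needed)."""
    ws, ss = list(window_size), list(shift_size)
    for i in range(2):  # height, width
        if input_size[i] <= ws[i]:
            ws[i] = input_size[i]  # window cannot exceed the feature map
            ss[i] = 0              # a single window covers this axis: no shift
    return ws, ss

# Last stage of a 224 input: a 7x7 feature map with a 7x7 window -> shift disabled.
effective_window([7, 7], [7, 7], [3, 3])    # -> ([7, 7], [0, 0])
# Earlier stage: 14x14 feature map -> window and shift unchanged.
effective_window([14, 14], [7, 7], [3, 3])  # -> ([7, 7], [3, 3])
```

Because this branch depends on runtime tensor sizes, it is the kind of dynamic control flow that `torch.fx` tracing struggles with, which is why the PR wraps it in a standalone `shifted_window_attention` function.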
This work relates to #2707 and #5410: add Swin Transformer to the torchvision model zoo.

1. Refactor code. I made some modifications compared to the official code:
   - Remove `absolute position embedding`: as we can see from table 4 in the paper, the Swin model with `relative position bias` gets the best results, so the default Swin model does not use `absolute position embedding`. Another problem with `absolute position embedding` is that the input size has to be passed to the model to initialize the pos_embedding.
   - Remove the `input_resolution` parameter: inputs of arbitrary shape can then be handled by the Swin model, which is necessary for some tasks, e.g. segmentation and object detection. Compared to the official code, we keep tensors with shape `[B, H, W, C]` instead of `[B, N, C]`, so we can recover width and height without `input_resolution`. After doing that, one must dynamically compare the window size and input size in the shifted window attention; for example, if the input size is not larger than the window size (when the image size is 224, the feature map of the last stage is 7x7), no shift operation is needed. Because this dynamic behavior is not well supported in torch.fx, I created a `shifted_window_attention` function and wrap it. Note: this modification adds some runtime, since the `attention_mask` has to be generated dynamically, but the overhead is insignificant.
2. Validate the training, which gives:

   Acc@1 81.222 Acc@5 95.332

   (train logs)

I also modified the reference code (https://github.com/xiaohu2015/vision/blob/main/references/classification/utils.py#L406), as the current code only supports no weight decay for norm layers.
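The reference-script change mentioned above extends which parameters can be excluded from weight decay. As a hedged, pure-Python sketch of the idea (the keyword list and function name are illustrative, not TorchVision's actual `set_weight_decay` API, which also inspects module types):

```python
def split_weight_decay(param_names,
                       skip_keywords=("norm", "bias", "relative_position_bias_table")):
    """Split parameter names into a decay group and a no-decay group.
    Real code would look at module types (e.g. LayerNorm) too."""
    decay, no_decay = [], []
    for name in param_names:
        if any(k in name for k in skip_keywords):
            no_decay.append(name)  # excluded from weight decay
        else:
            decay.append(name)
    return decay, no_decay

names = ["features.0.weight", "features.0.bias",
         "norm.weight", "attn.relative_position_bias_table"]
decay, no_decay = split_weight_decay(names)
# decay -> ["features.0.weight"]; the other three land in no_decay
```

The two groups would then be passed to the optimizer as separate parameter groups, one with `weight_decay=0`.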