Add MViT architecture in TorchVision #6198
Conversation
* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings.
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` in the supported meta-data.
* Switch Tuple[int, int, int] with List[int] to support easier the 2D case
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage.
* Update todos
* Performance optimizations.
* Switch implementation to v1 variant.
* Fix docs
* Adding back a v2 pseudovariant
* Changing the way the network are configured.
* Temporarily removing v2
* Adding weights.
* Expand _squeeze/_unsqueeze to support arbitrary dims.
* Update references script.
* Fix tests.
* Fixing frames and preprocessing.
* Fix std/mean values in transforms.
* Add permanent Dropout and update the weights.
* Update accuracies.
That work looks impressive @datumbox (et al), well done. Is there anything I can help with (even some documentation and/or code review)?
@yassineAlouini Absolutely, feel free to code-review or follow up with a PR to improve the documentation of the model. There are details that are not fully documented at the moment, such as the fact that the class supports not only V1 but also some of the extensions introduced in V2 (though not all of them). The API of the class is also expected to change in the near future, so here I just follow the established idioms, but @YosuaMichael is looking into proposing a slightly different structure going forward for new models.
Awesome, I will give this a look and see what I can contribute. Thanks @datumbox for the quick reply.
LGTM!
Added two nits, but feel free to merge as-is if you think it's more readable as it stands now. Especially for the second nit, I am not convinced about the readability benefits.
LGTM. Thank you for this great work!
The failing tests are unrelated to this PR. The issue is recorded at #6202. Merging.
Summary:
* Adding MViT v2 architecture (#6105)
  * Adding mvitv2 architecture
  * Fixing memory issues on tests and minor refactorings.
  * Adding input validation
  * Adding docs and minor refactoring
  * Add `min_temporal_size` in the supported meta-data.
  * Switch Tuple[int, int, int] with List[int] to support easier the 2D case
  * Adding more docs and references
  * Change naming conventions of classes to follow the same pattern as MobileNetV3
  * Fix test breakage.
  * Update todos
  * Performance optimizations.
* Add support to MViT v1 (#6179)
  * Switch implementation to v1 variant.
  * Fix docs
  * Adding back a v2 pseudovariant
  * Changing the way the network are configured.
  * Temporarily removing v2
  * Adding weights.
  * Expand _squeeze/_unsqueeze to support arbitrary dims.
  * Update references script.
  * Fix tests.
  * Fixing frames and preprocessing.
  * Fix std/mean values in transforms.
  * Add permanent Dropout and update the weights.
  * Update accuracies.
* Fix documentation
* Remove unnecessary expected file.
* Skip big model test
* Rewrite the configuration logic to reduce LOC.
* Fix mypy

Reviewed By: NicolasHug

Differential Revision: D37450352

fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899
@yassineAlouini Thanks for offering. Happy to help you implement it if you are interested. I'm actually working with the research team who developed MViT, and there are a few things we would need to clarify with them prior to supporting v2, mainly which configurations they consider canonical. The release of the codebase for the 3D case is extremely new (9 days at the time of writing), which is why I opted to wait a bit before jumping into it. Shall I give you a ping when we have clarified the details and are ready to start? I have some draft PRs that you could build upon (or start from scratch).
Yes, that would be awesome, thanks @datumbox. In the meantime, I am reading the v2 paper (https://arxiv.org/pdf/2112.01526.pdf) a bit more and checking the existing implementation. Please share additional resources if necessary, and thanks again.
Regarding the configurations and what is considered canonical, I guess you are referring to these @datumbox? =>
That is correct, but there are nuances. Several of the booleans that refer to the encodings are mutually exclusive, some of them are used only for ablation experiments, and some are more appropriate for a production environment. So it's complicated. We want to finalize the API of MViTv1 first, to ensure it covers both external and internal needs, and then move forward with the V2. That's, in a nutshell, the blocker.
Thanks for this clear explanation. 👌
By the way @datumbox, is the following v2 implementation useful to check? https://github.com/facebookresearch/mvit (particularly the models file https://github.com/facebookresearch/mvit/blob/main/mvit/models/mvit_model.py)
This implementation lacks the 3D capabilities. I recommend looking at the SlowFast repo instead, in particular the commit facebookresearch/SlowFast@1aebd71 which introduced it.
Hello @datumbox. Hope you are doing well. Any update on the implementation? How can I help, or is it still a WIP and I should work on something else in the meantime? Thanks. 🙏
@yassineAlouini Hey there! I just got back yesterday from PTO and am trying to catch up. Let me get up to speed on this and I'll get back to you. Thanks for following up!
Thanks @datumbox. Hope you had a good vacation. 🌴
@yassineAlouini Can I interest you in another architecture? 😄 Perhaps MobileViT v1/v2, which is part of #6323? This one might be a bit tricky. Changes on this model affect internal teams, and it will require confirming we didn't break anything on their end. Given that you don't have access to FBcode, that's going to complicate things. Let me know if the proposed alternative architecture works for you, or we can find another one that is closer to your interests. Thanks and apologies for this.
@datumbox I will give it a look and let you know. If the changes require some coordination, I will keep in touch with you so you can act as the link with the internal teams. Thanks again for the suggestion. 👌
@datumbox By the way, does that mean that MViT v2 will be handled by someone else, or is it still a WIP?
@yassineAlouini I'm currently doing some investigation at #6373. I'm also syncing with the research and production teams to ensure we won't break anything.
This PR adds to TorchVision the video variant of "Multiscale Vision Transformers". The model can also handle images by treating them as single-frame videos. This PR contains changes from #6086, #6105 and #6179.
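For context (not part of the original description), here is a minimal usage sketch of the new model builder; the weight enum name and the 16-frame, 224x224 input size follow the usual torchvision conventions and are assumptions here, not quotes from the PR:

```python
import torch
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights

# Load the ported Kinetics-400 weights (enum name assumed from torchvision's naming scheme).
weights = MViT_V1_B_Weights.KINETICS400_V1
model = mvit_v1_b(weights=weights).eval()

# One 16-frame RGB clip, shaped (batch, channels, time, height, width).
clip = torch.randn(1, 3, 16, 224, 224)

with torch.inference_mode():
    logits = model(clip)           # (1, 400) Kinetics-400 class scores
    top5 = logits.topk(5).indices
print(top5)
```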
The implementation is based on MViTv1 but includes some extensions from MViTv2 (the Attention layer supports a `residual_pool` option). The `mvit_v1_b` variant introduced here is canonical and the weights are ported from the paper. This is based on the work of @haooooooqi, @feichtenhofer and @lyttonhao on PyTorchVideo.

Verification process
Comparing outputs
To confirm that the implementation is compatible with the original from PyTorchVideo, we create a weight converter, load the same weights in both implementations and compare their outputs against the same input:
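The converter and checkpoints themselves are not shown in this description; the comparison boils down to something like the following sketch, where the two model arguments stand in for the PyTorchVideo reference and the TorchVision port loaded with the converted weights:

```python
import torch

def compare_outputs(original_model: torch.nn.Module, ported_model: torch.nn.Module) -> None:
    """Feed the same random clip to both implementations and check that the outputs agree."""
    original_model.eval()
    ported_model.eval()
    clip = torch.randn(1, 3, 16, 224, 224)  # (batch, channels, frames, height, width)
    with torch.inference_mode():
        expected = original_model(clip)
        actual = ported_model(clip)
    # With a correct weight conversion the two outputs should match up to numerical noise.
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
```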
Benchmarks
To ensure that we don't introduce any speed regression, we test the speed as follows:
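The benchmark script itself isn't reproduced here; conceptually it times repeated forward passes of each implementation on identical clips, roughly as in this sketch (the CUDA device, clip shape and iteration counts are assumptions):

```python
import time
import torch

def benchmark(model: torch.nn.Module, clip: torch.Tensor, warmup: int = 10, iters: int = 50) -> float:
    """Return the average forward-pass time (seconds) of `model` on `clip`."""
    model.eval()
    with torch.inference_mode():
        for _ in range(warmup):       # warm-up passes are excluded from the measurement
            model(clip)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(clip)
        torch.cuda.synchronize()      # wait for all GPU kernels before stopping the clock
    return (time.perf_counter() - start) / iters
```

Running the same routine for both the reference and the ported model on an identical clip gives the relative numbers quoted below.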
This was tested on an A100 and, as we see below, the implementation is 4% faster than the original:
Accuracy
Normally to test the accuracy of the model, we would run the following:
The above check shows reduced accuracy compared to the expected one. There seems to be a regression in our reference script, the Kinetics dataset, or the video decoder. The accuracy of the model using TorchVision's implementation was therefore verified using the SlowFast reference scripts:
The accuracy is close to the one reported in the SlowFast repo. Minor differences exist because we had to remove from our infra quite a few corrupted videos that are no longer available to redownload.
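(The actual reference commands are omitted above. For context only, the kind of measurement those scripts perform is a standard clip-level top-1/top-5 evaluation loop, sketched below with a hypothetical labelled video loader.)

```python
import torch

def evaluate(model: torch.nn.Module, loader, device: str = "cuda") -> tuple[float, float]:
    """Compute clip-level top-1 / top-5 accuracy over a labelled video loader."""
    model.eval().to(device)
    top1 = top5 = total = 0
    with torch.inference_mode():
        for clips, labels in loader:                     # clips: (B, C, T, H, W), labels: (B,)
            logits = model(clips.to(device))
            preds = logits.topk(5, dim=1).indices.cpu()  # (B, 5) highest-scoring classes
            top1 += (preds[:, 0] == labels).sum().item()
            top5 += (preds == labels.unsqueeze(1)).any(dim=1).sum().item()
            total += labels.numel()
    return top1 / total, top5 / total
```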
Follow up work
This PR doesn't address all potential extensions; these will be addressed in follow-up PRs:
* The `rel_pos` + `dim_mul_in_att` extensions.

The API needs to be finalized prior to the release of TorchVision v0.14, else we will need to move the implementation to the prototype area.