
Add MViT architecture in TorchVision #6198

Merged: 11 commits into main from mvit on Jun 24, 2022

Conversation

@datumbox (Contributor) commented Jun 23, 2022

This PR adds to TorchVision the Video variant of "Multiscale Vision Transformers". The model can also handle images by treating them as single-frame videos. This PR contains changes from #6086, #6105 and #6179.

The implementation is based on MViTv1 but includes some extensions from MViTv2 (the Attention layer supports a residual_pool option). The mvit_v1_b variant introduced here is canonical, and the weights are ported from the paper. This is based on the work of @haooooooqi, @feichtenhofer and @lyttonhao on PyTorchVideo.
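
For reference, here is a minimal usage sketch of the new builder. This is a hedged example rather than part of the PR: the output size assumes the default 400 Kinetics-400 classes, and the input uses the (B, C, T, H, W) layout used throughout this PR.

import torch
from torchvision.models.video import mvit_v1_b

# Build the canonical MViTv1-B variant (randomly initialized here).
model = mvit_v1_b().eval()

# A 16-frame clip in (batch, channels, frames, height, width) layout.
clip = torch.randn(1, 3, 16, 224, 224)
with torch.inference_mode():
    logits = model(clip)
print(logits.shape)  # expected: torch.Size([1, 400])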

Verification process

Comparing outputs

To confirm that the implementation is compatible with the original from PyTorchVideo, we create a weight converter, load the same weights into both implementations and compare their outputs on the same input:

import collections

import torch
from pytorchvideo.models.hub.vision_transformers import mvit_base_16x4
from pytorchvideo.models.vision_transformers import create_multiscale_vision_transformers
from torchvision.models.video import mvit as TorchVision


class PyTorchVideo:
    @staticmethod
    def mvit_v1_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            droppath_rate_block=0.2,
            # additional params for PyTorch Video
            residual_pool=False,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_base_16x4(**kwargs):
        # return mvit_base_16x4(pretrained=True)
        m = create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            droppath_rate_block=0.2,
            # additional params for PyTorch Video
            residual_pool=False,
            separate_qkv=True,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )
        # https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT_B_16x4.pyth
        d = torch.load("./MVIT_B_16x4.pyth")["model_state"]
        m.load_state_dict(d, strict=False)
        return m


def ptv_to_tv_weights(state_dict, separate_qkv):
    d = dict(state_dict)

    # merge qkv if necessary
    if separate_qkv:
        components = collections.defaultdict(dict)
        for k in list(d.keys()):
            for pattern in ["q", "k", "v"]:
                if f".attn.{pattern}." in k:
                    group = k.rsplit(".", 2)[0]
                    components[group][k] = d.pop(k)
                    break
        for group in components.keys():
            for typ in ["weight", "bias"]:
                tensors = []
                for pattern in ["q", "k", "v"]:
                    tensors.append(components[group].pop(f"{group}.{pattern}.{typ}"))
                d[f"{group}.qkv.{typ}"] = torch.cat(tensors, dim=0)

    # remapping keys
    mapping = collections.OrderedDict(
        [
            ("patch_embed.patch_model.weight", "conv_proj.weight"),
            ("patch_embed.patch_model.bias", "conv_proj.bias"),
            ("cls_positional_encoding.cls_token", "pos_encoding.class_token"),
            ("cls_positional_encoding.pos_embed_spatial", "pos_encoding.spatial_pos"),
            ("cls_positional_encoding.pos_embed_temporal", "pos_encoding.temporal_pos"),
            ("cls_positional_encoding.pos_embed_class", "pos_encoding.class_pos"),
            ("attn.proj.weight", "attn.project.0.weight"),
            ("attn.proj.bias", "attn.project.0.bias"),
            ("attn.pool_q.weight", "attn.pool_q.pool.weight"),
            ("attn.norm_q.weight", "attn.pool_q.norm_act.0.weight"),
            ("attn.norm_q.bias", "attn.pool_q.norm_act.0.bias"),
            ("attn.pool_k.weight", "attn.pool_k.pool.weight"),
            ("attn.norm_k.weight", "attn.pool_k.norm_act.0.weight"),
            ("attn.norm_k.bias", "attn.pool_k.norm_act.0.bias"),
            ("attn.pool_v.weight", "attn.pool_v.pool.weight"),
            ("attn.norm_v.weight", "attn.pool_v.norm_act.0.weight"),
            ("attn.norm_v.bias", "attn.pool_v.norm_act.0.bias"),
            ("mlp.fc1.weight", "mlp.0.weight"),
            ("mlp.fc1.bias", "mlp.0.bias"),
            ("mlp.fc2.weight", "mlp.3.weight"),
            ("mlp.fc2.bias", "mlp.3.bias"),
            ("norm_embed.weight", "norm.weight"),
            ("norm_embed.bias", "norm.bias"),
            ("head.proj.weight", "head.1.weight"),
            ("head.proj.bias", "head.1.bias"),
            ("proj.weight", "project.weight"),
            ("proj.bias", "project.bias"),
        ]
    )
    for k in list(d.keys()):
        for pattern, replacement in mapping.items():
            if pattern in k:
                new_key = k.replace(pattern, replacement)
                d[new_key] = d.pop(k)
                break

    # matching dimensions
    d["pos_encoding.class_token"] = d["pos_encoding.class_token"][0, 0, :]
    d["pos_encoding.spatial_pos"] = d["pos_encoding.spatial_pos"][0, :]
    d["pos_encoding.temporal_pos"] = d["pos_encoding.temporal_pos"][0, :]
    d["pos_encoding.class_pos"] = d["pos_encoding.class_pos"][0, 0, :]

    # removing unnecessary keys
    for k in list(d.keys()):
        if "attn._attention_pool_" in k:
            del d[k]
    return d


def compare_models(ptv_model_fn, tv_model_fn, input_shape):
    print(tv_model_fn.__name__)
    x = torch.randn(input_shape)

    ptv_m = ptv_model_fn().eval()
    exp_result = ptv_m(x).sum()

    # Detect whether the PyTorchVideo model uses separate q/k/v projections
    # (its fused `qkv` module is an nn.Identity placeholder in that case).
    separate_qkv = isinstance(ptv_m.blocks[0].attn.qkv, torch.nn.Identity)

    d = ptv_m.state_dict()
    d = ptv_to_tv_weights(d, separate_qkv)

    tv_m = tv_model_fn().eval()
    tv_m.load_state_dict(d)
    result = tv_m(x).sum()

    torch.testing.assert_close(result, exp_result, rtol=0, atol=1e-6)
    print("OK")


compare_models(PyTorchVideo.mvit_v1_b, TorchVision.mvit_v1_b, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_base_16x4, TorchVision.mvit_v1_b, (1, 3, 16, 224, 224))
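
As a side note, the same converter can be used to port the released PyTorchVideo checkpoint into a state dict that the TorchVision model loads directly. A rough sketch building on the script above; the output filename is only illustrative:

# mvit_base_16x4 above loads ./MVIT_B_16x4.pyth and uses separate q/k/v projections.
ptv_m = PyTorchVideo.mvit_base_16x4().eval()
converted = ptv_to_tv_weights(ptv_m.state_dict(), separate_qkv=True)

tv_m = TorchVision.mvit_v1_b()
tv_m.load_state_dict(converted)
torch.save(tv_m.state_dict(), "mvit_v1_b_kinetics400_ported.pth")  # hypothetical output path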

Benchmarks

To ensure that we don't introduce any speed regression, we test the speed as follows:

import time


def benchmark(model_fn, input_shape, device, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)

    s = []
    for i in range(n):
        start = time.time()
        m(x)
        t = time.time() - start
        if i > n * warmup:
            s.append(t)

    print(model_fn.__name__, torch.tensor(s).median())


device = "cuda"
batch_size = 4
n = 100

print(f"device={device}, batch_size={batch_size}, n={n}")
for name, backend in [("TorchVision", TorchVision), ("PyTorchVideo", PyTorchVideo)]:
    print(name)
    benchmark(backend.mvit_v1_b, (batch_size, 3, 16, 224, 224), device, n=n)

This was tested on an A100 and, as shown below, the new implementation is about 4% faster than the original:

device=cuda, batch_size=4, n=100
TorchVision
mvit_v1_b tensor(0.0298)
PyTorchVideo
mvit_v1_b tensor(0.0310)

Accuracy

Normally, to test the accuracy of the model, we would run the following:

torchrun --nproc_per_node=8 train.py --data-path /datasets01/kinetics/070618/400/ --batch-size 1 --test-only --clip-len 16 --frame-rate 4 --clips-per-video 5 --model mvit_v1_b --weights MViT_V1_B_Weights.KINETICS400_V1 --cache-dataset
* Clip Acc@1 73.153 Clip Acc@5 90.542

The above check shows reduced accuracy compared to the expected one. There seems to be a regression in our reference script, the Kinetics dataset or the video decoder. The accuracy of the TorchVision implementation was therefore verified using the Slowfast reference scripts:

INFO:test_net:testing done: _ak78.47 Top1 Acc: 78.47 Top5 Acc: 93.65 MEM: 1.97 dataset: k400

The accuracy is close to the one reported in the Slowfast repo. Minor differences exist because we had to remove quite a few corrupted videos from our infrastructure that are no longer available for redownload.
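
For a quick sanity check outside the reference scripts, one could also run single-clip inference with the released weights. This is only a sketch: the weights.transforms() preprocessing (which accepts (T, C, H, W) frames and returns (C, T, H, W)) and the weights.meta["categories"] field follow TorchVision's usual video-weight presets and are assumptions here, not something exercised in this PR.

import torch
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights

weights = MViT_V1_B_Weights.KINETICS400_V1
model = mvit_v1_b(weights=weights).eval()
preprocess = weights.transforms()  # resize, center-crop, rescale and normalize

# 16 decoded frames in (T, C, H, W) layout; a real pipeline would sample them at frame rate 4.
frames = torch.randint(0, 256, (16, 3, 256, 342), dtype=torch.uint8)
batch = preprocess(frames).unsqueeze(0)  # -> (1, C, T, H, W)

with torch.inference_mode():
    probs = model(batch).softmax(dim=1)
top5 = probs.topk(5).indices[0].tolist()
print([weights.meta["categories"][i] for i in top5])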

Follow up work

This PR doesn't address all potential extensions; these will be handled in follow-up PRs:

  1. Review the main class API to ensure it supports production use-cases from Meta
  2. Add a native image model for MViT
  3. Add full support for MViTv2 by including the rel_pos + dim_mul_in_att extensions.

The API needs to be finalized prior to the release of TorchVision v0.14; otherwise we will need to move the implementation to the prototype area.

datumbox and others added 4 commits June 7, 2022 13:28
* Adding mvitv2 architecture

* Fixing memory issues on tests and minor refactorings.

* Adding input validation

* Adding docs and minor refactoring

* Add `min_temporal_size` in the supported meta-data.

* Switch Tuple[int, int, int] with List[int] to support easier the 2D case

* Adding more docs and references

* Change naming conventions of classes to follow the same pattern as MobileNetV3

* Fix test breakage.

* Update todos

* Performance optimizations.
* Switch implementation to v1 variant.

* Fix docs

* Adding back a v2 pseudovariant

* Changing the way the network are configured.

* Temporarily removing v2

* Adding weights.

* Expand _squeeze/_unsqueeze to support arbitrary dims.

* Update references script.

* Fix tests.

* Fixing frames and preprocessing.

* Fix std/mean values in transforms.

* Add permanent Dropout and update the weights.

* Update accuracies.
@datumbox changed the title from "Mvit" to "Add MViT architecture in TorchVision" on Jun 23, 2022
@yassineAlouini (Contributor) commented Jun 24, 2022

That work looks impressive @datumbox (et al), well done.

Is there anything I can help with (even some documentation and/or code review)?

@datumbox (Contributor, Author)

@yassineAlouini Absolutely, feel free to code-review or follow up with a PR to improve documentation on the model. There are details that are not fully documented at the moment, such as the fact that the class supports not only V1 but also some of the extensions introduced in V2 (though not all). The API of the class is also expected to change in the near future, so here I just follow the established idioms, but @YosuaMichael is looking into proposing a slightly different structure going forward for new models.

@yassineAlouini (Contributor), replying to the above:

Awesome, I will give this a look and see what I can contribute, thanks @datumbox for the quick reply.

@jdsgomes (Contributor) left a comment

LGTM!

Added two nits, but feel free to merge as-is if you think it's more readable as it stands now. Especially regarding the second nit, I am not convinced about the readability benefits.

@jdsgomes (Contributor) left a comment

LGTM. Thank you for this great work!

@datumbox (Contributor, Author) commented Jun 24, 2022

The failing tests are unrelated to this PR. The issue is recorded at #6202. Merging.

@datumbox datumbox merged commit fb7f9a1 into main Jun 24, 2022
@datumbox datumbox deleted the mvit branch June 24, 2022 13:48
facebook-github-bot pushed a commit that referenced this pull request Jun 27, 2022
Summary:
* Adding MViT v2 architecture (#6105)

* Adding mvitv2 architecture

* Fixing memory issues on tests and minor refactorings.

* Adding input validation

* Adding docs and minor refactoring

* Add `min_temporal_size` in the supported meta-data.

* Switch Tuple[int, int, int] with List[int] to support easier the 2D case

* Adding more docs and references

* Change naming conventions of classes to follow the same pattern as MobileNetV3

* Fix test breakage.

* Update todos

* Performance optimizations.

* Add support to MViT v1 (#6179)

* Switch implementation to v1 variant.

* Fix docs

* Adding back a v2 pseudovariant

* Changing the way the network are configured.

* Temporarily removing v2

* Adding weights.

* Expand _squeeze/_unsqueeze to support arbitrary dims.

* Update references script.

* Fix tests.

* Fixing frames and preprocessing.

* Fix std/mean values in transforms.

* Add permanent Dropout and update the weights.

* Update accuracies.

* Fix documentation

* Remove unnecessary expected file.

* Skip big model test

* Rewrite the configuration logic to reduce LOC.

* Fix mypy

Reviewed By: NicolasHug

Differential Revision: D37450352

fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899
@yassineAlouini (Contributor)

@datumbox Do you think it would be a good idea for me to work on this point:

Add full support for MViTv2 by including the rel_pos + dim_mul_in_att extensions.

or is someone else going to work on it, or does it require a lot of knowledge that I might be lacking?

Thanks. 👍

@datumbox (Contributor, Author) commented Jun 30, 2022

@yassineAlouini Thanks for offering. Happy to help you implement it if you are interested. I'm actually working with the research team who developed MViT, and there are a few things we would need to clarify with them prior to supporting v2, mainly which configurations they consider canonical. The release of the codebase for the 3D case is extremely new (9 days at the time of writing), which is why I opted to wait a bit before jumping into it. Shall I give you a ping when we clarify the details and are ready to start? I have some draft PRs that you could build upon (or you could start from scratch).

@yassineAlouini (Contributor)

Yes, that would be awesome, thanks @datumbox. In the meantime, I am reading the v2 paper (https://arxiv.org/pdf/2112.01526.pdf) a bit more and checking the existing implementation.

Please share additional resources if necessary and thanks again.

@yassineAlouini (Contributor)

Regarding the configurations and what is considered canonical, I guess you are referring to these @datumbox? =>
mvit_v2_variants

@datumbox (Contributor, Author)

That is correct, but there are nuances. Several of the booleans that refer to the encodings are mutually exclusive, some of them are used only for ablation experiments, and some are more appropriate for a production environment. So it's complicated. We want to finalize the API of MViTv1 first to ensure it covers both external and internal needs, and then move forward with V2. That's, in a nutshell, the blocker.

@yassineAlouini (Contributor)

Thanks for this clear explanation. 👌

@yassineAlouini (Contributor)

By the way @datumbox, is the following v2 implementation useful to check? https://github.com/facebookresearch/mvit (particularly the models file https://github.com/facebookresearch/mvit/blob/main/mvit/models/mvit_model.py)

@datumbox (Contributor, Author) commented Jul 6, 2022

This implementation lacks the 3D capabilities. I recommend looking at the Slowfast repo instead, in particular the commit facebookresearch/SlowFast@1aebd71 which introduced them.

@yassineAlouini (Contributor)

Hello @datumbox. Hope you are doing well. Any update on the implementation? How can I help, or is it still a WIP and I should work on something else in the meantime? Thanks. 🙏

@datumbox (Contributor, Author)

@yassineAlouini Hey there! I just got back yesterday from PTO and am trying to catch up. Let me get up to speed on this and I'll get back to you. Thanks for following up!

@yassineAlouini (Contributor)

Thanks @datumbox. Hope you had a good vacation. 🌴

@datumbox (Contributor, Author) commented Aug 3, 2022

@yassineAlouini Can I interest you in another architecture? 😄 Perhaps MobileViT v1/v2, which is part of #6323?

This one might be a bit tricky. Changes to this model affect internal teams, and it will require confirming that we didn't break anything on their end. Given that you don't have access to FBcode, that's going to complicate things. Let me know if the proposed alternative architecture works for you, or we can find another one that is closer to your interests. Thanks, and apologies for this.

@yassineAlouini (Contributor)

@datumbox I will give it a look and let you know. If the changes require some coordination, I will keep in touch with you so you can act as the link with the internal teams. Thanks again for the suggestion. 👌

@yassineAlouini (Contributor)

@datumbox By the way, does that mean that MViT v2 will be handled by someone else or is it still a WIP?

@datumbox (Contributor, Author) commented Aug 8, 2022

@yassineAlouini I'm currently doing some investigation at #6373. I'm also syncing with the research and production teams to ensure we won't break anything.
