From 68add70259b9b8c94d899ddb54a46dc789bc4a7e Mon Sep 17 00:00:00 2001
From: ryan-qiyu-jiang
Date: Thu, 16 Dec 2021 14:22:48 -0800
Subject: [PATCH] [feat] Add pytorchvideo encoder wrapper (#1156)

Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/1156

Add an encoder class that constructs any pytorchvideo model from config and uses that model for its forward pass. Can load pretrained or randomly initialized models, based on config.

Test Plan:
Tested through unit tests on slowfast_r50 and mvit. Will be tested end-to-end when datasets and transformers are available in mmf.

```
(torchvideo) ryanjiang@learnfair5083:~/copy/mmf$ pytest tests/models/test_mmf_transformer.py
================================================== test session starts ==================================================
platform linux -- Python 3.7.11, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /private/home/ryanjiang/copy/mmf
plugins: forked-1.3.0, timeout-1.4.2, hydra-core-1.1.1, xdist-2.4.0, dash-2.0.0
collected 15 items

tests/models/test_mmf_transformer.py ............... [100%]

(torchvideo) ryanjiang@learnfair5083:~/copy/mmf$ pytest tests/modules/test_encoders.py
================================================== test session starts ==================================================
platform linux -- Python 3.7.11, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /private/home/ryanjiang/copy/mmf
plugins: forked-1.3.0, timeout-1.4.2, hydra-core-1.1.1, xdist-2.4.0, dash-2.0.0
collected 12 items

tests/modules/test_encoders.py ............ [100%]
```

Reviewed By: apsdehal

Differential Revision: D32631207

Pulled By: Ryan-Qiyu-Jiang

fbshipit-source-id: 6b549162f7ae9ccea162563e48ed910618a6da54
---
 mmf/modules/encoders.py              | 91 +++++++++++++++++++++++++++-
 tests/models/test_mmf_transformer.py | 61 ++++++++++++++++++-
 tests/modules/test_encoders.py       | 62 ++++++++++++++++++-
 tests/test_utils.py                  |  7 +++
 4 files changed, 217 insertions(+), 4 deletions(-)

diff --git a/mmf/modules/encoders.py b/mmf/modules/encoders.py
index 8a67bb1cc..98948e4e7 100644
--- a/mmf/modules/encoders.py
+++ b/mmf/modules/encoders.py
@@ -1,10 +1,12 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 
+import importlib
+import logging
 import os
 import pickle
 import re
 from collections import OrderedDict
 from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from enum import Enum
 from typing import Any
@@ -25,13 +27,15 @@
 from transformers.configuration_auto import AutoConfig
 from transformers.modeling_auto import AutoModel
 
-
 try:
     from detectron2.modeling import ShapeSpec, build_resnet_backbone
 except ImportError:
     pass
 
 
+logger = logging.getLogger()
+
+
 class Encoder(nn.Module):
     @dataclass
     class Config:
@@ -688,6 +692,89 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
+@registry.register_encoder("pytorchvideo")
+class PytorchVideoEncoder(Encoder):
+    """A thin wrapper around pytorchvideo models.
+
+    This class integrates pytorchvideo models as mmf encoders. It first
+    attempts to construct the requested pytorchvideo model from torch hub.
+    If that fails for a randomly initialized model and the pytorchvideo
+    package is available, the model is built with random weights directly
+    from pytorchvideo.models.
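+
+    For illustration, a minimal construction sketch (assumes pytorchvideo is
+    installed; it mirrors the unit tests added in this patch)::
+
+        config = OmegaConf.structured(PytorchVideoEncoder.Config())
+        encoder = PytorchVideoEncoder(config)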
+
+    Config:
+        name (str): Always 'pytorchvideo'. Used by build_encoder().
+        random_init (bool): If True, the model is built with random weights
+            instead of loading pretrained weights.
+        model_name (str): Name of the pytorchvideo model to use.
+        drop_last_n_layers (int): A value <= 0 giving the number of layers to
+            drop from the end of the model (0 keeps the full model).
+        pooler_name (str): Name of the pooler applied to the model output.
+
+    Raises:
+        ImportError: The constructor raises an ImportError if pytorchvideo
+            is not installed.
+    """
+
+    @dataclass
+    class Config(Encoder.Config):
+        name: str = "pytorchvideo"
+        random_init: bool = False
+        model_name: str = "slowfast_r50"
+        drop_last_n_layers: int = -1
+        pooler_name: str = "identity"
+
+    PYTORCHVIDEO_REPO = "facebookresearch/pytorchvideo:main"
+
+    def __init__(self, config: Config):
+        super().__init__()
+        config = OmegaConf.create({**asdict(self.Config()), **config})
+        if config.random_init:
+            params = dict(**OmegaConf.to_container(config))
+            params = {
+                k: v
+                for k, v in params.items()
+                if k not in PytorchVideoEncoder.Config().__dict__
+            }
+            try:
+                model = torch.hub.load(
+                    PytorchVideoEncoder.PYTORCHVIDEO_REPO,
+                    model=config.model_name,
+                    pretrained=False,
+                    **params,
+                )
+            except BaseException as err:
+                pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
+                if pytorchvideo_spec is None:
+                    raise err
+                import pytorchvideo.models.hub as hub
+
+                model_create_fn = getattr(hub, config.model_name)
+                model = model_create_fn(pretrained=False, **params)
+        else:
+            # load pretrained weights from TorchHub
+            model = torch.hub.load(
+                PytorchVideoEncoder.PYTORCHVIDEO_REPO,
+                model=config.model_name,
+                pretrained=True,
+            )
+        encoder_list = []
+        if config.drop_last_n_layers == 0:
+            encoder_list += [model]
+        else:
+            modules_list = list(model.children())
+            if len(modules_list) == 1:
+                modules_list = list(modules_list[0].children())
+            modules = modules_list[: config.drop_last_n_layers]
+            encoder_list += modules
+
+        pooler = registry.get_pool_class(config.pooler_name)()
+        encoder_list += [pooler]
+        self.encoder = nn.Sequential(*encoder_list)
+
+    def forward(self, *args, **kwargs):
+        # pass the input along to the wrapped model;
+        # assumes the caller obeys the dynamic model signature
+        return self.encoder(*args, **kwargs)
+
+
 @registry.register_encoder("r2plus1d_18")
 class R2Plus1D18VideoEncoder(PooledEncoder):
     """
diff --git a/tests/models/test_mmf_transformer.py b/tests/models/test_mmf_transformer.py
index 63f0e0259..609f7152a 100644
--- a/tests/models/test_mmf_transformer.py
+++ b/tests/models/test_mmf_transformer.py
@@ -21,7 +21,9 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports, teardown_imports
 from omegaconf import OmegaConf
-
+from tests.test_utils import (
+    skip_if_no_pytorchvideo,
+)
 
 BERT_VOCAB_SIZE = 30255
 ROBERTA_VOCAB_SIZE = 50265
@@ -444,6 +446,63 @@ def test_preprocessing_with_resnet_encoder(self):
         test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
         test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
 
+    @skip_if_no_pytorchvideo
+    def test_preprocessing_with_mvit_encoder(self):
+        encoder_config = OmegaConf.create(
+            {
+                "name": "pytorchvideo",
+                "model_name": "mvit_base_32x3",
+                "random_init": True,
+                "drop_last_n_layers": 0,
+                "pooler_name": "cls",
+                "spatial_size": 224,
+                "temporal_size": 8,
+                "head": None,
+                "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+                "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+                "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+                "pool_kv_stride_adaptive": [1, 8, 8],
+                "pool_kvq_kernel": [3, 3, 3],
+            }
+        )
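+        # Keys above that are not part of PytorchVideoEncoder.Config
+        # (spatial_size, temporal_size, head, embed_dim_mul, atten_head_mul,
+        # pool_q_stride_size, pool_kv_stride_adaptive, pool_kvq_kernel) are
+        # forwarded by the encoder as keyword arguments to the pytorchvideo
+        # mvit_base_32x3 builder; drop_last_n_layers=0 keeps the full model
+        # and pooler_name="cls" reduces the output to one embedding per clip.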
+        self._image_modality_config = MMFTransformerModalityConfig(
+            type="image",
+            key="image",
+            embedding_dim=768,
+            position_dim=1,
+            segment_id=0,
+            encoder=encoder_config,
+        )
+        modalities_config = [self._image_modality_config, self._text_modality_config]
+        config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
+        mmft = build_model(config)
+
+        sample_list = SampleList()
+        sample_list.image = torch.rand((2, 3, 8, 224, 224))
+        sample_list.text = torch.randint(0, 512, (2, 128))
+
+        transformer_input = mmft.preprocess_sample(sample_list)
+        input_ids = transformer_input["input_ids"]
+        self.assertEqual(input_ids["image"].dim(), 3)
+        self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])
+
+        self.assertEqual(input_ids["text"].dim(), 2)
+        self.assertEqual(list(input_ids["text"].size()), [2, 128])
+
+        position_ids = transformer_input["position_ids"]
+        test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(
+            position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
+        )
+
+        masks = transformer_input["masks"]
+        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
+        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())
+
+        segment_ids = transformer_input["segment_ids"]
+        test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
+
     def test_tie_mlm_head_weight_to_encoder(self):
         self._text_modality_config = MMFTransformerModalityConfig(
             type="text",
diff --git a/tests/modules/test_encoders.py b/tests/modules/test_encoders.py
index 25e8063c0..bf405cb46 100644
--- a/tests/modules/test_encoders.py
+++ b/tests/modules/test_encoders.py
@@ -6,7 +6,11 @@
 import torch
 from mmf.modules import encoders
 from omegaconf import OmegaConf
-from tests.test_utils import setup_proxy, skip_if_old_transformers
+from tests.test_utils import (
+    setup_proxy,
+    skip_if_old_transformers,
+    skip_if_no_pytorchvideo,
+)
 from torch import nn
 
 
@@ -102,3 +106,59 @@ def test_vit_encoder(self):
         x = torch.rand(32, 197, 768)
         output, _ = encoder(x)
         self.assertEqual(output.size(-1), config.out_dim)
+
+    @skip_if_no_pytorchvideo
+    def test_pytorchvideo_slowfast_r50_encoder(self):
+        # instantiate video encoder from pytorchvideo
+        # default model is slowfast_r50
+        config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
+        encoder = encoders.PytorchVideoEncoder(config)
+        fast = torch.rand((1, 3, 32, 224, 224))
+        slow = torch.rand((1, 3, 8, 224, 224))
+        output = encoder([slow, fast])
+        # check output tensor is the expected feature dim size
+        # (bs, feature_dim)
+        self.assertEqual(output.size(1), 2304)
+
+    @skip_if_no_pytorchvideo
+    def test_mvit_encoder(self):
+        config = {
+            "name": "pytorchvideo",
+            "model_name": "mvit_base_32x3",
+            "random_init": True,
+            "drop_last_n_layers": 0,
+            "pooler_name": "cls",
+            "spatial_size": 224,
+            "temporal_size": 8,
+            "head": None,
+            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+            "pool_kv_stride_adaptive": [1, 8, 8],
+            "pool_kvq_kernel": [3, 3, 3],
+        }
+        # test bert cls pooler
+        encoder = encoders.PytorchVideoEncoder(OmegaConf.create(config))
+        x = torch.rand((1, 3, 8, 224, 224))
+        output = encoder(x)
+        # check output tensor is the expected feature dim size
+        # based on pooled attention configs
+        # for more details consult https://arxiv.org/pdf/2104.11227
+        # and https://github.com/facebookresearch/pytorchvideo/
+        # (bs, feature_dim)
+        self.assertEqual(output.shape, torch.Size([1, 768]))
+
+        # test avg pooler
+        encoder = encoders.PytorchVideoEncoder(
+            OmegaConf.create(dict(config, pooler_name="avg"))
+        )
+        output = encoder(x)
+        self.assertEqual(output.shape, torch.Size([1, 768]))
+
+        # test no pooling
+        encoder = encoders.PytorchVideoEncoder(
+            OmegaConf.create(dict(config, pooler_name="identity"))
+        )
+        output = encoder(x)
+        # (bs, num_features, feature_dim)
+        self.assertEqual(output.shape, torch.Size([1, 197, 768]))
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c35433164..d48558543 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -102,6 +102,13 @@ def wrap(testfn, reason="Requires newer version of transformers"):
     return wrap
 
 
+def skip_if_no_pytorchvideo(testfn, reason="Requires pytorchvideo"):
+    import importlib
+
+    pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
+    return unittest.skipIf(pytorchvideo_spec is None, reason)(testfn)
+
+
 def compare_state_dicts(a, b):
     same = True
     same = same and (list(a.keys()) == list(b.keys()))
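A minimal usage sketch of the new encoder, for illustration only. It mirrors test_pytorchvideo_slowfast_r50_encoder above and assumes mmf and pytorchvideo are installed; the default config loads pretrained slowfast_r50 from torch hub, and the 2304-dim feature size is the value checked by the unit test, not an independent claim.

```
import torch
from omegaconf import OmegaConf

from mmf.modules import encoders

# default config: slowfast_r50 loaded from torch hub with the identity pooler
config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
encoder = encoders.PytorchVideoEncoder(config)

# slowfast models take a [slow_pathway, fast_pathway] pair of clips,
# each shaped (batch, channels, frames, height, width)
slow = torch.rand((1, 3, 8, 224, 224))
fast = torch.rand((1, 3, 32, 224, 224))

features = encoder([slow, fast])
print(features.size(1))  # 2304, as asserted in the unit test above
```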