diff --git a/mmf/modules/encoders.py b/mmf/modules/encoders.py
index 8a67bb1cc..98948e4e7 100644
--- a/mmf/modules/encoders.py
+++ b/mmf/modules/encoders.py
@@ -1,10 +1,12 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 
+import importlib
+import logging
 import os
 import pickle
 import re
 from collections import OrderedDict
 from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from enum import Enum
 from typing import Any
@@ -25,13 +27,15 @@ from transformers.configuration_auto import AutoConfig
 from transformers.modeling_auto import AutoModel
 
-
 try:
     from detectron2.modeling import ShapeSpec, build_resnet_backbone
 except ImportError:
     pass
 
+logger = logging.getLogger()
+
+
 class Encoder(nn.Module):
     @dataclass
     class Config:
@@ -688,6 +692,89 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
+@registry.register_encoder("pytorchvideo")
+class PytorchVideoEncoder(Encoder):
+    """A thin wrapper around pytorchvideo models.
+    This class is responsible for integrating pytorchvideo models as encoders.
+    It first attempts to construct a pytorchvideo model from torch hub.
+    If this fails for a randomly initialized model and the pytorchvideo package
+    is available, the model is built with random weights from pytorchvideo.models.
+
+    Config:
+        name (str): Always 'pytorchvideo'. Used by build_encoder().
+        random_init (bool): If True, use random weights instead of pretrained weights.
+        model_name (str): Name of the pytorchvideo model to use.
+        drop_last_n_layers (int):
+            Non-positive number of layers to drop off the end (0 keeps all layers).
+        pooler_name (str): Name of the pooler applied to the model output.
+
+    Raises:
+        ImportError:
+            The constructor raises an ImportError if pytorchvideo is not installed.
+    """
+
+    @dataclass
+    class Config(Encoder.Config):
+        name: str = "pytorchvideo"
+        random_init: bool = False
+        model_name: str = "slowfast_r50"
+        drop_last_n_layers: int = -1
+        pooler_name: str = "identity"
+
+    PYTORCHVIDEO_REPO = "facebookresearch/pytorchvideo:main"
+
+    def __init__(self, config: Config):
+        super().__init__()
+        config = OmegaConf.create({**asdict(self.Config()), **config})
+        if config.random_init:
+            params = dict(**OmegaConf.to_container(config))
+            params = {
+                k: v
+                for k, v in params.items()
+                if k not in PytorchVideoEncoder.Config().__dict__
+            }
+            try:
+                model = torch.hub.load(
+                    PytorchVideoEncoder.PYTORCHVIDEO_REPO,
+                    model=config.model_name,
+                    pretrained=False,
+                    **params,
+                )
+            except BaseException as err:
+                pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
+                if pytorchvideo_spec is None:
+                    raise err
+                import pytorchvideo.models.hub as hub
+
+                model_create_fn = getattr(hub, config.model_name)
+                model = model_create_fn(pretrained=False, **params)
+        else:
+            # load pretrained weights from torch hub
+            model = torch.hub.load(
+                PytorchVideoEncoder.PYTORCHVIDEO_REPO,
+                model=config.model_name,
+                pretrained=True,
+            )
+        encoder_list = []
+        if config.drop_last_n_layers == 0:
+            encoder_list += [model]
+        else:
+            modules_list = list(model.children())
+            if len(modules_list) == 1:
+                modules_list = list(modules_list[0].children())
+            modules = modules_list[: config.drop_last_n_layers]
+            encoder_list += modules
+
+        pooler = registry.get_pool_class(config.pooler_name)()
+        encoder_list += [pooler]
+        self.encoder = nn.Sequential(*encoder_list)
+
+    def forward(self, *args, **kwargs):
+        # pass the input along to the wrapped model;
+        # assumes the caller matches the model's dynamic signature
+        return self.encoder(*args, **kwargs)
+
+
 @registry.register_encoder("r2plus1d_18")
 class R2Plus1D18VideoEncoder(PooledEncoder):
     """
diff --git a/tests/models/test_mmf_transformer.py b/tests/models/test_mmf_transformer.py
index 63f0e0259..609f7152a 100644
--- a/tests/models/test_mmf_transformer.py
+++ b/tests/models/test_mmf_transformer.py
@@ -21,7 +21,9 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports, teardown_imports
 from omegaconf import OmegaConf
-
+from tests.test_utils import (
+    skip_if_no_pytorchvideo,
+)
 
 BERT_VOCAB_SIZE = 30255
 ROBERTA_VOCAB_SIZE = 50265
@@ -444,6 +446,63 @@ def test_preprocessing_with_resnet_encoder(self):
         test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
         test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
 
+    @skip_if_no_pytorchvideo
+    def test_preprocessing_with_mvit_encoder(self):
+        encoder_config = OmegaConf.create(
+            {
+                "name": "pytorchvideo",
+                "model_name": "mvit_base_32x3",
+                "random_init": True,
+                "drop_last_n_layers": 0,
+                "pooler_name": "cls",
+                "spatial_size": 224,
+                "temporal_size": 8,
+                "head": None,
+                "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+                "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+                "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+                "pool_kv_stride_adaptive": [1, 8, 8],
+                "pool_kvq_kernel": [3, 3, 3],
+            }
+        )
+        self._image_modality_config = MMFTransformerModalityConfig(
+            type="image",
+            key="image",
+            embedding_dim=768,
+            position_dim=1,
+            segment_id=0,
+            encoder=encoder_config,
+        )
+        modalities_config = [self._image_modality_config, self._text_modality_config]
+        config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
+        mmft = build_model(config)
+
+        sample_list = SampleList()
+        sample_list.image = torch.rand((2, 3, 8, 224, 224))
+        sample_list.text = torch.randint(0, 512, (2, 128))
+
+        transformer_input = mmft.preprocess_sample(sample_list)
+        input_ids = transformer_input["input_ids"]
+        self.assertEqual(input_ids["image"].dim(), 3)
+        self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])
+
+        self.assertEqual(input_ids["text"].dim(), 2)
+        self.assertEqual(list(input_ids["text"].size()), [2, 128])
+
+        position_ids = transformer_input["position_ids"]
+        test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(
+            position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
+        )
+
+        masks = transformer_input["masks"]
+        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
+        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())
+
+        segment_ids = transformer_input["segment_ids"]
+        test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
+
     def test_tie_mlm_head_weight_to_encoder(self):
         self._text_modality_config = MMFTransformerModalityConfig(
             type="text",
diff --git a/tests/modules/test_encoders.py b/tests/modules/test_encoders.py
index 25e8063c0..bf405cb46 100644
--- a/tests/modules/test_encoders.py
+++ b/tests/modules/test_encoders.py
@@ -6,7 +6,11 @@
 import torch
 from mmf.modules import encoders
 from omegaconf import OmegaConf
-from tests.test_utils import setup_proxy, skip_if_old_transformers
+from tests.test_utils import (
+    setup_proxy,
+    skip_if_old_transformers,
+    skip_if_no_pytorchvideo,
+)
 from torch import nn
 
 
@@ -102,3 +106,59 @@ def test_vit_encoder(self):
         x = torch.rand(32, 197, 768)
         output, _ = encoder(x)
         self.assertEqual(output.size(-1), config.out_dim)
+
+    @skip_if_no_pytorchvideo
+    def test_pytorchvideo_slowfast_r50_encoder(self):
+        # instantiate video encoder from pytorchvideo
+        # default model is slowfast_r50
+        config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
+        encoder = encoders.PytorchVideoEncoder(config)
+        fast = torch.rand((1, 3, 32, 224, 224))
+        slow = torch.rand((1, 3, 8, 224, 224))
+        output = encoder([slow, fast])
+        # check that the output tensor has the expected feature dim size
+        # (bs, feature_dim)
+        self.assertEqual(output.size(1), 2304)
+
+    @skip_if_no_pytorchvideo
+    def test_mvit_encoder(self):
+        config = {
+            "name": "pytorchvideo",
+            "model_name": "mvit_base_32x3",
+            "random_init": True,
+            "drop_last_n_layers": 0,
+            "pooler_name": "cls",
+            "spatial_size": 224,
+            "temporal_size": 8,
+            "head": None,
+            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+            "pool_kv_stride_adaptive": [1, 8, 8],
+            "pool_kvq_kernel": [3, 3, 3],
+        }
+        # test bert cls pooler
+        encoder = encoders.PytorchVideoEncoder(OmegaConf.create(config))
+        x = torch.rand((1, 3, 8, 224, 224))
+        output = encoder(x)
+        # check that the output tensor has the expected feature dim size
+        # based on the pooled attention configs
+        # for more details consult https://arxiv.org/pdf/2104.11227
+        # and https://github.com/facebookresearch/pytorchvideo/
+        # (bs, feature_dim)
+        self.assertEqual(output.shape, torch.Size([1, 768]))
+
+        # test avg pooler
+        encoder = encoders.PytorchVideoEncoder(
+            OmegaConf.create(dict(config, pooler_name="avg"))
+        )
+        output = encoder(x)
+        self.assertEqual(output.shape, torch.Size([1, 768]))
+
+        # test no pooling
+        encoder = encoders.PytorchVideoEncoder(
+            OmegaConf.create(dict(config, pooler_name="identity"))
+        )
+        output = encoder(x)
+        # (bs, num_features, feature_dim)
+        self.assertEqual(output.shape, torch.Size([1, 197, 768]))
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c35433164..d48558543 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -102,6 +102,13 @@ def wrap(testfn, reason="Requires newer version of transformers"):
     return wrap
 
 
+def skip_if_no_pytorchvideo(testfn, reason="Requires pytorchvideo"):
+    import importlib
+
+    pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
+    return unittest.skipIf(pytorchvideo_spec is None, reason)(testfn)
+
+
 def compare_state_dicts(a, b):
     same = True
     same = same and (list(a.keys()) == list(b.keys()))
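
A usage sketch of the encoder added in mmf/modules/encoders.py (illustrative only, not part of the diff). It assumes that registry.get_encoder_class() resolves names registered with @registry.register_encoder, as it does for existing MMF encoders, and that the pretrained slowfast_r50 weights are reachable through torch hub; the expected output shape is the one asserted in tests/modules/test_encoders.py:

# Hedged sketch -- assumes registry.get_encoder_class() exposes encoders
# registered via @registry.register_encoder and that torch hub is reachable.
import torch
from omegaconf import OmegaConf

from mmf.common.registry import registry
from mmf.modules import encoders  # noqa: F401  (importing registers 'pytorchvideo')

encoder_cls = registry.get_encoder_class("pytorchvideo")
config = OmegaConf.structured(encoder_cls.Config(model_name="slowfast_r50"))
encoder = encoder_cls(config)

# SlowFast models take a [slow, fast] list of clips; shapes mirror the tests.
slow = torch.rand((1, 3, 8, 224, 224))
fast = torch.rand((1, 3, 32, 224, 224))
with torch.no_grad():
    features = encoder([slow, fast])
print(features.shape)  # (1, 2304) per test_pytorchvideo_slowfast_r50_encoder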
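
The constructor forwards any config keys that are not fields of the base Config (for example spatial_size, temporal_size and the MViT pooling options) to the torch hub / pytorchvideo model builder. Below is a sketch of that pass-through under random initialization, reusing the MViT settings from the tests; the (1, 768) output shape is the one asserted in test_mvit_encoder:

# Hedged sketch -- the extra keys below are mvit_base_32x3 builder arguments
# taken from the tests; with random_init=True the base Config fields are
# filtered out and the remaining keys are forwarded as **params to the builder.
import torch
from omegaconf import OmegaConf

from mmf.modules import encoders

mvit_config = OmegaConf.create(
    {
        "name": "pytorchvideo",
        "model_name": "mvit_base_32x3",
        "random_init": True,  # build with random weights instead of pretrained
        "drop_last_n_layers": 0,  # keep the full model
        "pooler_name": "cls",  # pool the [CLS] token feature
        # everything below is forwarded to the mvit_base_32x3 builder
        "spatial_size": 224,
        "temporal_size": 8,
        "head": None,
        "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
        "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
        "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
        "pool_kv_stride_adaptive": [1, 8, 8],
        "pool_kvq_kernel": [3, 3, 3],
    }
)
encoder = encoders.PytorchVideoEncoder(mvit_config)
with torch.no_grad():
    features = encoder(torch.rand((1, 3, 8, 224, 224)))
print(features.shape)  # (1, 768) per test_mvit_encoder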