diff --git a/mmf/modules/encoders.py b/mmf/modules/encoders.py index 191178a7f..cebbb0a9e 100644 --- a/mmf/modules/encoders.py +++ b/mmf/modules/encoders.py @@ -9,7 +9,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from enum import Enum -from typing import Any +from typing import Any, Optional import torch import torchvision @@ -773,6 +773,31 @@ def filter_dict_to_signature(self, callable, params): return accepted_params, ignored_params +@registry.register_encoder("mvit") +class MViTEncoder(Encoder): + """ + MVIT from pytorchvideo + """ + + @dataclass + class Config(Encoder.Config): + name: str = "mvit" + random_init: bool = False + model_name: str = "multiscale_vision_transformers" + spatial_size: int = 224 + temporal_size: int = 8 + head: Optional[Any] = None + + def __init__(self, config: Config): + super().__init__() + self.encoder = TorchVideoEncoder(config) + + def forward(self, *args, **kwargs): + output = self.encoder(*args, **kwargs) + output = output.permute(0, 2, 1) + return output[:, :1, :] + + @registry.register_encoder("r2plus1d_18") class R2Plus1D18VideoEncoder(PooledEncoder): """ diff --git a/tests/models/test_mmf_transformer.py b/tests/models/test_mmf_transformer.py index 63f0e0259..45074af19 100644 --- a/tests/models/test_mmf_transformer.py +++ b/tests/models/test_mmf_transformer.py @@ -21,7 +21,9 @@ from mmf.utils.configuration import Configuration from mmf.utils.env import setup_imports, teardown_imports from omegaconf import OmegaConf - +from tests.test_utils import ( + skip_if_no_pytorchvideo, +) BERT_VOCAB_SIZE = 30255 ROBERTA_VOCAB_SIZE = 50265 @@ -444,6 +446,57 @@ def test_preprocessing_with_resnet_encoder(self): test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]])) test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long()) + @skip_if_no_pytorchvideo + def test_preprocessing_with_mvit_encoder(self): + encoder_config = OmegaConf.create( + { + "name": "mvit", + "model_name": "multiscale_vision_transformers", + "random_init": True, + "cls_layer_num": 0, + "spatial_size": 224, + "temporal_size": 8, + "head": None, + } + ) + self._image_modality_config = MMFTransformerModalityConfig( + type="image", + key="image", + embedding_dim=12545, + position_dim=1, + segment_id=0, + encoder=encoder_config, + ) + modalities_config = [self._image_modality_config, self._text_modality_config] + config = MMFTransformer.Config(modalities=modalities_config, num_labels=2) + mmft = build_model(config) + + sample_list = SampleList() + sample_list.image = torch.rand((2, 3, 8, 224, 224)) + sample_list.text = torch.randint(0, 512, (2, 128)) + + transformer_input = mmft.preprocess_sample(sample_list) + input_ids = transformer_input["input_ids"] + self.assertEqual(input_ids["image"].dim(), 3) + self.assertEqual(list(input_ids["image"].size()), [2, 1, 12545]) + + self.assertEqual(input_ids["text"].dim(), 2) + self.assertEqual(list(input_ids["text"].size()), [2, 128]) + + position_ids = transformer_input["position_ids"] + test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]])) + test_utils.compare_tensors( + position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128)) + ) + + masks = transformer_input["masks"] + test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]])) + test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long()) + + segment_ids = transformer_input["segment_ids"] + test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]])) + test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long()) + def test_tie_mlm_head_weight_to_encoder(self): self._text_modality_config = MMFTransformerModalityConfig( type="text", diff --git a/tests/modules/test_encoders.py b/tests/modules/test_encoders.py index 7a5f3bfdc..7c185e962 100644 --- a/tests/modules/test_encoders.py +++ b/tests/modules/test_encoders.py @@ -117,10 +117,10 @@ def test_torchvision_slowfast_r50_encoder(self): self.assertEqual(output.size(1), 2304) @skip_if_no_pytorchvideo - def test_torchvision_mvit_encoder(self): + def test_mvit_encoder(self): config = OmegaConf.create( { - "name": "torchvideo", + "name": "mvit", "model_name": "multiscale_vision_transformers", "random_init": True, "cls_layer_num": 0, @@ -129,7 +129,7 @@ def test_torchvision_mvit_encoder(self): "head": None, } ) - encoder = encoders.TorchVideoEncoder(config) + encoder = encoders.MViTEncoder(config) x = torch.rand((1, 3, 8, 224, 224)) output = encoder(x) - self.assertEqual(output.shape, torch.Size([1, 12545, 96])) + self.assertEqual(output.shape, torch.Size([1, 1, 12545])) diff --git a/website/docs/tutorials/pytorchvideo.md b/website/docs/tutorials/pytorchvideo.md new file mode 100644 index 000000000..19a1c9a60 --- /dev/null +++ b/website/docs/tutorials/pytorchvideo.md @@ -0,0 +1,101 @@ +--- +id: pytorchvideo +title: Using Pytorchvideo +sidebar_label: Using Pytorchvideo +--- + +MMF is integrating with Pytorchvideo! + +This means you'll be able to use Pytorchvideo models, datasets, and transforms in multimodal models from MMF. +Pytorch datasets and transforms are coming soon! + +If you find PyTorchVideo useful in your work, please use the following BibTeX entry for citation. +``` +@inproceedings{fan2021pytorchvideo, + author = {Haoqi Fan and Tullie Murrell and Heng Wang and Kalyan Vasudev Alwala and Yanghao Li and Yilei Li and Bo Xiong and Nikhila Ravi and Meng Li and Haichuan Yang and Jitendra Malik and Ross Girshick and Matt Feiszli and Aaron Adcock and Wan-Yen Lo and Christoph Feichtenhofer}, + title = {{PyTorchVideo}: A Deep Learning Library for Video Understanding}, + booktitle = {Proceedings of the 29th ACM International Conference on Multimedia}, + year = {2021}, + note = {\url{https://pytorchvideo.org/}}, +} +``` + +## Setup + +In order to use pytorchvideo in MMF you need pytorchvideo installed. +You can install pytorchvideo by running +``` +pip install pytorchvideo +``` +For detailed instructions consult https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md + + +## Using Pytorchvideo Models in MMF + +Currently Pytorchvideo models are supported as MMF encoders. +To use a Pytorchvideo model as the image encoder for your multimodal model, +use the MMF TorchVideoEncoder or write your own encoder that uses pytorchvideo directly. + +The TorchVideoEncoder class is a wrapper around pytorchvideo models. +To instantiate a pytorchvideo model as an encoder you can do, + +```python +from mmf.modules import encoders +from omegaconf import OmegaConfg + +config = OmegaConf.create( + { + "name": "torchvideo", + "model_name": "slowfast_r50", + "random_init": True, + "cls_layer_num": 1, + } + ) +encoder = encoders.TorchVideoEncoder(config) + +# some video input +fast = torch.rand((1, 3, 32, 224, 224)) +slow = torch.rand((1, 3, 8, 224, 224)) +output = encoder([slow, fast]) +``` + +In our config object, we specify that we want to build the `torchvideo` (name) encoder, +that we want to use the pytorchvideo model `slowfast_r50` (model_name), +without pretrained weights (`random_init: True`), +and that we want to remove the last module of the network (the transformer head) (`cls_layer_num: 1`) to just get the hidden state. +This part depends on which model you're using and what you need it for. + +This encoder is usually configured from yaml through your model_config yaml. + + +Suppose we want to use MViT as our image encoder and we only want the first hidden state. +As the MViT model in pytorchvideo returns hidden states in format (batch size, feature dim, num features), +we want to permute the tensor and take the first feature. +To do this we can write our own encoder class in encoders.py + +```python +@registry.register_encoder("mvit") +class MViTEncoder(Encoder): + """ + MVIT from pytorchvideo + """ + @dataclass + class Config(Encoder.Config): + name: str = "mvit" + random_init: bool = False + model_name: str = "multiscale_vision_transformers" + spatial_size: int = 224 + temporal_size: int = 8 + head: Optional[Any] = None + + def __init__(self, config: Config): + super().__init__() + self.encoder = TorchVideoEncoder(config) + + def forward(self, *args, **kwargs): + output = self.encoder(*args, **kwargs) + output = output.permute(0, 2, 1) + return output[:, :1, :] +``` + +Here we use the TorchVideoEncoder class to make our MViT model and transform the output to match what we need from an encoder.