Skip to content

Commit

Permalink
[docs] Add pytorchvideo docs
Browse files Browse the repository at this point in the history
Add pytorchvideo tutorial docs for using a
pytorchvideo model as an encoder with
TorchVideoEncoder class.

ghstack-source-id: 9cc594fc33357f86b45efe5db14e04f50f71804c
Pull Request resolved: #1163
  • Loading branch information
Ryan-Qiyu-Jiang committed Dec 1, 2021
1 parent 81e2e43 commit 74ad9d0
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 6 deletions.
27 changes: 26 additions & 1 deletion mmf/modules/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any
from typing import Any, Optional

import torch
import torchvision
Expand Down Expand Up @@ -773,6 +773,31 @@ def filter_dict_to_signature(self, callable, params):
return accepted_params, ignored_params


@registry.register_encoder("mvit")
class MViTEncoder(Encoder):
"""
MVIT from pytorchvideo
"""

@dataclass
class Config(Encoder.Config):
name: str = "mvit"
random_init: bool = False
model_name: str = "multiscale_vision_transformers"
spatial_size: int = 224
temporal_size: int = 8
head: Optional[Any] = None

def __init__(self, config: Config):
super().__init__()
self.encoder = TorchVideoEncoder(config)

def forward(self, *args, **kwargs):
output = self.encoder(*args, **kwargs)
output = output.permute(0, 2, 1)
return output[:, :1, :]


@registry.register_encoder("r2plus1d_18")
class R2Plus1D18VideoEncoder(PooledEncoder):
"""
Expand Down
55 changes: 54 additions & 1 deletion tests/models/test_mmf_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from mmf.utils.configuration import Configuration
from mmf.utils.env import setup_imports, teardown_imports
from omegaconf import OmegaConf

from tests.test_utils import (
skip_if_no_pytorchvideo,
)

BERT_VOCAB_SIZE = 30255
ROBERTA_VOCAB_SIZE = 50265
Expand Down Expand Up @@ -444,6 +446,57 @@ def test_preprocessing_with_resnet_encoder(self):
test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

@skip_if_no_pytorchvideo
def test_preprocessing_with_mvit_encoder(self):
encoder_config = OmegaConf.create(
{
"name": "mvit",
"model_name": "multiscale_vision_transformers",
"random_init": True,
"cls_layer_num": 0,
"spatial_size": 224,
"temporal_size": 8,
"head": None,
}
)
self._image_modality_config = MMFTransformerModalityConfig(
type="image",
key="image",
embedding_dim=12545,
position_dim=1,
segment_id=0,
encoder=encoder_config,
)
modalities_config = [self._image_modality_config, self._text_modality_config]
config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
mmft = build_model(config)

sample_list = SampleList()
sample_list.image = torch.rand((2, 3, 8, 224, 224))
sample_list.text = torch.randint(0, 512, (2, 128))

transformer_input = mmft.preprocess_sample(sample_list)
input_ids = transformer_input["input_ids"]
self.assertEqual(input_ids["image"].dim(), 3)
self.assertEqual(list(input_ids["image"].size()), [2, 1, 12545])

self.assertEqual(input_ids["text"].dim(), 2)
self.assertEqual(list(input_ids["text"].size()), [2, 128])

position_ids = transformer_input["position_ids"]
test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(
position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
)

masks = transformer_input["masks"]
test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

segment_ids = transformer_input["segment_ids"]
test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

def test_tie_mlm_head_weight_to_encoder(self):
self._text_modality_config = MMFTransformerModalityConfig(
type="text",
Expand Down
8 changes: 4 additions & 4 deletions tests/modules/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def test_torchvision_slowfast_r50_encoder(self):
self.assertEqual(output.size(1), 2304)

@skip_if_no_pytorchvideo
def test_torchvision_mvit_encoder(self):
def test_mvit_encoder(self):
config = OmegaConf.create(
{
"name": "torchvideo",
"name": "mvit",
"model_name": "multiscale_vision_transformers",
"random_init": True,
"cls_layer_num": 0,
Expand All @@ -129,7 +129,7 @@ def test_torchvision_mvit_encoder(self):
"head": None,
}
)
encoder = encoders.TorchVideoEncoder(config)
encoder = encoders.MViTEncoder(config)
x = torch.rand((1, 3, 8, 224, 224))
output = encoder(x)
self.assertEqual(output.shape, torch.Size([1, 12545, 96]))
self.assertEqual(output.shape, torch.Size([1, 1, 12545]))
101 changes: 101 additions & 0 deletions website/docs/tutorials/pytorchvideo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
---
id: pytorchvideo
title: Using Pytorchvideo
sidebar_label: Using Pytorchvideo
---

MMF is integrating with Pytorchvideo!

This means you'll be able to use Pytorchvideo models, datasets, and transforms in multimodal models from MMF.
Pytorch datasets and transforms are coming soon!

If you find PyTorchVideo useful in your work, please use the following BibTeX entry for citation.
```
@inproceedings{fan2021pytorchvideo,
author = {Haoqi Fan and Tullie Murrell and Heng Wang and Kalyan Vasudev Alwala and Yanghao Li and Yilei Li and Bo Xiong and Nikhila Ravi and Meng Li and Haichuan Yang and Jitendra Malik and Ross Girshick and Matt Feiszli and Aaron Adcock and Wan-Yen Lo and Christoph Feichtenhofer},
title = {{PyTorchVideo}: A Deep Learning Library for Video Understanding},
booktitle = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021},
note = {\url{https://pytorchvideo.org/}},
}
```

## Setup

In order to use pytorchvideo in MMF you need pytorchvideo installed.
You can install pytorchvideo by running
```
pip install pytorchvideo
```
For detailed instructions consult https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md


## Using Pytorchvideo Models in MMF

Currently Pytorchvideo models are supported as MMF encoders.
To use a Pytorchvideo model as the image encoder for your multimodal model,
use the MMF TorchVideoEncoder or write your own encoder that uses pytorchvideo directly.

The TorchVideoEncoder class is a wrapper around pytorchvideo models.
To instantiate a pytorchvideo model as an encoder you can do,

```python
from mmf.modules import encoders
from omegaconf import OmegaConfg

config = OmegaConf.create(
{
"name": "torchvideo",
"model_name": "slowfast_r50",
"random_init": True,
"cls_layer_num": 1,
}
)
encoder = encoders.TorchVideoEncoder(config)

# some video input
fast = torch.rand((1, 3, 32, 224, 224))
slow = torch.rand((1, 3, 8, 224, 224))
output = encoder([slow, fast])
```

In our config object, we specify that we want to build the `torchvideo` (name) encoder,
that we want to use the pytorchvideo model `slowfast_r50` (model_name),
without pretrained weights (`random_init: True`),
and that we want to remove the last module of the network (the transformer head) (`cls_layer_num: 1`) to just get the hidden state.
This part depends on which model you're using and what you need it for.

This encoder is usually configured from yaml through your model_config yaml.


Suppose we want to use MViT as our image encoder and we only want the first hidden state.
As the MViT model in pytorchvideo returns hidden states in format (batch size, feature dim, num features),
we want to permute the tensor and take the first feature.
To do this we can write our own encoder class in encoders.py

```python
@registry.register_encoder("mvit")
class MViTEncoder(Encoder):
"""
MVIT from pytorchvideo
"""
@dataclass
class Config(Encoder.Config):
name: str = "mvit"
random_init: bool = False
model_name: str = "multiscale_vision_transformers"
spatial_size: int = 224
temporal_size: int = 8
head: Optional[Any] = None

def __init__(self, config: Config):
super().__init__()
self.encoder = TorchVideoEncoder(config)

def forward(self, *args, **kwargs):
output = self.encoder(*args, **kwargs)
output = output.permute(0, 2, 1)
return output[:, :1, :]
```

Here we use the TorchVideoEncoder class to make our MViT model and transform the output to match what we need from an encoder.

0 comments on commit 74ad9d0

Please sign in to comment.