Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable dynamic resolution for vivit #30630

Merged
merged 19 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
b03c87a
feat: enable dynamic resolution for vivit
jla524 May 3, 2024
765c8e6
fix: formatting
jla524 May 3, 2024
4c18598
remove: print statement for testing
jla524 May 4, 2024
ccd1822
Merge branch 'main' of github.com:huggingface/transformers into vivit…
jla524 May 7, 2024
3022649
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
d46b1a4
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
8299ae9
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
47b6e9c
Update tests/models/vivit/test_modeling_vivit.py
jla524 May 8, 2024
976664d
Merge branch 'vivit_dynamic_resolution' of github.com:jla524/transfor…
jla524 May 8, 2024
eec7f61
Update tests/models/vivit/test_modeling_vivit.py
jla524 May 8, 2024
d9a0626
Merge branch 'vivit_dynamic_resolution' of github.com:jla524/transfor…
jla524 May 8, 2024
b342bab
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
eed482a
Update tests/models/vivit/test_modeling_vivit.py
jla524 May 8, 2024
12e8aa1
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
c5df2fb
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
da54336
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
23d2812
Update src/transformers/models/vivit/modeling_vivit.py
jla524 May 8, 2024
851ff5b
fix: style check
jla524 May 8, 2024
810eba9
Merge branch 'vivit_dynamic_resolution' of github.com:jla524/transfor…
jla524 May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions src/transformers/models/vivit/modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def __init__(self, config):
config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
)

def forward(self, pixel_values):
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
if height != self.image_size or width != self.image_size:
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
f"Image image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)

# permute to (batch_size, num_channels, num_frames, height, width)
Expand Down Expand Up @@ -102,16 +103,50 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config

def forward(self, pixel_values):
batch_size = pixel_values.shape[0]
embeddings = self.patch_embeddings(pixel_values)
def interpolate_pos_encoding(self, embeddings, height, width):
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.

cls_tokens = self.cls_token.tile([batch_size, 1, 1])
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

cls_tokens = self.cls_token.tile([batch_size, 1, 1])
embeddings = torch.cat((cls_tokens, embeddings), dim=1)

# add positional encoding to each token
embeddings = embeddings + self.position_embeddings
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings

embeddings = self.dropout(embeddings)

Expand Down Expand Up @@ -437,6 +472,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Expand Down Expand Up @@ -482,6 +519,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
r"""
Expand Down Expand Up @@ -571,7 +609,7 @@ def forward(

head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(pixel_values)
embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

encoder_outputs = self.encoder(
embedding_output,
Expand All @@ -596,8 +634,18 @@ def forward(


@add_start_docstrings(
"""ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
[CLS] token) e.g. for Kinetics-400.""",
"""
ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
[CLS] token) e.g. for Kinetics-400.

<Tip>

Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.

</Tip>
""",
VIVIT_START_DOCSTRING,
)
class VivitForVideoClassification(VivitPreTrainedModel):
Expand All @@ -622,6 +670,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
r"""
Expand Down Expand Up @@ -715,6 +764,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

Expand Down
23 changes: 23 additions & 0 deletions tests/models/vivit/test_modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,26 @@ def test_inference_for_video_classification(self):
expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))

@slow
def test_inference_interpolate_pos_encoding(self):
# Vivit models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)

image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
video = prepare_video()
inputs = image_processor(
video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
)
pixel_values = inputs.pixel_values.to(torch_device)

# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)

# verify the logits shape
expected_shape = torch.Size((1, 3137, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)