Skip to content

Commit

Permalink
[Fix] ViViT interpolate_pos_encoding (huggingface#33815)
Browse files Browse the repository at this point in the history
* fix:test_inference_interpolate_pos_encoding

* style:make style;make fixup

* test: add suggestion to test_modeling_vivit

* chore:add suggestions

* style:make style

* [run_slow] vivit

* ci:slow test fix

* [run_slow] vivit
  • Loading branch information
RUFFY-369 authored and BernardZach committed Dec 5, 2024
1 parent 42520fe commit 1b7f9ad
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/vivit/modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def __init__(self, config):
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.tubelet_size[1:]
self.config = config

- # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+ # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
Expand All @@ -129,8 +130,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

dim = embeddings.shape[-1]

- new_height = height // self.patch_size
- new_width = width // self.patch_size
+ new_height = height // self.patch_size[0]
+ new_width = width // self.patch_size[1]

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/vivit/test_modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,12 @@ def test_inference_interpolate_pos_encoding(self):
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
- model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)
+ model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400").to(torch_device)

- image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
+ image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
video = prepare_video()
inputs = image_processor(
- video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
+ video, size={"shortest_edge": 480}, crop_size={"height": 232, "width": 232}, return_tensors="pt"
)
pixel_values = inputs.pixel_values.to(torch_device)

Expand Down

0 comments on commit 1b7f9ad

Please sign in to comment.