Fix ViT-MAE decoder interpolate (#33330)
* Fix ViT-MAE decoder interpolate

* Add unit test for `interpolate_pos_encoding` w/ custom sizes

* [run_slow] vit_mae
xenova authored Sep 30, 2024
1 parent 1dba608 commit 18c5b21
Showing 2 changed files with 42 additions and 15 deletions.
15 changes: 7 additions & 8 deletions src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -841,21 +841,20 @@ def __init__(self, config, num_patches):
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
         """
-        This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
+        This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
         allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
         resolution images.
-        Source:
+        Adapted from:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         """
 
         # -1 removes the class dimension since we later append it without interpolation
         embeddings_positions = embeddings.shape[1] - 1
         num_positions = self.decoder_pos_embed.shape[1] - 1
 
         # Separation of class token and patch tokens
-        class_pos_embed = self.decoder_pos_embed[:, 0, :]
-        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
+        class_pos_embed = self.decoder_pos_embed[:, :1]
+        patch_pos_embed = self.decoder_pos_embed[:, 1:]
 
         # To retain the final 3d tensor with the required dimensions
         dim = self.decoder_pos_embed.shape[-1]
@@ -867,18 +866,18 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
-        # 1 keeps the other dimension constant
+        # we keep the second last dimension constant
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(1, embeddings_positions / num_positions),
+            size=(patch_pos_embed.shape[-2], embeddings_positions),
             mode="bicubic",
             align_corners=False,
         )
 
         # Converting back to the original shape
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         # Adding the class token back
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def initialize_weights(self, num_patches):
         # initialize (and freeze) position embeddings by sin-cos embedding
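
For readers skimming the hunk above: the substance of the fix is that nn.functional.interpolate now receives an explicit size= (the exact number of patch positions in the incoming embeddings) rather than a scale_factor=, whose floating-point rounding could produce an off-by-one sequence length, and the class-token embedding is carried through by slicing [:, :1] instead of indexing plus unsqueeze. The standalone sketch below reproduces that behaviour outside the model; the helper name interpolate_decoder_pos_embed, the num_target_positions argument, and the example shapes are illustrative, not part of the commit.

import torch
import torch.nn as nn


def interpolate_decoder_pos_embed(decoder_pos_embed: torch.Tensor, num_target_positions: int) -> torch.Tensor:
    # Illustrative sketch of the decoder position-embedding interpolation after the fix.
    # decoder_pos_embed: (1, 1 + num_pretrained_positions, dim), class token first.
    class_pos_embed = decoder_pos_embed[:, :1]   # kept as-is, never interpolated
    patch_pos_embed = decoder_pos_embed[:, 1:]
    dim = decoder_pos_embed.shape[-1]

    # Treat the position axis as the width of a 1-pixel-high "image" so bicubic
    # resizing can stretch it to exactly num_target_positions columns.
    patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)  # (1, dim, 1, N)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(patch_pos_embed.shape[-2], num_target_positions),  # exact target length, no rounding drift
        mode="bicubic",
        align_corners=False,
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)

    # Re-attach the class-token embedding in front of the interpolated patch positions.
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


pos_embed = torch.randn(1, 1 + 196, 512)                  # e.g. a 14x14 pre-training grid
resized = interpolate_decoder_pos_embed(pos_embed, 256)   # e.g. a 16x16 grid for 256x256 inputs
print(resized.shape)                                      # torch.Size([1, 257, 512])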
42 changes: 35 additions & 7 deletions tests/models/vit_mae/test_modeling_vit_mae.py
@@ -298,12 +298,16 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
     def default_image_processor(self):
         return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
 
+    @cached_property
+    def default_model(self):
+        return ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+
     @slow
     def test_inference_for_pretraining(self):
         # make random mask reproducible across the PT and TF model
         np.random.seed(2)
 
-        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+        model = self.default_model
 
         image_processor = self.default_image_processor
         image = prepare_img()
@@ -313,11 +317,11 @@ def test_inference_for_pretraining(self):
         # (this way we can ensure that the PT and TF models operate on the same inputs)
         vit_mae_config = ViTMAEConfig()
         num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
-        noise = np.random.uniform(size=(1, num_patches))
+        noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)
 
         # forward pass
         with torch.no_grad():
-            outputs = model(**inputs, noise=torch.from_numpy(noise).to(device=torch_device))
+            outputs = model(**inputs, noise=noise)
 
         # verify the logits
         expected_shape = torch.Size((1, 196, 768))
@@ -339,7 +343,7 @@ def test_inference_interpolate_pos_encoding(self):
         # make random mask reproducible across the PT and TF model
         np.random.seed(2)
 
-        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+        model = self.default_model
 
         image_processor = self.default_image_processor
         image = prepare_img()
@@ -349,14 +353,38 @@ def test_inference_interpolate_pos_encoding(self):
         # (this way we can ensure that the PT and TF models operate on the same inputs)
         vit_mae_config = ViTMAEConfig()
         num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
-        noise = np.random.uniform(size=(1, num_patches))
+        noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)
 
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
+
+        # verify the logits
+        expected_shape = torch.Size((1, 1200, 768))
+        self.assertEqual(outputs.logits.shape, expected_shape)
+
+    @slow
+    def test_inference_interpolate_pos_encoding_custom_sizes(self):
+        # Ensure custom sizes are correctly handled when interpolating the position embeddings
+
+        # make random mask reproducible across the PT and TF model
+        np.random.seed(2)
+
+        model = self.default_model
+        image_processor = self.default_image_processor
+
+        image = prepare_img()
+        inputs = image_processor(images=image, return_tensors="pt", size={"height": 256, "width": 256}).to(
+            torch_device
+        )
+
         # forward pass
         with torch.no_grad():
             outputs = model(
-                **inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
+                **inputs,
+                interpolate_pos_encoding=True,
             )
 
         # verify the logits
-        expected_shape = torch.Size((1, 1200, 768))
+        expected_shape = torch.Size((1, 256, 768))
         self.assertEqual(outputs.logits.shape, expected_shape)
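
End to end, the new test amounts to the following call pattern. This sketch is assembled from the test body above: the checkpoint id, the size={"height": 256, "width": 256} argument, interpolate_pos_encoding=True, and the expected (1, 256, 768) logits shape come from the diff, while the COCO image URL merely stands in for whatever prepare_img() loads in the test suite and is an assumption here.

import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTMAEForPreTraining

# Any RGB test image works; this URL is only a stand-in for prepare_img().
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# Resize to a non-default resolution; the checkpoint was pre-trained at 224x224.
inputs = image_processor(images=image, return_tensors="pt", size={"height": 256, "width": 256})

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# 256x256 input with 16x16 patches -> (256 // 16) ** 2 = 256 patch positions.
print(outputs.logits.shape)  # torch.Size([1, 256, 768])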
