diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index f6444999ac12..0be169a51b27 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -841,21 +841,20 @@ def __init__(self, config, num_patches):
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
         """
-        This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
+        This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
         allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
         resolution images.
 
-        Source:
+        Adapted from:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         """
 
         # -1 removes the class dimension since we later append it without interpolation
         embeddings_positions = embeddings.shape[1] - 1
-        num_positions = self.decoder_pos_embed.shape[1] - 1
 
         # Separation of class token and patch tokens
-        class_pos_embed = self.decoder_pos_embed[:, 0, :]
-        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
+        class_pos_embed = self.decoder_pos_embed[:, :1]
+        patch_pos_embed = self.decoder_pos_embed[:, 1:]
 
         # To retain the final 3d tensor with the required dimensions
         dim = self.decoder_pos_embed.shape[-1]
@@ -867,10 +866,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
-        # 1 keeps the other dimension constant
+        # we keep the second last dimension constant
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(1, embeddings_positions / num_positions),
+            size=(patch_pos_embed.shape[-2], embeddings_positions),
             mode="bicubic",
             align_corners=False,
         )
@@ -878,7 +877,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
         # Converting back to the original shape
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         # Adding the class token back
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def initialize_weights(self, num_patches):
         # initialize (and freeze) position embeddings by sin-cos embedding
diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py
index 6020edca81a7..5cff9616e004 100644
--- a/tests/models/vit_mae/test_modeling_vit_mae.py
+++ b/tests/models/vit_mae/test_modeling_vit_mae.py
@@ -298,12 +298,16 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
     def default_image_processor(self):
         return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
 
+    @cached_property
+    def default_model(self):
+        return ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+
     @slow
     def test_inference_for_pretraining(self):
         # make random mask reproducible across the PT and TF model
         np.random.seed(2)
 
-        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+        model = self.default_model
         image_processor = self.default_image_processor
 
         image = prepare_img()
@@ -313,11 +317,11 @@ def test_inference_for_pretraining(self):
         # (this way we can ensure that the PT and TF models operate on the same inputs)
         vit_mae_config = ViTMAEConfig()
         num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
-        noise = np.random.uniform(size=(1, num_patches))
+        noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)
 
         # forward pass
         with torch.no_grad():
-            outputs = model(**inputs, noise=torch.from_numpy(noise).to(device=torch_device))
+            outputs = model(**inputs, noise=noise)
 
         # verify the logits
         expected_shape = torch.Size((1, 196, 768))
@@ -339,7 +343,7 @@ def test_inference_interpolate_pos_encoding(self):
         # make random mask reproducible across the PT and TF model
         np.random.seed(2)
 
-        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+        model = self.default_model
         image_processor = self.default_image_processor
 
         image = prepare_img()
@@ -349,14 +353,38 @@ def test_inference_interpolate_pos_encoding(self):
         # (this way we can ensure that the PT and TF models operate on the same inputs)
         vit_mae_config = ViTMAEConfig()
         num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
-        noise = np.random.uniform(size=(1, num_patches))
+        noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
+
+        # verify the logits
+        expected_shape = torch.Size((1, 1200, 768))
+        self.assertEqual(outputs.logits.shape, expected_shape)
+
+    @slow
+    def test_inference_interpolate_pos_encoding_custom_sizes(self):
+        # Ensure custom sizes are correctly handled when interpolating the position embeddings
+
+        # make random mask reproducible across the PT and TF model
+        np.random.seed(2)
+
+        model = self.default_model
+        image_processor = self.default_image_processor
+
+        image = prepare_img()
+        inputs = image_processor(images=image, return_tensors="pt", size={"height": 256, "width": 256}).to(
+            torch_device
+        )
 
         # forward pass
         with torch.no_grad():
             outputs = model(
-                **inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
+                **inputs,
+                interpolate_pos_encoding=True,
+            )
 
         # verify the logits
-        expected_shape = torch.Size((1, 1200, 768))
+        expected_shape = torch.Size((1, 256, 768))
         self.assertEqual(outputs.logits.shape, expected_shape)
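A note on the interpolate() change above (not part of the patch): the hunk replaces scale_factor= with size= so that the stretched decoder position embeddings always end up with exactly as many positions as the input image has patches, which is what the new custom-sizes test exercises. The minimal standalone sketch below reproduces that tensor layout to show the difference; the shapes used (a 512-dim decoder embedding, 196 pre-trained positions, 256 target positions) are illustrative assumptions, not values read from the model config.

import torch
from torch import nn

dim = 512                    # assumed decoder embedding width
num_positions = 196          # assumed pre-trained grid: 14 x 14 patches
embeddings_positions = 256   # assumed target grid: 16 x 16 patches (e.g. a 256x256 image)

# Same layout the decoder code interpolates over: (1, 1, positions, dim) permuted to
# (1, dim, 1, positions), so only the last axis (the patch positions) is resized.
patch_pos_embed = torch.randn(1, 1, num_positions, dim).permute(0, 3, 1, 2)

# size= pins the output length to the number of patches in the input image.
by_size = nn.functional.interpolate(
    patch_pos_embed,
    size=(patch_pos_embed.shape[-2], embeddings_positions),
    mode="bicubic",
    align_corners=False,
)
assert by_size.shape[-1] == embeddings_positions  # always exact

# scale_factor= derives the output length as floor(input * factor), which for
# non-integer ratios can drift off the intended target by a position.
by_factor = nn.functional.interpolate(
    patch_pos_embed,
    scale_factor=(1, embeddings_positions / num_positions),
    mode="bicubic",
    align_corners=False,
)
print(by_size.shape[-1], by_factor.shape[-1])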