Potential error in num_patches calculation in src/transformers/models/vit_mae/modeling_vit_mae.py #32410
Hi @ziyiss, thanks for reporting the issue and for such a detailed description! That indeed looks like a bug in the implementation. I can confirm that the following example produces different outputs for the default shape (224, 224) depending on the `interpolate_pos_encoding` flag:

```python
import torch
import requests
from PIL import Image

from transformers import AutoImageProcessor, ViTMAEForPreTraining

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained('facebook/vit-mae-base')
model = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')

inputs = processor(images=image, return_tensors="pt")
noise = torch.rand(size=(1, 196), device=inputs.pixel_values.device)

with torch.no_grad():
    outputs_no_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=False)
    outputs_with_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=True)

max_diff = torch.max(torch.abs(outputs_no_interpolate.logits - outputs_with_interpolate.logits)).item()
assert max_diff < 1e-4, f"Max diff is {max_diff}"
```

Output:
The fix you suggested works and the outputs match. This will be a breaking change in case anyone uses the model with `interpolate_pos_encoding=True`.
Hi @qubvel, I agree with your assessment that this will be a breaking change for users who have been using the model with `interpolate_pos_encoding=True`.

Regarding the test results you shared, I have a friendly reminder that might help in further testing (correct me if I am wrong! :)). To further confirm the inconsistency: without setting a random seed before each run, inconsistent results may occur regardless of other factors, which can also lead to significant differences in model output. This effect can be verified with the following test run:
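A minimal sketch of such an unseeded comparison, reusing `model` and `inputs` from the example above; when `noise` is not passed, ViTMAE samples fresh masking noise on each forward pass:

```python
# Two unseeded runs with identical inputs and the same flag: the internally
# sampled masking noise differs between calls, so the logits differ.
with torch.no_grad():
    out_a = model(**inputs, interpolate_pos_encoding=False)
    out_b = model(**inputs, interpolate_pos_encoding=False)

max_diff = torch.max(torch.abs(out_a.logits - out_b.logits)).item()
print(f"Max diff without seeding: {max_diff}")  # typically large
```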
After setting a random seed before both runs, I was able to get consistent results when passing the same inputs, with no AssertionError raised:
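A sketch of that seeded variant, assuming `torch.manual_seed` is called before each run:

```python
# Re-seeding before each run reproduces the same internal masking noise,
# so two runs on the same inputs now match.
with torch.no_grad():
    torch.manual_seed(0)
    out_a = model(**inputs, interpolate_pos_encoding=False)
    torch.manual_seed(0)
    out_b = model(**inputs, interpolate_pos_encoding=False)

max_diff = torch.max(torch.abs(out_a.logits - out_b.logits)).item()
assert max_diff < 1e-4, f"Max diff is {max_diff}"
```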
Hope this example helps! Thank you and all the developers for continuously developing and maintaining this repo! It has been incredibly helpful for learning and for my ML work!
Hi @ziyiss, thanks for the update. Note that in my example above the same `noise` tensor is passed explicitly to both calls, so the random masking is identical across runs and a seed should not influence the comparison:

```python
inputs = processor(images=image, return_tensors="pt")
noise = torch.rand(size=(1, 196), device=inputs.pixel_values.device)

with torch.no_grad():
    outputs_no_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=False)
    outputs_with_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=True)
```
There are even more issues with the current `interpolate_pos_encoding`: the current implementation may lead to an error for certain image sizes. Example to reproduce:

```python
import torch
import requests
from PIL import Image

from transformers import AutoImageProcessor, ViTMAEForPreTraining

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

model = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')

for i in range(14, 30):
    for j in range(14, 30):
        processor = AutoImageProcessor.from_pretrained(
            'facebook/vit-mae-base',
            size={"height": i * 16, "width": j * 16},
            use_fast=True,
        )
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            try:
                outputs_with_interpolate = model(**inputs, interpolate_pos_encoding=True)
            except Exception as e:
                print(f"Failed with interpolate_pos_encoding=True for shape {inputs.pixel_values.shape}")
```

Output:
```text
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 224, 272])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 256, 256])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 256, 432])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 272, 224])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 272, 448])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 288, 352])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 288, 368])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 288, 384])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 304, 416])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 320, 352])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 352, 288])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 352, 320])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 368, 288])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 384, 288])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 416, 304])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 432, 256])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 448, 272])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 464, 464])
```

This is caused by using the
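One plausible culprit (an assumption on my part, not confirmed above): `nn.functional.interpolate` called with a floating-point `scale_factor` floors the computed output size, which can land one patch short of the target grid, whereas an explicit `size` cannot. A minimal probe of that rounding behavior:

```python
import torch
import torch.nn.functional as F

# Pretrained ViT-MAE grid: 14x14 position-embedding patches (224 / 16).
grid = 14
x = torch.randn(1, 1, grid, grid)

for target in range(14, 30):
    # scale_factor path: output size is floor(grid * scale), and
    # floating-point rounding can push it one patch below the target.
    out = F.interpolate(x, scale_factor=target / grid, mode="bicubic")
    if out.shape[-1] != target:
        print(f"target {target}: got {out.shape[-1]} via scale_factor")

# Passing size= avoids the rounding issue entirely.
out = F.interpolate(x, size=(16, 16), mode="bicubic")
assert out.shape[-2:] == (16, 16)
```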
Thank you @qubvel for the clarification and additional insights! I'm also very grateful to you for sharing information about the other related bugs in the current `interpolate_pos_encoding` implementation.
Yes - we should fix it! If I've understood correctly, this issue derives from the original implementation, but we can address it by switching the logic to rely on
#33330 should fix this 👍
System Info

I've been examining the `interpolate_pos_encoding` method in the ViTMAE implementation, and I've discovered what appears to be an error in the calculation of `num_patches`. This error seems to have unintended consequences on the method's behavior, especially when the input image has the same size as the pretrained image. Here are my findings (a sketch of the relevant logic follows the suggested fix below):

1. `num_patches` calculation: in `ViTMAEPatchEmbeddings.forward()`, `x = self.projection(pixel_values).flatten(2).transpose(1, 2)` produces patch embeddings without a [CLS] token, so for a 480x640 input, `embeddings.shape` = [1, 1200, 768]. Currently, `num_patches = embeddings.shape[1] - 1` gives `num_patches` = 1200 - 1 = 1199, which shouldn't be the case, since `num_patches` should be 30*40 = 1200. In `num_patches = embeddings.shape[1] - 1`, the -1 is presumably meant to remove the [CLS] token, but there is no [CLS] token to remove at this point.
2. At the pretrained 224x224 size, with `num_patches = embeddings.shape[1] - 1` we get `num_patches` = 196 - 1 = 195, so `num_patches != num_positions`, even when using the pre-trained image size. The early-return condition `num_patches == num_positions and height == width` is therefore never true, and the method never returns `self.position_embeddings` directly, even when it should (i.e., when the input image size matches the pretrained size).
3. As a consequence, the position embeddings are needlessly interpolated when `interpolate_pos_encoding = True` and the image size matches the pretrained size.

Suggested fix: change
`num_patches = embeddings.shape[1] - 1` to `num_patches = embeddings.shape[1]`.
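For orientation, a simplified sketch of the logic described above, with the suggested one-line fix shown as a comment (abridged; the exact code in `modeling_vit_mae.py` may differ in details):

```python
def interpolate_pos_encoding(self, embeddings, height, width):
    # `embeddings` are the outputs of ViTMAEPatchEmbeddings: no [CLS] token
    # yet, so shape[1] is already the true number of patches.
    num_patches = embeddings.shape[1] - 1  # buggy: 196 - 1 = 195 at 224x224
    # num_patches = embeddings.shape[1]    # suggested fix: 196
    num_positions = self.position_embeddings.shape[1] - 1  # 197 - 1 = 196
    if num_patches == num_positions and height == width:
        # With the fix, this early return triggers at the pretrained size,
        # and the learned position embeddings are used unchanged.
        return self.position_embeddings
    ...  # otherwise, interpolate the position embeddings to the new grid
```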
Can you confirm that this is indeed an error in the implementation? Are there any considerations I might be missing? Thank you for your time and for maintaining this amazing project!