-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix] ViViT interpolate_pos_encoding #33815
[Fix] ViViT interpolate_pos_encoding #33815
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
Main comment is about the test
@@ -363,9 +363,7 @@ def test_inference_interpolate_pos_encoding(self): | |||
|
|||
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2") | |||
video = prepare_video() | |||
inputs = image_processor( | |||
video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
crop_size option should still be included in the test, as this will force the interpolation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, let me push the commit in a second
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done 👍
@@ -104,9 +104,10 @@ def __init__(self, config): | |||
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size) | |||
) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.patch_size = config.tubelet_size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to set this correctly to a patch_size i.e. a tuple of len 2, rather than assign it the tublet size
self.patch_size = config.tubelet_size | |
self.patch_size = config.tubelet_size[1:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pushed the changes 👍
@amyeroberts All green and suggestions are pushed. |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@amyeroberts Weirdly the slow test is passing locally but in the
I pushed a commit which could fix it but still a little doubtful |
Hey! 🤗 Thanks for your contribution to the Before merging this pull request, slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for slow tests CI on PRs is here. |
|
||
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2") | ||
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why change the checkpoint and the crop_size in the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Changing the checkpoint was a test to see if that fixes the test failure in
PR Slow CI
that you triggered. (The one I mentioned above) - different
crop_size
(s) leads to an error during the calling of interpolation method for example when thecrop_size
wascrop_size={"height": 480, "width": 480}
the following error occurs:
# add positional encoding to each token
if interpolate_pos_encoding:
> embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
E RuntimeError: The size of tensor a (14401) must match the size of tensor b (901) at non-singleton dimension 1
same happens with some other crop sizes as well. But the error doesn't occur for crop_size
like 232 or 228 or even the default crop_size
224
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, thanks for explaining. The error shouldn't be triggered for the default crop_size
value (no interpolation should happen) but if it works for these none default values then it's all good :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default crop_size
value is 224, right?!, in all the image processing files. The error doesn't occur for that value tho. This value is only given in the test file so that the error doesn't occur & for the sake of testing a value apart from the default one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
* fix:test_inference_interpolate_pos_encoding * style:make style;make fixup * test: add suggestion to test_modeling_vivit * chore:add suggestions * style:make style * [run_slow] vivit * ci:slow test fix * [run_slow] vivit
* fix:test_inference_interpolate_pos_encoding * style:make style;make fixup * test: add suggestion to test_modeling_vivit * chore:add suggestions * style:make style * [run_slow] vivit * ci:slow test fix * [run_slow] vivit
* fix:test_inference_interpolate_pos_encoding * style:make style;make fixup * test: add suggestion to test_modeling_vivit * chore:add suggestions * style:make style * [run_slow] vivit * ci:slow test fix * [run_slow] vivit
What does this PR do?
Fixes #33814
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@amyeroberts