-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add I-JEPA #33125
Add I-JEPA #33125
Conversation
cc @amyeroberts and @qubvel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this model @jmtzt!
Overall it's looking great. Did an initial review outlining some small things to update.
Before merging, we'll need to make sure the slow model tests are passing. To trigger these, you'll need to push a commit (empty or otherwise) with the message [run_slow] ijpea
. Me or another person at HF will then need to approve the workflow
Thanks for reviewing the code, @amyeroberts! :) I've just pushed the adjustments, and now the fx support seems to be working fine. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating - looks great!
Just a small nit. It seems the slow model tests didn't run -- not really sure what happened there. Could you try with [run-slow] ijepa
instead of [run_slow] ijepa
?
After than, just a merge conflict to resolve and we're good to go!
Thanks again for checking @amyeroberts :) I've just pushed the adjustments, however some tests appear to be failing due to some OS errors inside CircleCI, do you know why's that? |
Thanks for pushing again! Hmmmm - I'm not sure why the slow model tests aren't being picked up or run here cc @ydshieh |
Yih-Dar found the issue - the actions won't be triggered whilst there are merge conflicts in the PR. Could you resolve these, then push another |
the original I-JEPA model doesn't have the pooling layer, so I think to get around this we might need to default the |
Ok, sounds good 👍 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @jmtzt! Seems like almost everything is fine! Thanks for correcting snippets in docs and model cards! Snippets work fine on my side, I also did an experiment fine-tuning classification model - converges really good.
A few final nits regarding docstrings/constants and a suggestion regarding classification head and I will pass it to a core maintainers review, thank you for the great job!
hi @qubvel, thanks for the support and reviewing the PR! Just pushed the suggested changes, let me know if it's alright now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing all the comments quickly! IMO it's ready for the next review
P.S. It may take a while since there are a lot of PRs for the final review in a line 🤗
checkpoint=_IMAGE_CLASS_CHECKPOINT, | ||
output_type=ImageClassifierOutput, | ||
config_class=_CONFIG_FOR_DOC, | ||
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose we can delete _IMAGE_CLASS_CHECKPOINT
and _IMAGE_CLASS_EXPECTED_OUTPUT
because we don't really have pretrained checkpoint for such a model
return_dict (`bool`, *optional*): | ||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | ||
""" | ||
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, it is not overwritten in modeling, while it exists in a modular file.. maybe a modular converter issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the modular_model_converter.py
implementation, line 472, we have something like:
These top-level variables will always use the value in the modular_xxx.py
file
ASSIGNMENTS_TO_KEEP = {"_CHECKPOINT_FOR_DOC", }
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, maybe we have to add "_EXPECTED_OUTPUT_SHAPE" too
@ArthurZucker please review whenever you have bandwidth! The model is similar to ViT, so the Checkpoints can be found here: (we will need to transfer them to facebook org as soon as we ensure the code is in the final stage + rename all occurrences in code and model cards) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the PR! Modular makes it easy to understand: basically VIT but no cls token embedding right? (checked that no models in transformers already has this!)
num_patches = self.patch_embeddings.num_patches | ||
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) | ||
|
||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any reason why this was resolved?
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) | ||
self.encoder = IJepaEncoder(config) | ||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||
self.pooler = IJepaPooler(config) if add_pooling_layer else None | ||
# Initialize weights and apply final processing | ||
self.post_init() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) | |
self.encoder = IJepaEncoder(config) | |
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |
self.pooler = IJepaPooler(config) if add_pooling_layer else None | |
# Initialize weights and apply final processing | |
self.post_init() | |
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) |
this should be the only thing you have to change as all the other classes are the exact same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO you should use either copied form or inheritance for the tests as the model is so similar to VIT models. The most important test to add is an integration tests!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Example of inheritance in tests with Gemma2 !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are already integration tests implemented, quite similar to the ViT ones actually, the only one missing is the test_inference_image_classification_head
, and IMO it doesn't make sense to have this, as we don't have a checkpoint for image classification tasks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also for your comment on the interpolate_pos_encoding
, my implementation is different, to account for the lack of a CLS token...
num_patches = self.patch_embeddings.num_patches | ||
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) | ||
|
||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not even have to right it in the modular as it should be inherited!
Feel free to ping me again for another review! 🫡 |
@ArthurZucker I've pushed the adjustments according to your comments, let me know if anything is missing. Thanks! :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! Camel casing is a bit wrong but would look super ugly otherwise!
@ArthurZucker can you transfer checkpoints to Facebook org? Also, code snippets need to be adjusted. |
@jmtzt congratulations on the model merged 🎉 was glad to collaborate with you on this! |
thanks for the support @qubvel and @ArthurZucker :) Sure, go ahead! |
What does this PR do?
This PR adds I-JEPA.
To-Do's:
ijepa_vith14_22k
,ijepa_vith16_1k
,ijepa_vitg16_22k
.