Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add I-JEPA #33125

Merged
merged 47 commits into from
Dec 5, 2024
Merged

Add I-JEPA #33125

merged 47 commits into from
Dec 5, 2024

Conversation

jmtzt
Copy link
Contributor

@jmtzt jmtzt commented Aug 26, 2024

What does this PR do?

This PR adds I-JEPA.

To-Do's:

  • convert remaining checkpoints ijepa_vith14_22k, ijepa_vith16_1k, ijepa_vitg16_22k.
  • transfer checkpoints to the meta org

@ArthurZucker
Copy link
Collaborator

cc @amyeroberts and @qubvel

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this model @jmtzt!

Overall it's looking great. Did an initial review outlining some small things to update.

Before merging, we'll need to make sure the slow model tests are passing. To trigger these, you'll need to push a commit (empty or otherwise) with the message [run_slow] ijpea. Me or another person at HF will then need to approve the workflow

docs/source/en/model_doc/ijepa.md Outdated Show resolved Hide resolved
docs/source/en/model_doc/ijepa.md Outdated Show resolved Hide resolved
src/transformers/models/ijepa/__init__.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/configuration_ijepa.py Outdated Show resolved Hide resolved
tests/models/ijepa/test_modeling_ijepa.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/modeling_ijepa.py Outdated Show resolved Hide resolved
@jmtzt
Copy link
Contributor Author

jmtzt commented Aug 28, 2024

Thanks for reviewing the code, @amyeroberts! :)

I've just pushed the adjustments, and now the fx support seems to be working fine.

@jmtzt jmtzt requested a review from amyeroberts August 30, 2024 17:08
Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating - looks great!

Just a small nit. It seems the slow model tests didn't run -- not really sure what happened there. Could you try with [run-slow] ijepa instead of [run_slow] ijepa?

After than, just a merge conflict to resolve and we're good to go!

src/transformers/models/ijepa/modeling_ijepa.py Outdated Show resolved Hide resolved
@jmtzt
Copy link
Contributor Author

jmtzt commented Aug 30, 2024

Thanks again for checking @amyeroberts :) I've just pushed the adjustments, however some tests appear to be failing due to some OS errors inside CircleCI, do you know why's that?

@amyeroberts
Copy link
Collaborator

Thanks for pushing again! Hmmmm - I'm not sure why the slow model tests aren't being picked up or run here cc @ydshieh

@amyeroberts
Copy link
Collaborator

Yih-Dar found the issue - the actions won't be triggered whilst there are merge conflicts in the PR. Could you resolve these, then push another [run-slow] ijepa commit? This should hopefully run the workflow!

@jmtzt
Copy link
Contributor Author

jmtzt commented Nov 18, 2024

the original I-JEPA model doesn't have the pooling layer, so I think to get around this we might need to default the add_pooling_layer to False in its initialization, and modify this snippet accordingly to get the last hidden states rather than the pooler_output

@qubvel
Copy link
Member

qubvel commented Nov 18, 2024

Ok, sounds good 👍

Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jmtzt! Seems like almost everything is fine! Thanks for correcting snippets in docs and model cards! Snippets work fine on my side, I also did an experiment fine-tuning classification model - converges really good.

A few final nits regarding docstrings/constants and a suggestion regarding classification head and I will pass it to a core maintainers review, thank you for the great job!

src/transformers/models/ijepa/modeling_ijepa.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/modeling_ijepa.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/modeling_ijepa.py Outdated Show resolved Hide resolved
tests/models/ijepa/test_modeling_ijepa.py Show resolved Hide resolved
docs/source/en/model_doc/ijepa.md Show resolved Hide resolved
@jmtzt
Copy link
Contributor Author

jmtzt commented Nov 19, 2024

hi @qubvel, thanks for the support and reviewing the PR! Just pushed the suggested changes, let me know if it's alright now.

Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing all the comments quickly! IMO it's ready for the next review

P.S. It may take a while since there are a lot of PRs for the final review in a line 🤗

Comment on lines +677 to +680
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we can delete _IMAGE_CLASS_CHECKPOINT and _IMAGE_CLASS_EXPECTED_OUTPUT because we don't really have pretrained checkpoint for such a model

return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, it is not overwritten in modeling, while it exists in a modular file.. maybe a modular converter issue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the modular_model_converter.py implementation, line 472, we have something like:

These top-level variables will always use the value in the modular_xxx.py file
ASSIGNMENTS_TO_KEEP = {"_CHECKPOINT_FOR_DOC", }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, maybe we have to add "_EXPECTED_OUTPUT_SHAPE" too

@qubvel qubvel requested a review from ArthurZucker November 19, 2024 14:47
@qubvel
Copy link
Member

qubvel commented Nov 19, 2024

@ArthurZucker please review whenever you have bandwidth! The model is similar to ViT, so the Modular is used here to remove the CLS token.

Checkpoints can be found here:
https://huggingface.co/jmtzt

(we will need to transfer them to facebook org as soon as we ensure the code is in the final stage + rename all occurrences in code and model cards)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR! Modular makes it easy to understand: basically VIT but no cls token embedding right? (checked that no models in transformers already has this!)

src/transformers/models/auto/configuration_auto.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/__init__.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/modular_ijepa.py Outdated Show resolved Hide resolved
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why this was resolved?

Comment on lines 207 to 212
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = IJepaEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = IJepaPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = IJepaEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = IJepaPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)

this should be the only thing you have to change as all the other classes are the exact same

src/transformers/models/ijepa/modular_ijepa.py Outdated Show resolved Hide resolved
src/transformers/models/ijepa/modular_ijepa.py Outdated Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO you should use either copied form or inheritance for the tests as the model is so similar to VIT models. The most important test to add is an integration tests!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example of inheritance in tests with Gemma2 !

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are already integration tests implemented, quite similar to the ViT ones actually, the only one missing is the test_inference_image_classification_head, and IMO it doesn't make sense to have this, as we don't have a checkpoint for image classification tasks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also for your comment on the interpolate_pos_encoding, my implementation is different, to account for the lack of a CLS token...

num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not even have to right it in the modular as it should be inherited!

@ArthurZucker ArthurZucker removed the request for review from amyeroberts November 20, 2024 18:27
@ArthurZucker
Copy link
Collaborator

Feel free to ping me again for another review! 🫡

@jmtzt jmtzt requested a review from ArthurZucker November 26, 2024 15:35
@jmtzt
Copy link
Contributor Author

jmtzt commented Nov 30, 2024

Feel free to ping me again for another review! 🫡

@ArthurZucker I've pushed the adjustments according to your comments, let me know if anything is missing. Thanks! :)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! Camel casing is a bit wrong but would look super ugly otherwise!

@ArthurZucker ArthurZucker merged commit 50189e3 into huggingface:main Dec 5, 2024
22 checks passed
@qubvel
Copy link
Member

qubvel commented Dec 5, 2024

@ArthurZucker can you transfer checkpoints to Facebook org? Also, code snippets need to be adjusted.
Checkpoints are here: https://huggingface.co/jmtzt

@qubvel
Copy link
Member

qubvel commented Dec 5, 2024

@jmtzt congratulations on the model merged 🎉 was glad to collaborate with you on this!
Can you please confirm if we can transfer checkpoints to Facebook org from your account?

@jmtzt
Copy link
Contributor Author

jmtzt commented Dec 5, 2024

thanks for the support @qubvel and @ArthurZucker :)

Sure, go ahead!

@NielsRogge NielsRogge mentioned this pull request Dec 8, 2024
@qubvel qubvel mentioned this pull request Jan 13, 2025
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants