Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PerSAM #23652

Closed
wants to merge 23 commits into from
Closed

Add PerSAM #23652

wants to merge 23 commits into from

Conversation

NielsRogge
Copy link
Contributor

What does this PR do?

This PR adds the PerSAM model.

Question: when you do:

from transformers import PerSamModel

model = PerSamModel.from_pretrained("facebook/sam-vit-huge")

you get this warning:

You are using a model of type sam to instantiate a model of type persam. This is not supported for all configurations of models and can yield errors.

was wondering whether we could suppress this warning. PerSAM uses the exact same weights as the original SAM model, just modifies the forward pass with 2 additional arguments. Currently the model_type is set to "persam" in PerSamConfig.

@NielsRogge NielsRogge requested a review from sgugger May 22, 2023 12:14
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented May 22, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this. If the modification works with all checkpoints for SAM then it passes the test for which we don't require a new model. So in this instance it's fine to add a new config argument to SAM to activate persam mode and adapt the forward, which will also get rid of your warning.

Comment on lines +579 to +584
self.layer_norm1 = SamLayerNorm(
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
)
self.layer_norm2 = SamLayerNorm(
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@younesbelkada can you confirm if this a mistake in the SAM implementation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am unsure if this is needed, @NielsRogge could you elaborate more on why this change is needed? 🙏

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I had to include this fix to make input_masks work. I noticed that input_masks is not tested in tests/models/test_modeling_sam.py (only input_points, input_labels and input_boxes are).

Might be good to add a test for this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me! Indeed we didn't tested input masks!

@NielsRogge
Copy link
Contributor Author

Ok, will close this PR in favor of modifying modeling_sam.py.

@NielsRogge NielsRogge closed this May 22, 2023
@NielsRogge NielsRogge mentioned this pull request May 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants