Fix DETA save_pretrained #30326

Merged
merged 3 commits into from
Apr 22, 2024

@qubvel qubvel commented Apr 18, 2024

What does this PR do?

save_pretrained does not work for DetaForObjectDetection. This PR fixes the error.

A simple script to reproduce:

from transformers import DetaConfig, DetaForObjectDetection

config = DetaConfig()
model = DetaForObjectDetection(config)
model.save_pretrained("output_dir/")

or

from transformers import AutoModelForObjectDetection

model = AutoModelForObjectDetection.from_pretrained("jozhang97/deta-resnet-50")
model.save_pretrained("output_dir/")

The following error occurs:

RuntimeError: The weights trying to be saved contained shared tensors [{'model.decoder.class_embed.0.weight', 'class_embed.0.weight'}, {'model.decoder.class_embed.0.bias', 'class_embed.0.bias'}, {'model.decoder.class_embed.1.weight', 'class_embed.1.weight'}, {'model.decoder.class_embed.1.bias', 'class_embed.1.bias'}, {'model.decoder.class_embed.2.weight', 'class_embed.2.weight'}, {'model.decoder.class_embed.2.bias', 'class_embed.2.bias'}, {'model.decoder.class_embed.3.weight', 'class_embed.3.weight'}, {'class_embed.3.bias', 'model.decoder.class_embed.3.bias'}, {'model.decoder.class_embed.4.weight', 'class_embed.4.weight'}, {'model.decoder.class_embed.4.bias', 'class_embed.4.bias'}, {'model.decoder.class_embed.5.weight', 'class_embed.5.weight'}, {'class_embed.5.bias', 'model.decoder.class_embed.5.bias'}, {'class_embed.6.weight', 'model.decoder.class_embed.6.weight'}, {'class_embed.6.bias', 'model.decoder.class_embed.6.bias'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.
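
The error lists groups of state-dict names that point at the same tensor. As a rough illustration of what "shared tensors" means here, the sketch below groups names by the identity of the object they refer to. This is not the transformers or safetensors implementation: `find_shared_groups`, the stand-in list objects, and the `bbox_embed.0.weight` entry are hypothetical, and real code would compare tensor storage rather than Python object identity.

```python
from collections import defaultdict


def find_shared_groups(state_dict):
    """Group state-dict names that refer to the very same object.

    Hypothetical helper for illustration only: real tensors would be
    compared by underlying storage, not by Python object identity.
    """
    names_by_id = defaultdict(set)
    for name, value in state_dict.items():
        names_by_id[id(value)].add(name)
    # Only groups with more than one name count as "shared".
    return [names for names in names_by_id.values() if len(names) > 1]


# Stand-ins for tensors: plain lists are enough to show aliasing.
class_embed_weight = [0.0, 1.0]
bbox_weight = [2.0, 3.0]

state_dict = {
    "class_embed.0.weight": class_embed_weight,
    "model.decoder.class_embed.0.weight": class_embed_weight,  # same object
    "bbox_embed.0.weight": bbox_weight,  # illustrative, not shared
}

shared = find_shared_groups(state_dict)
print(shared)
```

Each group printed corresponds to one `{...}` set in the error message above: two names, one underlying tensor.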

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qubvel qubvel changed the title Fix DETA save_pretrained [WIP] Fix DETA save_pretrained Apr 18, 2024
@qubvel qubvel changed the title [WIP] Fix DETA save_pretrained Fix DETA save_pretrained Apr 18, 2024

qubvel commented Apr 19, 2024

cc @NielsRogge

@NielsRogge NielsRogge left a comment

Thanks for fixing!

@NielsRogge NielsRogge requested a review from amyeroberts April 22, 2024 07:46

@amyeroberts amyeroberts left a comment

Thanks for fixing and adding a test!

Just a small comment on removing the exception

# Detect we get a hit for each key
for key in tied_weight_keys:
    if not any(re.search(key, p) for group in tied_params for p in group):
        raise ValueError(f"{key} is not a tied weight key for {model_class}.")

We shouldn't really be raising exceptions in a test. Exceptions are for terminating out of code when there's an incorrect input / code behaviour which we can then choose to handle. In tests, we're really performing sanity checks, which should always be True given the test.

Either we should test this behaviour with an assert or just remove the check.
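
For reference, a minimal sketch of what the assert-based version of the check could look like. The loop and message come from the snippet under review; the wrapper function `check_tied_weight_keys` and the sample values are hypothetical, and the code actually merged in the PR may differ.

```python
import re


def check_tied_weight_keys(tied_weight_keys, tied_params, model_class):
    """Assert that every tied-weight key pattern matches at least one tied parameter.

    Hypothetical wrapper for illustration; in the test suite this logic
    would live inline in the test method.
    """
    for key in tied_weight_keys:
        is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
        assert is_tied_key, f"{key} is not a tied weight key for {model_class}."


# Hypothetical sample values, modelled on the names in the error message.
tied_weight_keys = [r"class_embed\.\d+\.weight"]
tied_params = [["class_embed.0.weight", "model.decoder.class_embed.0.weight"]]

check_tied_weight_keys(tied_weight_keys, tied_params, "DetaForObjectDetection")
```

With an assert, a key that matches nothing shows up as a regular test failure rather than as an exception raised from inside the test.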

@qubvel qubvel Apr 22, 2024

Thanks, Amy, I fixed this test case and the original one too 👍 (replaced with assert)

@amyeroberts amyeroberts left a comment

Thanks for fixing and iterating!

@amyeroberts amyeroberts merged commit 13b3b90 into huggingface:main Apr 22, 2024
18 checks passed
4 participants