fix_mbart_tied_weights #26422
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hmm, I'm not sure that all mBART models share all weight matrices with each other. We should make sure that at least for all the following models:
Yes, I think we should look at the
Thanks for the link @LysandreJik. I've updated the code and added a test.
As @patrickvonplaten was saying, could you also quickly verify that it works with the most downloaded mBART models on the Hub? When doing the FSMT change I ended up breaking a few FSMT models on the Hub, so let's try to prevent that here 😁 Thanks for your help @SunMarc
Hi @LysandreJik, I confirm that for the most downloaded mBART models on the Hub, all three embedding matrices are identical. Here's the snippet that I used:

```python
from transformers import AutoModelForSeq2SeqLM

models = [
    "facebook/mbart-large-50-many-to-many-mmt",
    "facebook/mbart-large-50-many-to-one-mmt",
    "facebook/mbart-large-50-one-to-many-mmt",
    "facebook/mbart-large-50",
    "facebook/mbart-large-cc25",
    "facebook/mbart-large-en-ro",
    "facebook/mgenre-wiki",
]

for model_id in models:
    for safetensors in [True, False]:
        for device_map in ["auto", None]:
            try:
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_id, use_safetensors=safetensors, device_map=device_map
                )
            except Exception:
                # Skip this combination if the checkpoint cannot be loaded at all.
                print(f"{model_id} failed to load with safetensors={safetensors} and device_map={device_map}")
                continue
            # All four embeddings must point to the same underlying storage.
            assert len(
                {
                    model.get_output_embeddings().weight.data_ptr(),
                    model.get_input_embeddings().weight.data_ptr(),
                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
                }
            ) == 1, f"Embeddings are not tied in {model_id}"
```
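For context on why this check works: `torch.Tensor.data_ptr()` returns the address of the tensor's underlying data, so if the set of the four addresses collapses to a single element, the output embedding, the input embedding, and both `embed_tokens` modules are backed by the very same tensor, i.e. the weights really are tied.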
Thanks a lot @SunMarc!
Yay, that works. Thanks a lot everyone!
* fix_mbart_tied_weights
* add test
What does this PR do?
Fixes #26266. This PR fixes the tied weights for the mBART model. Before this PR, only `lm_head` was tied to `model.shared`. Now, we also make sure to tie `model.encoder.embed_tokens` and `model.decoder.embed_tokens` to `model.shared`, by defining the `_tie_weights` method which will be called when we do `model.tie_weights()`. I've checked that we get the same weights at the end. This issue only happens when we load with `safetensors` + `device_map`, because we don't save the shared tensors and the weights are on the meta device.
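For readers skimming the discussion, here is a minimal, hypothetical sketch of the kind of `_tie_weights` hook described above. The names mirror the mBART modeling code (`self.shared`, `self.encoder.embed_tokens`, `self.decoder.embed_tokens`), but it is an illustration rather than the exact merged diff:

```python
from transformers import MBartModel


class TiedMBartModel(MBartModel):
    # Hypothetical subclass used only to illustrate the hook; the actual fix adds the
    # method to MBartModel itself in modeling_mbart.py.

    def _tie_weights(self):
        # PreTrainedModel.tie_weights() calls this hook, so both embed_tokens modules
        # end up pointing at the shared embedding instead of separate (or meta-device) copies.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
```

With a hook like this in place, loading with `use_safetensors=True` and a `device_map` still goes through `tie_weights()`, so all the embedding modules share one weight tensor, which is exactly what the verification snippet earlier in the thread asserts.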