
fix_mbart_tied_weights #26422

Merged 2 commits into huggingface:main on Sep 28, 2023
Conversation

SunMarc (Member) commented Sep 26, 2023

What does this PR do?

Fixes #26266. This PR fixes the tied weights for the MBart model. Before this PR, only lm_head was tied to model.shared. Now we also make sure to tie model.encoder.embed_tokens and model.decoder.embed_tokens to model.shared by defining the _tie_weights method, which is called when we do model.tie_weights(). I've checked that we get the same weights at the end. This issue only happens when we load with safetensors + device_map, because the shared tensors are not saved and the weights are left on the meta device.
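In other words, the fix amounts to something along these lines (a sketch only, assuming the standard PreTrainedModel._tie_or_clone_weights helper; the exact merged code may differ):

# Sketch of the approach described above, not the exact merged diff.
# MBartModel keeps a single shared embedding table in self.shared, and
# _tie_weights is picked up by PreTrainedModel.tie_weights() after loading.
def _tie_weights(self):
    # Re-point the encoder and decoder embeddings at the shared table so that
    # weights left on the meta device also receive the loaded values.
    self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
    self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)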

HuggingFaceDocBuilderDev commented Sep 26, 2023

The documentation is not available anymore as the PR was closed or merged.

patrickvonplaten (Contributor)

Hmm, I'm not sure that all mBART checkpoints share all of these weight matrices with each other.

We should make sure that, at least for all of the following models:
https://huggingface.co/models?other=mbart&sort=trending&search=facebook
all three embedding matrices are identical (I'm not sure this is always the case, e.g. for the multilingual ones).

LysandreJik (Member)

Yes, I think we should look at the config.tie_word_embeddings value and adapt accordingly. See the recent PR on FSMT: #26292

SunMarc (Member, Author) commented Sep 27, 2023

Thanks for the link @LysandreJik. I've updated the code and added a test.
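Concretely, the tying is now conditioned on that flag, roughly along these lines (a sketch only, reusing the helper assumed above; the exact code may differ):

def _tie_weights(self):
    # Only tie when the checkpoint's config requests shared word embeddings.
    if self.config.tie_word_embeddings:
        self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
        self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)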

LysandreJik (Member)

As @patrickvonplaten was saying, could you also quickly verify that it works with the most downloaded mBART models on the Hub? When doing the FSMT change I ended up breaking a few FSMT models on the Hub, so let's try to prevent that here 😁

Thanks for your help @SunMarc

SunMarc (Member, Author) commented Sep 28, 2023

Hi @LysandreJik, I confirm that for the most downloaded mBART models on the Hub, all three embedding matrices are identical. Here's the snippet I used:

from transformers import AutoModelForSeq2SeqLM

models = [
    "facebook/mbart-large-50-many-to-many-mmt",
    "facebook/mbart-large-50-many-to-one-mmt",
    "facebook/mbart-large-50-one-to-many-mmt",
    "facebook/mbart-large-50",
    "facebook/mbart-large-cc25",
    "facebook/mbart-large-en-ro",
    "facebook/mgenre-wiki",
]
for model_id in models:
    for safetensors in [True, False]:
        for device_map in ["auto", None]:
            try:
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_id, use_safetensors=safetensors, device_map=device_map
                )
            except Exception:
                print(f"{model_id} failed to load with safetensors={safetensors} and device_map={device_map}")
                continue  # skip the check if the model could not be loaded
            # All four embedding references must point to the same storage.
            assert len(
                {
                    model.get_output_embeddings().weight.data_ptr(),
                    model.get_input_embeddings().weight.data_ptr(),
                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
                }
            ) == 1, f"Embeddings are not tied in {model_id}"

LysandreJik (Member)

Thanks a lot @SunMarc!

LysandreJik merged commit 5e11d72 into huggingface:main on Sep 28, 2023
BramVanroy (Collaborator)

Yay, that works. Thanks a lot everyone!

blbadger pushed a commit to blbadger/transformers that referenced this pull request on Nov 8, 2023.
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request on Nov 18, 2023.