TF: GPT2 with native embedding layers #23436
Conversation
The documentation is not available anymore as the PR was closed or merged.
Nice - thanks for adding!
Looks clean to me!
I noticed this change one year later, while doing a demo of LoRA in the embedding layer :-D Nice, but I am always afraid of doing TensorFlow operations outside of Keras layers. Usually masks are lost, as well as any subtle behavior that depends on chaining layers. It is not a problem in transformers because the -100 trick is the standard mask and the loss is aware of it.
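To illustrate the point about masks, here is a minimal sketch (assuming tf.keras / TF 2.x masking semantics; the shapes and values are made up, not taken from this PR):

```python
import tensorflow as tf

# A Keras layer with mask_zero=True attaches a mask to its output...
embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=4, mask_zero=True)
token_ids = tf.constant([[5, 3, 0, 0]])           # trailing zeros are padding
hidden = embedding(token_ids)
print(getattr(hidden, "_keras_mask", None))       # [[True True False False]]

# ...but a plain tf.* op outside any Keras layer returns a fresh tensor
# without the mask attribute, so downstream layers no longer see it.
scaled = hidden * 2.0
print(getattr(scaled, "_keras_mask", None))       # None

# transformers does not rely on Keras masks for the LM loss: ignored label
# positions are set to -100 and filtered out explicitly (the "-100 trick").
labels = tf.constant([[5, 3, -100, -100]])
logits = tf.random.normal((1, 4, 10))
active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
flat_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)
flat_logits = tf.boolean_mask(tf.reshape(logits, (-1, 10)), active)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(loss_fn(flat_labels, flat_logits))
```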
What does this PR do?
This PR continues the (paused) goal of deprecating our custom TF embedding layers and related code. Previously, we converted the encoder-decoder models (e.g. here), removing `TFSharedEmbeddings` there and making the necessary adaptations.

In this PR, I make the corresponding adaptations for GPT2. The goal is for you, the reviewers, to raise objections in this PR :D All slow tests for TF GPT2 are passing.
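For readers not familiar with the change, the general pattern being adopted looks roughly like the sketch below. This is not the actual diff: the class, shapes, and the missing transformer blocks are illustrative, with only the `wte` name borrowed from GPT2.

```python
import tensorflow as tf

class TinyTiedLM(tf.keras.layers.Layer):
    """Toy decoder-style model showing the native-embedding + weight-tying idea."""

    def __init__(self, vocab_size, hidden_size, **kwargs):
        super().__init__(**kwargs)
        # Native Keras embedding in place of the custom TFSharedEmbeddings layer.
        self.wte = tf.keras.layers.Embedding(vocab_size, hidden_size, name="wte")

    def call(self, input_ids):
        hidden_states = self.wte(input_ids)  # (batch, seq, hidden)
        # ... transformer blocks would run here ...
        # Weight tying: reuse the embedding matrix as the LM head, i.e.
        # logits = hidden_states @ wte^T, with no separate output projection.
        return tf.einsum("bsh,vh->bsv", hidden_states, self.wte.embeddings)

lm = TinyTiedLM(vocab_size=100, hidden_size=16)
print(lm(tf.constant([[1, 2, 3]])).shape)  # (1, 3, 100)
```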
Then, the following sequence of PRs will be opened:
1. Remove `TFSharedEmbeddings` from the other decoder-only models
2. Remove the remaining uses of `TFSharedEmbeddings` in the codebase (e.g. in tests)
3. Remove `resize_token_embeddings` and all related functions (it is only used to resize our models' embeddings instantiated with `TFSharedEmbeddings`; see the usage sketch after this list)

`test_save_load_after_resize_token_embeddings` will be fixed as a consequence of these changes 🙌
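For context, the user-facing operation behind that resizing machinery is embedding resizing after growing the tokenizer. The sketch below only shows standard public transformers usage; which internal helpers are actually removed is left to the follow-up PRs.

```python
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# Standard transformers usage: grow the (tied) embedding matrix after adding
# tokens, so input ids beyond the original vocab size have rows to look up.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))
```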