-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sdpa for DistilBert #33724
Add sdpa for DistilBert #33724
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this feature and all the benchmark numbers!
Regarding the tests:
- The slow tests were skipped because the PR didn't have the
run-slow
label on the PR. Could you try retriggering with another[run-slow] distilbert
commit? - The failure on the fetch_tests step is something which seems to be afflicting other PRs -- we've contacted CircleCI to try and get this resolved
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with | ||
a MaskedLM head, we saw the following speedups during training and inference. | ||
|
||
#### Training |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️
Hey! 🤗 Thanks for your contribution to the Before merging this pull request, slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for slow tests CI on PRs is here. |
@amyeroberts It seems the tests worked. |
* Add sdpa for DistilBert * [run_slow] distilbert * [run_slow] distilbert * [run_slow] distilbert * Try without slow tests * [run_slow] distilbert * [run_slow] distilbert
* Add sdpa for DistilBert * [run_slow] distilbert * [run_slow] distilbert * [run_slow] distilbert * Try without slow tests * [run_slow] distilbert * [run_slow] distilbert
* Add sdpa for DistilBert * [run_slow] distilbert * [run_slow] distilbert * [run_slow] distilbert * Try without slow tests * [run_slow] distilbert * [run_slow] distilbert
What does this PR do?
Towards #28005
Adds sdpa for the DistilBert model
Who can review?
@amyeroberts