[Gemma] Fix eager attention #29187
Conversation
```diff
@@ -276,7 +276,7 @@ def forward(

         attn_output = attn_output.transpose(1, 2).contiguous()

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
```
This is the only modelling code change required - the remainder of the changes in this PR are logit + integration tests
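For reference, a minimal, self-contained sketch (not the modelling code itself; the dimensions are taken from Gemma-7B's config, where num_heads * head_dim does not equal hidden_size) of why the old reshape fails while the new view works:

```python
# Sketch of the shape issue fixed above. Dimensions follow Gemma-7B's config,
# where num_heads * head_dim (16 * 256 = 4096) != hidden_size (3072).
import torch

bsz, q_len = 2, 8
num_heads, head_dim, hidden_size = 16, 256, 3072

# per-head attention output: (bsz, num_heads, q_len, head_dim)
attn_output = torch.randn(bsz, num_heads, q_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()  # (bsz, q_len, num_heads, head_dim)

# old line: errors out, since 2 * 8 * 16 * 256 elements cannot fill (2, 8, 3072)
try:
    attn_output.reshape(bsz, q_len, hidden_size)
except RuntimeError as err:
    print(f"reshape to hidden_size fails: {err}")

# fixed line: flatten the head dimensions to num_heads * head_dim = 4096,
# which is what o_proj (a Linear mapping 4096 -> 3072 in this config) expects
attn_output = attn_output.view(bsz, q_len, -1)
print(attn_output.shape)  # torch.Size([2, 8, 4096])
```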
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the catch. A real pity our 280 prior tests did not catch this! 🤗
* fix modelling code
* add tests
* fix tests
* add some logit tests
* style
* fix fix
I'll look at integrating 2 tests in
It's also that it tests dummy models, but here head_dim != hidden_size / num_heads! So a case not studied.
@sanchit-gandhi Not true, see transformers/tests/test_modeling_common.py, line 3431 at 2cc8cf6: eager vs sdpa passed 😉
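To illustrate the point about dummy models, here is a sketch with made-up sizes (not the actual tester config): when hidden_size == num_heads * head_dim, as in the tiny configs the common tests build, the old reshape and the new view produce identical tensors, so an eager-vs-sdpa check cannot tell them apart.

```python
# Illustrative sizes only: when hidden_size == num_heads * head_dim,
# reshape(..., hidden_size) and view(..., -1) coincide, which is why the
# common test still passed on dummy models.
import torch

bsz, q_len, num_heads, head_dim = 2, 8, 4, 8
hidden_size = num_heads * head_dim  # 32: the equality that hid the bug

attn_output = torch.randn(bsz, num_heads, q_len, head_dim).transpose(1, 2).contiguous()
old = attn_output.reshape(bsz, q_len, hidden_size)
new = attn_output.view(bsz, q_len, -1)
assert torch.equal(old, new)  # identical outputs, so the bug stayed hidden
```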
What does this PR do?
Fixes the Gemma "eager" attention implementation, which is the default for torch versions <= 2.1. This issue was reported on the Hub discussions and by @osanseviero from the model cards/blog post.
The PR also includes a set of slow tests; these confirm that all attention implementations work and are equivalent across back-ends.
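As a rough sketch of what such an equivalence check looks like (the model id, prompt, and tolerances below are illustrative assumptions, not the PR's actual slow tests):

```python
# Hedged sketch of an eager-vs-sdpa equivalence check; "google/gemma-2b" is an
# assumed checkpoint for illustration (gated on the Hub).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt")

model_eager = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="eager", torch_dtype=torch.float32
)
model_sdpa = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa", torch_dtype=torch.float32
)

with torch.no_grad():
    logits_eager = model_eager(**inputs).logits
    logits_sdpa = model_sdpa(**inputs).logits

# the two back-ends should agree up to numerical tolerance
torch.testing.assert_close(logits_eager, logits_sdpa, atol=1e-3, rtol=1e-3)
```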