Fix SDPA sliding window compatibility #30127
Co-authored-by: ehuaa <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@ehuaa Thank you, will do for sure! Running the mistral, mixtral, and starcoder2 tests: those fail, but they are already failing on main:
I don't know if I mentioned it offline, but we'll refactor this into a single function without inheritance, similar to update_causal_mask. See Recurrent Gemma, as it supports sliding window!
Thanks for re-enabling sliding window.
    ignore_causal_mask = False

    if attention_mask is None:
        if sliding_window is None or key_value_length < sliding_window:
            ignore_causal_mask = not is_tracing
    elif sliding_window is None or key_value_length < sliding_window:
There are basically 2 cases:
- You ignore the causal mask
- You don't ignore it.
The code is not very clear, but we will refactor this soon anyway.
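The two cases can be sketched as a small standalone predicate. This is a minimal sketch, not the library's implementation: the function name `can_ignore_causal_mask` is hypothetical, only the `attention_mask is None` branch follows the diff excerpt above, and the branch for a provided mask (not reproduced in the excerpt) is replaced here by a conservative "keep the mask" assumption.

```python
from typing import Optional, Sequence


def can_ignore_causal_mask(
    attention_mask: Optional[Sequence],
    key_value_length: int,
    sliding_window: Optional[int] = None,
    is_tracing: bool = False,
) -> bool:
    """Hypothetical helper: decide whether an explicit causal mask can be dropped.

    Case 1 (ignore): no padding mask is given and the sliding window cannot
    truncate anything yet, so a plain causal flag would be enough.
    Case 2 (don't ignore): everything else.
    """
    ignore_causal_mask = False

    if attention_mask is None:
        # Mirrors the diff excerpt: the window only matters once the number
        # of keys reaches it, and tracing always keeps the explicit mask.
        if sliding_window is None or key_value_length < sliding_window:
            ignore_causal_mask = not is_tracing
    # ASSUMPTION: when a mask is provided, this sketch always keeps it.
    # The PR handles that case in a branch that is not shown in the excerpt.

    return ignore_causal_mask


if __name__ == "__main__":
    print(can_ignore_causal_mask(None, key_value_length=3, sliding_window=4))
    print(can_ignore_causal_mask(None, key_value_length=10, sliding_window=4))
```

The point of the sliding-window check is that once `key_value_length` reaches the window size, dropping the mask would silently turn windowed attention into full causal attention.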
* fix sdpa + sliding window
* give credit
  Co-authored-by: ehuaa <[email protected]>
* remove unnecessary warning
* fix typog
* add test

Co-authored-by: ehuaa <[email protected]>
As per title, fixes #28980
Supersedes #29220 #29407 as the implementation ends up being different (added you as co-author here @ehuaa).
This bug dates back to #26572, where sliding_window was not properly accounted for in the _prepare_4d_causal_attention_mask_for_sdpa method. Since then, SDPA support was added to models that use sliding window, but this bug was not yet fixed.
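To illustrate why the window has to be accounted for, here is a small pure-Python sketch comparing a plain causal mask with a sliding-window causal mask. The helper names are hypothetical, and the window convention is an assumption made for this example (each query sees at most `window` most recent keys, itself included); the exact boundary convention in the library may differ.

```python
def causal_mask(q_len, kv_len):
    """True where query i may attend to key j under plain causal attention."""
    offset = kv_len - q_len  # number of cached (past) keys before the queries
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]


def sliding_window_causal_mask(q_len, kv_len, window):
    """Causal mask that additionally hides keys older than `window` positions.

    ASSUMPTION: a query at absolute position p sees keys p - window + 1 .. p.
    """
    offset = kv_len - q_len
    return [
        [(j <= i + offset) and (i + offset - j < window) for j in range(kv_len)]
        for i in range(q_len)
    ]


def window_is_noop(q_len, kv_len, window):
    """True when the sliding window removes nothing from the causal mask."""
    return causal_mask(q_len, kv_len) == sliding_window_causal_mask(q_len, kv_len, window)


if __name__ == "__main__":
    # Short sequence: the window never truncates, so plain causal is enough.
    print(window_is_noop(4, 4, window=8))
    # Longer than the window: the two masks differ, so the mask must be kept.
    print(window_is_noop(4, 4, window=2))
```

Under this convention the two masks coincide whenever the number of keys does not exceed the window, which is why skipping the explicit mask is only safe for short enough sequences.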