
Fix SDPA sliding window compatibility #30127

Merged — 5 commits merged into main on Apr 17, 2024

Conversation

fxmarty
Contributor

@fxmarty fxmarty commented Apr 8, 2024

As per title, fixes #28980

Supersedes #29220 and #29407, as the implementation ends up being different (added you as a co-author here, @ehuaa).

This bug dates back to #26572, where sliding_window was not properly accounted for in the _prepare_4d_causal_attention_mask_for_sdpa method. SDPA support was later added to models that use sliding window, but the bug itself remained unfixed.
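
To make the bug concrete, here is a minimal, hypothetical sketch (not the transformers implementation) of what a correctly prepared mask has to express: a query at position i may attend to key j only if j <= i (causal) AND i - j < sliding_window. The buggy path effectively dropped the second condition. The function name and signature below are illustrative only.

```python
import torch

def sliding_window_causal_mask(query_length: int, key_value_length: int,
                               sliding_window: int) -> torch.Tensor:
    # Queries occupy the last `query_length` positions of the key/value
    # sequence, as during cached generation.
    past = key_value_length - query_length
    q_pos = torch.arange(past, key_value_length).unsqueeze(-1)  # (q, 1)
    k_pos = torch.arange(key_value_length).unsqueeze(0)         # (1, k)
    causal = k_pos <= q_pos                      # no attending to the future
    in_window = (q_pos - k_pos) < sliding_window  # no keys beyond the window
    return causal & in_window  # True = attend, False = masked out
```

With sliding_window=2 and 4 positions, row 2 masks out position 0 even though it is in the causal past, which is exactly the clipping the buggy mask failed to apply.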

@fxmarty fxmarty requested a review from ArthurZucker April 8, 2024 16:14
@ehuaa
Contributor

ehuaa commented Apr 8, 2024

Hi @fxmarty, thanks for your great work! I think you could add the tests I mentioned in #29407, to check whether the result with sliding window under SDPA is the same as with flash-attention-2.
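
The spirit of this suggestion can be sketched without special hardware: check that SDPA, given an explicit sliding-window mask, matches a plain eager softmax(QK^T/sqrt(d))V computation with the same mask. This is a hypothetical test sketch, not the test added in this PR (which compares against flash-attention-2 on GPU).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, seq, dim, window = 1, 2, 8, 16, 3

q = torch.randn(b, h, seq, dim)
k = torch.randn(b, h, seq, dim)
v = torch.randn(b, h, seq, dim)

# Boolean sliding-window causal mask: attend iff j <= i and i - j < window.
pos = torch.arange(seq)
mask = (pos.unsqueeze(-1) >= pos) & (pos.unsqueeze(-1) - pos < window)

# SDPA path: True entries in a boolean attn_mask mean "take part in attention".
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Reference eager path with the identical mask.
scores = (q @ k.transpose(-1, -2)) / dim**0.5
scores = scores.masked_fill(~mask, float("-inf"))
eager_out = scores.softmax(dim=-1) @ v

assert torch.allclose(sdpa_out, eager_out, atol=1e-5)
```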


@fxmarty
Contributor Author

fxmarty commented Apr 8, 2024

@ehuaa Thank you, for sure will do!

Running the mistral, mixtral, and starcoder2 tests, the following fail, but they are already failing on main:

FAILED tests/models/mistral/test_modeling_mistral.py::MistralIntegrationTest::test_speculative_generation - AssertionError: 'My f[19 chars]is 100% Sriracha. I love the heat, the tang and the fact costs' != 'My f[19 chars]is 100% Sriracha. I love the heat, the ...
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits - AssertionError: The values for attribute 'dtype' do not match: torch.float16 != torch.float32.
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits_batched - AssertionError: The values for attribute 'device' do not match: cuda:0 != cpu.
FAILED tests/models/starcoder2/test_modeling_starcoder2.py::Starcoder2IntegrationTest::test_starcoder2_batched_generation_4bit - AssertionError: Lists differ: ['Hel[110 chars]t is related to the topic of "How to make a ga[179 chars]ute'] != ['Hel[110 chars]t is aimed at creating a...
FAILED tests/models/starcoder2/test_modeling_starcoder2.py::Starcoder2IntegrationTest::test_starcoder2_batched_generation_eager - AssertionError: Lists differ: ['Hel[181 chars]I am currently working on', "def hello_world()[114 chars]app"] != ['Hel[181 chars]I am looking for a', "de...

@fxmarty fxmarty requested review from amyeroberts, LysandreJik and ArthurZucker and removed request for ArthurZucker April 15, 2024 08:09
Collaborator

@ArthurZucker ArthurZucker left a comment


Not sure if I mentioned it offline: we'll refactor this into a single function without inheritance, similar to update_causal_mask. See Recurrent Gemma, which supports sliding window!

Thanks for re-enabling sliding window.

Comment on lines +319 to +324
ignore_causal_mask = False

if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
elif sliding_window is None or key_value_length < sliding_window:

There are basically two cases:

  1. You ignore the causal mask.
  2. You don't ignore it.

The code is really not super clear, but we will refactor this soon anyway.
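
The two cases above can be sketched in a simplified, standalone form (this is not the exact code merged here; in particular, the real branch for a provided attention_mask performs additional checks, such as whether the mask is all ones, which this sketch conservatively skips). SDPA can take its is_causal fast path and skip materializing a 4D mask only when no padding mask was passed, we are not tracing, and the sliding window cannot clip any past key yet:

```python
from typing import Optional

def should_ignore_causal_mask(attention_mask,
                              key_value_length: int,
                              sliding_window: Optional[int],
                              is_tracing: bool) -> bool:
    # The window is irrelevant while the key/value sequence is still
    # shorter than it: nothing would be clipped anyway.
    window_irrelevant = sliding_window is None or key_value_length < sliding_window
    if attention_mask is None:
        # Case 1: no padding mask. The mask can be dropped unless tracing
        # (torch.jit / FX export needs explicit masks) or the window
        # would already clip some past keys.
        return window_irrelevant and not is_tracing
    # Case 2 (simplified): a padding mask was provided, so a 4D mask is
    # built; the merged code can still skip it in some sub-cases.
    return False
```

The bug being fixed is precisely that the pre-PR logic could answer "ignore" even when key_value_length >= sliding_window, silently disabling the window.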

@fxmarty fxmarty merged commit 40eb6d6 into huggingface:main Apr 17, 2024
19 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Apr 18, 2024
* fix sdpa + sliding window

* give credit

Co-authored-by: ehuaa <[email protected]>

* remove unnecessary warning

* fix typog

* add test

---------

Co-authored-by: ehuaa <[email protected]>
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
@gugarosa gugarosa mentioned this pull request Apr 24, 2024
itazap pushed a commit that referenced this pull request May 14, 2024
Successfully merging this pull request may close these issues:

Add sliding window attention to sdpa in mistral
4 participants