[Mistral & Mixtral] Add sliding window for sdpa #29407
Conversation
…com/https://github.com/ehuaa/transformers into add_sliding_window_for_sdpa
Thanks! Let's throw in a generation test as well and we should be good to go! 🤗
OK, but the flash vs. sdpa test I submitted above does not pass. Have you debugged it? I'm also curious why it fails here.
No, I have not debugged it and won't have the bandwidth. Do you need help with this? cc @younesbelkada, I think this is pretty important.
As for the generation test you mentioned above, I think test_model_7b_long_prompt_sdpa is enough; it covers generation with sdpa and the sliding window.
I also see that https://github.com/huggingface/transformers/blob/main/tests/models/gemma/test_modeling_gemma.py#L471 has a similar sdpa logits test for Gemma to the one I committed. I think that test passes, so maybe it can help with debugging.
Late, but glad we waited! The _prepare_4d_causal_attention_mask_for_sdpa does not seem to fare well with sliding_window when there is no mask. Let's add one more full generation test similar to test_model_7b_logits_long_with_sdpa_and_flash2, but generating!
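A minimal sketch of what such a generation test could look like follows; comparing greedy outputs between the sdpa and flash_attention_2 backends, the bfloat16 dtype, and the max_new_tokens value are assumptions made for illustration, not taken from this PR.

# Hypothetical sketch of the requested generation test. The prompt length
# exceeds Mistral's 4096-token sliding window so the window code path is used.
import torch
from transformers import MistralForCausalLM

input_ids = [1] + [306, 338] * 2048  # 4097 tokens, longer than the window

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
)
input_tensor = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    gen_sdpa = model.generate(input_tensor, max_new_tokens=20, do_sample=False)

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
)
input_tensor = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    gen_fa2 = model.generate(input_tensor, max_new_tokens=20, do_sample=False)

# Greedy decoding should produce identical tokens with both attention backends.
assert torch.equal(gen_sdpa.cpu(), gen_fa2.cpu())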
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
Suggested change (add torch_dtype=torch.bfloat16 to both from_pretrained calls):

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out = model(input_ids).logits.cpu()
input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
)
I am getting an error because by default it seems to be float32.
This passes for me:
with torch.no_grad():
    out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out1 = model(input_ids).logits.cpu()
torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
Let's make sure we test all logits, not just the mean:
Suggested change:
-torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
+torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
With this, the test is failing:
> torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 90967735 / 131104000 (69.4%)
E Greatest absolute difference: 0.328125 at index (0, 2310, 338) (up to 0.0001 allowed)
E Greatest relative difference: 114689.0 at index (0, 1267, 4581) (up to 0.0001 allowed)
@@ -1190,6 +1192,7 @@ def forward(
             (batch_size, seq_length),
             inputs_embeds,
             past_key_values_length,
+            sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
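For context, a rough sketch of how a torch-version gate like the one in the added line could be computed; the constant name mirrors the diff, while the config dataclass is a stand-in for the model config (an assumption, not the library's actual code).

from dataclasses import dataclass

import torch
from packaging import version

# Mirrors the flag used in the diff above: True when torch >= 2.2.0.
is_torch_version_greater_or_equal_than_2_2_0 = version.parse(
    torch.__version__.split("+")[0]
) >= version.parse("2.2.0")


@dataclass
class DummyConfig:  # stand-in for the Mistral config
    sliding_window: int = 4096


config = DummyConfig()
# Only pass the sliding window to the SDPA mask helper on torch >= 2.2.0;
# older torch versions fall back to a plain causal mask.
sliding_window = config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None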
The issue here is that _prepare_4d_causal_attention_mask_for_sdpa seems to return None if attention_mask is None (which is the case in the test), while if we actually want to use the sliding window we need to return the full causal mask. cc @fxmarty
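To illustrate the point with plain torch (a standalone sketch, not the transformers helper): when the mask is None, SDPA is typically called with is_causal=True, which cannot express a sliding window, so an explicit 4D mask has to be built instead.

import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seq_len, window, dtype=torch.float32):
    # Allowed positions: j <= i (causal) and i - j < window (sliding window).
    idx = torch.arange(seq_len)
    allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask[None, None]  # shape (1, 1, seq_len, seq_len), broadcast over batch and heads

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # window silently ignored
out_window = F.scaled_dot_product_attention(q, k, v, attn_mask=sliding_window_causal_mask(16, 4))
print(torch.allclose(out_causal, out_window))  # False: outputs differ once seq_len exceeds the window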
@fxmarty if you want to take over in a new PR, this is fairly important IMO
This PR will solve #28980
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing as #30127 was merged and takes inspiration from this PR.
@ArthurZucker Arthur has reviewed this before, but my git change log is weird, so I opened a new PR instead. I uploaded a new test for sliding window flash vs. sdpa for checking.
Supersedes #29220