
[Mistral&Mixtral] Add sliding window for sdpa #29407

Closed
wants to merge 21 commits

Conversation

ehuaa (Contributor) commented Mar 2, 2024

@ArthurZucker Arthur has reviewed this before, but my git change log got messy, so I opened a new PR instead. I uploaded a new test comparing sliding-window flash attention against SDPA for checking.

Supersedes #29220

ehuaa added 21 commits February 22, 2024 21:36
Collaborator:

Thanks! Let's throw in a generation test as well and we should be good to go! 🤗

ehuaa (Author):

> Thanks! Let's throw in a generation test as well and we should be good to go! 🤗

OK, but the flash vs SDPA test I submitted above does not pass. Have you debugged it? I'm also curious why it fails here.

Collaborator:

No, I have not debugged it; I won't have the bandwidth. Do you need help with this? cc @younesbelkada, I think this is pretty important.

ehuaa (Author):

> No, I have not debugged it; I won't have the bandwidth. Do you need help with this? cc @younesbelkada, I think this is pretty important.

As for the generation test you mentioned above, I think test_model_7b_long_prompt_sdpa is enough; it already covers generation with SDPA and a sliding window.
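
For reference, a minimal sketch of what a generation parity test along these lines could look like. The function name, `max_new_tokens`, and the exact comparison are illustrative assumptions, not the code in this PR:

```python
import torch
from transformers import MistralForCausalLM

def check_long_prompt_generation_sdpa_vs_flash():
    # Prompt longer than Mistral-7B's 4096-token sliding window, so the window is exercised.
    prompt = [1] + [306, 338] * 2048
    generations = {}
    for impl in ("flash_attention_2", "sdpa"):
        model = MistralForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-v0.1",
            device_map="auto",
            attn_implementation=impl,
            torch_dtype=torch.bfloat16,
        )
        input_ids = torch.tensor([prompt]).to(model.model.embed_tokens.weight.device)
        with torch.no_grad():
            # Greedy decoding, so both implementations should pick identical tokens.
            generations[impl] = model.generate(input_ids, max_new_tokens=20, do_sample=False).cpu()
        del model
    # If the two mask implementations agree, the generated token ids match exactly.
    assert torch.equal(generations["flash_attention_2"], generations["sdpa"])
```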

ehuaa (Author):

> No, I have not debugged it; I won't have the bandwidth. Do you need help with this? cc @younesbelkada, I think this is pretty important.

And I see that Gemma has a similar SDPA logits test to the one I committed: https://github.com/huggingface/transformers/blob/main/tests/models/gemma/test_modeling_gemma.py#L471. It seems to pass there, so maybe it can help with the debugging.

@ArthurZucker (Collaborator) left a comment:

Late but glad we waited!
The _prepare_4d_causal_attention_mask_for_sdpa does not seem to fare well with sliding_window when there is no mask. Let's add one more full generation test, similar to test_model_7b_logits_long_with_sdpa_and_flash2 but generating!

Comment on lines +491 to +501
```python
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
```
Collaborator:

Suggested change:

```diff
 model = MistralForCausalLM.from_pretrained(
-    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
+    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
 )
 input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
 with torch.no_grad():
     out = model(input_ids).logits.cpu()

 input_ids = [1] + [306, 338] * 2048
 model = MistralForCausalLM.from_pretrained(
-    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
+    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
 )
```

I am getting an error because by default the model loads in float32 (and flash-attention-2 requires fp16/bf16).

Collaborator:

This passes for me.

```python
with torch.no_grad():
    out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
```
Collaborator:

Suggested change (delete the duplicated line):

```diff
-input_ids = [1] + [306, 338] * 2048
```
```python
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
```
Collaborator:

Suggested change (delete the duplicated line):

```diff
-input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
```
```python
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out1 = model(input_ids).logits.cpu()
torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
```
Collaborator:

Let's make sure we test all the logits, not just the mean:

Suggested change:

```diff
-torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
+torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
```

With this, the test is failing:

```
>       torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 90967735 / 131104000 (69.4%)
E       Greatest absolute difference: 0.328125 at index (0, 2310, 338) (up to 0.0001 allowed)
E       Greatest relative difference: 114689.0 at index (0, 1267, 4581) (up to 0.0001 allowed)
```

```diff
@@ -1190,6 +1192,7 @@ def forward(
     (batch_size, seq_length),
     inputs_embeds,
     past_key_values_length,
+    sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
```
Collaborator:

The issue here is that _prepare_4d_causal_attention_mask_for_sdpa seems to return None when attention_mask is None (which is the case in the test), while if we actually want to use the sliding window we need to return the full causal mask. cc @fxmarty
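
To illustrate the mask being discussed, here is a minimal sketch of a sliding-window causal mask for SDPA. The helper name and shape handling are simplified assumptions for illustration, not the actual transformers implementation:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Position i may attend only to positions j with i - window < j <= i.
    idx = torch.arange(seq_len, device=device)
    causal = idx[None, :] <= idx[:, None]            # standard causal constraint
    in_window = idx[None, :] > (idx[:, None] - window)  # sliding-window constraint
    allowed = causal & in_window
    # SDPA accepts an additive float mask: 0 where attention is allowed,
    # a large negative value where it is not.
    mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)
```

If the helper returns None in the no-mask case, SDPA presumably falls back to a plain causal mask (is_causal=True), so tokens beyond the window still get attended to and the logits diverge from flash-attention-2, exactly as in the failing test above.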

@ArthurZucker (Collaborator):

@fxmarty if you want to take over in a new PR, this is fairly important IMO


cyr0930 commented Apr 8, 2024

This PR will solve #28980.


github-actions bot commented May 2, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

fxmarty (Contributor) commented May 2, 2024

Closing, as #30127 was merged and takes inspiration from this PR.

@fxmarty fxmarty closed this May 2, 2024