Llama: SDPA FA2 path + static cache fix #30437
Conversation
cc @younesbelkada @fxmarty (SDPA + FA2 changes)
@@ -533,10 +533,13 @@ def forward(
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
120 char limit OCD :D
It is not in `make style`, is it?
Line 5 in a98c417: # Never enforce `E501` (line length violations).
LGTM! Thanks!
Okay very interesting and partially related to #30442.
Before merging, let's test that this does not affect compiled performance, as indexing can be costly.
As Arthur wrote, the current state of the PR adds a slowdown on the eager path. I'm exploring an alternative path: first standardize our static cache to behave like our other caches (living outside the model as a stand-alone object), then force the generation of the full mask if a static cache is used (which fixes this issue).
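For illustration, a minimal sketch of that alternative (hypothetical helper name, not the actual diff): skip the SDPA `is_causal` fast path whenever a `StaticCache` is in use, so the full mask is always built.

# Hypothetical sketch (not the actual transformers code): always materialize the
# full causal mask when a StaticCache is used, instead of taking the SDPA
# `is_causal` fast path that leaves the mask as None.
import torch
from transformers.cache_utils import StaticCache

def _update_causal_mask(self, attention_mask, input_tensor, past_key_values):
    using_static_cache = isinstance(past_key_values, StaticCache)
    if self.config._attn_implementation == "sdpa" and not using_static_cache:
        # Fast path: returning None lets SDPA dispatch to Flash Attention 2.
        if attention_mask is None or bool(torch.all(attention_mask == 1)):
            return None
    # Slow-but-correct path: build the full (batch, 1, q_len, kv_len) float mask so
    # the padded static-cache slots are explicitly masked out.
    return self._build_full_causal_mask(attention_mask, input_tensor, past_key_values)  # hypothetical helper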
Thanks a lot for noticing; indeed, I should have tested static without compile. It is very good to add a test for it.
I am wondering: couldn't this change cause issues with CUDA graph capture, adding dynamicity in the tensor shapes? Making the capture slower?
@@ -1073,6 +1080,7 @@ def _update_causal_mask(
        if self.config._attn_implementation == "sdpa":
            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
            # in order to dispatch on Flash Attention 2.
            breakpoint()
remove
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
You can `skipTest` on the torch version.
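For illustration, a minimal version of that suggestion, to be placed at the top of the compile test method (the version bounds come from the comment above; `self.skipTest` and `packaging.version.parse` are standard APIs):

# Sketch: skip the compile test on torch==2.2.x, where it is known to fail
# (see https://github.com/pytorch/pytorch/issues/121943); 2.1.2 and >2.2 work.
import torch
from packaging import version

torch_version = version.parse(torch.__version__)
if version.parse("2.2.0") <= torch_version < version.parse("2.3.0"):
    self.skipTest("torch==2.2 errors out on this compilation test; 2.1.2 and >2.2 work as intended")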
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

# Static Cache
generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
Shouldn't the static cache tests be somewhere like test_modeling_common.py and test every supported model?
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
    7: [
Note you might get different results from A10 to T4, hence this dict. This change can cause the push-important-models test + slow tests to fail 😢
You can now SSH into our runners and get the value of the generations for each device type.
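For context, a sketch of how such a dict is typically consumed in the slow tests (the exact lookup in the test file may differ, and the strings below are placeholders): the expectations are keyed by the GPU's major compute capability.

# Sketch: expected generations differ between GPU families, so the test keys the
# expectations by the device's major compute capability (7 = T4, 8 = A10/A100).
import torch

NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
    7: ["<expected completion on capability-7 runners, e.g. T4>"],   # placeholder, not a real output
    8: ["<expected completion on capability-8 runners, e.g. A10>"],  # placeholder, not a real output
}

major, _minor = torch.cuda.get_device_capability(0)
expected_text = EXPECTED_TEXT_COMPLETION[major]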
(closed in favor of #30476, a much better long-term solution)
What does this PR do?
Problem
The recently enabled SDPA FA2 path doesn't pass the `attn_mask` (causal mask) argument. As such, when the static cache is used, `S` (see the SDPA docs for terminology) is the full cache length as opposed to the sequence length. Therefore, the inferred mask in SDPA is incorrect, resulting in bad numerical values.

PR that introduced the issue: #30317. The issue was not caught in our llama testing suite because we didn't have a test for the static cache WITHOUT compilation: causal mask is `None` = problem described above triggered.
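To make the failure mode concrete, here is a small standalone sketch (toy shapes and names, not the modeling code): during decoding the query length is 1 while the static cache keys/values are padded to the full cache length, so passing no mask lets SDPA attend over the empty cache slots.

# Toy illustration of the bug: keys/values are padded to the static cache length.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, D = 1, 1, 8
valid, S = 5, 16                        # 5 real cache entries, cache padded to S=16
q = torch.randn(B, H, 1, D)             # one decoded token, q_len = 1
k = torch.zeros(B, H, S, D)
v = torch.zeros(B, H, S, D)
k[..., :valid, :] = torch.randn(B, H, valid, D)
v[..., :valid, :] = torch.randn(B, H, valid, D)

# Correct: an explicit mask that hides the empty cache slots.
mask = torch.zeros(1, 1, 1, S)
mask[..., valid:] = float("-inf")
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# What the FA2 fast path effectively did with a static cache: no mask at all,
# so the softmax also spreads probability over the empty (zero) slots.
bad = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

print(torch.allclose(ref, bad))  # False -> the "bad numerical values" described above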
Solution

There were two possible paths here:
1. Stop letting the causal mask be `None` when the static cache is used (Fix attn mask for static cache #30414);
2. Keep the `None` causal mask (and thus the FA2 dispatch) and index the static cache so that only the filled positions take part in the attention (this PR).

I went with 2, as it saves us tons of masked computations :) I've also ensured the static cache without compilation is numerically tested in the llama test file.

Fixes #30417
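One possible shape of such an indexing-based fix, sketched with assumed variable names (`key_states`, `value_states`, `cache_position`) and not necessarily matching the PR's actual diff:

# Hypothetical illustration of option 2: crop the zero-padded static cache before SDPA
# so that S matches the filled length, keeping attn_mask=None and therefore the
# Flash Attention 2 dispatch inside SDPA.
import torch
import torch.nn.functional as F

def sdpa_with_cropped_static_cache(query_states, key_states, value_states, cache_position):
    # key/value come back from the static cache with length max_cache_len;
    # cache_position holds the positions written during this forward pass.
    kv_len = int(cache_position[-1]) + 1
    key_states = key_states[:, :, :kv_len, :]
    value_states = value_states[:, :, :kv_len, :]
    return F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=None,
        # only the prefill pass (q_len > 1) needs the causal structure
        is_causal=query_states.shape[2] > 1,
    )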
Slow tests ran locally: `llama`, `gemma`, `cohere`, `test_cache_utils.py`.