Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama: SDPA FA2 path + static cache fix #30437

Closed
wants to merge 4 commits into from

Conversation

gante
Copy link
Member

@gante gante commented Apr 23, 2024

What does this PR do?

Problem

The recently enabled SDPA FA2 path doesn't pass the attn_mask (causal mask) argument. As such, when the static cache is used, S (see the SDPA docs for terminology) is the full cache length as opposed to the sequence length. Therefore, the inferred mask in SDPA is incorrect, resulting in bad numerical values.

PR that introduced the issue: #30317. The issue was not caught in our llama testing suite because we didn't have a test for the static cache WITHOUT compilation:

  • dynamic cache -> no issue with the length
  • static cache + compile -> we manually build the causal attention mask, passing it to sdpa
  • static cache -> we don't build the causal mask = causal mask is None = problem described above triggered

Solution

There were two possible paths here:

  1. Build and pass a full attention mask, corresponding to the full cache length, to avoid causal mask is None (Fix attn mask for static cache #30414 );
  2. Crop empty KV entries (corresponding to the empty cache) before SDPA.

I went with 2, as it saves us tons of masked computations :) I've also ensured the static cache without compilation is numerically tested in the llama test file.

Fixes #30417


Slow tests ran locally: llama, gemma, cohere, test_cache_utils.py.

@gante gante requested a review from ArthurZucker April 23, 2024 17:43
@gante
Copy link
Member Author

gante commented Apr 23, 2024

cc @younesbelkada @fxmarty (SDPA + FA2 changes)
cc @zucchini-nlp (alternative fix to #30414, see possible solutions in the PR header)

@@ -533,10 +533,13 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

120 char limit OCD :D

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not in make style is it?

# Never enforce `E501` (line length violations).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM ! Thanks !

@gante gante force-pushed the spda_fa2_static_fix branch from 8b15576 to 631c2da Compare April 24, 2024 10:30
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay very interesting and partially related to #30442.
Before merging, let's test that this does not affect compiled performances, as indexing can be costly.

@gante
Copy link
Member Author

gante commented Apr 24, 2024

As Arthur wrote -- the current state of the PR adds a slowdown on the eager path.

I'm exploring an alternative path: first standardize our static cache to behave like our other caches (living outside the model as a stand-alone object), then forcing the generation of the full mask if a static cache is used (which fixes this issue).

Copy link
Contributor

@fxmarty fxmarty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for noticing, indeed I should have tested static without compile. It is very good to add a test for it.

I am wondering - couldn this change cause issues with cuda graph capture, adding dynamicity in the tensor shapes? Making the capture slower?

@@ -1073,6 +1080,7 @@ def _update_causal_mask(
if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
breakpoint()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

Comment on lines +686 to +687
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can skipTest on torch version

@@ -533,10 +533,13 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not in make style is it?

# Never enforce `E501` (line length violations).

Comment on lines +709 to +727
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the static cache tests be somewhere like test_modeling_common.py and test every supported models?

NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
7: [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note you might get different results from A10 to T4, hence this dict. this change can lead to the push-important-models test + Slow tests to fail 😢
You can now SSH into our runners and get the value of the generations for each device type

@gante
Copy link
Member Author

gante commented Apr 26, 2024

(closed in favor of #30476, a much better long-term solution)

@gante gante closed this Apr 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

StaticCache Bad generation results with Llama after v4.39.0
5 participants