
[GPTNeoX] Flex Attention + Refactor #34896

Merged · 20 commits into huggingface:main on Dec 4, 2024

Conversation

@vasqu (Contributor) commented on Nov 23, 2024

What does this PR do?

Adds flex attention and the refactor according to #34809

However, I discovered several issues in the current version of gemma2 (#34282):

  • It seems that flex attention needs a transpose afterwards, just like sdpa (see the sketch below)
  • Loading flex attn via from_pretrained didn't work, hence the current tests use another attn implementation (eager or sdpa, not sure which)
  • Tests would benefit from common tests like the sdpa ones :D for now it's a bit of a hassle to always add an integration test when it could be a more general test for all subsequent models
  • I'm not familiar with BetterTransformer or the limitations of flex attn --> added some TODOs in case we need to check in
  • Flex attn doesn't support dropout (or maybe I've overlooked something)
  • Setting model.config._attn_implementation = ... should be tracked somewhere and sanity-checked like it is the first time - for now it silently overwrites and could cause some ugly errors (tested by changing to flash attention 2 while not having fa2 installed)
  • Documentation should be added somewhere (probably the perf docs or something else)

So tbh, I'm not sure whether or not to split this PR into several ones, e.g. a gemma fix, general loading, general tests, docs, and then subsequent models.
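
For context, a minimal sketch of the transpose point, assuming torch >= 2.5's flex_attention and the usual (batch, heads, seq_len, head_dim) layout; the tensor names and sizes below are made up for illustration:

import torch
from torch.nn.attention.flex_attention import flex_attention

batch, num_heads, seq_len, head_dim = 2, 8, 128, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# like sdpa, flex_attention returns (batch, num_heads, seq_len, head_dim) ...
attn_output = flex_attention(query, key, value)

# ... so the result still needs a transpose before the usual reshape to
# (batch, seq_len, hidden_size) that feeds the output projection
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch, seq_len, num_heads * head_dim)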

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@vasqu (Contributor, Author) left a comment

A collection of comments which partially show the issues I listed above

Comment on lines 462 to 484
@slow
def test_lm_generate_flex_attn_gptneox(self):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
    for checkpointing in [True, False]:
        model = GPTNeoXForCausalLM.from_pretrained(
            "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
        )

        if checkpointing:
            model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_disable()
        model.to(torch_device)

        inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
        # The hub repo. is updated on 2023-04-04, resulting in poor outputs.
        # See: https://github.com/huggingface/transformers/pull/24193
        expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"

        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
        output_str = tokenizer.batch_decode(output_ids)[0]

        self.assertEqual(output_str, expected_output)
@vasqu (Contributor, Author):

Would love to have common tests in the future (instead)

Comment on lines 378 to 386
if (
    self.training
    and self.config.attention_dropout > 0
    and self.config._attn_implementation == "flex_attention"
):
    logger.warning_once(
        f"Setting `attention_type` to `eager` because `dropout` is not supported in {attention_type}"
    )
    attention_type = "eager"
@vasqu (Contributor, Author):

No dropout in flex attn

@ArthurZucker (Collaborator):

that's a great catch! but we can add it to the score-mod no?

@vasqu (Contributor, Author) replied on Nov 28, 2024:

Sadly not. In the case of the head mask the order of ops doesn't matter, since we turn the whole head to zeros, but dropout still depends on the correctly computed probability distribution and only then zeroes out some values.

The order of ops is: dropout(softmax(score_mod(Q, K))) --> we would introduce unwanted behaviour (see the sketch below the llama snippet).

Edit: llama for ref

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
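
For context, a rough sketch of where flex attention's score_mod hook sits relative to softmax and dropout; causal_mod below is the standard example from the torch docs, not this PR's actual implementation:

import torch

# eager attention (llama above) applies dropout AFTER the softmax:
#   attn_probs = dropout(softmax(scores + mask)); output = attn_probs @ V
# flex attention's score_mod only ever sees the PRE-softmax scores:
#   output = softmax(score_mod(scores)) @ V
# so dropout on the attention probabilities cannot be emulated inside score_mod.

def causal_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    # additive or masking biases fit here because they act on the raw scores
    return torch.where(q_idx >= kv_idx, score, -float("inf"))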

@vasqu (Contributor, Author) commented on Nov 23, 2024

CI failures seem unrelated, flaky tests (e.g. XLM, Qwen2VL).

@vasqu (Contributor, Author) commented on Nov 23, 2024

Possible TODO -> fall back to eager when using a head mask with fa2 or sdpa, and add head mask support in flex attention (should be possible via score_mod); a rough sketch of the fallback idea follows below.

Edit: Added now
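
A rough sketch of the fallback idea as a small helper; the function name and exact condition are assumptions for illustration, not the merged code:

def pick_attention_type(config, head_mask, attention_type: str) -> str:
    # the head mask is only applied in the eager path, so fall back whenever a
    # head mask is passed together with fa2 or sdpa
    if head_mask is not None and config._attn_implementation in ("flash_attention_2", "sdpa"):
        return "eager"
    return attention_type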

@dame-cell commented on Nov 23, 2024

Yes, I was testing it for Gemma as well; it also needed a transpose at the end.

If you don't mind, could you check the pull request I did for gemma? I seem to keep failing some tests.

Also, gemma 2 now supports new things in the configuration which confused me a lot, e.g. the attn logit soft capping.

Also, model.config._attn_implementation is not really handled correctly; for example, it does not actually use the correct attn implementation upon choosing one.

Still working on the gemma flex attention PR; might help with the docs as well.

@vasqu (Contributor, Author) commented on Nov 23, 2024

@dame-cell I'll take a look tomorrow! I'm busted for today :)

But as a quick pointer to get the loading handled correctly: look into my changes to the utils folder and modeling_utils. With those changes, loading should be handled correctly. Tbh, that's one of the main reasons why I think it might be better to split this into several PRs and get loading etc. right first before we start adding.

Edit: One last thing to change would be to add _supports_flex_attn = True, like it's done for sdpa and fa2 (rough sketch below).
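
Concretely, that is just another capability flag on the pretrained model class, next to the existing ones; the Gemma class below is only an example of the pattern, not necessarily how this PR wires it up:

from transformers import GemmaConfig, PreTrainedModel

class GemmaPreTrainedModel(PreTrainedModel):
    config_class = GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True  # opts the model into the flex attention loading path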

@dame-cell replied:

Hmmm ohh I get it I see thanks for letting me know 😀

@ArthurZucker (Collaborator) commented:

Feel free to ping me once you feel like this is ready! 🤗

@vasqu (Contributor, Author) commented on Nov 25, 2024

I think it should be ready @ArthurZucker; just found something on the fly a minute ago, should be good to go 😄

Edit: the CI failure doesn't seem related to this PR.

@ArthurZucker (Collaborator) left a comment

Thanks for improving the API! _check_and_enable_flex_attn was missing from my initial PR!

@vasqu (Contributor, Author) commented on Nov 26, 2024

I'll take a look tomorrow or the day after :)

@vasqu (Contributor, Author) left a comment

I think the PR is slowly getting wrapped up; some things that should be made into separate PR(s) imo:

  • Common tests
  • Docs
  • Tracking when attn implementation is manually changed within the config

Comment on lines +303 to +307
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
query_states, key_states, value_states = fa_peft_integration_check(
    query_states, key_states, value_states, target_dtype
)

@vasqu (Contributor, Author):

New peft check within the FA2 interface.

@ArthurZucker (Collaborator):

nice

@ArthurZucker (Collaborator):

Nice, we can remove it from all of the other modeling code 👀

@vasqu (Contributor, Author):

I'd leave that to a separate PR, this PR is big enough already :p

@vasqu (Contributor, Author) commented on Nov 28, 2024

@ArthurZucker updated per review, looking forward to the next round ;)

score_mod=causal_mod,
enable_gqa=True,
scale=norm_factor,
return_lse=output_attentions,
@vasqu (Contributor, Author) commented on Nov 29, 2024:

I think we could drop output_attentions here, always return both and let the calling forward handle it.

Looked into the torch code; they also always compute both but use an if/else to return just one or both, so there shouldn't be any downside imo (rough sketch of the two call modes below).
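
For illustration, a minimal sketch of the two call modes, assuming torch >= 2.5's flex_attention; the tensors are dummies for the example:

import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 4, 128, 64)
k = torch.randn(1, 4, 128, 64)
v = torch.randn(1, 4, 128, 64)

# return_lse=False (default): only the attention output is returned
attn_output = flex_attention(q, k, v)

# return_lse=True: an (output, logsumexp) pair is returned; torch computes the
# logsumexp either way, so always requesting it adds no extra work, and the
# caller can decide whether to surface it when output_attentions is set
attn_output, lse = flex_attention(q, k, v, return_lse=True)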

@ArthurZucker (Collaborator):

Okay if torch code does that, it makes sense (no additional computation).
Let's add a comment as to why we do it and good for me!

@vasqu (Contributor, Author):

Added a comment!

@ArthurZucker (Collaborator) left a comment

Very nice!
Left a few small comments but almost ready to go!


Comment on lines 331 to 341
target_dtype = None
if self.config._attn_implementation == "flash_attention_2":
    input_dtype = value.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.query_key_value.weight.dtype
@ArthurZucker (Collaborator):

This is quite heavy! Would be cool if we manage to only do it in the flash attention forward function! (Passing just the config, for example, would be enough to do so.)

@vasqu (Contributor, Author):

The problem is that we depend on the model's weights in the last case, i.e. self.query_key_value.weight.dtype - I moved it to a separate function, but it should definitely be deprecated at some point.

@vasqu (Contributor, Author) left a comment

Just a list of TODOs (separate PRs):

  • Deprecate FA peft integration checks
  • Deprecate RoPE reconversion
  • Docs
  • Common Tests
  • Track when attn implementation is manually changed

Otherwise, I think this PR is good now!

Comment on lines +244 to +245
# lse is returned in float32
attn_weights = attn_weights.to(value.dtype)
@vasqu (Contributor, Author):

Reconvert to correct dtype

@ArthurZucker (Collaborator):

sounds good!

Comment on lines +360 to +361
# Flash Attention 2 specific PEFT check
target_dtype=self._fa_peft_dtype_check(value),
@vasqu (Contributor, Author):

Like I said in another comment, this depends on the weights, so it's hard to move it into the FA forward without passing additional info (like the weights' dtype).

Comment on lines +456 to +473
def _fa_peft_dtype_check(self, value):
    """
    PEFT can silently cast the dtype to float32 - this method returns the target dtype to which
    FA should convert back to (if necessary). For now, we can not move this to the forward pass
    itself due to the dependency on checking on some part of its own weights (last case).
    """
    target_dtype = None
    if self.config._attn_implementation == "flash_attention_2":
        input_dtype = value.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.query_key_value.weight.dtype
    return target_dtype
@vasqu (Contributor, Author):

The new function I mentioned.

@ArthurZucker (Collaborator):

thanks

@vasqu (Contributor, Author) commented on Dec 2, 2024

@ArthurZucker Hopefully the last round 🤞

@ArthurZucker (Collaborator) left a comment

LGTM now thanks a lot for the refactor!


@ArthurZucker merged commit 46df859 into huggingface:main on Dec 4, 2024
22 checks passed
@vasqu deleted the flex-gptneox branch on Dec 4, 2024 at 15:25
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* gpt neox flex attention + refactor

* some formatting

* small fix on dropout

* add assertion on flex attn test

* flaky ci :(

* add head mask support

* style

* handle dtype, replace torch where

* fixup flex with output attns

* code review and several other fixes

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <[email protected]>

* style

* remove unnecessary comment

* remove incorrect comment

* make flex attn check more agnostic to versions and centralized

* change peft input dtype check to value since q and k could be affected by other stuff like RoPE

* i forgor

* flaky

* code review and small fixes

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024