[GPTNeoX] Flex Attention + Refactor #34896
Conversation
A collection of comments which partially show the issues I listed above
@slow
def test_lm_generate_flex_attn_gptneox(self):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
    for checkpointing in [True, False]:
        model = GPTNeoXForCausalLM.from_pretrained(
            "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
        )

        if checkpointing:
            model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_disable()
        model.to(torch_device)

        inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
        # The hub repo. is updated on 2023-04-04, resulting in poor outputs.
        # See: https://github.com/huggingface/transformers/pull/24193
        expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"

        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
        output_str = tokenizer.batch_decode(output_ids)[0]

        self.assertEqual(output_str, expected_output)
Would love to have common tests in the future (instead)
if (
    self.training
    and self.config.attention_dropout > 0
    and self.config._attn_implementation == "flex_attention"
):
    logger.warning_once(
        f"Setting `attention_type` to `eager` because `dropout` is not supported in {attention_type}"
    )
    attention_type = "eager"
No dropout in flex attn
that's a great catch! but we can add it to the score-mod no?
Sadly not. In the case of the head mask the order of ops doesn't matter, since we completely turn the head to all zeros, but dropout depends on the correct distribution being computed first and only then turns off some values.
The order of ops is: dropout(softmax(score_mod(Q, K)))
--> we would introduce unwanted behaviour.
Edit: llama for ref
transformers/src/transformers/models/llama/modeling_llama.py
Lines 339 to 340 in 0b5b5e6
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
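For illustration, a minimal eager-attention sketch of the ordering constraint described above (shapes and the dropout probability are arbitrary placeholders, not the PR's code):

import torch
import torch.nn.functional as F

# Eager attention: the scores can be modified before the softmax (which is what a
# flex attention score_mod does), but dropout has to act on the softmax output,
# so it cannot be folded into the score_mod itself.
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))  # (batch, heads, seq, head_dim)

scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)  # a score_mod would act here
probs = F.softmax(scores, dim=-1)                        # correct distribution first...
probs = F.dropout(probs, p=0.1, training=True)           # ...then zero out some entries
out = probs @ v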
CI failures seem unrelated - flaky tests (e.g. XLM, Qwen2VL)
Possible TODO -> fall back to eager when using a head mask in FA2/SDPA + add head mask support in flex attention (should be possible via score mod). Edit: Added now
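For context, a minimal, generic sketch of the score_mod hook that flex attention exposes (a causal bias here, purely illustrative and not the PR's head-mask code; assumes a PyTorch version that ships torch.nn.attention.flex_attention, i.e. 2.5+):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 4, 8, 16
query, key, value = (torch.randn(B, H, S, D) for _ in range(3))

def causal_mod(score, b, h, q_idx, kv_idx):
    # score_mod receives one raw attention score plus its (batch, head, q, kv)
    # indices and returns the modified score *before* the softmax.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

attn_output = flex_attention(query, key, value, score_mod=causal_mod)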
Yes, I was testing it for Gemma as well; it needed a transpose at the end too. If you don't mind, could you check the pull request I did for Gemma? It seems I keep failing some tests. Also, Gemma 2 now supports new stuff in the configuration, which confused me a lot. And model.config._attn_implementation is not really implemented correctly - for example, it does not actually use the attention implementation you choose. Still working on the Gemma flex attention PR; it might help with the docs as well.
@dame-cell I'll take a look tomorrow! I'm busted for today :) But as a quick tip: to get the loading handled correctly, look into my changes in the utils folder and modeling_utils. With those changes, loading should be handled correctly. Tbh, that's one of the main reasons why I think it might be better to split this into several PRs and get loading etc. right first before we start adding. Edit: One last thing to change would be to add
Hmmm, ohh I get it, I see - thanks for letting me know 😀
Feel free to ping me once you feel like this is ready! 🤗
I think it should be ready @ArthurZucker - just found something on the fly a min ago, should be good to go 😄 Edit: the CI failure doesn't seem related to this PR
Thanks for improving the API - _check_and_enable_flex_attn was missing from my initial PR!
that's a great catch! but we can add it to the score-mod no?
I'll take a look tomorrow or the day after :)
Co-authored-by: Arthur <[email protected]>
I think the PR is slowly getting wrapped up; some things that should be made into separate PR(s) imo:
- Common tests
- Docs
- Tracking when attn implementation is manually changed within the config
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
query_states, key_states, value_states = fa_peft_integration_check(
    query_states, key_states, value_states, target_dtype
)
New peft check within the FA2 interface.
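Conceptually, such a check boils down to something like the simplified stand-in below (an assumed sketch, not the library's actual implementation):

import torch

def fa_peft_integration_check_sketch(q, k, v, target_dtype):
    # If PEFT silently upcast the activations to float32, cast them back to the
    # dtype flash attention expects; otherwise this is a no-op.
    if target_dtype is not None and q.dtype == torch.float32:
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
    return q, k, v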
Nice
Nice, we can remove it from all of the other modeling code 👀
I'd leave that to a separate PR, this PR is big enough already :p
@ArthurZucker updated per review, looking forward to the next round ;)
score_mod=causal_mod,
enable_gqa=True,
scale=norm_factor,
return_lse=output_attentions,
I think we could drop the output attentions here, always return both and let the remaining forward (we call from) handle it.
Looked into the torch code, and they also always return both but make an if/else to return just one or both so there shouldn't be any downside imo.
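A self-contained sketch of that pattern (placeholder shapes and a no-op score_mod; the real call site passes the model's own mods and scaling):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 4, 8, 16
query, key, value = (torch.randn(B, H, S, D) for _ in range(3))

# Always request the log-sum-exp; the calling forward decides whether to expose it.
attn_output, lse = flex_attention(
    query,
    key,
    value,
    score_mod=lambda score, b, h, q_idx, kv_idx: score,  # identity mod for the sketch
    scale=1.0 / D**0.5,
    return_lse=True,
)
attn_weights = lse.to(value.dtype)  # the lse comes back in float32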
Okay if torch code does that, it makes sense (no additional computation).
Let's add a comment as to why we do it and good for me!
Added a comment!
Very nice!
Left a few small comments but almost ready to go!
target_dtype = None
if self.config._attn_implementation == "flash_attention_2":
    input_dtype = value.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.query_key_value.weight.dtype
This is quite heavy! Would be cool if we manage to only do it in the flash attention forward function (passing just the config, for example, would be enough to do so).
The problem is that we are dependent on the model's weights in the last case, i.e. self.query_key_value.weight.dtype - I moved it to a separate function, but it definitely should be deprecated at some point.
Just a list of TODOs (separate PRs):
- Deprecate FA peft integration checks
- Deprecate RoPE reconversion
- Docs
- Common Tests
- Track when attn implementation is manually changed
Otherwise, I think this PR is good now!
# lse is returned in float32
attn_weights = attn_weights.to(value.dtype)
Reconvert to correct dtype
sounds good!
# Flash Attention 2 specific PEFT check
target_dtype=self._fa_peft_dtype_check(value),
Like I said in another comment, this is dependent on the weights, so it's hard to move it into the FA forward without passing additional info (like the weights' dtype).
def _fa_peft_dtype_check(self, value):
    """
    PEFT can silently cast the dtype to float32 - this method returns the target dtype to which
    FA should convert back (if necessary). For now, we cannot move this to the forward pass
    itself due to the dependency on some of its own weights (last case).
    """
    target_dtype = None
    if self.config._attn_implementation == "flash_attention_2":
        input_dtype = value.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.query_key_value.weight.dtype
    return target_dtype
The new function I mentioned.
thanks
@ArthurZucker Hopefully the last round 🤞
LGTM now thanks a lot for the refactor!
* gpt neox flex attention + refactor
* some formatting
* small fix on dropout
* add assertion on flex attn test
* flaky ci :(
* add head mask support
* style
* handle dtype, replace torch where
* fixup flex with output attns
* code review and several other fixes
* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <[email protected]>
* style
* remove unnecessary comment
* remove incorrect comment
* make flex attn check more agnostic tor versions and centralized
* change peft input dtype check to value since q and k could be affected by other stuff like RoPE
* i forgor
* flaky
* code review and small fixes
* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py
Co-authored-by: Arthur <[email protected]>
---------
Co-authored-by: Arthur <[email protected]>
What does this PR do?
Adds flex attention and the refactor according to #34809
However, I discovered several issues in the current version of gemma2 (#34282):
- model.config._attn_implementation = ... should be tracked somewhere and checked for sanity as done the first time - for now it silently overwrites and could cause some ugly errors (tested by changing to flash attention 2 while not having fa2 installed); see the sketch below.

So tbh, I'm not sure whether to split this PR into several ones, e.g. a gemma fix, general loading, general tests, docs, and then subsequent models, or not.
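A quick reproduction sketch of that behaviour (the checkpoint name is just an example from this PR's tests):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")

# The assignment below is accepted without the sanity checks that run at load
# time - e.g. it succeeds even if flash-attn is not installed - which can then
# surface as confusing errors later instead of failing here.
model.config._attn_implementation = "flash_attention_2"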
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker