Integrate FlashAttention into HF OPT #18439

Closed
wants to merge 2 commits into from

erichan1

@erichan1 erichan1 commented Aug 2, 2022

Integrate FlashAttention.

  • Requires the flash_attention integration in pytorch/pytorch#81434 to work; torch._scaled_dot_product_attention is only available there.
  • Turn the fast path on, or fall back to the slow path, with the fast_attention=True/False flag.
  • Turn the causal mask on or off for the fast attention path with fast_attention_causal=True/False.
  • Does not support an attention mask or padding mask on the fast path.
  • Currently requires an unnecessary conversion to NestedTensor and back, because the current FlashAttn implementation only takes NestedTensor. Will remove once torch._scaled_dot_product_attention supports regular tensors.

@erichan1 erichan1 changed the title from "Erichan1/flashatt opt" to "Integrate FlashAttention into HF OPT" Aug 2, 2022
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Comment on lines +200 to +202
query_states_fast = torch.nested_tensor(torch.unbind(query_states, dim=0))
key_states_fast = torch.nested_tensor(torch.unbind(key_states_fast, dim=0))
value_states_fast = torch.nested_tensor(torch.unbind(value_states_fast, dim=0))

won't this result in padding within the nested tensors?

i.e. if query_states is a padded rectangular tensor, calling unbind on it will produce sequences padded to the same length, so we won't be taking advantage of nested tensors to reduce padding. Or am I missing something?
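To make the concern concrete, here is a minimal sketch, assuming PyTorch 2.x with the prototype torch.nested API. The batch shape and the `lengths` list are hypothetical and not taken from the PR. It shows that unbinding a padded rectangular tensor keeps the padding, while trimming each sequence to its true length first does not.

```python
import torch

# Hypothetical batch: two sequences with true lengths 3 and 1, padded to length 3.
hidden = 4
lengths = [3, 1]
padded = torch.zeros(len(lengths), max(lengths), hidden)
for i, n in enumerate(lengths):
    padded[i, :n] = torch.randn(n, hidden)

# Unbinding the rectangular tensor keeps the padding: every piece has length 3.
naive_pieces = torch.unbind(padded, dim=0)
naive_lengths = [piece.shape[0] for piece in naive_pieces]

# Trimming each sequence to its true length before nesting removes the padding.
trimmed_pieces = [padded[i, :n] for i, n in enumerate(lengths)]
nested = torch.nested.nested_tensor(trimmed_pieces)
trimmed_lengths = [piece.shape[0] for piece in nested.unbind()]

print(naive_lengths, trimmed_lengths)
```

When the inputs carry no padding at all, as the author assumes in the reply, the two approaches produce the same pieces.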

Author

I'm assuming inputs with no padding for now. This is a hack just to turn those tensors into NestedTensors, because FlashAttention SDP currently requires NestedTensor. If FlashAttn SDP supported regular tensors, I would remove these lines entirely.

gotcha, thanks for the clarification :)

@github-actions

github-actions bot commented Sep 2, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Sep 10, 2022
@fzyzcjy
Contributor

fzyzcjy commented Dec 30, 2022

Hi, are there any updates? Coming from https://github.com/HazyResearch/flash-attention/blob/main/usage.md

@puyuanOT

Looking forward to the update!

@erichan1
Author

Looking forward to the update!

Hey there @puyuanOT! Not working on this actively anymore. Check out torch SDP to use FlashAttn in native torch!

@puyuanOT

Thanks @erichan1 ! I will check it out.

@vincentmin

vincentmin commented Apr 24, 2023

@erichan1 Could you explain why you stopped working on this feature? I think it would be a great addition to the transformers library.
Regarding the torch SDP link, could you give instructions on how to use this torch feature when using a model in Huggingface transformers?

Edit: Is it the case that flash attention is now activated by default with recent versions of torch? If so, I would recommend a HuggingFace blog article to advertise this feature and explain its workings. Currently documentation is rather lacking on flash-attention support.

@amyeroberts
Collaborator

Within the Hugging Face ecosystem, it's possible to use BetterTransformer and the optimum library to improve model performance: [1], [2]. @younesbelkada Is flash attention available yet through this?

@erichan1
Author

@amyeroberts @vincentmin I'm from the PyTorch team. We decided that the best way to provide FlashAttention was to create a new module covering just the component FlashAttention accelerates, Scaled Dot Product Attention (SDP). This is the part that computes softmax(Q @ K^T / sqrt(d)) @ V, and it doesn't include the input and output projections. Having built this abstraction, we also decided to use it to offer other implementations of SDP, including a memory-efficient one that we've built in house, which uses less memory than FlashAttn but is slower.
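For reference, the computation that SDP covers can be sketched in plain Python. This is an unoptimized illustration and not the PyTorch implementation: the function names, the list-of-lists layout, and the `causal` flag are assumptions made for this example, and the scores are scaled by 1/sqrt(d) as in the standard definition of scaled dot product attention.

```python
import math

def softmax(row):
    # Subtract the row maximum for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(q, k, v, causal=False):
    """q: [L][d], k: [S][d], v: [S][dv], all as nested lists of floats."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scaled dot product of one query with every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if causal:
            # Position i may only attend to positions j <= i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal_out = scaled_dot_product_attention(q, k, v, causal=True)
print(causal_out)
```

With the causal mask on, the first position can only attend to itself, so its output row is exactly the first value row. Fused kernels such as FlashAttention aim to produce the same result without materializing the full score matrix.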

You can just directly use SDP by replacing the necessary chunk of code in your transformer definition. But I'm unsure about a way to use it with a flag you flip in HuggingFace. I'll let @younesbelkada speak to that. I believe BetterTransformer and SDP (which is part of BetterTransformer) support is already part of Optimum.
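As a rough illustration of "replacing the necessary chunk of code", the sketch below swaps a hand-written attention computation for torch.nn.functional.scaled_dot_product_attention, which is the public name in PyTorch 2.0 and later. The tensor shapes and the `manual_attention` helper are hypothetical. Which backend actually runs (FlashAttention, memory-efficient, or the math fallback) depends on the device, dtype, and inputs, so this example does not guarantee that FlashAttention is selected.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

def manual_attention(q, k, v):
    # Hand-written causal attention, the kind of chunk SDP replaces.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    causal = torch.ones(q.size(-2), k.size(-2), dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

manual_out = manual_attention(q, k, v)

# Drop-in replacement: one call covers scaling, masking, softmax, and the matmul.
fused_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual_out, fused_out, atol=1e-4))
```

The input and output projections stay in the surrounding module; only the softmax(Q @ K^T / sqrt(d)) @ V step is swapped.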

@vincentmin

@erichan1 @amyeroberts Thank you for the clarifications. I now understand that BetterTransformer should offer the features I am looking for. I encourage you to write a blog post on Huggingface to advertise this to the world!

@younesbelkada
Contributor

Hi @erichan1 @amyeroberts @vincentmin
This is correct: SDPA is now part of Optimum's BetterTransformer API. However, it is only available for decoder-based models right now.
We are indeed planning to write a blog post with PyTorch soon to publicly announce the feature. We will keep you posted here!

@pseudotensor pseudotensor mentioned this pull request Apr 28, 2023
@KatarinaYuan

Hi, any recent updates on this blogpost for BetterTransformer that you mentioned earlier?

@younesbelkada
Contributor

Hi @KatarinaYuan
Yes the blogpost is out and is here: https://pytorch.org/blog/out-of-the-box-acceleration/

@KatarinaYuan

KatarinaYuan commented Jun 14, 2023 via email

@ASR-SCI

ASR-SCI commented Jun 29, 2023

I'm training LLaMA with the transformers Trainer and its FSDP options. The model cannot be saved, and I'm unable to use bettertransformer.reverse() to convert it back to the original model. I don't know how to deal with this problem.

@EwoutH

EwoutH commented Jul 18, 2023

Are there any updates on the integration of FlashAttention into HuggingFace Transformers?

@younesbelkada
Contributor

younesbelkada commented Jul 18, 2023

@EwoutH
FlashAttention should be used as a backend for PyTorch's SDPA, which is itself integrated into the BetterTransformer API. Make sure to install the latest transformers and optimum libraries and run:

model = model.to_bettertransformer()

Check the blogpost: https://pytorch.org/blog/out-of-the-box-acceleration/ for reference

cc @fxmarty as well

@tmm1
Contributor

tmm1 commented Aug 1, 2023

is BetterTransformer up to date with FlashAttention v2?

@fxmarty
Contributor

fxmarty commented Aug 1, 2023

Hi, BetterTransformer integrates with PyTorch SDPA (for now), and PyTorch has not integrated flash v2 yet: pytorch/pytorch#105602. Hopefully it will be there in PyTorch 2.1.
