Enable Attention Mask for Training #1516
Hi @Sanger2000, that's a good point. Just so you know, we are upstreaming SDPA support directly into Transformers, where it is used by default, and you can already use it for a few models (see https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention), with good performance during training (see the benchmark at huggingface/transformers#28005). I won't be putting too much effort into BetterTransformer (when it is only about using SDPA, not e.g. nested tensors), but will rather extend SDPA support to more models in Transformers.
The attention mask is only supported by memory-efficient attention, not FlashAttention. But even without the mask, FlashAttention won't be used on CUDA anyway, because on CUDA it requires fp16 or bf16 dtypes. You can test which kernel is being used by enabling only one kernel at a time with the relevant functions in …

@fxmarty Your point makes sense, but since not all models are currently supported, and since this library is the recommended solution for models that aren't yet supported in the main Transformers library, it would help to allow the masks, at least when running on PyTorch v2.1+. Would you be willing to accept a pull request that addresses this?
Yes!
Right, happy to review PRs. In the future, I think PRs such as huggingface/transformers#28802 are the way to go (although they are taking ages to be merged).
Feature request
It appears that attention masks were originally ignored during training because passing one forced PyTorch's scaled dot product attention onto its slow path.
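For reference, the unfused "math" path conceptually computes softmax(QKᵀ/√d + mask)·V. The sketch below is a minimal re-implementation for illustration only (`reference_sdpa` is a hypothetical name, not an optimum or PyTorch function). It follows PyTorch's documented mask convention: in a boolean `attn_mask`, True means the position may be attended to, while a float mask is added to the scores.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(query, key, value, attn_mask=None):
    """Illustrative, unfused scaled dot product attention (no dropout)."""
    # Attention scores with shape (..., L, S), scaled by 1/sqrt(head_dim).
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: False entries are excluded from attention.
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # Float mask: added to the scores before the softmax.
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ value
```

Materializing the full (L, S) score matrix is exactly what the fused kernels avoid, which is why falling back to this path costs memory and speed.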
I am not fully confident, but I believe PyTorch now supports custom attention masks with memory-efficient attention, as per pytorch/pytorch#104310.
It would be good to enable custom attention masks in BetterTransformer training.
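One way to check which SDPA kernel accepts a mask, along the lines of the maintainer's suggestion to enable only one kernel at a time, is sketched below. It assumes a recent PyTorch that provides `torch.nn.attention.sdpa_kernel` and falls back to the older, deprecated `torch.backends.cuda.sdp_kernel` context manager otherwise. The helpers `only_backend` and `backend_accepts` are hypothetical names, and the exact version boundaries are approximate.

```python
import torch
import torch.nn.functional as F

try:
    # Newer PyTorch releases expose backend selection in torch.nn.attention.
    from torch.nn.attention import SDPBackend, sdpa_kernel

    _BACKENDS = {
        "flash": SDPBackend.FLASH_ATTENTION,
        "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
        "math": SDPBackend.MATH,
    }

    def only_backend(name):
        # Restrict SDPA to a single backend inside the `with` block.
        return sdpa_kernel(_BACKENDS[name])

except ImportError:
    # PyTorch 2.0/2.1: the older (now deprecated) context manager.
    def only_backend(name):
        return torch.backends.cuda.sdp_kernel(
            enable_flash=(name == "flash"),
            enable_mem_efficient=(name == "mem_efficient"),
            enable_math=(name == "math"),
        )


def backend_accepts(name, dtype, with_mask, device="cuda"):
    """Return True if SDPA runs with only `name` enabled, False otherwise."""
    q = torch.randn(1, 2, 8, 16, device=device, dtype=dtype)
    mask = None
    if with_mask:
        # Boolean causal mask: True means "may attend".
        mask = torch.ones(8, 8, dtype=torch.bool, device=device).tril()
    try:
        with only_backend(name):
            F.scaled_dot_product_attention(q, q, q, attn_mask=mask)
        return True
    except RuntimeError:
        # SDPA raises RuntimeError when no enabled kernel supports the inputs.
        return False


if torch.cuda.is_available():
    for dtype in (torch.float32, torch.float16):
        for with_mask in (False, True):
            for name in ("flash", "mem_efficient", "math"):
                print(dtype, "mask" if with_mask else "no mask", name,
                      backend_accepts(name, dtype, with_mask))
```

If the thread is right, a CUDA run on PyTorch 2.1 or later should show FlashAttention rejecting fp32 inputs and any mask, while memory-efficient attention accepts the masked call; actual results depend on the GPU and PyTorch version.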
Motivation
I want to pass in a custom attention mask (for example, to fit multiple examples into a single sequence while only letting tokens attend to others in the same example).
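The packing use case amounts to a block-diagonal mask. Below is a minimal sketch, assuming PyTorch's boolean `attn_mask` convention (True means "may attend") and an optional causal constraint inside each example; `packed_attention_mask` is a hypothetical helper, not part of optimum or Transformers.

```python
import torch
import torch.nn.functional as F


def packed_attention_mask(lengths, causal=True, device=None):
    """Block-diagonal boolean mask for examples packed end to end in one sequence."""
    lengths = torch.as_tensor(lengths, device=device)
    # Example id of every position, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
    seg = torch.repeat_interleave(torch.arange(len(lengths), device=device), lengths)
    # Tokens may only attend to tokens from the same example.
    mask = seg[:, None] == seg[None, :]
    if causal:
        # Additionally forbid attending to later positions.
        mask = mask & torch.ones_like(mask).tril()
    return mask


# Usage: two examples of lengths 2 and 3 packed into one sequence of 5 tokens.
q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim)
mask = packed_attention_mask([2, 3])
out = F.scaled_dot_product_attention(q, q, q, attn_mask=mask)  # mask broadcasts
```

With this mask, the output for each packed example matches running attention on that example alone, which is the behavior that dropping the mask during training would break.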
Your contribution
It could be as straightforward as just removing the lines:
In all implementations. I would be happy to do this. Perhaps it is also worth warning the user that memory-efficient attention will be used instead of FlashAttention.
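The proposed behavior (keep the mask during training and warn that FlashAttention cannot be used) might look roughly like the sketch below. This is not optimum's actual code: `masked_sdpa` is a hypothetical wrapper, and the warning text is illustrative.

```python
import warnings

import torch
import torch.nn.functional as F


def masked_sdpa(query, key, value, attention_mask=None, dropout_p=0.0, training=False):
    """Hypothetical SDPA wrapper that keeps the attention mask during training."""
    if attention_mask is not None:
        # Python's default warning filter shows this once per call site.
        warnings.warn(
            "An attention mask was passed: FlashAttention cannot be used, so SDPA "
            "will use memory-efficient attention or fall back to the math path.",
            UserWarning,
        )
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        # Apply dropout only in training, as attention layers usually do.
        dropout_p=dropout_p if training else 0.0,
    )
```

Passing the mask through instead of discarding it keeps results correct for padded or packed batches; the warning makes the possible speed cost visible rather than silent.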