diff --git a/docs/source/en/llm_tutorial_optimization.md b/docs/source/en/llm_tutorial_optimization.md
index a90fc045aff419..93848d72b0d811 100644
--- a/docs/source/en/llm_tutorial_optimization.md
+++ b/docs/source/en/llm_tutorial_optimization.md
@@ -441,7 +441,7 @@ flush()
```
For comparison, let's run the same function, but enable Flash Attention instead.
-To do so, we convert the model to [BetterTransformers](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is based on Flash Attention.
+To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is able to use Flash Attention.
```python
model.to_bettertransformer()
diff --git a/docs/source/en/model_doc/bark.md b/docs/source/en/model_doc/bark.md
index 2160159bd783a3..7c02e4be701187 100644
--- a/docs/source/en/model_doc/bark.md
+++ b/docs/source/en/model_doc/bark.md
@@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation
##### Usage
-To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
+To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
-model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
+model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```
##### Performance comparison
@@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# load in fp16 and use Flash Attention 2
-model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
+model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
# enable CPU offload
model.enable_cpu_offload()
diff --git a/docs/source/en/model_doc/distilbert.md b/docs/source/en/model_doc/distilbert.md
index 233a182a553fa6..bd39260d3ca492 100644
--- a/docs/source/en/model_doc/distilbert.md
+++ b/docs/source/en/model_doc/distilbert.md
@@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
->>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> text = "Replace me by any text you'd like."
diff --git a/docs/source/en/model_doc/gpt_bigcode.md b/docs/source/en/model_doc/gpt_bigcode.md
index 0f3bc72d03a55f..b3cb078e2a140c 100644
--- a/docs/source/en/model_doc/gpt_bigcode.md
+++ b/docs/source/en/model_doc/gpt_bigcode.md
@@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
->>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
>>> prompt = "def hello_world():"
diff --git a/docs/source/en/model_doc/gpt_neo.md b/docs/source/en/model_doc/gpt_neo.md
index 96b6a8c96fe71a..3c7858c998207e 100644
--- a/docs/source/en/model_doc/gpt_neo.md
+++ b/docs/source/en/model_doc/gpt_neo.md
@@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
->>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
>>> prompt = "def hello_world():"
diff --git a/docs/source/en/model_doc/gpt_neox.md b/docs/source/en/model_doc/gpt_neox.md
index 1885d44450aab9..fd105a3e82e1ee 100644
--- a/docs/source/en/model_doc/gpt_neox.md
+++ b/docs/source/en/model_doc/gpt_neox.md
@@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation
### Usage
-To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
+To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
-model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
+model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```
diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md
index 8e37bc2caf888d..8e4d75ef2382c3 100644
--- a/docs/source/en/model_doc/mistral.md
+++ b/docs/source/en/model_doc/mistral.md
@@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
->>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
>>> prompt = "My favourite condiment is"
diff --git a/docs/source/en/model_doc/opt.md b/docs/source/en/model_doc/opt.md
index 3da7b22fab747d..1b02b888994ecf 100644
--- a/docs/source/en/model_doc/opt.md
+++ b/docs/source/en/model_doc/opt.md
@@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto
->>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
diff --git a/docs/source/en/model_doc/phi.md b/docs/source/en/model_doc/phi.md
index 03eac894162724..3076aa378cbe85 100644
--- a/docs/source/en/model_doc/phi.md
+++ b/docs/source/en/model_doc/phi.md
@@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer
>>> # define the model and tokenizer and push the model and tokens to the GPU.
->>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
+>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> # feel free to change the prompt to your liking.
@@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward
-
\ No newline at end of file
+
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index d91ed2094f0bae..b12670584a4ec5 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -36,13 +36,29 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them
-FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
+FlashAttention-2 is currently supported for the following architectures:
+* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
+* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
+* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
+* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
+* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
+* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
+* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
+* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
+* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
+* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
+* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
+* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
+* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
+
+You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest to refer to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
-To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
+To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:
```python
import torch
@@ -54,13 +70,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
```
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
+
+Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.
@@ -77,14 +95,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
```
@@ -124,41 +142,21 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
-## BetterTransformer
-
-
-
-Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.
-
-
-
-BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:
-
-1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
-2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors
-
-BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.
-
-Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).
-
-Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:
-
-```python
-model = model.to_bettertransformer()
-```
-
-You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:
+## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention
-```py
-model = model.reverse_bettertransformer()
-model.save_pretrained("saved_model")
-```
+PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.
-### FlashAttention
+For now, Transformers supports inference and training through SDPA for the following architectures:
+* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
+* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
+* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
+* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
+* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
-SDPA can also call FlashAttention kernels under the hood. FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it.
+Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type before using it.
-To enable FlashAttention or to check whether it is available in a given setting (hardware, problem size), use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
+By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
```diff
import torch
@@ -187,6 +185,43 @@ RuntimeError: No available kernel. Aborting execution.
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```
+## BetterTransformer
+
+
+
+Part of BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to support natively SDPA in Transformers.
+
+
+
+
+
+
+Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.
+
+
+
+BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:
+
+1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
+2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors
+
+BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.
+
+Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).
+
+Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:
+
+```python
+model = model.to_bettertransformer()
+```
+
+You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:
+
+```py
+model = model.reverse_bettertransformer()
+model.save_pretrained("saved_model")
+```
+
## bitsandbytes
bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.
diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md
index 9193b8a61e2b5e..4e73ee2be5da72 100644
--- a/docs/source/en/quantization.md
+++ b/docs/source/en/quantization.md
@@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
-model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
+model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
diff --git a/docs/source/ja/perf_infer_gpu_one.md b/docs/source/ja/perf_infer_gpu_one.md
index d6a18a6f3e2047..6d7466e022220a 100644
--- a/docs/source/ja/perf_infer_gpu_one.md
+++ b/docs/source/ja/perf_infer_gpu_one.md
@@ -44,7 +44,7 @@ Flash Attention 2は、モデルのdtypeが`fp16`または`bf16`の場合にの
### Quick usage
-モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`use_flash_attention_2`を追加します。
+モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`attn_implementation="flash_attention_2"`を追加します。
```python
@@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
```
@@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
```
@@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
```
@@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
lora_config = LoraConfig(
diff --git a/docs/source/zh/main_classes/quantization.md b/docs/source/zh/main_classes/quantization.md
index 0a2c1eb4039c36..3c7e4d9212a1d0 100644
--- a/docs/source/zh/main_classes/quantization.md
+++ b/docs/source/zh/main_classes/quantization.md
@@ -66,12 +66,12 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0")
### 结合 AWQ 和 Flash Attention
-您可以将AWQ量化与Flash Attention结合起来,得到一个既被量化又更快速的模型。只需使用`from_pretrained`加载模型,并传递`use_flash_attention_2=True`参数。
+您可以将AWQ量化与Flash Attention结合起来,得到一个既被量化又更快速的模型。只需使用`from_pretrained`加载模型,并传递`attn_implementation="flash_attention_2"`参数。
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
-model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
+model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
### 基准测试
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 22ea0abbd60104..5c419ee0fc7cf3 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -236,6 +236,8 @@ class PretrainedConfig(PushToHubMixin):
This attribute is currently not being used during model loading time, but this may change in the future
versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
+ attn_implementation (`str`, *optional*):
+ The attention implementation to use in the model. Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> TensorFlow specific parameters
@@ -374,6 +376,9 @@ def __init__(self, **kwargs):
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)
+ # Attention implementation to use, if relevant.
+ self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
+
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
@@ -422,6 +427,22 @@ def num_labels(self, num_labels: int):
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
+ @property
+ def _attn_implementation(self):
+ # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+ if hasattr(self, "_attn_implementation_internal"):
+ if self._attn_implementation_internal is None:
+ # `config.attn_implementation` should never be None, for backward compatibility.
+ return "eager"
+ else:
+ return self._attn_implementation_internal
+ else:
+ return "eager"
+
+ @_attn_implementation.setter
+ def _attn_implementation(self, value):
+ self._attn_implementation_internal = value
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
@@ -747,6 +768,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
+ # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
+ config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
+
config = cls(**config_dict)
if hasattr(config, "pruned_heads"):
@@ -861,8 +885,8 @@ def to_diff_dict(self) -> Dict[str, Any]:
self.dict_torch_dtype_to_str(serializable_config_dict)
- if "_flash_attn_2_enabled" in serializable_config_dict:
- del serializable_config_dict["_flash_attn_2_enabled"]
+ if "_attn_implementation_internal" in serializable_config_dict:
+ del serializable_config_dict["_attn_implementation_internal"]
return serializable_config_dict
@@ -880,8 +904,8 @@ def to_dict(self) -> Dict[str, Any]:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
- if "_flash_attn_2_enabled" in output:
- del output["_flash_attn_2_enabled"]
+ if "_attn_implementation_internal" in output:
+ del output["_attn_implementation_internal"]
# Transformers version when serializing the model
output["transformers_version"] = __version__
diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 2c4a839ae19b39..734f443e1fc9d4 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -68,7 +68,7 @@ def to_causal_4d(
key_value_length: int,
dtype: torch.dtype,
device: Union[torch.device, "str"] = "cpu",
- ) -> torch.Tensor:
+ ) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
@@ -184,6 +184,95 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+ @staticmethod
+ def _unmask_unattended(
+ expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
+ ):
+ # fmt: off
+ """
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ Details: https://github.com/pytorch/pytorch/issues/110213
+
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+ `attention_mask` is [bsz, src_seq_len].
+
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+
+ For example, if `attention_mask` is
+ ```
+ [[0, 0, 1],
+ [1, 1, 1],
+ [0, 1, 1]]
+ ```
+ and `expanded_mask` is (e.g. here left-padding case)
+ ```
+ [[[[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 1]]],
+ [[[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]],
+ [[[0, 0, 0],
+ [0, 1, 0],
+ [0, 1, 1]]]]
+ ```
+ then the modified `expanded_mask` will be
+ ```
+ [[[[1, 1, 1], <-- modified
+ [1, 1, 1], <-- modified
+ [0, 0, 1]]],
+ [[[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]],
+ [[[1, 1, 1], <-- modified
+ [0, 1, 0],
+ [0, 1, 1]]]]
+ ```
+ """
+ # fmt: on
+
+ # Get the index of the first non-zero value for every sample in the batch.
+ # In the above example, indices = [[2], [0], [1]]]
+ tmp = torch.arange(attention_mask.shape[1], 0, -1)
+ indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
+
+ # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
+ # expanded mask will be completely unattended.
+ left_masked_rows = torch.where(indices > 0)[0]
+
+ if left_masked_rows.shape[0] == 0:
+ return expanded_mask
+ indices = indices[left_masked_rows]
+
+ max_len = torch.max(indices)
+ range_tensor = torch.arange(max_len).unsqueeze(0)
+ range_tensor = range_tensor.repeat(indices.size(0), 1)
+
+ # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
+ range_tensor[range_tensor >= indices] = 0
+
+ # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
+ if expanded_mask.dim() == 4:
+ num_masks = expanded_mask.shape[1]
+ if num_masks == 1:
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
+ mask_slice = (left_masked_rows[:, None], 0, range_tensor)
+ else:
+ # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
+ mask_slice = (
+ left_masked_rows[:, None, None],
+ torch.arange(num_masks)[None, :, None],
+ range_tensor[:, None, :],
+ )
+ else:
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
+ mask_slice = (left_masked_rows[:, None], range_tensor)
+
+ expanded_mask[mask_slice] = unmasked_value
+
+ return expanded_mask
+
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
@@ -225,6 +314,78 @@ def _prepare_4d_causal_attention_mask(
return attention_mask
+# Adapted from _prepare_4d_causal_attention_mask
+def _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask: Optional[torch.Tensor],
+ input_shape: Union[torch.Size, Tuple, List],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = input_shape[-1] + past_key_values_length
+ batch_size, query_length = input_shape
+
+ # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+ # TODO: Fix this as well when using torchdynamo with fullgraph=True.
+ is_tracing = torch.jit.is_tracing()
+
+ if attention_mask is not None:
+ if torch.all(attention_mask == 1):
+ if is_tracing:
+ pass
+ elif query_length == 1:
+ # For query_length == 1, causal attention and bi-directional attention are the same.
+ attention_mask = None
+ elif key_value_length == query_length:
+ attention_mask = None
+ else:
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
+ pass
+ elif query_length > 1 and key_value_length != query_length:
+ # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
+ # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
+ attention_mask = True
+ elif is_tracing:
+ raise ValueError(
+ 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
+ )
+
+ if attention_mask is None:
+ expanded_4d_mask = None
+ elif attention_mask is True:
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ else:
+ expanded_4d_mask = attn_mask_converter.to_4d(
+ attention_mask,
+ input_shape[-1],
+ dtype=inputs_embeds.dtype,
+ key_value_length=key_value_length,
+ )
+
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+ if query_length > 1:
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+ expanded_4d_mask, attention_mask, unmasked_value=0.0
+ )
+
+ return expanded_4d_mask
+
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -241,13 +402,51 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor` or `None`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ batch_size, key_value_length = mask.shape
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
+
+ # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+ # TODO: Fix this as well when using torchdynamo with fullgraph=True.
+ is_tracing = torch.jit.is_tracing()
+
+ if torch.all(mask == 1):
+ if is_tracing:
+ pass
+ elif tgt_len == 1:
+ # For query_length == 1, causal attention and bi-directional attention are the same.
+ return None
+ elif key_value_length == tgt_len:
+ return None
+ else:
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
+ # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+ else:
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
def _create_4d_causal_attention_mask(
input_shape: Union[torch.Size, Tuple, List],
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
-):
+) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 980d1c837a537e..2588893d2575be 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -81,6 +81,7 @@
is_peft_available,
is_remote_url,
is_safetensors_available,
+ is_torch_sdpa_available,
is_torch_tpu_available,
logging,
replace_return_docstrings,
@@ -1128,6 +1129,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False
+ # SDPA support
+ _supports_sdpa = False
+
# Has support for a `Cache` instance as `past_key_values`
_supports_cache_class = False
@@ -1154,7 +1158,11 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
+ config = self._autoset_attn_implementation(
+ config, torch_dtype=torch.get_default_dtype(), check_device_map=False
+ )
self.config = config
+
self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
@@ -1185,8 +1193,6 @@ def _from_config(cls, config, **kwargs):
Args:
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
- use_flash_attention_2 (`bool`, *optional*):
- Whether to load the model with Flash Attention 2 modules.
"""
torch_dtype = kwargs.pop("torch_dtype", None)
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@@ -1196,8 +1202,11 @@ def _from_config(cls, config, **kwargs):
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
- if use_flash_attention_2:
- config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
+ config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
+ config._attn_implementation = kwargs.pop("attn_implementation", None)
+ config = cls._autoset_attn_implementation(
+ config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
+ )
if is_deepspeed_zero3_enabled():
import deepspeed
@@ -1216,6 +1225,67 @@ def _from_config(cls, config, **kwargs):
return model
+ @classmethod
+ def _autoset_attn_implementation(
+ cls,
+ config,
+ use_flash_attention_2: bool = False,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ check_device_map: bool = True,
+ ):
+ """
+ Automatically checks and dispatches to a default attention implementation. In order of priority:
+ 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
+ 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
+ 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+ 4. The default model's implementation otherwise (`LlamaAttention` for example) .
+ """
+ # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
+ # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
+ # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+ if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
+ if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
+ raise ValueError(
+ f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
+ ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
+ )
+
+ if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
+ message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+ if cls._supports_flash_attn_2:
+ message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+ if cls._supports_sdpa:
+ message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
+ raise ValueError(message + ".")
+
+ # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
+ hard_check_only = True
+ else:
+ hard_check_only = False
+
+ if use_flash_attention_2:
+ logger.warning_once(
+ 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
+ )
+ config._attn_implementation = "flash_attention_2"
+
+ if config._attn_implementation == "flash_attention_2":
+ cls._check_and_enable_flash_attn_2(
+ config,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ hard_check_only=hard_check_only,
+ check_device_map=check_device_map,
+ )
+ elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+ # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+ config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
+ elif not hard_check_only:
+ config._attn_implementation = "eager"
+
+ return config
+
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@@ -1266,38 +1336,33 @@ def can_generate(cls) -> bool:
@classmethod
def _check_and_enable_flash_attn_2(
- cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ check_device_map: bool = True,
+ hard_check_only: bool = False,
) -> PretrainedConfig:
"""
- If you don't know about Flash Attention, check out the official repository of flash attention:
- https://github.com/Dao-AILab/flash-attention
-
- For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
- specific section of the documentation to learn more about it:
- https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+ Checks the availability of Flash Attention 2 and compatibility with the current model.
- The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
- half precision and not ran on CPU.
-
- If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
- can initialize the correct attention module
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_2:
raise ValueError(
- "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
+ f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
if not is_flash_attn_2_available():
- flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
-
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
- if torch.version.cuda:
- if importlib.util.find_spec("flash_attn") is None:
- raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
- flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+ if importlib.util.find_spec("flash_attn") is None:
+ raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
+
+ flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+ if torch.version.cuda:
if flash_attention_version < version.parse("2.1.0"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
@@ -1305,9 +1370,6 @@ def _check_and_enable_flash_attn_2(
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
elif torch.version.hip:
- if importlib.util.find_spec("flash_attn") is None:
- raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
- flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version < version.parse("2.0.4"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
@@ -1332,20 +1394,23 @@ def _check_and_enable_flash_attn_2(
" unexpected behaviour."
)
- if device_map is None:
+ # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
+ # or the model may be initialized under the context manager `with torch.device("cuda"):`.
+ if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
if torch.cuda.is_available():
logger.warning(
- "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
+ "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
- "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
+ "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
- device_map is not None
+ check_device_map
+ and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
@@ -1353,7 +1418,37 @@ def _check_and_enable_flash_attn_2(
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
- config._flash_attn_2_enabled = True
+ if not hard_check_only:
+ config._attn_implementation = "flash_attention_2"
+ return config
+
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+ """
+ Checks the availability of SDPA for a given model.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
+ """
+ if hard_check_only:
+ if not cls._supports_sdpa:
+ raise ValueError(
+ f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please open an issue on GitHub to "
+ "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
+ )
+ if not is_torch_sdpa_available():
+ raise ImportError(
+ "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
+ )
+
+ if not is_torch_sdpa_available() or not cls._supports_sdpa:
+ return config
+
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+ if _is_bettertransformer:
+ return config
+
+ if not hard_check_only:
+ config._attn_implementation = "sdpa"
return config
def enable_input_require_grads(self):
@@ -3312,8 +3407,10 @@ def from_pretrained(
elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
- if use_flash_attention_2:
- config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
+ config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
+ config = cls._autoset_attn_implementation(
+ config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
+ )
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index d472b854748d88..703886d500ba12 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -389,7 +389,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
BARK_ATTENTION_CLASSES = {
- "default": BarkSelfAttention,
+ "eager": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
}
@@ -436,8 +436,7 @@ def __init__(self, config, is_causal=False):
self.layernorm_1 = nn.LayerNorm(config.hidden_size)
self.layernorm_2 = nn.LayerNorm(config.hidden_size)
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)
+ self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)
self.mlp = BarkMLP(config)
@@ -670,6 +669,7 @@ def __init__(self, config):
self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
@@ -805,7 +805,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
@@ -1265,6 +1265,7 @@ def __init__(self, config):
self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_final = nn.LayerNorm(config.hidden_size)
@@ -1434,7 +1435,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
@@ -1875,7 +1876,11 @@ def generate(
@classmethod
def _check_and_enable_flash_attn_2(
- cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ hard_check_only: bool = False,
):
"""
`_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
@@ -1892,12 +1897,14 @@ def _check_and_enable_flash_attn_2(
The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not ran on CPU.
- If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model
can initialize the correct attention module
"""
- config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
+ config = super()._check_and_enable_flash_attn_2(
+ config, torch_dtype, device_map, hard_check_only=hard_check_only
+ )
- config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
- config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
- config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+ config.semantic_config._attn_implementation = config._attn_implementation
+ config.coarse_acoustics_config._attn_implementation = config._attn_implementation
+ config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index a71f79b3014251..16527216c7a501 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -25,7 +25,12 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -505,8 +510,109 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+class BartSdpaAttention(BartAttention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_states = self._shape(query_states, tgt_len, bsz)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
BART_ATTENTION_CLASSES = {
- "default": BartAttention,
+ "eager": BartAttention,
+ "sdpa": BartSdpaAttention,
"flash_attention_2": BartFlashAttention2,
}
@@ -515,9 +621,8 @@ class BartEncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = BART_ATTENTION_CLASSES[attn_type](
+ self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -587,8 +692,7 @@ def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = BART_ATTENTION_CLASSES[attn_type](
+ self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -601,7 +705,7 @@ def __init__(self, config: BartConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = BART_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -735,6 +839,7 @@ class BartPreTrainedModel(PreTrainedModel):
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_sdpa = True
def _init_weights(self, module):
std = self.config.init_std
@@ -961,6 +1066,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
embed_dim,
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
@@ -1048,8 +1155,13 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
+ elif self._use_sdpa and head_mask is None and not output_attentions:
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
@@ -1136,6 +1248,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
config.d_model,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
+
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@@ -1254,9 +1369,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
@@ -1265,8 +1389,17 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 2fbd4621361edd..4512c3b503a4be 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -252,7 +252,7 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention}
+BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
@@ -260,9 +260,8 @@ class BlenderbotEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
+ self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -332,9 +331,8 @@ class BlenderbotDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
+ self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -347,7 +345,7 @@ def __init__(self, config: BlenderbotConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 1669602832d866..dc4fa30b875ef2 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -254,9 +254,8 @@ class BlenderbotSmallEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
+ self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -321,7 +320,10 @@ def forward(
return outputs
-BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention}
+# TODO: Implement attention with SDPA for TimeSeriesTransformer.
+BLENDERBOT_SMALL_ATTENTION_CLASSES = {
+ "eager": BlenderbotSmallAttention,
+}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
@@ -330,8 +332,7 @@ def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
+ self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -344,7 +345,7 @@ def __init__(self, config: BlenderbotSmallConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 2e58d1728eae28..6e38ee84e98f6c 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -471,6 +471,12 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
return x
+DISTILBERT_ATTENTION_CLASSES = {
+ "eager": MultiHeadSelfAttention,
+ "flash_attention_2": DistilBertFlashAttention2,
+}
+
+
class TransformerBlock(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
@@ -479,11 +485,7 @@ def __init__(self, config: PretrainedConfig):
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
- self.attention = (
- MultiHeadSelfAttention(config)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else DistilBertFlashAttention2(config)
- )
+ self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
self.ffn = FFN(config)
@@ -703,6 +705,7 @@ def __init__(self, config: PretrainedConfig):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
@@ -808,7 +811,7 @@ def forward(
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 01d9f0c3edbf31..4684f46ded0242 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -16,7 +16,7 @@
import math
import warnings
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -24,7 +24,11 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -33,6 +37,7 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -44,6 +49,9 @@
from .configuration_falcon import FalconConfig
+if TYPE_CHECKING:
+ from ...configuration_utils import PretrainedConfig
+
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -278,6 +286,7 @@ def __init__(self, config: FalconConfig):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
+ self._use_sdpa = config._attn_implementation == "sdpa"
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
@@ -439,16 +448,15 @@ def forward(
present = None
if alibi is None:
- if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
- # TODO: deprecate this once we add FA2 support in Falcon
- logger.warning_once(
- "The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the"
- " future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call "
- "`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations."
- )
-
+ if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(
- query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ 0.0,
+ # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
+ is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attention_scores = None
else:
@@ -456,58 +464,70 @@ def forward(
attention_scores /= math.sqrt(self.head_dim)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
+ # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
- output_tensor = self.dense(attn_output)
+ attn_output = self.dense(attn_output)
if output_attentions:
- return output_tensor, present, attention_scores
+ return attn_output, present, attention_scores
else:
- return output_tensor, present
+ return attn_output, present
else:
- matmul_result = query_layer @ key_layer.transpose(-1, -2)
+ if self._use_sdpa and not output_attentions and head_mask is None:
+ attn_output = F.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout.p if self.training else 0.0,
+ is_causal=self.is_causal and attention_mask is None and query_length > 1,
+ )
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+ attn_output = self.dense(attn_output)
+ else:
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
+
+ # change view to [batch_size, num_heads, q_length, kv_length]
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
- # change view to [batch_size, num_heads, q_length, kv_length]
- attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+ input_dtype = attention_scores.dtype
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
+ attention_scores = attention_scores.to(torch.float32)
- # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
- input_dtype = attention_scores.dtype
- # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
- if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
- attention_scores = attention_scores.to(torch.float32)
- # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
- # adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically
- # equivalent and more performant, but there might be a numerical difference. If you're reading this
- # and you'd like to experiment and maybe file a PR, feel free!
- attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
- attention_logits *= self.inv_norm_factor
- attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
- # [batch_size, num_heads, q_length, kv_length]
- attention_probs = self.attention_dropout(attention_probs)
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+ attention_logits *= self.inv_norm_factor
+ attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
+ # [batch_size, num_heads, q_length, kv_length]
+ attention_probs = self.attention_dropout(attention_probs)
- if head_mask is not None:
- attention_probs = attention_probs * head_mask
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
- # change view [batch_size, num_heads, q_length, kv_length]
- attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
+ # change view [batch_size, num_heads, q_length, kv_length]
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
- # matmul: [batch_size * num_heads, q_length, head_dim]
- context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
- # change view [batch_size, q_length, num_heads * head_dim]
- context_layer = self._merge_heads(context_layer)
+ # change view [batch_size, q_length, num_heads * head_dim]
+ attn_output = self._merge_heads(attn_output)
- output_tensor = self.dense(context_layer)
+ attn_output = self.dense(attn_output)
if output_attentions:
- return output_tensor, present, attention_probs
+ return attn_output, present, attention_probs
else:
- return output_tensor, present
+ return attn_output, present
class FalconFlashAttention2(FalconAttention):
@@ -734,17 +754,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
+FALCON_ATTENTION_CLASSES = {
+ "eager": FalconAttention,
+ "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
+ "flash_attention_2": FalconFlashAttention2,
+}
+
+
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
- self.self_attention = (
- FalconAttention(config)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else FalconFlashAttention2(config)
- )
+ self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
@@ -912,6 +935,7 @@ class FalconPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -932,6 +956,25 @@ def _init_weights(self, module: nn.Module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
+ # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
+ # NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
+ if hard_check_only:
+ if not is_torch_greater_or_equal_than_2_0:
+ raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
+
+ if not is_torch_greater_or_equal_than_2_0:
+ return config
+
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+ if _is_bettertransformer:
+ return config
+
+ if not hard_check_only:
+ config._attn_implementation = "sdpa"
+ return config
+
@add_start_docstrings(
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
@@ -950,6 +993,8 @@ def __init__(self, config: FalconConfig):
# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -1003,12 +1048,6 @@ def forward(
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
- # Prepare head mask if needed
- # 1.0 in head_mask indicate we keep the head
- # attention_probs has shape batch_size x num_heads x N x N
- # head_mask has shape n_layer x batch x num_heads x N x N
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -1047,15 +1086,61 @@ def forward(
)
position_ids = position_ids.unsqueeze(0)
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ if alibi is None:
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif head_mask is None:
+ alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+
+ attention_mask_2d = attention_mask
+ # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ # We take care to integrate alibi bias in the attention_mask here.
+ if attention_mask_2d is None:
+ attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
+ else:
+ attention_mask = torch.masked_fill(
+ alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+ attention_mask < -1,
+ torch.finfo(alibi.dtype).min,
+ )
+
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+ if seq_length > 1:
+ attention_mask = AttentionMaskConverter._unmask_unattended(
+ attention_mask, attention_mask_2d, unmasked_value=0.0
+ )
+ else:
+ # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 45c96146a1d6f7..eab28ad131e800 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -22,6 +22,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -128,6 +129,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.scale_attention_softmax_in_fp32 = (
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
)
+ self.attn_pdrop = config.attn_pdrop
if self.is_cross_attention:
if self.multi_query:
@@ -359,7 +361,7 @@ def forward(
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
- attn_dropout = self.config.attn_pdrop if self.training else 0.0
+ attn_dropout = self.attn_pdrop if self.training else 0.0
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
upcast = query.dtype != softmax_dtype
@@ -509,6 +511,137 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ if head_mask is not None:
+ # The super dispatch is done in the forward.
+ raise ValueError(
+ "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
+ )
+
+ scale = None
+ if not self.scale_attn_weights:
+ scale = 1
+
+ # MQA models: (batch_size, query_length, num_heads * head_dim)
+ # MHA models: (batch_size, num_heads, query_length, head_dim)
+ query_shape = query.shape
+ batch_size = query_shape[0]
+ key.shape[-2]
+
+ if self.multi_query:
+ query_length = query_shape[1]
+
+ # NOTE: Maybe there is better than this?
+ query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+ key = key.unsqueeze(1)
+ value = value.unsqueeze(1)
+
+ # Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to mem-efficient attention
+ # and flash attention (No available kernel. Aborting execution.) from the shapes
+ # query = [batch_size, num_heads, query_length, head_dim]
+ # key = [batch_size, 1, past_length, head_dim]
+ # value = [batch_size, 1, past_length, head_dim]
+ # which is unfortunate. Hopefully can be improved in the future. These expand should not be too expansive as they do not do memory copy.
+ key = key.expand(-1, self.num_heads, -1, -1)
+ value = value.expand(-1, self.num_heads, -1, -1)
+ else:
+ query_length = query_shape[-1]
+
+ sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=self.attn_pdrop if self.training else 0.0,
+ # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
+ is_causal=self.is_causal and attention_mask is None and query_length > 1,
+ scale=scale,
+ )
+
+ if self.multi_query:
+ # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+ sdpa_result = sdpa_result.transpose(1, 2)
+
+ # Reshape is kind of expensive here, as it does a memory copy,
+ # but I did not manage to make away without it (logits do not match when using view)
+ # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+ sdpa_result = sdpa_result.reshape(query_shape)
+
+ return sdpa_result, None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+ ]:
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key_value = self.c_attn(encoder_hidden_states)
+ attention_mask = encoder_attention_mask
+ elif self.multi_query:
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
+ else:
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
+ # i.e., the memory layout is not the same as GPT2.
+ # This makes the concatenation with past_key_value more efficient.
+ query, key_value = (
+ self.c_attn(hidden_states)
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+ .transpose(1, 2)
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
+ )
+
+ if layer_past is not None:
+ key_value = torch.cat((layer_past, key_value), dim=-2)
+ present = key_value if use_cache else None
+
+ key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+ if not output_attentions and head_mask is None:
+ # Difference with the original implementation: there is no need to transpose the key here,
+ # as SDPA expects seq_length to be at index -2 for the key as well
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+ else:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
+ ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
+
+ if not self.multi_query:
+ attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ if self.multi_query:
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+ attn_weights = attn_weights.transpose(1, 2)
+ outputs += (attn_weights,)
+
+ return outputs
+
+
class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -527,6 +660,13 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states
+GPTBIGCODE_ATTENTION_CLASSES = {
+ "eager": GPTBigCodeAttention,
+ "flash_attention_2": GPTBigCodeFlashAttention2,
+ "sdpa": GPTBigCodeSdpaAttention,
+}
+
+
class GPTBigCodeBlock(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
@@ -534,21 +674,19 @@ def __init__(self, config, layer_idx=None):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = (
- GPTBigCodeAttention(config, layer_idx=layer_idx)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
- )
+
+ self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
- self.crossattention = (
- GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)
+
+ self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
+ config, is_cross_attention=True, layer_idx=layer_idx
)
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
@@ -629,6 +767,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -770,6 +909,9 @@ def __init__(self, config):
self.gradient_checkpointing = False
+ self._use_sdpa = config._attn_implementation == "sdpa"
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
# Initialize weights and apply final processing
self.post_init()
@@ -850,7 +992,7 @@ def forward(
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
@@ -867,7 +1009,34 @@ def forward(
# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
- attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
+ self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
+
+ if self._use_sdpa and head_mask is None and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ if self.multi_query:
+ # gpt_bigcode using MQA has the bad taste to use a causal mask with shape
+ # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
+ self_attention_mask = self_attention_mask.transpose(1, 2)
+
+ if query_length > 1 and attention_mask is not None:
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+ self_attention_mask = AttentionMaskConverter._unmask_unattended(
+ self_attention_mask, attention_mask, unmasked_value=True
+ )
+
+ # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
+ dtype = self.wte.weight.dtype
+ self_attention_mask = torch.where(
+ self_attention_mask,
+ torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
+ torch.full(
+ [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
+ ),
+ )
+
+ attention_mask = self_attention_mask
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 1089322cf9fbbe..a6a73dbb8cfd9f 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -487,6 +487,12 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+GPT_NEO_ATTENTION_CLASSES = {
+ "eager": GPTNeoSelfAttention,
+ "flash_attention_2": GPTNeoFlashAttention2,
+}
+
+
class GPTNeoAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -495,11 +501,7 @@ def __init__(self, config, layer_id=0):
self.attention_type = self.attention_layers[layer_id]
if self.attention_type in ["global", "local"]:
- self.attention = (
- GPTNeoSelfAttention(config, self.attention_type)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else GPTNeoFlashAttention2(config, self.attention_type)
- )
+ self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type)
else:
raise NotImplementedError(
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
@@ -718,6 +720,7 @@ def __init__(self, config):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
@@ -795,7 +798,7 @@ def forward(
hidden_states = inputs_embeds + position_embeds
# Attention mask.
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 30feda146eabbf..d1c10f58d9d67a 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -658,6 +658,12 @@ def forward(self, hidden_states):
return hidden_states
+GPT_NEOX_ATTENTION_CLASSES = {
+ "eager": GPTNeoXAttention,
+ "flash_attention_2": GPTNeoXFlashAttention2,
+}
+
+
class GPTNeoXLayer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -666,11 +672,7 @@ def __init__(self, config):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
- self.attention = (
- GPTNeoXAttention(config)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else GPTNeoXFlashAttention2(config)
- )
+ self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = GPTNeoXMLP(config)
def forward(
@@ -785,6 +787,7 @@ def __init__(self, config):
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
@@ -861,7 +864,7 @@ def forward(
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# We create a 3D attention mask from a 2D tensor mask.
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index fe40cdb92922d8..d102a2b2a1264a 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -29,7 +29,7 @@
from ... import PreTrainedModel
from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -578,6 +578,7 @@ def __init__(
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.dropout = dropout
+ self.is_causal = True
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
@@ -693,6 +694,8 @@ def forward(
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -960,6 +963,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
+ _supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
@@ -975,6 +979,18 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
+ # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+ # We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Falcon supports SDPA from torch==2.0.0 (no requirement on 2.1).
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+ if _is_bettertransformer:
+ return config
+
+ if not hard_check_only:
+ config._attn_implementation = "sdpa"
+ return config
+
LLAMA_INPUTS_DOCSTRING = r"""
Args:
@@ -1240,7 +1256,7 @@ def forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
- attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index f4234e9a775499..43d5c6faef86ed 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -34,6 +34,7 @@
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
@@ -518,7 +519,7 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
+ dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -654,15 +655,99 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+class LlamaSdpaAttention(LlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ "sdpa": LlamaSdpaAttention,
+}
+
+
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = (
- LlamaAttention(config=config, layer_idx=layer_idx)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else LlamaFlashAttention2(config=config, layer_idx=layer_idx)
- )
+
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -757,6 +842,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
@@ -862,6 +948,8 @@ def __init__(self, config: LlamaConfig):
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self._use_sdpa = config._attn_implementation == "sdpa"
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
@@ -922,9 +1010,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 5dc06a5d77548b..200f46e5389078 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -232,9 +232,8 @@ def __init__(self, config: LlavaConfig):
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.vocab_size
- use_flash_attention_2 = getattr(config, "_flash_attn_2_enabled", False)
self.language_model = AutoModelForCausalLM.from_config(
- config.text_config, use_flash_attention_2=use_flash_attention_2
+ config.text_config, attn_implementation=config._attn_implementation
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index c05948540f7865..656c526536c563 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -325,9 +325,8 @@ class M2M100EncoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
+ self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -392,7 +391,7 @@ def forward(
return outputs
-M2M100_ATTENTION_CLASSES = {"default": M2M100Attention}
+M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention}
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
@@ -400,9 +399,8 @@ class M2M100DecoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
+ self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -415,7 +413,7 @@ def __init__(self, config: M2M100Config):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index cabf0c68f8b62b..d52a060d4723c8 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -272,9 +272,8 @@ class MarianEncoderLayer(nn.Module):
def __init__(self, config: MarianConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
+ self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -339,7 +338,7 @@ def forward(
return outputs
-MARIAN_ATTENTION_CLASSES = {"default": MarianAttention}
+MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
@@ -348,8 +347,7 @@ def __init__(self, config: MarianConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
+ self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -362,7 +360,7 @@ def __init__(self, config: MarianConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index dab8d4dae1b246..3d25d75b3ef28f 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -501,7 +501,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
MBART_ATTENTION_CLASSES = {
- "default": MBartAttention,
+ "eager": MBartAttention,
"flash_attention_2": MBartFlashAttention2,
}
@@ -510,9 +510,8 @@ class MBartEncoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
+ self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -581,9 +580,8 @@ class MBartDecoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
+ self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -596,7 +594,7 @@ def __init__(self, config: MBartConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -935,6 +933,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
embed_dim,
)
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1023,7 +1022,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1112,6 +1111,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
config.d_model,
)
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1231,7 +1231,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1242,7 +1242,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index ef65d8a0894b12..29af7c0e88e979 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -601,15 +601,19 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+MISTRAL_ATTENTION_CLASSES = {
+ "eager": MistralAttention,
+ "flash_attention_2": MistralFlashAttention2,
+}
+
+
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = (
- MistralAttention(config=config, layer_idx=layer_idx)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else MistralFlashAttention2(config, layer_idx=layer_idx)
- )
+
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -807,6 +811,7 @@ def __init__(self, config: MistralConfig):
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
@@ -870,12 +875,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if (
- attention_mask is not None
- and hasattr(self.config, "_flash_attn_2_enabled")
- and self.config._flash_attn_2_enabled
- and use_cache
- ):
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
@@ -884,7 +884,7 @@ def forward(
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 7d6b71dc090a51..9b22a51abc77aa 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -491,15 +491,18 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+OPT_ATTENTION_CLASSES = {
+ "eager": OPTAttention,
+ "flash_attention_2": OptFlashAttention2,
+}
+
+
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
- if not getattr(config, "_flash_attn_2_enabled", False):
- self.self_attn = OPTAttention(config=config, is_decoder=True)
- else:
- self.self_attn = OptFlashAttention2(config=config, is_decoder=True)
+ self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
@@ -732,6 +735,7 @@ def __init__(self, config: OPTConfig):
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -830,7 +834,7 @@ def forward(
mask_seq_length = past_key_values_length + seq_length
# embed positions
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 18af4d518a899b..1a75c43e58e0ee 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -267,7 +267,7 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention}
+PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
@@ -275,9 +275,8 @@ class PegasusEncoderLayer(nn.Module):
def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
+ self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -347,9 +346,8 @@ class PegasusDecoderLayer(nn.Module):
def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
+ self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -362,7 +360,7 @@ def __init__(self, config: PegasusConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index ca32193d535893..c73d5b942e6d4f 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -612,14 +612,16 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+PHI_ATTENTION_CLASSES = {
+ "eager": PhiAttention,
+ "flash_attention_2": PhiFlashAttention2,
+}
+
+
class PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
- self.self_attn = (
- PhiAttention(config=config, layer_idx=layer_idx)
- if not getattr(config, "_flash_attn_2_enabled", False)
- else PhiFlashAttention2(config=config, layer_idx=layer_idx)
- )
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = PhiMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -813,6 +815,7 @@ def __init__(self, config: PhiConfig):
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -876,7 +879,7 @@ def forward(
inputs_embeds = self.embed_dropout(inputs_embeds)
# Attention mask.
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index ad298c6d389048..f03f90183a59a5 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -23,7 +23,12 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -265,9 +270,8 @@ class PLBartEncoderLayer(nn.Module):
def __init__(self, config: PLBartConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
+ self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -332,7 +336,8 @@ def forward(
return outputs
-PLBART_ATTENTION_CLASSES = {"default": PLBartAttention}
+# TODO: Implement attention with SDPA for PLBart.
+PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
@@ -341,8 +346,7 @@ def __init__(self, config: PLBartConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
+ self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -355,7 +359,7 @@ def __init__(self, config: PLBartConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -670,6 +674,8 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
embed_dim,
)
self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
@@ -757,8 +763,13 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
+ elif self._use_sdpa and head_mask is None and not output_attentions:
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
@@ -846,6 +857,9 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
config.d_model,
)
self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
+
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@@ -964,9 +978,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
@@ -975,8 +998,17 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index 57c74c8c42e2a6..76e088415ab88d 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -30,7 +30,12 @@
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
from .configuration_speech_to_text import Speech2TextConfig
@@ -326,7 +331,7 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention}
+SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
@@ -334,9 +339,8 @@ class Speech2TextEncoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
+ self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -406,9 +410,8 @@ class Speech2TextDecoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
+ self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -421,7 +424,7 @@ def __init__(self, config: Speech2TextConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
index 1e3596f600fe6c..b6e86735c6a3d0 100644
--- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
+++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -32,7 +32,12 @@
)
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
from .configuration_time_series_transformer import TimeSeriesTransformerConfig
@@ -436,9 +441,8 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
+ self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -503,7 +507,10 @@ def forward(
return outputs
-TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention}
+# TODO: Implement attention with SDPA for TimeSeriesTransformer.
+TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {
+ "eager": TimeSeriesTransformerAttention,
+}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
@@ -512,8 +519,7 @@ def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
+ self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -526,7 +532,7 @@ def __init__(self, config: TimeSeriesTransformerConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 4f7f25cc826f1f..21c4c82b7f40e7 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -690,9 +690,111 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)
+class WhisperSdpaAttention(WhisperAttention):
+ # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_states = self._shape(query_states, tgt_len, bsz)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
WHISPER_ATTENTION_CLASSES = {
- "default": WhisperAttention,
+ "eager": WhisperAttention,
"flash_attention_2": WhisperFlashAttention2,
+ "sdpa": WhisperSdpaAttention,
}
@@ -701,9 +803,8 @@ class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
+ self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -773,9 +874,8 @@ class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
- attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
- self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
+ self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -788,7 +888,7 @@ def __init__(self, config: WhisperConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type](
+ self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -897,6 +997,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_sdpa = True
def _init_weights(self, module):
std = self.config.init_std
@@ -1227,6 +1328,8 @@ def __init__(self, config: WhisperConfig):
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1336,9 +1439,14 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if getattr(self.config, "_flash_attn_2_enabled", False):
+ if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and head_mask is None and not output_attentions:
+ # output_attentions=True & head_mask can not be supported when using SDPA.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index eb21cbac2303e6..4dcca595a1dc61 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -107,6 +107,7 @@
is_torch_fp16_available_on_device,
is_torch_neuroncore_available,
is_torch_npu_available,
+ is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
@@ -440,6 +441,15 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
+def require_torch_sdpa(test_case):
+ """
+ Decorator marking a test that requires PyTorch's SDPA.
+
+ These tests are skipped when requirements are not met (torch version).
+ """
+ return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
+
+
def require_peft(test_case):
"""
Decorator marking a test that requires PEFT.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 719f78af2aeaf9..e9e2f9e0403987 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -180,6 +180,7 @@
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
+ is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 2da0dbc891b863..bf7530e84f4b15 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -258,6 +258,19 @@ def get_torch_version():
return _torch_version
+def is_torch_sdpa_available():
+ if not is_torch_available():
+ return False
+ elif _torch_version == "N/A":
+ return False
+
+ # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
+ # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
+ # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
+ # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
+ return version.parse(_torch_version) >= version.parse("2.1.1")
+
+
def is_torchvision_available():
return _torchvision_available
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 713fb6c3ee790a..1246fa561583a6 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -890,13 +890,11 @@ def test_flash_attn_2_inference(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
- model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
- )
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
@@ -949,12 +947,13 @@ def test_flash_attn_2_inference_padding_right(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
)
model.to(torch_device)
diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py
index 8194c4285916de..9ab9d01577a974 100644
--- a/tests/models/distilbert/test_modeling_distilbert.py
+++ b/tests/models/distilbert/test_modeling_distilbert.py
@@ -319,13 +319,11 @@ def test_flash_attn_2_inference(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
- model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
- )
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
@@ -373,12 +371,13 @@ def test_flash_attn_2_inference_padding_right(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
)
model.to(torch_device)
diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py
index 75204637bd0784..fa7ea2af816cb0 100644
--- a/tests/models/falcon/test_modeling_falcon.py
+++ b/tests/models/falcon/test_modeling_falcon.py
@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Falcon model. """
+import tempfile
import unittest
from parameterized import parameterized
@@ -26,7 +27,7 @@
is_torch_available,
set_seed,
)
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -437,6 +438,76 @@ def test_model_rope_scaling(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+ @require_torch_sdpa
+ @slow
+ def test_eager_matches_sdpa_generate(self):
+ max_new_tokens = 30
+
+ if len(self.all_generative_model_classes) == 0:
+ self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_sdpa:
+ self.skipTest(f"{model_class.__name__} does not support SDPA")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model_sdpa = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+ model_eager = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ attn_implementation="eager",
+ ).to(torch_device)
+
+ self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+ # NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason).
+ # for name, submodule in model_eager.named_modules():
+ # if "SdpaAttention" in submodule.__class__.__name__:
+ # raise ValueError("The eager model should not have SDPA attention layers")
+
+ # has_sdpa = False
+ # for name, submodule in model_sdpa.named_modules():
+ # if "SdpaAttention" in submodule.__class__.__name__:
+ # has_sdpa = True
+ # break
+ # if not has_sdpa:
+ # raise ValueError("The SDPA model should have SDPA attention layers")
+
+ # Just test that a large cache works as expected
+ res_eager = model_eager.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+ )
+
+ res_sdpa = model_sdpa.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(res_eager, res_sdpa))
+
@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 1e2bb12ee4609e..28530c72194585 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -16,11 +16,14 @@
import unittest
+from parameterized import parameterized
+
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
+ require_torch_sdpa,
require_vision,
slow,
torch_device,
@@ -309,6 +312,12 @@ def prepare_config_and_inputs_for_common(self):
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ @require_torch_sdpa
+ @slow
+ @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+ self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
+
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
@@ -557,6 +566,12 @@ def test_model_from_pretrained(self):
model = IdeficsModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @require_torch_sdpa
+ @slow
+ @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+ self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
+
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 26aecaeb1ad9b7..427f94f873cff2 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
+import tempfile
import unittest
import pytest
@@ -26,6 +27,7 @@
require_torch,
require_torch_accelerator,
require_torch_gpu,
+ require_torch_sdpa,
slow,
torch_device,
)
@@ -411,7 +413,7 @@ def test_flash_attn_2_generate_padding_right(self):
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
- "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
+ "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -419,6 +421,85 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn
+ @require_torch_gpu
+ @slow
+ def test_use_flash_attention_2_true(self):
+ """
+ NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model = model_class(config)
+ model.save_pretrained(tmp_dir)
+
+ new_model = LlamaForCausalLM.from_pretrained(
+ tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+ ).to("cuda")
+
+ self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
+
+ has_flash = False
+ for name, submodule in new_model.named_modules():
+ if "FlashAttention" in submodule.__class__.__name__:
+ has_flash = True
+ break
+ if not has_flash:
+ raise ValueError("The flash model should have flash attention layers")
+
+ @require_torch_sdpa
+ @slow
+ def test_eager_matches_sdpa_generate(self):
+ """
+ Overwritting the common test as the test is flaky on tiny models
+ """
+ max_new_tokens = 30
+
+ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ model_sdpa = LlamaForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf",
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+ model_eager = LlamaForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf",
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ attn_implementation="eager",
+ ).to(torch_device)
+
+ self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+ for name, submodule in model_eager.named_modules():
+ if "SdpaAttention" in submodule.__class__.__name__:
+ raise ValueError("The eager model should not have SDPA attention layers")
+
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ if "SdpaAttention" in submodule.__class__.__name__:
+ has_sdpa = True
+ break
+ if not has_sdpa:
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
+ texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
+
+ for padding_side in ["left", "right"]:
+ tokenizer.padding_side = padding_side
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
+
+ res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+
+ res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+ self.assertTrue(torch.allclose(res_eager, res_sdpa))
+
@require_torch
class LlamaIntegrationTest(unittest.TestCase):
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index fcb1f2495aab8a..35a2341b4e69d6 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -387,9 +387,9 @@ def test_flash_attn_2_generate_padding_right(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
- ).to(torch_device)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
@@ -397,7 +397,10 @@ def test_flash_attn_2_generate_padding_right(self):
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
@@ -437,7 +440,7 @@ def test_flash_attn_2_generate_use_cache(self):
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
@@ -507,7 +510,7 @@ def test_model_7b_long_prompt(self):
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py
index 76bc5c2104306a..516dd1ee626e7f 100644
--- a/tests/models/phi/test_modeling_phi.py
+++ b/tests/models/phi/test_modeling_phi.py
@@ -389,7 +389,7 @@ def test_flash_attn_2_generate_padding_right(self):
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
- "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
+ "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 6f01cfdac29fc7..9de3b8ff2c21b6 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -891,12 +891,13 @@ def test_flash_attn_2_inference(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
)
model.to(torch_device)
@@ -936,11 +937,11 @@ def test_flash_attn_2_inference_padding_right(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
@@ -981,6 +982,7 @@ def _create_and_check_torchscript(self, config, inputs_dict):
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
+ configs_no_init._attn_implementation = "eager"
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
@@ -2337,13 +2339,20 @@ def test_encoder_outputs(self):
with torch.no_grad():
outputs = model(**inputs)[0]
- input_ids = inputs["input_features"]
+ encoder = model.encoder
+
+ encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]
- encoder = model.encoder
+ if "head_mask" in inputs:
+ encoder_inputs["head_mask"] = inputs["head_mask"]
+ if "attention_mask" in inputs:
+ encoder_inputs["attention_mask"] = inputs["attention_mask"]
+ if "output_attentions" in inputs:
+ encoder_inputs["output_attentions"] = inputs["output_attentions"]
with torch.no_grad():
- inputs["encoder_outputs"] = encoder(input_ids)
+ inputs["encoder_outputs"] = encoder(**encoder_inputs)
outputs_embeds = model(**inputs)[0]
self.assertTrue((outputs_embeds == outputs).all())
diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py
index a6e9e6b0390abe..2ceb1695fa96cc 100644
--- a/tests/test_configuration_utils.py
+++ b/tests/test_configuration_utils.py
@@ -198,7 +198,14 @@ def test_config_common_kwargs_is_complete(self):
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
self.assertListEqual(
- missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
+ missing_keys,
+ [
+ "is_encoder_decoder",
+ "_name_or_path",
+ "_commit_hash",
+ "_attn_implementation_internal",
+ "transformers_version",
+ ],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 2d725e112056f3..f0e6c0f1fce37f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import collections
import copy
import gc
@@ -28,6 +27,7 @@
from typing import Dict, List, Tuple
import numpy as np
+from parameterized import parameterized
from pytest import mark
import transformers
@@ -71,6 +71,7 @@
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
+ require_torch_sdpa,
slow,
torch_device,
)
@@ -776,102 +777,120 @@ def _create_and_check_torchscript(self, config, inputs_dict):
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
- model = model_class(config=configs_no_init)
- model.to(torch_device)
- model.eval()
- inputs = self._prepare_for_class(inputs_dict, model_class)
-
- main_input_name = model_class.main_input_name
+ for attn_implementation in ["eager", "sdpa"]:
+ if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+ continue
- try:
- if model.config.is_encoder_decoder:
- model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
- main_input = inputs[main_input_name]
- attention_mask = inputs["attention_mask"]
- decoder_input_ids = inputs["decoder_input_ids"]
- decoder_attention_mask = inputs["decoder_attention_mask"]
- model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
- traced_model = torch.jit.trace(
- model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
- )
- elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
- input_ids = inputs["input_ids"]
- bbox = inputs["bbox"]
- image = inputs["image"].tensor
- model(input_ids, bbox, image)
- traced_model = torch.jit.trace(
- model, (input_ids, bbox, image), check_trace=False
- ) # when traced model is checked, an error is produced due to name mangling
- elif "bbox" in inputs: # Bros requires additional inputs (bbox)
- input_ids = inputs["input_ids"]
- bbox = inputs["bbox"]
- model(input_ids, bbox)
- traced_model = torch.jit.trace(
- model, (input_ids, bbox), check_trace=False
- ) # when traced model is checked, an error is produced due to name mangling
- else:
- main_input = inputs[main_input_name]
- model(main_input)
- traced_model = torch.jit.trace(model, main_input)
- except RuntimeError:
- self.fail("Couldn't trace module.")
+ configs_no_init._attn_implementation = attn_implementation
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class)
- with tempfile.TemporaryDirectory() as tmp_dir_name:
- pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+ main_input_name = model_class.main_input_name
try:
- torch.jit.save(traced_model, pt_file_name)
- except Exception:
- self.fail("Couldn't save module.")
-
- try:
- loaded_model = torch.jit.load(pt_file_name)
- except Exception:
- self.fail("Couldn't load module.")
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+ main_input = inputs[main_input_name]
+ attention_mask = inputs["attention_mask"]
+ decoder_input_ids = inputs["decoder_input_ids"]
+ decoder_attention_mask = inputs["decoder_attention_mask"]
+ model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+ traced_model = torch.jit.trace(
+ model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+ )
+ elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
+ input_ids = inputs["input_ids"]
+ bbox = inputs["bbox"]
+ image = inputs["image"].tensor
+ model(input_ids, bbox, image)
+ traced_model = torch.jit.trace(
+ model, (input_ids, bbox, image), check_trace=False
+ ) # when traced model is checked, an error is produced due to name mangling
+ elif "bbox" in inputs: # Bros requires additional inputs (bbox)
+ input_ids = inputs["input_ids"]
+ bbox = inputs["bbox"]
+ model(input_ids, bbox)
+ traced_model = torch.jit.trace(
+ model, (input_ids, bbox), check_trace=False
+ ) # when traced model is checked, an error is produced due to name mangling
+ else:
+ main_input = inputs[main_input_name]
+
+ if model.config._attn_implementation == "sdpa":
+ trace_input = {main_input_name: main_input}
+
+ if "attention_mask" in inputs:
+ trace_input["attention_mask"] = inputs["attention_mask"]
+ else:
+ self.skipTest("testing SDPA without attention_mask is not supported")
+
+ model(main_input, attention_mask=inputs["attention_mask"])
+ # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
+ traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
+ else:
+ model(main_input)
+ traced_model = torch.jit.trace(model, (main_input,))
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
- model.to(torch_device)
- model.eval()
+ model.to(torch_device)
+ model.eval()
- loaded_model.to(torch_device)
- loaded_model.eval()
+ loaded_model.to(torch_device)
+ loaded_model.eval()
- model_state_dict = model.state_dict()
- loaded_model_state_dict = loaded_model.state_dict()
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
- non_persistent_buffers = {}
- for key in loaded_model_state_dict.keys():
- if key not in model_state_dict.keys():
- non_persistent_buffers[key] = loaded_model_state_dict[key]
+ non_persistent_buffers = {}
+ for key in loaded_model_state_dict.keys():
+ if key not in model_state_dict.keys():
+ non_persistent_buffers[key] = loaded_model_state_dict[key]
- loaded_model_state_dict = {
- key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
- }
+ loaded_model_state_dict = {
+ key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
+ }
- self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
- model_buffers = list(model.buffers())
- for non_persistent_buffer in non_persistent_buffers.values():
- found_buffer = False
- for i, model_buffer in enumerate(model_buffers):
- if torch.equal(non_persistent_buffer, model_buffer):
- found_buffer = True
- break
+ model_buffers = list(model.buffers())
+ for non_persistent_buffer in non_persistent_buffers.values():
+ found_buffer = False
+ for i, model_buffer in enumerate(model_buffers):
+ if torch.equal(non_persistent_buffer, model_buffer):
+ found_buffer = True
+ break
- self.assertTrue(found_buffer)
- model_buffers.pop(i)
+ self.assertTrue(found_buffer)
+ model_buffers.pop(i)
- models_equal = True
- for layer_name, p1 in model_state_dict.items():
- if layer_name in loaded_model_state_dict:
- p2 = loaded_model_state_dict[layer_name]
- if p1.data.ne(p2.data).sum() > 0:
- models_equal = False
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ if layer_name in loaded_model_state_dict:
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
- self.assertTrue(models_equal)
+ self.assertTrue(models_equal)
- # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
- # (Even with this call, there are still memory leak by ~0.04MB)
- self.clear_torch_jit_class_registry()
+ # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
+ # (Even with this call, there are still memory leak by ~0.04MB)
+ self.clear_torch_jit_class_registry()
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -2832,8 +2851,6 @@ def test_model_is_small(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
- import torch
-
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -2845,7 +2862,7 @@ def test_flash_attn_2_conversion(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(torch_device)
for _, module in model.named_modules():
@@ -2859,8 +2876,6 @@ def test_flash_attn_2_conversion(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
- import torch
-
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -2871,12 +2886,12 @@ def test_flash_attn_2_inference(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
@@ -2956,8 +2971,6 @@ def test_flash_attn_2_inference(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
- import torch
-
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -2968,12 +2981,12 @@ def test_flash_attn_2_inference_padding_right(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
@@ -3049,8 +3062,6 @@ def test_flash_attn_2_inference_padding_right(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
- import torch
-
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3060,9 +3071,9 @@ def test_flash_attn_2_generate_left_padding(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
- ).to(torch_device)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -3078,7 +3089,10 @@ def test_flash_attn_2_generate_left_padding(self):
)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
@@ -3092,8 +3106,6 @@ def test_flash_attn_2_generate_left_padding(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
- import torch
-
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3103,9 +3115,9 @@ def test_flash_attn_2_generate_padding_right(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
- ).to(torch_device)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -3121,7 +3133,10 @@ def test_flash_attn_2_generate_padding_right(self):
)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
@@ -3130,13 +3145,330 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+ @require_torch_sdpa
+ @slow
+ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+ if not self.all_model_classes[0]._supports_sdpa:
+ self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+ if torch_device == "cpu" and torch_dtype == "float16":
+ self.skipTest("float16 not supported on cpu")
+
+ # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
+ if torch_dtype == "float16":
+ torch_dtype = torch.float16
+ elif torch_dtype == "bfloat16":
+ torch_dtype = torch.bfloat16
+ elif torch_dtype == "float32":
+ torch_dtype = torch.float32
+
+ atols = {
+ ("cpu", False, torch.float32): 1e-6,
+ ("cpu", False, torch.bfloat16): 1e-2,
+ ("cpu", True, torch.float32): 1e-6,
+ ("cpu", True, torch.bfloat16): 1e-2,
+ ("cuda", False, torch.float32): 1e-6,
+ ("cuda", False, torch.bfloat16): 1e-2,
+ ("cuda", False, torch.float16): 1e-3,
+ ("cuda", True, torch.float32): 1e-6,
+ ("cuda", True, torch.bfloat16): 1e-2,
+ ("cuda", True, torch.float16): 5e-3,
+ }
+ rtols = {
+ ("cpu", False, torch.float32): 1e-4,
+ ("cpu", False, torch.bfloat16): 1e-2,
+ ("cpu", True, torch.float32): 1e-4,
+ ("cpu", True, torch.bfloat16): 1e-2,
+ ("cuda", False, torch.float32): 1e-4,
+ ("cuda", False, torch.bfloat16): 1e-2,
+ ("cuda", False, torch.float16): 1e-3,
+ ("cuda", True, torch.float32): 1e-4,
+ ("cuda", True, torch.bfloat16): 3e-2,
+ ("cuda", True, torch.float16): 5e-3,
+ }
+
+ def get_mean_reldiff(failcase, x, ref, atol, rtol):
+ return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ is_encoder_decoder = model.config.is_encoder_decoder
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
+ model_sdpa = model_sdpa.eval().to(torch_device)
+
+ self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+ model_eager = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch_dtype,
+ attn_implementation="eager",
+ )
+ model_eager = model_eager.eval().to(torch_device)
+
+ self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+ for name, submodule in model_eager.named_modules():
+ if "SdpaAttention" in submodule.__class__.__name__:
+ raise ValueError("The eager model should not have SDPA attention layers")
+
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ if "SdpaAttention" in submodule.__class__.__name__:
+ has_sdpa = True
+ break
+ if not has_sdpa and model_sdpa.config.model_type != "falcon":
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
+ # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
+ # but it would be nicer to have an efficient way to use parameterized.expand
+ fail_cases = []
+ for padding_side in ["left", "right"]:
+ for use_mask in [False, True]:
+ for batch_size in [1, 5]:
+ dummy_input = inputs_dict[model.main_input_name]
+
+ if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+ dummy_input = dummy_input.to(torch_dtype)
+
+ dummy_input = dummy_input[:batch_size]
+ if dummy_input.shape[0] != batch_size:
+ if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+ extension = torch.rand(
+ batch_size - dummy_input.shape[0],
+ *dummy_input.shape[1:],
+ dtype=torch_dtype,
+ device=torch_device,
+ )
+ dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+ else:
+ extension = torch.randint(
+ high=5,
+ size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
+ dtype=dummy_input.dtype,
+ device=torch_device,
+ )
+ dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+
+ if not use_mask:
+ dummy_attention_mask = None
+ else:
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+ if dummy_attention_mask is None:
+ if is_encoder_decoder:
+ seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
+ else:
+ seqlen = dummy_input.shape[-1]
+ dummy_attention_mask = (
+ torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
+ )
+
+ dummy_attention_mask = dummy_attention_mask[:batch_size]
+ if dummy_attention_mask.shape[0] != batch_size:
+ extension = torch.ones(
+ batch_size - dummy_attention_mask.shape[0],
+ *dummy_attention_mask.shape[1:],
+ dtype=dummy_attention_mask.dtype,
+ device=torch_device,
+ )
+ dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
+ dummy_attention_mask = dummy_attention_mask.to(torch_device)
+
+ dummy_attention_mask[:] = 1
+ if padding_side == "left":
+ dummy_attention_mask[-1, :-1] = 1
+ dummy_attention_mask[-1, -4:] = 0
+ elif padding_side == "right":
+ dummy_attention_mask[-1, 1:] = 1
+ dummy_attention_mask[-1, :3] = 0
+
+ for enable_kernels in [False, True]:
+ failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
+ if is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size]
+ if decoder_input_ids.shape[0] != batch_size:
+ extension = torch.ones(
+ batch_size - decoder_input_ids.shape[0],
+ *decoder_input_ids.shape[1:],
+ dtype=decoder_input_ids.dtype,
+ device=torch_device,
+ )
+ decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
+ decoder_input_ids = decoder_input_ids.to(torch_device)
+
+ # TODO: never an `attention_mask` arg here?
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+
+ # Otherwise fails for e.g. WhisperEncoderModel
+ if "attention_mask" in inspect.signature(model_eager.forward).parameters:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ # TODO: test gradients as well (& for FA2 as well!)
+ with torch.no_grad():
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=enable_kernels,
+ enable_math=True,
+ enable_mem_efficient=enable_kernels,
+ ):
+ outputs_eager = model_eager(dummy_input, **other_inputs)
+ outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
+
+ logits_eager = (
+ outputs_eager.hidden_states[-1]
+ if not is_encoder_decoder
+ else outputs_eager.decoder_hidden_states[-1]
+ )
+ logits_sdpa = (
+ outputs_sdpa.hidden_states[-1]
+ if not is_encoder_decoder
+ else outputs_sdpa.decoder_hidden_states[-1]
+ )
+
+ if torch_device in ["cpu", "cuda"]:
+ atol = atols[torch_device, enable_kernels, torch_dtype]
+ rtol = rtols[torch_device, enable_kernels, torch_dtype]
+ else:
+ atol = 1e-7
+ rtol = 1e-4
+
+ # Masked tokens output slightly deviates - we don't mind that.
+ if use_mask:
+ if padding_side == "left":
+ sub_sdpa = logits_sdpa[:-1]
+ sub_eager = logits_eager[:-1]
+ if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
+ fail_cases.append(
+ get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
+ )
+
+ sub_sdpa = logits_sdpa[-1, :-4]
+ sub_eager = logits_eager[-1, :-4]
+ if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
+ fail_cases.append(
+ get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
+ )
+
+ # Testing the padding tokens is not really meaningful but anyway
+ # sub_sdpa = logits_sdpa[-1, -4:]
+ # sub_eager = logits_eager[-1, -4:]
+ # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
+ # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
+ elif padding_side == "right":
+ sub_sdpa = logits_sdpa[:-1]
+ sub_eager = logits_eager[:-1]
+ if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
+ fail_cases.append(
+ get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
+ )
+
+ sub_sdpa = logits_sdpa[-1, 3:]
+ sub_eager = logits_eager[-1, 3:]
+ if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
+ fail_cases.append(
+ get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
+ )
+
+ # Testing the padding tokens is not really meaningful but anyway
+ # sub_sdpa = logits_sdpa[-1, :3]
+ # sub_eager = logits_eager[-1, :3]
+ # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
+ # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
+
+ else:
+ if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
+ fail_cases.append(
+ get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
+ )
+
+ self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
+
+ @require_torch_sdpa
+ @slow
+ def test_eager_matches_sdpa_generate(self):
+ max_new_tokens = 30
+
+ if len(self.all_generative_model_classes) == 0:
+ self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_sdpa:
+ self.skipTest(f"{model_class.__name__} does not support SDPA")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model_sdpa = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+ model_eager = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ attn_implementation="eager",
+ ).to(torch_device)
+
+ self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+ for name, submodule in model_eager.named_modules():
+ if "SdpaAttention" in submodule.__class__.__name__:
+ raise ValueError("The eager model should not have SDPA attention layers")
+
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ if "SdpaAttention" in submodule.__class__.__name__:
+ has_sdpa = True
+ break
+ if not has_sdpa:
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
+ # Just test that a large cache works as expected
+ res_eager = model_eager.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+ )
+
+ res_sdpa = model_sdpa.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(res_eager, res_sdpa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
- import torch
-
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
@@ -3163,7 +3495,7 @@ def test_flash_attn_2_generate_use_cache(self):
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
@@ -3182,8 +3514,6 @@ def test_flash_attn_2_generate_use_cache(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
- import torch
-
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3204,7 +3534,7 @@ def test_flash_attn_2_fp32_ln(self):
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
@@ -3282,8 +3612,6 @@ def test_flax_from_pt_safetensors(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
- import torch
-
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3291,7 +3619,7 @@ def test_flash_attn_2_from_config(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config(
- config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
+ config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
@@ -3313,7 +3641,7 @@ def test_flash_attn_2_from_config(self):
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
- self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
+ self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
fa2_correctly_converted = False
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index e1c37ec2687ed0..ddfaad5214dc50 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -60,7 +60,13 @@
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
-from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
+from transformers.utils.import_utils import (
+ is_flash_attn_2_available,
+ is_flax_available,
+ is_tf_available,
+ is_torch_sdpa_available,
+ is_torchdynamo_available,
+)
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -1689,3 +1695,158 @@ def test_torch_compile_fullgraph(self):
res_compiled = compiled_model(mask, inputs_embeds)
self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+ @require_torch
+ @slow
+ def test_unmask_unattended_left_padding(self):
+ attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)
+
+ expanded_mask = torch.Tensor(
+ [
+ [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
+ [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
+ [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
+ ]
+ ).to(torch.int64)
+
+ reference_output = torch.Tensor(
+ [
+ [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
+ [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
+ [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
+ ]
+ ).to(torch.int64)
+
+ result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)
+
+ self.assertTrue(torch.equal(result, reference_output))
+
+ attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)
+
+ attn_mask_converter = AttentionMaskConverter(is_causal=True)
+ past_key_values_length = 0
+ key_value_length = attention_mask.shape[-1] + past_key_values_length
+
+ expanded_mask = attn_mask_converter.to_4d(
+ attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
+ )
+
+ result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
+ min_inf = torch.finfo(torch.float32).min
+ reference_output = torch.Tensor(
+ [
+ [
+ [
+ [0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0],
+ [min_inf, min_inf, 0, min_inf, min_inf],
+ [min_inf, min_inf, 0, 0, min_inf],
+ [min_inf, min_inf, 0, 0, 0],
+ ]
+ ],
+ [
+ [
+ [0, min_inf, min_inf, min_inf, min_inf],
+ [0, 0, min_inf, min_inf, min_inf],
+ [0, 0, 0, min_inf, min_inf],
+ [0, 0, 0, 0, min_inf],
+ [0, 0, 0, 0, 0],
+ ]
+ ],
+ [
+ [
+ [0, 0, 0, 0, 0],
+ [min_inf, 0, min_inf, min_inf, min_inf],
+ [min_inf, 0, 0, min_inf, min_inf],
+ [min_inf, 0, 0, 0, min_inf],
+ [min_inf, 0, 0, 0, 0],
+ ]
+ ],
+ ]
+ )
+
+ self.assertTrue(torch.equal(reference_output, result))
+
+ @require_torch
+ @slow
+ def test_unmask_unattended_right_padding(self):
+ attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)
+
+ attn_mask_converter = AttentionMaskConverter(is_causal=True)
+ past_key_values_length = 0
+ key_value_length = attention_mask.shape[-1] + past_key_values_length
+
+ expanded_mask = attn_mask_converter.to_4d(
+ attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
+ )
+
+ result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
+
+ self.assertTrue(torch.equal(expanded_mask, result))
+
+ @require_torch
+ @slow
+ def test_unmask_unattended_random_mask(self):
+ attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)
+
+ attn_mask_converter = AttentionMaskConverter(is_causal=True)
+ past_key_values_length = 0
+ key_value_length = attention_mask.shape[-1] + past_key_values_length
+
+ expanded_mask = attn_mask_converter.to_4d(
+ attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
+ )
+
+ result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
+
+ self.assertTrue(torch.equal(expanded_mask, result))
+
+
+@require_torch
+class TestAttentionImplementation(unittest.TestCase):
+ def test_error_no_sdpa_available(self):
+ with self.assertRaises(ValueError) as cm:
+ _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")
+
+ self.assertTrue(
+ "does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
+ in str(cm.exception)
+ )
+
+ _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
+
+ def test_error_no_flash_available(self):
+ with self.assertRaises(ValueError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
+ )
+
+ self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
+
+ def test_error_wrong_attn_implementation(self):
+ with self.assertRaises(ValueError) as cm:
+ _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
+
+ self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))
+
+ def test_not_available_flash(self):
+ if is_flash_attn_2_available():
+ self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
+ )
+
+ self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
+
+ def test_not_available_sdpa(self):
+ if is_torch_sdpa_available():
+ self.skipTest("This test requires torch<=2.0")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
+ )
+
+ self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
diff --git a/tests/utils/test_doc_samples.py b/tests/utils/test_doc_samples.py
index 84c5a4d2bf5008..953654537843ee 100644
--- a/tests/utils/test_doc_samples.py
+++ b/tests/utils/test_doc_samples.py
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import doctest
import logging
import os
import unittest
+from glob import glob
from pathlib import Path
from typing import List, Union
@@ -27,6 +27,63 @@
logger = logging.getLogger()
+@require_torch
+class TestDocLists(unittest.TestCase):
+ def test_flash_support_list(self):
+ with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
+ doctext = f.read()
+
+ doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
+ doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
+
+ patterns = glob("./src/transformers/models/**/modeling_*.py")
+ patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
+ patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
+ patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
+ archs_supporting_fa2 = []
+ for filename in patterns:
+ with open(filename, "r") as f:
+ text = f.read()
+
+ if "_supports_flash_attn_2 = True" in text:
+ model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
+ archs_supporting_fa2.append(model_name)
+
+ for arch in archs_supporting_fa2:
+ if arch not in doctext:
+ raise ValueError(
+ f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation."
+ )
+
+ def test_sdpa_support_list(self):
+ with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
+ doctext = f.read()
+
+ doctext = doctext.split(
+ "For now, Transformers supports inference and training through SDPA for the following architectures:"
+ )[1]
+ doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
+
+ patterns = glob("./src/transformers/models/**/modeling_*.py")
+ patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
+ patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
+ patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
+ archs_supporting_sdpa = []
+ for filename in patterns:
+ with open(filename, "r") as f:
+ text = f.read()
+
+ if "_supports_sdpa = True" in text:
+ model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
+ archs_supporting_sdpa.append(model_name)
+
+ for arch in archs_supporting_sdpa:
+ if arch not in doctext:
+ raise ValueError(
+ f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
+ )
+
+
@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf