From c1316c3f60d2260474712404c50604363edd886a Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 22 Feb 2024 13:50:22 +0100
Subject: [PATCH 1/3] fix torch.export.export for llama
---
docs/source/en/perf_infer_gpu_one.md | 2 +-
docs/source/en/perf_train_gpu_one.md | 22 ++-----------------
src/transformers/modeling_attn_mask_utils.py | 18 ++++++++++-----
.../models/llama/modeling_llama.py | 16 ++++++++++----
4 files changed, 28 insertions(+), 30 deletions(-)
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 69512acd6a6c3f..b03460a7a0d15c 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -184,7 +184,7 @@ For now, Transformers supports SDPA inference and training for the following arc
-FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first.
+FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
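Note (not part of the patch): the sentence above distinguishes the FlashAttention backend, which needs `fp16`/`bf16` weights, from the memory-efficient backend, which also handles `fp32`. A minimal sketch of what that means with the standard `torch_dtype` and `attn_implementation` arguments of `from_pretrained`; the checkpoint name is only an example:
```python
import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint, any causal LM works

# FlashAttention-2 only supports fp16/bf16, so the model is cast on load.
model_fa2 = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)

# SDPA can also handle fp32, e.g. through its memory-efficient backend.
model_sdpa = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, attn_implementation="sdpa"
)
```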
diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md
index 1d885ba03646c7..99a1771e8b5121 100644
--- a/docs/source/en/perf_train_gpu_one.md
+++ b/docs/source/en/perf_train_gpu_one.md
@@ -527,26 +527,8 @@ Most related papers and implementations are built around Tensorflow/TPUs:
And for Pytorch DeepSpeed has built one as well: [DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale](https://arxiv.org/abs/2201.05596), [Mixture of Experts](https://www.deepspeed.ai/tutorials/mixture-of-experts/) - blog posts: [1](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/), [2](https://www.microsoft.com/en-us/research/publication/scalable-and-efficient-moe-training-for-multitask-multilingual-models/) and specific deployment with large transformer-based natural language generation models: [blog post](https://www.deepspeed.ai/2021/12/09/deepspeed-moe-nlg.html), [Megatron-Deepspeed branch](https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training).
-## Using PyTorch native attention and Flash Attention
+## Using PyTorch native scaled dot product attention
-PyTorch 2.0 released a native [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA),
-that allows using fused GPU kernels such as [memory-efficient attention](https://arxiv.org/abs/2112.05682) and [flash attention](https://arxiv.org/abs/2205.14135).
-
-After installing the [`optimum`](https://github.com/huggingface/optimum) package, the relevant internal modules can be
-replaced to use PyTorch's native attention with:
-
-```python
-model = model.to_bettertransformer()
-```
-
-Once converted, train the model as usual.
-
-
-
-The PyTorch-native `scaled_dot_product_attention` operator can only dispatch to Flash Attention if no `attention_mask` is provided.
-
-By default, in training mode, the BetterTransformer integration **drops the mask support and can only be used for training that does not require a padding mask for batched training**. This is the case, for example, during masked language modeling or causal language modeling. BetterTransformer is not suited for fine-tuning models on tasks that require a padding mask.
-
-
+PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. Please refer to [PyTorch scaled dot product attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) for a list of supported models and more details.
Check out this [blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to learn more about acceleration and memory-savings with SDPA.
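Note (not part of the diff): a small standalone sketch of what the reworded paragraph describes, i.e. `scaled_dot_product_attention` dispatching to the FlashAttention or memory-efficient kernels. It assumes `torch>=2.0` and a CUDA device; `torch.backends.cuda.sdp_kernel` is the backend selector available at the time of this patch:
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Force the FlashAttention kernel; this raises if it cannot be used
# (for example with fp32 inputs or an unsupported attention mask).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 128, 64])
```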
diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 67555239c758ae..1a2c0db7bb140c 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -349,8 +349,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
- # TODO: Fix this as well when using torchdynamo with fullgraph=True.
- is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(inputs_embeds, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
if attention_mask is not None:
# 4d mask is passed through
@@ -448,10 +452,14 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
batch_size, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length
- # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
- # TODO: Fix this as well when using torchdynamo with fullgraph=True.
- is_tracing = torch.jit.is_tracing()
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(mask, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
if torch.all(mask == 1):
if is_tracing:
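Note (not part of the patch): a quick sanity check of the detection pattern introduced in the two hunks above, assuming `torch>=2.1`. The point is that the check resolves to an ordinary Python bool during dynamo compilation, so branching on it does not add data-dependent control flow to the captured graph:
```python
import torch

def is_tracing_like(x) -> bool:
    # Same three checks as the hunks above: TorchScript tracing, torch.fx symbolic
    # tracing, and torchdynamo compilation (torch.compile / torch.export).
    return (
        torch.jit.is_tracing()
        or isinstance(x, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

print(is_tracing_like(torch.ones(2, 3)))  # False in eager mode

@torch.compile(fullgraph=True)
def probe(t):
    # While dynamo traces, the detection is a constant True, so this branch is
    # folded away instead of becoming data-dependent control flow.
    return t + 1 if is_tracing_like(t) else t - 1

print(probe(torch.zeros(2)))  # tensor([1., 1.]) -- the compiled path saw True
```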
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8e494adefc2d73..1d41bf13710e62 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1076,10 +1076,18 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask, torch.finfo(dtype).min
)
- if self.config._attn_implementation == "sdpa":
- is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
- if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
- causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(input_tensor, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+ if not is_tracing and torch.any(attention_mask != 1):
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
dtype
)
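Note (not part of the patch): a toy illustration of the `causal_mask.mul(...)` expression added above. A row that sits at the mask's minimum value everywhere (e.g. the padded-out rows under left padding) would make the memory-efficient SDPA kernel produce NaNs (see the linked PyTorch issue 110213), so such rows are rewritten to attend to every position:
```python
import torch

min_val = torch.finfo(torch.float32).min
causal_mask = torch.tensor(
    [
        [min_val, min_val, min_val],  # fully masked row (every key position disallowed)
        [0.0, min_val, min_val],
        [0.0, 0.0, min_val],
    ]
)

# Rows equal to the mask minimum everywhere are multiplied by 0, i.e. un-masked.
fixed = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True))
print(fixed[0])   # all zeros: the previously dead row now attends to every position
print(fixed[1:])  # the other rows are unchanged
```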
From c2fdb495ccd8c34be4089182df5a5de0060fb48a Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 22 Feb 2024 13:53:42 +0100
Subject: [PATCH 2/3] do not change doc title
---
docs/source/en/perf_train_gpu_one.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md
index 99a1771e8b5121..df27f178616b91 100644
--- a/docs/source/en/perf_train_gpu_one.md
+++ b/docs/source/en/perf_train_gpu_one.md
@@ -527,7 +527,7 @@ Most related papers and implementations are built around Tensorflow/TPUs:
And for Pytorch DeepSpeed has built one as well: [DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale](https://arxiv.org/abs/2201.05596), [Mixture of Experts](https://www.deepspeed.ai/tutorials/mixture-of-experts/) - blog posts: [1](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/), [2](https://www.microsoft.com/en-us/research/publication/scalable-and-efficient-moe-training-for-multitask-multilingual-models/) and specific deployment with large transformer-based natural language generation models: [blog post](https://www.deepspeed.ai/2021/12/09/deepspeed-moe-nlg.html), [Megatron-Deepspeed branch](https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training).
-## Using PyTorch native scaled dot product attention
+## Using PyTorch native attention and Flash Attention
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. Please refer to [PyTorch scaled dot product attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) for a list of supported models and more details.
From 25dd52185cdb3f94644bedda1f9c496a62eb754e Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 22 Feb 2024 15:27:49 +0100
Subject: [PATCH 3/3] make fix copies
---
src/transformers/models/gemma/modeling_gemma.py | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 165ef5a0545182..d5cfed296a903e 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -969,10 +969,18 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask, torch.finfo(dtype).min
)
- if self.config._attn_implementation == "sdpa":
- is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
- if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
- causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(input_tensor, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+ if not is_tracing and torch.any(attention_mask != 1):
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
dtype
)
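Note (not part of the patches): a minimal end-to-end sketch of the export path this series is meant to fix, assuming `torch>=2.2`, an SDPA-enabled checkpoint, and access to the weights. The model id is only an example, and whether the full forward exports cleanly still depends on the exact torch/transformers versions:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example; any Llama/Gemma checkpoint with SDPA
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa").eval()

inputs = tokenizer(["Hello, my dog is cute"], return_tensors="pt")

with torch.no_grad():
    exported = torch.export.export(
        model,
        args=(inputs["input_ids"],),
        kwargs={"attention_mask": inputs["attention_mask"]},
    )

print(exported)  # the captured ExportedProgram, with no mask-related graph breaks
```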