From c1316c3f60d2260474712404c50604363edd886a Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 22 Feb 2024 13:50:22 +0100
Subject: [PATCH 1/3] fix torch.export.export for llama

---
 docs/source/en/perf_infer_gpu_one.md | 2 +-
 docs/source/en/perf_train_gpu_one.md | 22 ++--------------------
 src/transformers/modeling_attn_mask_utils.py | 18 +++++++++++++-----
 .../models/llama/modeling_llama.py | 16 ++++++++++++----
 4 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 69512acd6a6c3f..b03460a7a0d15c 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -184,7 +184,7 @@ For now, Transformers supports SDPA inference and training for the following arc
-FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first.
+FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md
index 1d885ba03646c7..99a1771e8b5121 100644
--- a/docs/source/en/perf_train_gpu_one.md
+++ b/docs/source/en/perf_train_gpu_one.md
@@ -527,26 +527,8 @@ Most related papers and implementations are built around Tensorflow/TPUs:
 And for Pytorch DeepSpeed has built one as well: [DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale](https://arxiv.org/abs/2201.05596), [Mixture of Experts](https://www.deepspeed.ai/tutorials/mixture-of-experts/) - blog posts: [1](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/), [2](https://www.microsoft.com/en-us/research/publication/scalable-and-efficient-moe-training-for-multitask-multilingual-models/) and specific deployment with large transformer-based natural language generation models: [blog post](https://www.deepspeed.ai/2021/12/09/deepspeed-moe-nlg.html), [Megatron-Deepspeed branch](https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training).
-## Using PyTorch native attention and Flash Attention
+## Using PyTorch native scaled dot product attention
-PyTorch 2.0 released a native [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA),
-that allows using fused GPU kernels such as [memory-efficient attention](https://arxiv.org/abs/2112.05682) and [flash attention](https://arxiv.org/abs/2205.14135).
-
-After installing the [`optimum`](https://github.com/huggingface/optimum) package, the relevant internal modules can be
-replaced to use PyTorch's native attention with:
-
-```python
-model = model.to_bettertransformer()
-```
-
-Once converted, train the model as usual.
-
-
-
-The PyTorch-native `scaled_dot_product_attention` operator can only dispatch to Flash Attention if no `attention_mask` is provided.
-
-By default, in training mode, the BetterTransformer integration **drops the mask support and can only be used for training that does not require a padding mask for batched training**. This is the case, for example, during masked language modeling or causal language modeling. BetterTransformer is not suited for fine-tuning models on tasks that require a padding mask.
-
-
+PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. Please refer to [PyTorch scaled dot product attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) for a list of supported models and more details. Check out this [blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to learn more about acceleration and memory-savings with SDPA.
diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 67555239c758ae..1a2c0db7bb140c 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -349,8 +349,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
-    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
-    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)
+    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+    is_tracing = (
+        torch.jit.is_tracing()
+        or isinstance(inputs_embeds, torch.fx.Proxy)
+        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+    )
     if attention_mask is not None:
         # 4d mask is passed through
@@ -448,10 +452,14 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
     batch_size, key_value_length = mask.shape
     tgt_len = tgt_len if tgt_len is not None else key_value_length
-    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
-    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
-    is_tracing = torch.jit.is_tracing()
+    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+    is_tracing = (
+        torch.jit.is_tracing()
+        or isinstance(mask, torch.fx.Proxy)
+        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+    )
     if torch.all(mask == 1):
         if is_tracing:
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8e494adefc2d73..1d41bf13710e62 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1076,10 +1076,18 @@ def _update_causal_mask(self, attention_mask, input_tensor):
                 padding_mask, torch.finfo(dtype).min
             )
-        if self.config._attn_implementation == "sdpa":
-            is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
-            if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+            is_tracing = (
+                torch.jit.is_tracing()
+                or isinstance(input_tensor, torch.fx.Proxy)
+                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            )
+            if not is_tracing and torch.any(attention_mask != 1):
+                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+                # Details: https://github.com/pytorch/pytorch/issues/110213
+                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
                     dtype
                 )

From c2fdb495ccd8c34be4089182df5a5de0060fb48a Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 22 Feb 2024 13:53:42 +0100
Subject: [PATCH 2/3] do not change doc title

---
 docs/source/en/perf_train_gpu_one.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md
index 99a1771e8b5121..df27f178616b91 100644
--- a/docs/source/en/perf_train_gpu_one.md
+++ b/docs/source/en/perf_train_gpu_one.md
@@ -527,7 +527,7 @@ Most related papers and implementations are built around Tensorflow/TPUs:
 And for Pytorch DeepSpeed has built one as well: [DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale](https://arxiv.org/abs/2201.05596), [Mixture of Experts](https://www.deepspeed.ai/tutorials/mixture-of-experts/) - blog posts: [1](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/), [2](https://www.microsoft.com/en-us/research/publication/scalable-and-efficient-moe-training-for-multitask-multilingual-models/) and specific deployment with large transformer-based natural language generation models: [blog post](https://www.deepspeed.ai/2021/12/09/deepspeed-moe-nlg.html), [Megatron-Deepspeed branch](https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training).
-## Using PyTorch native scaled dot product attention
+## Using PyTorch native attention and Flash Attention
 PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood.
 SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. Please refer to [PyTorch scaled dot product attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) for a list of supported models and more details.

From 25dd52185cdb3f94644bedda1f9c496a62eb754e Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 22 Feb 2024 15:27:49 +0100
Subject: [PATCH 3/3] make fix copies

---
 src/transformers/models/gemma/modeling_gemma.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 165ef5a0545182..d5cfed296a903e 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -969,10 +969,18 @@ def _update_causal_mask(self, attention_mask, input_tensor):
                 padding_mask, torch.finfo(dtype).min
             )
-        if self.config._attn_implementation == "sdpa":
-            is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
-            if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+            is_tracing = (
+                torch.jit.is_tracing()
+                or isinstance(input_tensor, torch.fx.Proxy)
+                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            )
+            if not is_tracing and torch.any(attention_mask != 1):
+                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+                # Details: https://github.com/pytorch/pytorch/issues/110213
+                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
                     dtype
                 )
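
Note (not part of the patches above): a minimal sketch of how one might exercise this fix end to end is shown below. It assumes a Transformers build that includes these patches and `torch>=2.2`; the checkpoint name, prompt, and printed output are illustrative placeholders, and whether the export fully succeeds still depends on the exact torch and transformers versions.

```python
# Sketch: load a Llama-architecture checkpoint on the SDPA attention path touched by
# these patches, then try capturing it with torch.export. The checkpoint id and the
# prompt are placeholders, not something prescribed by the patch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
model.eval()

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    # torch.export traces through dynamo in strict mode, so the torch._dynamo.is_compiling()
    # branch added in modeling_attn_mask_utils.py and modeling_llama.py is the code path
    # this call is meant to hit.
    exported = torch.export.export(
        model,
        args=(),
        kwargs={"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
    )
    print(exported)
```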