From 1576d5824d895d8010d71e61f292f1b11248d1ac Mon Sep 17 00:00:00 2001 From: hyenal Date: Thu, 16 May 2024 10:56:11 +0100 Subject: [PATCH] add sdpa to ViT [follow up of #29325] (#30555) remove blank line (+1 squashed commit) Squashed commits: [24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits) Squashed commits: [08bd27e7a] [run-slow]vit_msn,vision_encoder_decoder [ec96a8db3] [run-slow]vit_msn [ead817eca] fix vit msn multi gpu [d12cdc8fd] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos [3fdbfa88f] doc [a3ff33e4a] finish implementation [e20b7b7fb] Update test_modeling_common.py [e290c5810] Update test_modeling_flax_common.py [d3af86f46] comment [ff7dd32d8] more comments [59b137889] suggestion [7e2ba6d67] attn_implementation as attribute of the class [fe66ab71f] minor [38642b568] Apply suggestions from code review Accept comments Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [22cde7d52] Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [48e137cc6] Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [99f4c679f] Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [96cf20a6d] Update src/transformers/models/vit_msn/modeling_vit_msn.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [c59377d23] Update src/transformers/models/vit_mae/modeling_vit_mae.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [b70a47259] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [00c84d216] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos [61f00ebb0] all tests are passing locally [e9e0b82b7] vision encoder/decoder [4d5076b56] test-vision (+20 squashed commits) Squashed commits: [d1add8db9] yolo [9fde65716] fix flax [986566c28] minor [ca2f21d1f] vit [3333efd7a] easy models change [ebfc21402] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos [b8b8603ed] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos [48ecc7e26] all tests are passing locally [bff7fc366] minor [62f88306f] fix yolo and text_encoder tests [121507555] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae [1064cae0a] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos [b7f52ff3a] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae [cffaa10dd] fix-copies [ef6c511c4] test vit hybrid [7d4ba8644] vit hybrid [66f919033] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae [1fcc0a031] fixes [cfde6eb21] fixup [e77df1ed3] all except yolo end encoder decoder (+17 squashed commits) Squashed commits: [602913e22] vit + vit_mae are working [547f6c4cc] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/ passes [61a97dfa9] it s the complete opposite... [aefab37d4] fix more tests [71802a1b9] fix all torch tests [40b12eb58] encoder - decoder tests [941552b69] slow decorator where appropriate [14d055d80] has_attentions to yolo and msn [3381fa19f] add correct name [e261316a7] repo consistency [31c6d0c08] fixup [9d214276c] minor fix [11ed2e1b7] chore [eca6644c4] add sdpa to vit-based models [cffbf390b] make fix-copies result [6468319b0] fix style [d324cd02a] add sdpa for vit Co-authored-by: Liubov Yaronskaya --- .../audio-spectrogram-transformer.md | 28 +++++++++ docs/source/en/model_doc/deit.md | 28 +++++++++ docs/source/en/model_doc/videomae.md | 28 +++++++++ docs/source/en/model_doc/vit.md | 28 +++++++++ docs/source/en/model_doc/vit_hybrid.md | 28 +++++++++ docs/source/en/model_doc/vit_mae.md | 28 +++++++++ docs/source/en/model_doc/vit_msn.md | 28 +++++++++ docs/source/en/model_doc/yolos.md | 28 +++++++++ docs/source/en/perf_infer_gpu_one.md | 8 +++ .../modeling_audio_spectrogram_transformer.py | 50 +++++++++++++++- src/transformers/models/deit/modeling_deit.py | 50 +++++++++++++++- .../models/videomae/modeling_videomae.py | 50 +++++++++++++++- .../modeling_vision_encoder_decoder.py | 14 ++++- src/transformers/models/vit/modeling_vit.py | 46 ++++++++++++++- .../models/vit_hybrid/modeling_vit_hybrid.py | 48 +++++++++++++++- .../models/vit_mae/modeling_vit_mae.py | 57 +++++++++++++++++-- .../models/vit_msn/modeling_vit_msn.py | 49 +++++++++++++++- .../models/yolos/modeling_yolos.py | 47 ++++++++++++++- ..._modeling_audio_spectrogram_transformer.py | 3 + tests/models/deit/test_modeling_deit.py | 7 +++ tests/models/deit/test_modeling_tf_deit.py | 3 + .../models/videomae/test_modeling_videomae.py | 6 +- ...test_modeling_tf_vision_encoder_decoder.py | 4 +- tests/models/vit/test_modeling_flax_vit.py | 3 + tests/models/vit/test_modeling_tf_vit.py | 3 + tests/models/vit/test_modeling_vit.py | 7 +++ .../vit_hybrid/test_modeling_vit_hybrid.py | 3 + .../vit_mae/test_modeling_tf_vit_mae.py | 3 + tests/models/vit_mae/test_modeling_vit_mae.py | 8 ++- tests/models/vit_msn/test_modeling_vit_msn.py | 3 + tests/models/yolos/test_modeling_yolos.py | 3 + tests/test_modeling_common.py | 30 +++++++++- tests/test_modeling_flax_common.py | 4 +- utils/check_support_list.py | 2 +- 34 files changed, 709 insertions(+), 26 deletions(-) mode change 100644 => 100755 src/transformers/models/videomae/modeling_videomae.py diff --git a/docs/source/en/model_doc/audio-spectrogram-transformer.md b/docs/source/en/model_doc/audio-spectrogram-transformer.md index 3eac3781667eb4..d83c3bbb6cf2fe 100644 --- a/docs/source/en/model_doc/audio-spectrogram-transformer.md +++ b/docs/source/en/model_doc/audio-spectrogram-transformer.md @@ -43,6 +43,34 @@ the authors compute the stats for a downstream dataset. - Note that the AST needs a low learning rate (the authors use a 10 times smaller learning rate compared to their CNN model proposed in the [PSLA paper](https://arxiv.org/abs/2102.01243)) and converges quickly, so please search for a suitable learning rate and learning rate scheduler for your task. +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import ASTForAudioClassification +model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MIT/ast-finetuned-audioset-10-10-0.4593` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 27 | 6 | 4.5 | +| 2 | 12 | 6 | 2 | +| 4 | 21 | 8 | 2.62 | +| 8 | 40 | 14 | 2.86 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with the Audio Spectrogram Transformer. diff --git a/docs/source/en/model_doc/deit.md b/docs/source/en/model_doc/deit.md index 7d9918a45eeeb6..6a4e141facaeac 100644 --- a/docs/source/en/model_doc/deit.md +++ b/docs/source/en/model_doc/deit.md @@ -68,6 +68,34 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The Tenso *facebook/deit-base-patch16-384*. Note that one should use [`DeiTImageProcessor`] in order to prepare images for the model. +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import DeiTForImageClassification +model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/deit-base-distilled-patch16-224` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 8 | 6 | 1.33 | +| 2 | 9 | 6 | 1.5 | +| 4 | 9 | 6 | 1.5 | +| 8 | 8 | 6 | 1.33 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DeiT. diff --git a/docs/source/en/model_doc/videomae.md b/docs/source/en/model_doc/videomae.md index 75eb9617380c57..a785611185700d 100644 --- a/docs/source/en/model_doc/videomae.md +++ b/docs/source/en/model_doc/videomae.md @@ -33,6 +33,34 @@ alt="drawing" width="600"/> This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/MCG-NJU/VideoMAE). +## Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import VideoMAEForVideoClassification +model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MCG-NJU/videomae-base-finetuned-kinetics` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 37 | 10 | 3.7 | +| 2 | 24 | 18 | 1.33 | +| 4 | 43 | 32 | 1.34 | +| 8 | 84 | 60 | 1.4 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with VideoMAE. If diff --git a/docs/source/en/model_doc/vit.md b/docs/source/en/model_doc/vit.md index 25c3a6c8f537f4..b49cb821859f59 100644 --- a/docs/source/en/model_doc/vit.md +++ b/docs/source/en/model_doc/vit.md @@ -88,6 +88,34 @@ who already converted the weights from JAX to PyTorch. Credits go to him! language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% to training from scratch, but still 4% behind supervised pre-training. +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import ViTForImageClassification +model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 7 | 6 | 1.17 | +| 2 | 8 | 6 | 1.33 | +| 4 | 8 | 6 | 1.33 | +| 8 | 8 | 6 | 1.33 | + ## Resources Demo notebooks regarding inference as well as fine-tuning ViT on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer). diff --git a/docs/source/en/model_doc/vit_hybrid.md b/docs/source/en/model_doc/vit_hybrid.md index 52c0d35bc13538..ec98fc5e1ef8e0 100644 --- a/docs/source/en/model_doc/vit_hybrid.md +++ b/docs/source/en/model_doc/vit_hybrid.md @@ -39,6 +39,34 @@ substantially fewer computational resources to train.* This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be found [here](https://github.com/google-research/vision_transformer). +## Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import ViTHybridForImageClassification +model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-hybrid-base-bit-384` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 29 | 18 | 1.61 | +| 2 | 26 | 18 | 1.44 | +| 4 | 25 | 18 | 1.39 | +| 8 | 34 | 24 | 1.42 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT Hybrid. diff --git a/docs/source/en/model_doc/vit_mae.md b/docs/source/en/model_doc/vit_mae.md index 27d6d26816ae49..8d0a40c8a3e1be 100644 --- a/docs/source/en/model_doc/vit_mae.md +++ b/docs/source/en/model_doc/vit_mae.md @@ -52,6 +52,34 @@ consists of Transformer blocks) takes as input. Each mask token is a shared, lea sin/cos position embeddings are added both to the input of the encoder and the decoder. - For a visual understanding of how MAEs work you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/). +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import ViTMAEModel +model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-mae-base` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 11 | 6 | 1.83 | +| 2 | 8 | 6 | 1.33 | +| 4 | 8 | 6 | 1.33 | +| 8 | 8 | 6 | 1.33 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViTMAE. diff --git a/docs/source/en/model_doc/vit_msn.md b/docs/source/en/model_doc/vit_msn.md index 666b7dd0dfda83..e1210ce7f9dd9a 100644 --- a/docs/source/en/model_doc/vit_msn.md +++ b/docs/source/en/model_doc/vit_msn.md @@ -49,6 +49,34 @@ use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMS - MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K labels when fine-tuned. +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import ViTMSNForImageClassification +model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-base", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-msn-base` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 7 | 6 | 1.17 | +| 2 | 8 | 6 | 1.33 | +| 4 | 8 | 6 | 1.33 | +| 8 | 8 | 6 | 1.33 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT MSN. diff --git a/docs/source/en/model_doc/yolos.md b/docs/source/en/model_doc/yolos.md index 023256914f64c1..ebe249517fdf3b 100644 --- a/docs/source/en/model_doc/yolos.md +++ b/docs/source/en/model_doc/yolos.md @@ -32,6 +32,34 @@ alt="drawing" width="600"/> This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/hustvl/YOLOS). +## Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import AutoModelForObjectDetection +model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `hustvl/yolos-base` model, we saw the following speedups during inference. + +| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) | +|--------------|-------------------------------------------|-------------------------------------------|------------------------------| +| 1 | 106 | 76 | 1.39 | +| 2 | 154 | 90 | 1.71 | +| 4 | 222 | 116 | 1.91 | +| 8 | 368 | 168 | 2.19 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with YOLOS. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 18fb16151185ed..41a5d09a0d2d35 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -192,10 +192,12 @@ FlashAttention is more memory efficient, meaning you can train on much larger se PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. For now, Transformers supports SDPA inference and training for the following architectures: +* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) +* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) @@ -216,12 +218,18 @@ For now, Transformers supports SDPA inference and training for the following arc * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) +* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel) +* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel) +* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel) +* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel) +* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell) * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 1d70e57c2fd128..523ab85f14f7cd 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -169,6 +169,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST +class ASTSdpaSelfAttention(ASTSelfAttention): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST class ASTSelfOutput(nn.Module): """ @@ -228,6 +260,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST +class ASTSdpaAttention(ASTAttention): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + self.attention = ASTSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST class ASTIntermediate(nn.Module): def __init__(self, config: ASTConfig) -> None: @@ -261,7 +300,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST +AST_ATTENTION_CLASSES = { + "eager": ASTAttention, + "sdpa": ASTSdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST class ASTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -269,7 +314,7 @@ def __init__(self, config: ASTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ASTAttention(config) + self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ASTIntermediate(config) self.output = ASTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -366,6 +411,7 @@ class ASTPreTrainedModel(PreTrainedModel): base_model_prefix = "audio_spectrogram_transformer" main_input_name = "input_values" supports_gradient_checkpointing = True + _supports_sdpa = True # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 2480b99586192f..fe811ecc4a70c9 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -190,6 +190,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT +class DeiTSdpaSelfAttention(DeiTSelfAttention): + def __init__(self, config: DeiTConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT class DeiTSelfOutput(nn.Module): """ @@ -249,6 +281,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT +class DeiTSdpaAttention(DeiTAttention): + def __init__(self, config: DeiTConfig) -> None: + super().__init__(config) + self.attention = DeiTSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT class DeiTIntermediate(nn.Module): def __init__(self, config: DeiTConfig) -> None: @@ -282,7 +321,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT +DEIT_ATTENTION_CLASSES = { + "eager": DeiTAttention, + "sdpa": DeiTSdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT class DeiTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -290,7 +335,7 @@ def __init__(self, config: DeiTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = DeiTAttention(config) + self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = DeiTIntermediate(config) self.output = DeiTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -388,6 +433,7 @@ class DeiTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["DeiTLayer"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py old mode 100644 new mode 100755 index 100bee54389569..05f74328b4c85f --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -134,7 +134,6 @@ def forward(self, pixel_values, bool_masked_pos): # add position embeddings embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach() - # only keep visible patches # ~bool_masked_pos means visible if bool_masked_pos is not None: @@ -268,6 +267,40 @@ def forward( return outputs +class VideoMAESdpaSelfAttention(VideoMAESelfAttention): + def __init__(self, config: VideoMAEConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None + keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias) + values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias) + queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias) + + key_layer = self.transpose_for_scores(keys) + value_layer = self.transpose_for_scores(values) + query_layer = self.transpose_for_scores(queries) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE class VideoMAESelfOutput(nn.Module): """ @@ -327,6 +360,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE +class VideoMAESdpaAttention(VideoMAEAttention): + def __init__(self, config: VideoMAEConfig) -> None: + super().__init__(config) + self.attention = VideoMAESdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE class VideoMAEIntermediate(nn.Module): def __init__(self, config: VideoMAEConfig) -> None: @@ -360,7 +400,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE +VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE class VideoMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -368,7 +411,7 @@ def __init__(self, config: VideoMAEConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VideoMAEAttention(config) + self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = VideoMAEIntermediate(config) self.output = VideoMAEOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -465,6 +508,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel): base_model_prefix = "videomae" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index fc72eb1cbdf831..1dbee5efccaf94 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -336,8 +336,20 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): del tf_model gc.collect() + attn_implementation = kwargs.get("attn_implementation", None) + kwargs_encoder_decoder = {} + if attn_implementation: + kwargs_encoder_decoder = { + "encoder_attn_implementation": attn_implementation, + "decoder_attn_implementation": attn_implementation, + } + model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained( - encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True + encoder_dir, + decoder_dir, + encoder_from_tf=True, + decoder_from_tf=True, + **kwargs_encoder_decoder, ) # This is only for copying some specific attributes of this particular model. model.config = config diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8aa43c5c43c500..dfda7bf731ba0b 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -236,6 +236,37 @@ def forward( return outputs +class ViTSdpaSelfAttention(ViTSelfAttention): + def __init__(self, config: ViTConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + class ViTSelfOutput(nn.Module): """ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the @@ -293,6 +324,12 @@ def forward( return outputs +class ViTSdpaAttention(ViTAttention): + def __init__(self, config: ViTConfig) -> None: + super().__init__(config) + self.attention = ViTSdpaSelfAttention(config) + + class ViTIntermediate(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() @@ -324,6 +361,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states +VIT_ATTENTION_CLASSES = { + "eager": ViTAttention, + "sdpa": ViTSdpaAttention, +} + + class ViTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -331,7 +374,7 @@ def __init__(self, config: ViTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTAttention(config) + self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTIntermediate(config) self.output = ViTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -428,6 +471,7 @@ class ViTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 20579e0d3db2cc..359b5e3fb9b08f 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -248,6 +248,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTHybrid +class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid class ViTHybridSelfOutput(nn.Module): """ @@ -307,6 +339,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTHybrid +class ViTHybridSdpaAttention(ViTHybridAttention): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + self.attention = ViTHybridSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid class ViTHybridIntermediate(nn.Module): def __init__(self, config: ViTHybridConfig) -> None: @@ -340,6 +379,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states +VIT_HYBRID_ATTENTION_CLASSES = { + "eager": ViTHybridAttention, + "sdpa": ViTHybridSdpaAttention, +} + + class ViTHybridLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -347,7 +392,7 @@ def __init__(self, config: ViTHybridConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTHybridAttention(config) + self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTHybridIntermediate(config) self.output = ViTHybridOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -447,6 +492,7 @@ class ViTHybridPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index b652c9e71f9106..f434efbe3e5af5 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -241,8 +241,8 @@ def random_masking(self, sequence, noise=None): noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1] # sort noise for each sample - ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove - ids_restore = torch.argsort(ids_shuffle, dim=1) + ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] @@ -370,6 +370,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE +class ViTMAESdpaSelfAttention(ViTMAESelfAttention): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE class ViTMAESelfOutput(nn.Module): """ @@ -429,6 +461,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE +class ViTMAESdpaAttention(ViTMAEAttention): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__(config) + self.attention = ViTMAESdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE class ViTMAEIntermediate(nn.Module): def __init__(self, config: ViTMAEConfig) -> None: @@ -462,7 +501,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE +VITMAE_ATTENTION_CLASSES = { + "eager": ViTMAEAttention, + "sdpa": ViTMAESdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE class ViTMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -470,7 +515,7 @@ def __init__(self, config: ViTMAEConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTMAEAttention(config) + self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTMAEIntermediate(config) self.output = ViTMAEOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -567,6 +612,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel): base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -764,7 +810,8 @@ def forward( # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token - x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + # unshuffle + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)) x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token # add pos embed diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 0632738455d1ab..fe5a880027fc60 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -222,6 +222,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN +class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN class ViTMSNSelfOutput(nn.Module): """ @@ -281,6 +313,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN +class ViTMSNSdpaAttention(ViTMSNAttention): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__(config) + self.attention = ViTMSNSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN class ViTMSNIntermediate(nn.Module): def __init__(self, config: ViTMSNConfig) -> None: @@ -314,7 +353,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN +VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN class ViTMSNLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -322,7 +364,7 @@ def __init__(self, config: ViTMSNConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTMSNAttention(config) + self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTMSNIntermediate(config) self.output = ViTMSNOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -419,7 +461,8 @@ class ViTMSNPreTrainedModel(PreTrainedModel): base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True - _no_split_modules = ["ViTMSNAttention"] + _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] + _supports_sdpa = True # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # when creating pre-training scripts. diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 9d6536b6c27258..faae349e6e7707 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -307,6 +307,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos +class YolosSdpaSelfAttention(YolosSelfAttention): + def __init__(self, config: YolosConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos class YolosSelfOutput(nn.Module): """ @@ -366,6 +398,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos +class YolosSdpaAttention(YolosAttention): + def __init__(self, config: YolosConfig) -> None: + super().__init__(config) + self.attention = YolosSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos class YolosIntermediate(nn.Module): def __init__(self, config: YolosConfig) -> None: @@ -399,7 +438,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos +YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS class YolosLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -407,7 +449,7 @@ def __init__(self, config: YolosConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = YolosAttention(config) + self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = YolosIntermediate(config) self.output = YolosOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -531,6 +573,7 @@ class YolosPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = [] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py b/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py index 564ca4d48c6a7f..ad10b99547f261 100644 --- a/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py +++ b/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py @@ -63,6 +63,7 @@ def __init__( scope=None, frequency_stride=2, time_stride=2, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -83,6 +84,7 @@ def __init__( self.scope = scope self.frequency_stride = frequency_stride self.time_stride = time_stride + self.attn_implementation = attn_implementation # in AST, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens) frequency_out_dimension = (self.num_mel_bins - self.patch_size) // self.frequency_stride + 1 @@ -117,6 +119,7 @@ def get_config(self): initializer_range=self.initializer_range, frequency_stride=self.frequency_stride, time_stride=self.time_stride, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, input_values, labels): diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 9a54f16dab689f..adc7154a5072a5 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -80,6 +80,8 @@ def __init__( num_labels=3, scope=None, encoder_stride=2, + mask_ratio=0.5, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -99,10 +101,14 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.encoder_stride = encoder_stride + self.attn_implementation = attn_implementation # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 2 + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -130,6 +136,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/deit/test_modeling_tf_deit.py b/tests/models/deit/test_modeling_tf_deit.py index 26980e84207d50..c26635cef6ddb9 100644 --- a/tests/models/deit/test_modeling_tf_deit.py +++ b/tests/models/deit/test_modeling_tf_deit.py @@ -71,6 +71,7 @@ def __init__( num_labels=3, scope=None, encoder_stride=2, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -90,6 +91,7 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.encoder_stride = encoder_stride + self.attn_implementation = attn_implementation # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens) num_patches = (image_size // patch_size) ** 2 @@ -121,6 +123,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py index e5b1c6b78e40dd..425fe2bcd7e09b 100644 --- a/tests/models/videomae/test_modeling_videomae.py +++ b/tests/models/videomae/test_modeling_videomae.py @@ -70,6 +70,7 @@ def __init__( initializer_range=0.02, mask_ratio=0.9, scope=None, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -91,6 +92,7 @@ def __init__( self.initializer_range = initializer_range self.mask_ratio = mask_ratio self.scope = scope + self.attn_implementation = attn_implementation # in VideoMAE, the number of tokens equals num_frames/tubelet_size * num_patches per frame self.num_patches_per_frame = (image_size // patch_size) ** 2 @@ -132,6 +134,7 @@ def get_config(self): decoder_intermediate_size=self.intermediate_size, decoder_num_attention_heads=self.num_attention_heads, decoder_num_hidden_layers=self.num_hidden_layers, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): @@ -197,7 +200,8 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): # hence we define a single mask, which we then repeat for each example in the batch mask = torch.ones((self.model_tester.num_masks,)) mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))]) - bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool() + batch_size = inputs_dict["pixel_values"].shape[0] + bool_masked_pos = mask.expand(batch_size, -1).bool() inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device) if return_labels: diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index d512ff25fe35ac..171f33d6802a4a 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -492,7 +492,9 @@ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict): with tempfile.TemporaryDirectory() as tmpdirname: tf_model.save_pretrained(tmpdirname, safe_serialization=False) - pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True) + pt_model = VisionEncoderDecoderModel.from_pretrained( + tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation + ) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) diff --git a/tests/models/vit/test_modeling_flax_vit.py b/tests/models/vit/test_modeling_flax_vit.py index af56f4717b888f..fb53caa3433ac2 100644 --- a/tests/models/vit/test_modeling_flax_vit.py +++ b/tests/models/vit/test_modeling_flax_vit.py @@ -49,6 +49,7 @@ def __init__( attention_probs_dropout_prob=0.1, type_sequence_label_size=10, initializer_range=0.02, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -66,6 +67,7 @@ def __init__( self.attention_probs_dropout_prob = attention_probs_dropout_prob self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range + self.attn_implementation = attn_implementation # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 @@ -87,6 +89,7 @@ def prepare_config_and_inputs(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + attn_implementation=self.attn_implementation, ) return config, pixel_values diff --git a/tests/models/vit/test_modeling_tf_vit.py b/tests/models/vit/test_modeling_tf_vit.py index dee2c8f18c171a..2c06b0bc609b7c 100644 --- a/tests/models/vit/test_modeling_tf_vit.py +++ b/tests/models/vit/test_modeling_tf_vit.py @@ -63,6 +63,7 @@ def __init__( initializer_range=0.02, num_labels=3, scope=None, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -81,6 +82,7 @@ def __init__( self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range self.scope = scope + self.attn_implementation = attn_implementation # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 @@ -111,6 +113,7 @@ def get_config(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index 7298543a563438..e150bd0be9539a 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -68,6 +68,8 @@ def __init__( initializer_range=0.02, scope=None, encoder_stride=2, + mask_ratio=0.5, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -87,10 +89,14 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.encoder_stride = encoder_stride + self.attn_implementation = attn_implementation # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 1 + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -118,6 +124,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_hybrid/test_modeling_vit_hybrid.py b/tests/models/vit_hybrid/test_modeling_vit_hybrid.py index d48a8853921649..b3d0040c22f360 100644 --- a/tests/models/vit_hybrid/test_modeling_vit_hybrid.py +++ b/tests/models/vit_hybrid/test_modeling_vit_hybrid.py @@ -58,6 +58,7 @@ def __init__( initializer_range=0.02, backbone_featmap_shape=[1, 16, 4, 4], scope=None, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -77,6 +78,7 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.backbone_featmap_shape = backbone_featmap_shape + self.attn_implementation = attn_implementation # in ViT hybrid, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) # the number of patches is based on the feature map of the backbone, which by default uses an output stride @@ -122,6 +124,7 @@ def get_config(self): backbone_featmap_shape=self.backbone_featmap_shape, backbone_config=backbone_config, backbone=None, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 6a77e95102c969..4221d6bfd358af 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -72,6 +72,7 @@ def __init__( num_labels=3, mask_ratio=0.6, scope=None, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -91,6 +92,7 @@ def __init__( self.initializer_range = initializer_range self.mask_ratio = mask_ratio self.scope = scope + self.attn_implementation = attn_implementation # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above # (we add 1 for the [CLS] token) @@ -127,6 +129,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, mask_ratio=self.mask_ratio, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index ffb679d646ffda..6c981adeb82cd9 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -63,8 +63,9 @@ def __init__( type_sequence_label_size=10, initializer_range=0.02, num_labels=3, - mask_ratio=0.6, scope=None, + mask_ratio=0.5, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -84,11 +85,15 @@ def __init__( self.initializer_range = initializer_range self.mask_ratio = mask_ratio self.scope = scope + self.attn_implementation = attn_implementation # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above # (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1))) + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -120,6 +125,7 @@ def get_config(self): decoder_intermediate_size=self.intermediate_size, decoder_num_attention_heads=self.num_attention_heads, decoder_num_hidden_layers=self.num_hidden_layers, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_msn/test_modeling_vit_msn.py b/tests/models/vit_msn/test_modeling_vit_msn.py index 5fe494c105cb62..be0857181e9e1b 100644 --- a/tests/models/vit_msn/test_modeling_vit_msn.py +++ b/tests/models/vit_msn/test_modeling_vit_msn.py @@ -59,6 +59,7 @@ def __init__( type_sequence_label_size=10, initializer_range=0.02, scope=None, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -77,6 +78,7 @@ def __init__( self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range self.scope = scope + self.attn_implementation = attn_implementation # in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 @@ -106,6 +108,7 @@ def get_config(self): hidden_dropout_prob=self.hidden_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob, initializer_range=self.initializer_range, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index 64a439f27a4e45..9c145388ffdb46 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -62,6 +62,7 @@ def __init__( scope=None, n_targets=8, num_detection_tokens=10, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -83,6 +84,7 @@ def __init__( self.scope = scope self.n_targets = n_targets self.num_detection_tokens = num_detection_tokens + self.attn_implementation = attn_implementation # we set the expected sequence length (which is used in several tests) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size) @@ -123,6 +125,7 @@ def get_config(self): initializer_range=self.initializer_range, num_detection_tokens=self.num_detection_tokens, num_labels=self.num_labels, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 64a8a348ff6b81..d6102dc2ef281f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2788,7 +2788,9 @@ def test_equivalence_flax_to_pt(self): with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + pt_model_loaded = model_class.from_pretrained( + tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation + ) # send pytorch model to the correct device pt_model_loaded.to(torch_device) @@ -3724,6 +3726,11 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) + # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors. + # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask. + # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code. + # However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it. + deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters is_encoder_decoder = model.config.is_encoder_decoder @@ -3861,6 +3868,27 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): and "output_attentions" in inspect.signature(model_sdpa.forward).parameters ): processed_inputs["output_attentions"] = output_attentions + if not deactivate_mask and ( + "bool_masked_pos" in inspect.signature(model_eager.forward).parameters + ): + dummy_mask = torch.ones((self.model_tester.num_masks,)) + + # In case of additional token (like class) we define a custom `mask_length` + if hasattr(self.model_tester, "mask_length"): + mask_length = self.model_tester.mask_length - dummy_mask.size(0) + else: + mask_length = self.model_tester.seq_length - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int( + (self.model_tester.image_size // self.model_tester.patch_size) ** 2 + ) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) # TODO: test gradients as well (& for FA2 as well!) with torch.no_grad(): diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 22d6b241f0048c..c7d098be3ea8f2 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -371,7 +371,9 @@ def test_equivalence_flax_to_pt(self): with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + pt_model_loaded = pt_model_class.from_pretrained( + tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation + ) # send pytorch model to the correct device pt_model_loaded.to(torch_device) diff --git a/utils/check_support_list.py b/utils/check_support_list.py index f6aaa2bb67dce4..3cb0b616022426 100644 --- a/utils/check_support_list.py +++ b/utils/check_support_list.py @@ -84,7 +84,7 @@ def check_sdpa_support_list(): archs_supporting_sdpa.append(model_name) for arch in archs_supporting_sdpa: - if arch not in doctext: + if arch not in doctext and arch not in doctext.replace("-", "_"): raise ValueError( f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation." )