diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 819fe5fcf288be..dd2ed9d695e73b 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -236,8 +236,6 @@ class PretrainedConfig(PushToHubMixin):
             This attribute is currently not being used during model loading time, but this may change in the future
             versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
-        attn_implementation (`str`, *optional*):
-            The attention implementation to use in the model. Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

     > TensorFlow specific parameters
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b3102a37d37f31..38dde4ec91e267 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2696,6 +2696,8 @@ def from_pretrained(
                 [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
+            attn_implementation (`str`, *optional*):
+                The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

             > Parameters for big model inference
diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index ce7884d2ef120e..98c0e851bcc22d 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -58,6 +58,8 @@
    The model class to instantiate is selected based on the configuration class:

        List options
+        attn_implementation (`str`, *optional*):
+            The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

    Examples:
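
A minimal usage sketch of the `attn_implementation` argument documented above, passed through `from_pretrained`. The checkpoint name below (`facebook/opt-350m`) is only an illustrative stand-in; any supported causal LM checkpoint works the same way.

```python
# Sketch: selecting an attention implementation at model load time.
import torch
from transformers import AutoModelForCausalLM

# Manual attention ("eager") is always available.
model_eager = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # hypothetical stand-in checkpoint
    attn_implementation="eager",
)

# SDPA (torch.nn.functional.scaled_dot_product_attention) is the default
# when torch>=2.1.1 is installed and the model supports it.
model_sdpa = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    attn_implementation="sdpa",
)

# flash_attention_2 additionally requires the flash-attn package and is
# typically used with a half-precision dtype on GPU.
model_fa2 = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
```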