From 820290cbb5798e34b424c8f68c23dd2e57d3e1c7 Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 11 Dec 2023 09:25:44 +0000
Subject: [PATCH] fix sdpa dispatch

---
 src/transformers/modeling_utils.py | 15 ++++++++-------
 tests/test_modeling_common.py      |  3 ++-
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2588893d2575be..51fb21987f487a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1244,6 +1244,7 @@ def _autoset_attn_implementation(
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
             if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                 raise ValueError(
@@ -1260,9 +1261,7 @@ def _autoset_attn_implementation(
                 raise ValueError(message + ".")

             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
-            hard_check_only = True
-        else:
-            hard_check_only = False
+            requested_attn_implementation = config._attn_implementation_internal

         if use_flash_attention_2:
             logger.warning_once(
@@ -1275,13 +1274,15 @@ def _autoset_attn_implementation(
                 config,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
-                hard_check_only=hard_check_only,
+                hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+        elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
-        elif not hard_check_only:
+            config = cls._check_and_enable_sdpa(
+                config, hard_check_only=False if requested_attn_implementation is None else True
+            )
+        else:
             config._attn_implementation = "eager"

         return config
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f0e6c0f1fce37f..8293829b009d04 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -83,6 +83,7 @@
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
+    is_torch_sdpa_available,
 )
 from transformers.utils.generic import ModelOutput

@@ -778,7 +779,7 @@ def _create_and_check_torchscript(self, config, inputs_dict):
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                     continue

                 configs_no_init._attn_implementation = attn_implementation
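
For orientation, the dispatch order this patch establishes is roughly: an explicit Flash Attention 2 request is hard-checked first, SDPA is tried next (hard-checked only when the user explicitly asked for "sdpa", silently skipped otherwise), and everything else falls back to eager attention. Below is a minimal, self-contained Python sketch of that order; the function name `resolve_attn_implementation` and its boolean flags are hypothetical stand-ins, not the actual `_autoset_attn_implementation` implementation.

# Illustrative sketch only: shows the dispatch precedence introduced by the patch,
# with hypothetical names standing in for the real config/model attributes.
from typing import Optional


def resolve_attn_implementation(
    requested: Optional[str],      # implementation explicitly set on the config by the user, or None
    use_flash_attention_2: bool,   # explicit request via the loading argument
    supports_flash_attn_2: bool,
    supports_sdpa: bool,
) -> str:
    if use_flash_attention_2 or requested == "flash_attention_2":
        # An explicit Flash Attention 2 request is always hard-checked.
        if not supports_flash_attn_2:
            raise ValueError("Flash Attention 2 was requested but is not supported by this model.")
        return "flash_attention_2"
    if requested in (None, "sdpa"):
        # SDPA is tried next. An explicit "sdpa" request is hard-checked (raise if
        # unavailable); with no explicit request we silently fall back to eager.
        if supports_sdpa:
            return "sdpa"
        if requested == "sdpa":
            raise ValueError("SDPA was requested but is not supported by this model.")
    return "eager"


# No explicit request on a model without SDPA support -> silent fall-back to eager.
print(resolve_attn_implementation(None, False, False, False))   # eager
# Explicit "sdpa" request on the same model -> hard error instead of a silent fall-back.
# resolve_attn_implementation("sdpa", False, False, False)       # raises ValueError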