-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 #27940
Merged
+10
−8
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1244,6 +1244,7 @@ def _autoset_attn_implementation( | |
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. | ||
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). | ||
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) | ||
requested_attn_implementation = None | ||
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: | ||
if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: | ||
raise ValueError( | ||
|
@@ -1260,9 +1261,7 @@ def _autoset_attn_implementation( | |
raise ValueError(message + ".") | ||
|
||
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. | ||
hard_check_only = True | ||
else: | ||
hard_check_only = False | ||
requested_attn_implementation = config._attn_implementation_internal | ||
|
||
if use_flash_attention_2: | ||
logger.warning_once( | ||
|
@@ -1275,13 +1274,15 @@ def _autoset_attn_implementation( | |
config, | ||
torch_dtype=torch_dtype, | ||
device_map=device_map, | ||
hard_check_only=hard_check_only, | ||
hard_check_only=False, | ||
check_device_map=check_device_map, | ||
) | ||
elif cls._supports_sdpa or config._attn_implementation == "sdpa": | ||
elif requested_attn_implementation in [None, "sdpa"]: | ||
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. | ||
config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only) | ||
elif not hard_check_only: | ||
config = cls._check_and_enable_sdpa( | ||
config, hard_check_only=False if requested_attn_implementation is None else True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks better thanks |
||
) | ||
else: | ||
config._attn_implementation = "eager" | ||
|
||
return config | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be "default" instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, the idea here is to check whether the user passed
attn_implementation="eager"
,attn_implementation="sdpa"
orattn_implementation="sdpa"
explicitly when loading the model fromfrom_pretrained
orfrom_config
.In case
attn_implementation
is explicitly set, we hard error if a dependency is missing (torch>=2.1.1, model does not support SDPA), otherwise we smoothly fall back on eager.