Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 #27940
Conversation
Ok, looks good, would like @ArthurZucker to take a quick look before merging. Will cherry-pick this for the release.
@@ -1244,6 +1244,7 @@ def _autoset_attn_implementation(
        # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
        # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
        # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
        requested_attn_implementation = None
Should this be "default" instead?
No, the idea here is to check whether the user passed `attn_implementation="eager"`, `attn_implementation="sdpa"` or `attn_implementation="flash_attention_2"` explicitly when loading the model from `from_pretrained` or `from_config`.

In case `attn_implementation` is explicitly set, we hard error if a requirement is not met (torch>=2.1.1, a model that supports SDPA); otherwise we smoothly fall back on eager.
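A minimal sketch of that dispatch rule, not the actual transformers code: an explicit request hard-errors when its requirements are missing, while the default (`None`) silently falls back on eager. The function name and the `model_supports_sdpa` flag are hypothetical stand-ins for the per-model check.

```python
import torch
from packaging import version


def pick_attn_implementation(requested, model_supports_sdpa):
    """Illustrative dispatch: strict when requested explicitly, lenient otherwise."""
    torch_ok = version.parse(torch.__version__) >= version.parse("2.1.1")
    sdpa_ok = torch_ok and model_supports_sdpa

    if requested == "sdpa" and not sdpa_ok:
        # Explicit request: hard error when a requirement is not met.
        raise ImportError("SDPA requires torch>=2.1.1 and a model with SDPA support.")
    if requested is not None:
        return requested

    # Nothing requested: prefer SDPA when available, otherwise fall back on eager.
    return "sdpa" if sdpa_ok else "eager"
```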
Thanks
config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
elif not hard_check_only:
    config = cls._check_and_enable_sdpa(
        config, hard_check_only=False if requested_attn_implementation is None else True
    )
looks better thanks
As per title.

On torch==2.0.1, these tests pass.
On torch==2.1.1, these tests pass (#26572 (comment)).

There was a bug where, even though we manually request `attn_implementation="eager"`, we would still go into the SDPA control flow and hard-check that the requirements are fine, which is not what we want.
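A hedged usage sketch of the fixed behavior, assuming a transformers version that includes this change (the model id is only illustrative): explicitly requesting eager should never enter the SDPA requirement check, even on torch==2.0.1.

```python
from transformers import AutoModelForCausalLM

# Requesting eager explicitly must not trigger the SDPA hard check,
# even when torch < 2.1.1 is installed.
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
print(model.config._attn_implementation)  # expected: "eager"
```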