Add Flex Attention for Mistral along with refactoring #34845
Conversation
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: | ||
if ("SdpaAttention" in class_name or "SdpaSelfAttention" in class_name) or ( | ||
hasattr(submodule, "_uses_attention_functions") and submodule._uses_attention_functions | ||
): |
I am not exactly sure how to handle this correctly.
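For context, a minimal sketch of what the new condition is meant to catch; the `_uses_attention_functions` attribute comes from the diff above, while the module class and helper names below are hypothetical:

import torch.nn as nn

class RefactoredAttention(nn.Module):
    # Hypothetical: a refactored attention module advertises that it routes
    # attention through the shared attention-function interface.
    _uses_attention_functions = True

def is_attention_function_module(submodule: nn.Module) -> bool:
    # Mirrors the check above: match the legacy SDPA class names, or any
    # module that opts in via the attribute.
    class_name = submodule.__class__.__name__
    return (
        "SdpaAttention" in class_name
        or "SdpaSelfAttention" in class_name
        or getattr(submodule, "_uses_attention_functions", False)
    )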
Thanks for working on this!
if self._attn_implementation != "flash_attention_2":
    cache_kwargs["cache_position"] = cache_position
I don't think this escape is required, no?
Currently, this is how it is handled in FlashAttention2:
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
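For comparison, a minimal sketch of what dropping the escape would look like, with cache_position always forwarded to the cache (an assumption about the suggested change, not the final code in this PR):

if past_key_value is not None:
    # Sketch: pass cache_position unconditionally and let the cache
    # implementation decide whether it needs it.
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(
        key_states, value_states, self.layer_idx, cache_kwargs
    )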
@ArthurZucker What should we do about these failing tests? I think they are related to the sdpa tests, where we might have
Hi. Let's take a look at the failing test(s) step by step. First, do you know why
Idefics2 uses Mistral as the text model:

self.text_model = AutoModel.from_config(config.text_config)

if isinstance(text_config, dict):
    text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
    text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
    logger.info("text_config is None, using default text config")
    text_config = CONFIG_MAPPING["mistral"](
        max_position_embeddings=4096 * 8,
        rms_norm_eps=1e-5,
        # None in the original configuration_mistral, we set it to the unk_token_id
        pad_token_id=0,
        tie_word_embeddings=False,
    )
BTW let's make sure we rebase now that #34896 was merged!
Can you make sure the CIs are green? 🤗
Should I reset the default back to eager instead of flex, since the eager-matches-sdpa tests fail for float32 when using flex? Or should we adjust the thresholds so that flex stays the default while the tests are also green?
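If the threshold route is taken, a minimal sketch of what a loosened comparison might look like; the helper name and tolerance values are placeholders, not the suite's actual numbers:

import torch

def assert_attention_outputs_close(eager_out, flex_out):
    # Placeholder tolerances; the real values would need tuning so that the
    # flex vs. eager comparison also passes in float32.
    torch.testing.assert_close(flex_out, eager_out, atol=3e-3, rtol=3e-3)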
Hey! We actually shipped this in #35235! 🤗 Super sorry for the late notice.
Thanks for letting me know.
What does this PR do?
Towards #34809
Who can review?
@ArthurZucker