
Add Flex Attention for Mistral along with refactoring #34845

Closed
wants to merge 9 commits

Conversation

OmarManzoor (Contributor):

What does this PR do?

Towards #34809

  • Adds Flex Attention for Mistral
  • Refactors the attention mechanisms to use functions instead of classes (a rough sketch of the idea is below)
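
As a rough sketch of the second bullet (not this PR's actual code: the registry name, the exact signature, and the torch >= 2.5 requirement for flex attention are assumptions), each backend becomes a plain function with a shared signature and the attention layer dispatches by name instead of subclassing a per-backend class:

    import torch.nn.functional as F
    from torch.nn.attention.flex_attention import flex_attention  # requires torch >= 2.5

    def sdpa_attention_forward(query, key, value, attention_mask=None, scaling=None):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, scale=scaling)

    def flex_attention_forward(query, key, value, attention_mask=None, scaling=None):
        def score_mod(score, batch, head, q_idx, kv_idx):
            # Fold an additive (e.g. causal) mask into the scores, if one is given.
            if attention_mask is not None:
                return score + attention_mask[batch][0][q_idx][kv_idx]
            return score

        return flex_attention(query, key, value, score_mod=score_mod, scale=scaling)

    # One registry instead of per-backend MistralSdpaAttention / MistralFlexAttention subclasses.
    ATTENTION_FUNCTIONS = {"sdpa": sdpa_attention_forward, "flex_attention": flex_attention_forward}

    # Inside MistralAttention.forward, dispatch would then look roughly like:
    #   attn_fn = ATTENTION_FUNCTIONS[self.config._attn_implementation]
    #   attn_output = attn_fn(query_states, key_states, value_states, attention_mask, scaling=self.scaling)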

Who can review?

@ArthurZucker

if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if ("SdpaAttention" in class_name or "SdpaSelfAttention" in class_name) or (
hasattr(submodule, "_uses_attention_functions") and submodule._uses_attention_functions
):
OmarManzoor (Contributor, Author):

I am not exactly sure how to handle this correctly.
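
For reference, a hedged sketch of where a check like this usually sits: walking the submodules and deciding per module whether it dispatches through SDPA. `_uses_attention_functions` is the flag proposed in this PR, not an attribute the library has today.

    def module_uses_sdpa(submodule) -> bool:
        class_name = submodule.__class__.__name__
        if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
            return True
        # Function-based attention modules no longer encode the backend in the
        # class name, so they would have to advertise it explicitly.
        return getattr(submodule, "_uses_attention_functions", False)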

@ArthurZucker (Collaborator) left a comment:

Thanks for working on this!

Comment on lines +369 to +370:

    if self._attn_implementation != "flash_attention_2":
        cache_kwargs["cache_position"] = cache_position
@ArthurZucker (Collaborator):

I don't think this escape is required, no?

OmarManzoor (Contributor, Author):

Currently, this is how it is handled in FlashAttention2:

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
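
For comparison, a sketch (same local variables as the snippet above) of what dropping the escape would mean: cache_position is always passed through to the cache update, with no flash-attention-specific branch.

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)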

Three further review comments on src/transformers/models/mistral/modeling_mistral.py were marked as resolved (outdated).
@OmarManzoor (Contributor, Author):

@ArthurZucker What should we do about these failing tests? I think they are related to the SDPA tests, where output_attentions might be True.
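
For context, a hedged sketch (illustrative names, not the library's code) of why output_attentions=True trips these tests: fused SDPA/flex kernels return only the attention output, never the weights, so returning weights needs the eager path that materialises the softmax matrix.

    import torch

    def attention_forward(query, key, value, output_attentions=False):
        if output_attentions:
            # Eager path: weights are built explicitly so they can be returned.
            scale = query.shape[-1] ** -0.5
            attn_weights = torch.softmax((query @ key.transpose(-2, -1)) * scale, dim=-1)
            return attn_weights @ value, attn_weights
        # Fused path: faster, but the weights are never materialised.
        return torch.nn.functional.scaled_dot_product_attention(query, key, value), None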

@ydshieh (Collaborator) commented Nov 27, 2024:

Hi. Let's take a look at the failing test(s) step by step.

First, do you know why Idefics2 would have a MistralAttention? It's very strange (see the test_torch job log).

    FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_retain_grad_hidden_states_attentions - AttributeError: 'MistralAttention' object has no attribute 'scaling'

@OmarManzoor (Contributor, Author):

> Hi. Let's take a look at the failing test(s) step by step.
>
> First, do you know why Idefics2 would have a MistralAttention? It's very strange (see the test_torch job log).
>
>     FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_retain_grad_hidden_states_attentions - AttributeError: 'MistralAttention' object has no attribute 'scaling'

Idefics2 uses Mistral as its text model:

    self.text_model = AutoModel.from_config(config.text_config)

and the configuration falls back to Mistral by default:

    if isinstance(text_config, dict):
        text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
        text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
    elif text_config is None:
        logger.info("text_config is None, using default text config")
        text_config = CONFIG_MAPPING["mistral"](
            max_position_embeddings=4096 * 8,
            rms_norm_eps=1e-5,
            # None in the original configuration_mistral, we set it to the unk_token_id
            pad_token_id=0,
            tie_word_embeddings=False,
        )
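
For context, a hedged sketch (illustrative, not the PR's code) of the attribute the traceback points at: the refactored attention functions read a precomputed scaling from the module, so any attention class reached through AutoModel, including the Mistral text model inside Idefics2, has to set it in __init__.

    import torch

    class AttentionSketch(torch.nn.Module):
        def __init__(self, hidden_size: int, num_attention_heads: int):
            super().__init__()
            self.head_dim = hidden_size // num_attention_heads
            # The failing test suggests this attribute was missing on the
            # instance that Idefics2 builds from the Mistral config.
            self.scaling = self.head_dim ** -0.5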

@ArthurZucker (Collaborator):

BTW let's make sure we rebase now that #34896 was merged!

@ArthurZucker (Collaborator):

can you make sure the CIs are green? 🤗

@OmarManzoor (Contributor, Author):

> can you make sure the CIs are green? 🤗

Should I reset the default back to eager instead of flex, since the eager-matches-SDPA test fails for float32 when using flex? Or do we need to change the thresholds so that flex remains the default while the tests still pass?
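
For illustration, a hedged sketch of the trade-off being asked about: the backend-equivalence tests compare outputs within dtype-dependent tolerances, so keeping flex as the default means either matching eager within the current float32 thresholds or loosening them. The numbers below are made up, not the test suite's actual values.

    import torch

    eager_out = torch.full((2, 4, 8), 1.0, dtype=torch.float32)
    flex_out = eager_out + 2e-5  # stand-in for a small numerical divergence between backends

    torch.testing.assert_close(flex_out, eager_out, atol=1e-4, rtol=1e-5)    # passes with a looser threshold
    # torch.testing.assert_close(flex_out, eager_out, atol=1e-6, rtol=1e-7)  # would fail with a tight one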

@ArthurZucker (Collaborator) left a comment:

Hey! We actually shipped this in #35235! 🤗 Super sorry for the late notice.

@OmarManzoor (Contributor, Author):

> Hey! We actually shipped this in #35235! 🤗 Super sorry for the late notice.

Thanks for letting me know.

@OmarManzoor deleted the mistral_flex branch on December 23, 2024 at 15:09.