
Add Flash Attention 2 support to Musicgen and Musicgen Melody #29939

Merged
ylacombe merged 16 commits into huggingface:main on Apr 2, 2024

Conversation

ylacombe
Contributor

@ylacombe ylacombe commented Mar 28, 2024

What does this PR do?

Supersedes #27924

The attention tests all pass, but there is no integration-level equivalence between the original attention models and the FA2 ones. The generated songs aren't identical, though I don't hear any difference in quality.

cc @sanchit-gandhi and @amyeroberts, could you review please?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@sanchit-gandhi sanchit-gandhi left a comment


Thanks for adding this!

@@ -239,3 +239,20 @@ def from_sub_models_config(
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate

@property
def _attn_implementation(self):
Contributor


This method is identical to the one in the PretrainedConfig class:

def _attn_implementation(self):

Can we remove it from here?

Contributor Author


Not if we want to keep the setter part!
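To illustrate the point about the setter (a minimal sketch with hypothetical class names, not the actual Musicgen code): redefining the property in the composite config is what lets assignments propagate down to the sub-model configs, which a plain inherited getter would not do.

```python
# Hypothetical sketch: a composite config whose _attn_implementation
# property delegates both reads and writes to a sub-model config.
class SubConfig:
    def __init__(self):
        self._attn_implementation = "eager"


class CompositeConfig:
    def __init__(self):
        self.decoder = SubConfig()

    @property
    def _attn_implementation(self):
        # Read from the sub-model config.
        return self.decoder._attn_implementation

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        # Propagate writes to the sub-model config as well.
        self.decoder._attn_implementation = value


config = CompositeConfig()
config._attn_implementation = "flash_attention_2"
print(config.decoder._attn_implementation)  # → flash_attention_2
```

Dropping the override in the subclass would keep the getter behavior but lose this write-through, which is why the property is kept here.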


MUSICGEN_ATTENTION_CLASSES = {
"eager": MusicgenAttention,
"flash_attention_2": MusicgenFlashAttention2,
Contributor


Worth adding SDPA in one go as well? It would let you showcase switching the attention implementation via SDPA on a free-tier Colab T4 GPU (where FA2 is not available).
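The suggestion amounts to registering one more entry in the dispatch dict (a sketch with placeholder classes, mirroring the pattern in the diff above; the real attention classes are much larger):

```python
# Placeholder attention variants standing in for the real implementations.
class MusicgenAttention: ...
class MusicgenFlashAttention2(MusicgenAttention): ...
class MusicgenSdpaAttention(MusicgenAttention): ...


# Each implementation registers under the string used in the config,
# so layers can look up the class from config._attn_implementation.
MUSICGEN_ATTENTION_CLASSES = {
    "eager": MusicgenAttention,
    "flash_attention_2": MusicgenFlashAttention2,
    "sdpa": MusicgenSdpaAttention,
}

attn_class = MUSICGEN_ATTENTION_CLASSES["sdpa"]
print(attn_class.__name__)  # → MusicgenSdpaAttention
```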

@@ -254,3 +252,20 @@ def from_sub_models_config(
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate

@property
def _attn_implementation(self):
Contributor


Same here

else outputs_fa.decoder_hidden_states[-1]
)

assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
Contributor


Good enough for a generative audio model with FA2

Contributor Author


I've used the same tolerance threshold as the other models (regardless of modality), btw.
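For readers unfamiliar with how `atol` and `rtol` combine in the test above, here is an illustration with NumPy (made-up values, not the actual model logits): an element passes when `|a - b| <= atol + rtol * |b|`, so larger logits get proportionally more slack.

```python
import numpy as np

# Reference logits and a slightly drifted FA2-style counterpart.
logits = np.array([0.50, 1.00, 2.00])
logits_fa = np.array([0.52, 1.03, 2.05])  # small kernel-level numerical drift

# Same tolerances as the test: atol=4e-2, rtol=4e-2.
# Per-element bounds: 0.06, 0.08, 0.12 — all diffs (0.02, 0.03, 0.05) fit.
print(np.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))  # → True
```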

@ylacombe
Contributor Author

I've also added SDPA!

cc @amyeroberts or @ArthurZucker could you review when you have time?

@ylacombe ylacombe requested a review from ArthurZucker March 29, 2024 13:08
Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM! The tests are... huge; it would be nice if you could use `# Copied from` — it would help the review 😅

Comment on lines 1029 to 1030
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
Collaborator


Let's only save `self._attn_implementation`, please.
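The reviewer's suggestion, sketched with a simplified hypothetical layer (not the real modeling code): keep the implementation string once and branch on it, rather than caching two derived booleans that can drift out of sync.

```python
class DecoderLayer:
    """Simplified stand-in for a decoder layer holding attention config."""

    def __init__(self, attn_implementation="eager"):
        # Store only the string, instead of:
        #   self._use_flash_attention_2 = attn_implementation == "flash_attention_2"
        #   self._use_sdpa = attn_implementation == "sdpa"
        self._attn_implementation = attn_implementation

    def forward_kind(self):
        # Branch on the single source of truth where needed.
        if self._attn_implementation == "flash_attention_2":
            return "fa2 path"
        if self._attn_implementation == "sdpa":
            return "sdpa path"
        return "eager path"


print(DecoderLayer("sdpa").forward_kind())  # → sdpa path
```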


return attn_output, None, past_key_value


Collaborator


`# Copied from` can be used here as well!

Collaborator

@ArthurZucker ArthurZucker left a comment


Ouf! Thanks for the big PR and adding those tests!

@ylacombe ylacombe merged commit 0d04b1e into huggingface:main Apr 2, 2024
21 checks passed
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* add FA2 to o.g Musicgen

* make style

* add FA2 support to Musicgen Melody

* add generation FA2 tests to o.g Musicgen

* make style and fix copies

* add Musicgen to FA2 docs + deprecate list

* add sdpa supports to Musicgen's

* make style and fix copies

* refactor attention implementation arguments

* add Copied from to sdpa tests

* add copied form in sdpa tests melody

* add copied for FA2 generation tests

* add FA2 inference copied from

* make style
itazap pushed a commit that referenced this pull request May 14, 2024
4 participants