
Fix FA2 tests #29909

Merged · ylacombe merged 2 commits into huggingface:main from fix-fa2-tests on Apr 1, 2024
Conversation

ylacombe (Contributor) commented Mar 27, 2024

What does this PR do?

#26572 introduced an artifact that prevented properly testing inference with Flash Attention 2: the model that was supposed to be loaded without Flash Attention 2 (as a reference to compare against) was in fact using Flash Attention 2!

cc @fxmarty @ArthurZucker @amyeroberts
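
For reference, a minimal sketch of what the corrected comparison looks like; model_class and tmpdirname come from the test itself, while inputs, torch_device and the tolerance values are illustrative assumptions rather than the exact test code:

import torch

# Reference model, loaded WITHOUT Flash Attention 2 (default/eager attention).
model_ref = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model_ref.to(torch_device)

# Model under test, loaded WITH Flash Attention 2.
model_fa2 = model_class.from_pretrained(
    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa2.to(torch_device)

with torch.no_grad():
    out_ref = model_ref(**inputs, output_hidden_states=True)
    out_fa2 = model_fa2(**inputs, output_hidden_states=True)

# Before this PR, both models were loaded with attn_implementation="flash_attention_2",
# so the test was effectively comparing FA2 against itself.
assert torch.allclose(out_ref.hidden_states[-1], out_fa2.hidden_states[-1], atol=4e-2, rtol=4e-2)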


ylacombe requested a review from ArthurZucker on Mar 28, 2024
ArthurZucker (Collaborator) left a comment


AH. That's a great catch. Thanks for it!

model = model_class.from_pretrained(
    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
Collaborator


let's update the name to test_flash_attn_2_inference_equivalence or something like that!

ylacombe (Contributor, Author)


Will do!

On a side note, how can we make sure that every model using FA2 still passes? The tests are slow, so I'm not actually sure the CI is totally green.

Collaborator


You'll need to run the tests manually. You can select just the flash attention tests by doing something like:

RUN_SLOW=1 pytest tests/models -k "flash_attn" (on a GPU setup)

amyeroberts (Collaborator) left a comment


Good spot - thanks for fixing!


ylacombe (Contributor, Author)

I've run RUN_SLOW=1 pytest tests/models -k "flash_attn" as requested and got the following results. In particular, inference tests for Qwen, Whisper and StableLM failed!

I'll open an issue to keep track of the different failures. Should I still merge the PR in the meantime?

FAILED tests/models/bark/test_modeling_bark.py::BarkSemanticModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkSemanticConfig'> for this kind of AutoModel: AutoModelForCausalLM.
FAILED tests/models/bark/test_modeling_bark.py::BarkCoarseModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkCoarseConfig'> for this kind of AutoModel: AutoModelForCausalLM.
FAILED tests/models/bark/test_modeling_bark.py::BarkCoarseModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/gemma/test_modeling_gemma.py::GemmaModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMHAModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_flash_attn_2_inference_equivalence - AssertionError: assert False
FAILED tests/models/stablelm/test_modeling_stablelm.py::StableLmModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelTest::test_flash_attn_2_inference_equivalence_right_padding - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_equivalence - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_equivalence_right_padding - AssertionError: assert False

amyeroberts (Collaborator)

@ylacombe Thanks for running and sharing the results! Merging depends on whether the same tests are failing on main: if they are, then merging is fine; if not, the tests will need to be fixed :)

ylacombe (Contributor, Author)

Testing this right now then!

ylacombe (Contributor, Author)

Well, the same tests fail on main, except for the qwen2 and stablelm failures, which are introduced by this PR. That makes sense, though, since the FA2 tests weren't actually testing FA2 before.

ArthurZucker (Collaborator)

Feel free to merge!
FYI @ydshieh, more failing tests incoming I'm afraid 😨

ylacombe merged commit 569f6c7 into huggingface:main on Apr 1, 2024 (18 checks passed)
ylacombe deleted the fix-fa2-tests branch on Apr 1, 2024 at 08:20
ArthurZucker mentioned this pull request on Apr 1, 2024
ydshieh (Collaborator) commented Apr 2, 2024

😨

😨😨😨😨😨

ydshieh (Collaborator) commented Apr 2, 2024

@ylacombe

Thanks a lot ❤️ for the fix and great catch!

One nit: it would be really nice 🙏 if you could mention in the PR description a bit about why the previous testing was done improperly. Something as simple as:

the model supposed to be loaded without attn_implementation="flash_attention_2" (as a reference to compare) was using attn_implementation="flash_attention_2"

This way, it's super clear what the PR is doing even before diving into the changes.

fxmarty (Contributor) commented Apr 2, 2024

afaik many FA2 tests were already failing (they are not in the CI) due to diffs in logits

ydshieh (Collaborator) commented Apr 2, 2024

afaik many FA2 tests were already failing (they are not in the CI) due to diffs in logits

@fxmarty I think we (or you?) ran those tests before merging. Do you know why we have so many failing FA2 tests? Or are those failing tests only for the many newly added models?

ydshieh (Collaborator) commented Apr 2, 2024

Oh, they are not run on T4 GPUs.

fxmarty (Contributor) commented Apr 2, 2024

@ydshieh When I used to run these tests locally (some months ago), the failures were because the diff tolerance between eager and FA2 was too low. Some models (such as Whisper) somehow require a large diff tolerance.
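
To make the "diff tolerance" point concrete, here is a minimal sketch of the kind of eager-vs-FA2 comparison these tests perform; the model/input names and the atol/rtol values are illustrative assumptions, not the exact values used in the test suite:

import torch

# model_eager / model_fa2 / inputs are assumed to come from the test setup.
with torch.no_grad():
    logits_eager = model_eager(**inputs).logits  # reference, eager attention
    logits_fa2 = model_fa2(**inputs).logits      # Flash Attention 2

# FA2 uses different kernels, so outputs are close but not bit-identical.
# If atol/rtol are set too low, the assertion fails even when FA2 works
# correctly; models like Whisper apparently need a larger tolerance.
assert torch.allclose(logits_eager, logits_fa2, atol=4e-2, rtol=4e-2)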

ArthurZucker pushed a commit that referenced this pull request on Apr 22, 2024:
* fix FA2 tests
* refactor inference test name

itazap pushed a commit that referenced this pull request on May 14, 2024:
* fix FA2 tests
* refactor inference test name