Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F.scaled_dot_product_attention support #26572

Merged
merged 114 commits into from
Dec 8, 2023

Conversation

fxmarty
Copy link
Contributor

@fxmarty fxmarty commented Oct 3, 2023

As per title, this PR proposes to support natively torch.nn.functional.scaled_dot_product_attention in transformers. I propose to enable SDPA by default if torch>=2.1.1 (released 15 Nov. 2023), for the reasons written in the PR. The support could then be extended using https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py.


The introduced _unmask_unattended is a workaround for pytorch/pytorch#110213.

It behaves as follow:

If attention_mask is

   [[0, 0, 1]
    [1, 1, 1]
    [0, 1, 1]]

and expanded_mask is (e.g. here left-padding case)

    [[[[0, 0, 0],
       [0, 0, 0],
       [0, 0, 1]]],
     [[[1, 0, 0],
       [1, 1, 0],
       [1, 1, 1]]],
     [[[0, 0, 0],
       [0, 1, 0],
       [0, 1, 1]]]]

then the modified expanded_mask will be

    [[[[1, 1, 1],   <-- modified
       [1, 1, 1],   <-- modified
       [0, 0, 1]]],
     [[[1, 0, 0],
       [1, 1, 0],
       [1, 1, 1]]],
     [[[1, 1, 1],   <-- modified
       [0, 1, 0],
       [0, 1, 1]]]]

Modifying as such the attention mask is fine given that we modify it only for pad tokens on the -2 dimension. Softmax is computed on the -1 dimension, and thus there is no change for the relevant non-padding tokens.

@fxmarty fxmarty changed the title F.scaled_dot_product preliminary support F.scaled_dot_product_attention preliminary support Oct 3, 2023
@fxmarty fxmarty marked this pull request as draft October 3, 2023 17:11
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for starting over the preliminary support, I really like the design of it!

I see there are a bunch of deduplicated code from the vanilla attention, my question would be why not add everything inside the forward pass of the vanilla attention? The only problem is that it might make the attention code quite hard to read.

On the other hand, in case we go for a standalone LlamaSDPAAttention module, it might make the modeling file of all models more bloated (FA-2, SDPA, ..). @ydshieh suggested offline that we could offload those modules in a new file to make the modeling file cleaner and nearly untouched.
I would personally advocate to add the SDPA support directly inside xxxAttention as the changes relative to it is only ~20 LoC, it would be surprising for users to see that the xxxAttention modules has suddenly changed to xxxSDPAAttention by just upgrading transformers with no other intervention.
I would like to hear opinions from others @LysandreJik @ArthurZucker @patrickvonplaten on this matter, and I will be happy to help you extend this PR on other archs and adding relevant tests

@fxmarty
Copy link
Contributor Author

fxmarty commented Oct 31, 2023

As #26792 was merged will get back to it this week, targeting next to next transformers release.

@fxmarty fxmarty changed the title F.scaled_dot_product_attention preliminary support F.scaled_dot_product_attention support Oct 31, 2023
@fxmarty fxmarty force-pushed the torch-sdpa-preliminary-support branch from 7bb4857 to dd646c1 Compare October 31, 2023 14:21
@fxmarty
Copy link
Contributor Author

fxmarty commented Dec 8, 2023

It is ready. Here is a summary of the relevant CI.

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/llama -s -vvvvv

Flacky (new):

FAILED tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=False, batch_size=5, enable_kernels=True: mean relative difference: 8.728e-03, torch ato...

Already failing on main:

FAILED tests/models/llama/test_modeling_llama.py::CodeLlamaIntegrationTest::test_model_7b_logits - AssertionError: Lists differ: ['<s>▁<PRE> def remove_non_ascii(s: str) -> st[893 chars]ID>'] != ['<s> <PRE> def remove_non_ascii(s: str) -> st[893 chars...
FAILED tests/models/llama/test_tokenization_llama.py::LlamaIntegrationTest::test_conversion - AssertionError: '{\n [964 chars]or": {\n    "type": "TemplateProcessing",\n   [1795198 chars]}\n}' != '{\n [964 chars]or": null,\n  "decoder": {\n    "t...

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/whisper -s -vvvvv

Already failing on main:

FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_large_generation_multilingual - FileNotFoundError: https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch - assert [' While Porashaggy sits there, a cooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room besi...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_hard - assert " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany ch...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, acco...

Flacky (on main):

FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_generate_left_padding - AssertionError: False is not true
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_padding_right - AssertionError: assert False

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/bart -s -vvvvv

Flacky on main:

AILED tests/models/bart/test_modeling_bart.py::BartStandaloneDecoderModelTest::test_cpu_offload - AssertionError: False is not true

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/falcon -s -vvvvv

Flacky (new):

FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=True, batch_size=5, enable_kernels=True: mean relative difference: 7.141e-03, torch atol...

Flacky (on main):

FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/idefics -s -vvvvv

all pass

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/bert -s -vvvvv

all pass

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/gpt2 -s -vvvvv

all pass

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/test_modeling_utils.py -s -vvvvv

Already failing on main:

FAILED tests/test_modeling_utils.py::ModelUtilsTest::test_legacy_load_from_url - huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': 'https://huggingface.co/hf-intern...
FAILED tests/test_modeling_utils.py::ModelUtilsTest::test_load_from_one_file - huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/tmp/tmp64wrpwyf'. Use `repo_typ...
FAILED tests/test_modeling_utils.py::ModelUtilsTest::test_model_from_pretrained - AssertionError: 7 != 8
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_conversion - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_conversion_gated - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_conversion_private - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_sharded_conversion - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_sharded_conversion_gated - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_sharded_conversion_private - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_specific_revision - ValueError: Cannot run tests as secret isn't setup.
ERROR tests/test_modeling_utils.py::ModelOnTheFlyConversionTester::test_safetensors_on_the_fly_wrong_user_opened_pr - ValueError: Cannot run tests as secret isn't setup.

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/ -s -vvvvv -k "flash or sdpa"

Flacky (new):

FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=False, batch_size=1, enable_kernels=True: mean relative difference: 7.660e-03, torch ato...

Already failing/flacky on main:

FAILED tests/models/bark/test_modeling_bark.py::BarkSemanticModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkSemanticConfig'> for this kind of AutoModel: AutoMo...
FAILED tests/models/bark/test_modeling_bark.py::BarkSemanticModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/bark/test_modeling_bark.py::BarkCoarseModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkCoarseConfig'> for this kind of AutoModel: AutoMode...
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_flash_attn_2_inference_padding_right - AssertionError: False is not true
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/opt/test_modeling_opt.py::OPTModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_padding_right - AssertionError: assert False

if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
config._attn_implementation = kwargs.pop("attn_implementation", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This overrides existing _attn_implementation value inside config. And sets it to None when attn_implementation is not passed in kwargs ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intended, users should not use _attn_implementation. Is there a case where you have no choice but to use it?

Copy link
Contributor

@BowenBao BowenBao Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fxmarty thanks for reply. For context it is the same background as this PR #28823 where I tried to unblock export in our benchmark pipeline.

I guess we misunderstood the error message, and tried to pass attn_implementation="eager" to config constructor instead of from_config call.

Regarding your comment though I'm not sure if that is the right behavior. attn_implementation is indeed documented in PretrainedConfig, and it is not respected if called in this way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BowenBao you are right, it is an issue in the documentation. This should not be exposed in the config.

from transformers import AutoModelForCausalLM, AutoConfig, LlamaForCausalLM

cfg = AutoConfig.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")

cfg._attn_implementation = "eager"

model = LlamaForCausalLM(cfg)

works. It is true that there is no API exposed to the user for initializing with XxxForCausalLM(cfg) and selecting the attention implementation, apart from using this private attribute.

Any of:

model = AutoModel.from_config(cfg, attn_implementation="eager")
model = LlamaModel.from_pretrained("xxx", attn_implementation="eager")

work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants