
Add FA2 and sdpa support for SigLIP #31499

Merged: 35 commits merged into huggingface:main on Jul 8, 2024
Conversation

@qubvel (Member) commented Jun 19, 2024

What does this PR do?

Add Flash Attention 2 and SDPA (torch.nn.functional.scaled_dot_product_attention) attention implementations for the SigLIP model.

Fixes #31138
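For context, a minimal usage sketch (an illustration based on the standard transformers loading API; the checkpoint name and device handling here are assumptions, not part of this PR):

import torch
from transformers import SiglipModel

# Select the attention backend via `attn_implementation`: "eager", "sdpa", or "flash_attention_2".
# Flash Attention 2 requires a CUDA device, the flash-attn package, and half precision.
model = SiglipModel.from_pretrained(
    "google/siglip-so400m-patch14-384",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")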

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qubvel qubvel marked this pull request as draft June 19, 2024 16:56
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) commented:

Nice! We can probably combine this with CLIP cc @sayakpaul

For reference, there's an FA2 SigLIP implementation for IDEFICS2, but I'm not sure how much testing was done on the equivalence between the eager and FA2 classes.

@qubvel (Member, Author) commented Jun 20, 2024

@amyeroberts there are some discrepancies with the attention mask; I am digging deeper into the equivalence testing.

@qubvel qubvel changed the title Add FA2 support for SigLIP Add FA2 and sdpa support for SigLIP Jun 20, 2024
@qubvel qubvel force-pushed the siglip-fa2-support branch from c593ac3 to a12367b Compare June 24, 2024 13:37
@qubvel (Member, Author) commented Jun 24, 2024

Running the tests locally:

Flash Attention

RUN_SLOW=1 python -m pytest --verbose -m flash_attn_test \
    tests/models/siglip/test_modeling_siglip.py
[Screenshot: Flash Attention test results, 2024-06-24]

SDPA

RUN_SLOW=1 python -m pytest --verbose \
    tests/models/siglip/test_modeling_siglip.py -k "sdpa"
[Screenshot: SDPA test results, 2024-06-24]

@qubvel qubvel marked this pull request as ready for review June 24, 2024 14:55
@qubvel (Member, Author) commented Jun 24, 2024

@amyeroberts @molbap please review if you have time

@amyeroberts (Collaborator) left a comment:

Thanks for adding this!

  • Main comment is about the attention selection: we should instead be instantiating the model components with `from_config` (or possibly `_from_config`?) and passing in `attn_implementation=config._attn_implementation`.
  • Could we extend this to CLIP and add both at the same time?
  • There should be a section added to the model doc page + benchmarks showing some example timings/speedups for SDPA and FA2, e.g. like here for Mistral.

@@ -543,6 +825,33 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

@classmethod
def _autoset_attn_implementation(
Collaborator comment:

c.f. #31203 (comment)

I let this slip in for IDEFICS2, but it never should have been included

@@ -55,6 +67,178 @@
from transformers import SiglipProcessor


class SiglipModelTesterMixin(ModelTesterMixin):
Collaborator comment:

cc @ydshieh for comments / opinions on this mixin structure within the model's testing file

@ydshieh (Collaborator) commented Jun 25, 2024:

IIRC, it's just to override test_eager_matches_sdpa_inference (which is a large block) in ModelTesterMixin.

So we don't really need this new class. However, there are 3 or more model test classes in this file, so it's nice to have SiglipModelTesterMixin and just override with a mini block like:

    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=("pooler_output", "last_hidden_state"),
            use_attention_mask_options=(False,),
        )

@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
Collaborator comment:

Is it flaky for siglip? Do we know why? I know we have this decorator for the common tests, but not for the model-specific implementations

@qubvel (Member, Author) replied:

I didn't notice it being flaky for SigLIP; however, I don't know how much of that is hardware / CUDA / FA2 version specific, so I decided to also mark it flaky, as in the initial common implementation. I will remove it to make it consistent with the other model-specific tests.

@ydshieh ydshieh self-assigned this Jun 24, 2024
@qubvel (Member, Author) commented Jun 25, 2024

@amyeroberts

Main comment is about the attention selection: we should instead be instantiating the model components with `from_config` (or possibly `_from_config`?) and passing in `attn_implementation=config._attn_implementation`.

I tried to make it work with _from_config; however, the internal model components do not inherit from PreTrainedModel, they are just plain nn.Module and don't have a _from_config method. I changed the inheritance, but that led to other issues: the models then have to be included in the docs and some fields have to be specified.

I'm also not sure we can change the internal model components, for example change SiglipVisionTransformer(nn.Module) to SiglipVisionModel(SiglipPreTrainedModel). This would lead to incompatible checkpoints.

See implementation 669c537

Do you have any thoughts on that?

P.S. I am looking at #30390 with similar questions 👀

Could we extend this to CLIP and add both at the same time?

I hope we can, I will check this!

P.S. Given the work done in #30390, it will probably be better to continue that work rather than merge both PRs.

There should be a section added to the model doc page + benchmarks showing some example timings/speedups for SDPA and FA2, e.g. like here for Mistral.

Addressed in d41955d

[Speedup plot: FA2 / SDPA vs. eager inference time]

@sayakpaul (Member) commented:

Regarding the CLIP part: #30390. Feel free to cherry-pick commits if that helps. I got stuck on some Flax tests that I never had time to resolve.

@ydshieh (Collaborator) commented Jun 25, 2024

I tried to make it work with _from_config; however, the internal model components do not inherit from PreTrainedModel, they are just plain nn.Module and don't have a _from_config method

Hi, regarding this part: even if an internal component is just an nn.Module, we can still pass the config to its __init__, just like what is done in GemmaDecoderLayer. Hope this helps.

class GemmaDecoderLayer(nn.Module):
    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)


atols = {
("cpu", False, torch.float32): 1e-5,
("cpu", False, torch.bfloat16): 3e-2,
@qubvel (Member, Author):

The maximum allowed diff for bfloat16 is increased from 1e-2 to 3e-2 compared to the common model test.

@qubvel qubvel force-pushed the siglip-fa2-support branch from baf5b7b to 23457a2 Compare June 26, 2024 10:07
@qubvel (Member, Author) commented Jun 26, 2024

I changed the attention implementation propagation as follows (commit 23457a2):

  1. Initialize the *Model class instead of the *Transformer module to utilize the _from_config method.
  2. Use only the submodule of the *Model to maintain the overall model structure and weights for backward compatibility.
# First, initialize the text and vision models with proper attention implementation
text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)

# Second, get the text and vision submodules (for backward compatibility)
self.text_model = text_model.text_model
self.vision_model = vision_model.vision_model

With this approach, the underlying modules will exhibit the same behavior for the attn_implementation setting.

The disadvantage of this method is that the post_init() method is called twice: once for each *Model and again for the parent model.

@amyeroberts please let me know what you think.

@molbap (Contributor) left a comment:

Did a small review before heading off, so posting it now: mostly typos/minor suggestions :)

@amyeroberts (Collaborator) left a comment:

🔥 🔥 🔥 🔥 🔥 🔥

Thanks for adding and iterating on this! Re the double call to post_init: it's not ideal, but I think it should be fine, as on the second call all the layers should already be marked as initialized.

@@ -786,7 +1069,7 @@ def forward(

# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
if attention_mask is not None:
if attention_mask is not None and not self._use_flash_attention_2:
Collaborator comment:

Do we need special attention_mask preparation for the SDPA case?

@qubvel (Member, Author) replied:

I guess we don't need it. I additionally tested with the following code: with the same attention mask preparation, the outputs matched for both the eager and SDPA implementations.

import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from transformers import SiglipConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention, SiglipSdpaAttention

torch.manual_seed(235093093)

dtype = torch.float16
device = "cuda"

# Configure
config = SiglipConfig()

hidden_size = 6
num_attention_heads = 1
seq_len = 5
batch_size = 1

config.vision_config.hidden_size = hidden_size
config.vision_config.num_attention_heads = num_attention_heads

# Eager attention
attention = SiglipAttention(config.vision_config)
attention = attention.to(dtype).to(device)

# SDPA attention
attention_sdpa = SiglipSdpaAttention(config.vision_config)
attention_sdpa.load_state_dict(attention.state_dict())
attention_sdpa = attention_sdpa.to(dtype).to(device)

# Prepare inputs
dummy_input = torch.rand(
    [batch_size, seq_len, hidden_size], dtype=dtype, device=device,
)
dummy_attention_mask = torch.ones(
    [batch_size, seq_len], dtype=dtype, device=device,
)

# padding
dummy_attention_mask[:1, -2:] = 0
print("Dummy attention mask:\n", dummy_attention_mask)

# Prepare attention mask
dummy_attention_mask_eager = _prepare_4d_attention_mask(
    dummy_attention_mask, dummy_input.dtype
)  # shape: (batch_size, 1, seq_len, seq_len) = (1, 1, 5, 5)

# the same for SDPA
dummy_attention_mask_sdpa = dummy_attention_mask_eager

with torch.no_grad():
    attn_output, attn_weights = attention(dummy_input, dummy_attention_mask_eager)
    attn_output_sdpa, attn_weights_sdpa = attention_sdpa(dummy_input, dummy_attention_mask_sdpa)

print("\nEager:\n", attn_output)
print("\nSDPA:\n", attn_output_sdpa)
print("\nDiff:\n", attn_output - attn_output_sdpa)

diff_with_sdpa = torch.abs(attn_output - attn_output_sdpa).max()
print("\nDiff with SDPA:", diff_with_sdpa)
Output:

Dummy attention mask:
 tensor([[1., 1., 1., 0., 0.]], device='cuda:0', dtype=torch.float16)

Eager:
 tensor([[[ 0.0267,  0.3291,  0.3442,  0.6152, -0.1914,  0.1541],
         [ 0.0174,  0.3289,  0.3398,  0.6138, -0.1968,  0.1560],
         [ 0.0283,  0.3301,  0.3447,  0.6162, -0.1914,  0.1542],
         [ 0.0176,  0.3289,  0.3401,  0.6138, -0.1968,  0.1559],
         [ 0.0228,  0.3293,  0.3423,  0.6147, -0.1941,  0.1550]]],
       device='cuda:0', dtype=torch.float16)

SDPA:
 tensor([[[ 0.0267,  0.3291,  0.3442,  0.6152, -0.1914,  0.1541],
         [ 0.0174,  0.3289,  0.3398,  0.6138, -0.1968,  0.1560],
         [ 0.0283,  0.3301,  0.3447,  0.6162, -0.1914,  0.1542],
         [ 0.0176,  0.3289,  0.3401,  0.6138, -0.1968,  0.1559],
         [ 0.0228,  0.3293,  0.3423,  0.6147, -0.1941,  0.1550]]],
       device='cuda:0', dtype=torch.float16)

Diff:
 tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)

Diff with SDPA: tensor(0., device='cuda:0', dtype=torch.float16)


## Expected speedups

Below is an expected speedup diagram that compares inference time between the native implementation in transformers using the `google/siglip-so400m-patch14-384` checkpoint in `float16` precision and the Flash Attention 2 / SDPA versions of the model, using different batch sizes.
Collaborator comment:

❤️
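A minimal sketch of how such a timing comparison could be measured (this is not the script used to produce the plot above; the batch sizes, iteration count, and use of random pixel values are illustrative assumptions):

import time
import torch
from transformers import SiglipVisionModel

def time_vision_tower(attn_implementation: str, batch_size: int, n_iters: int = 20) -> float:
    # Load only the vision tower of the SigLIP checkpoint with the requested attention backend.
    model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
    ).to("cuda").eval()
    # Random images at the checkpoint's native 384x384 resolution.
    pixel_values = torch.randn(batch_size, 3, 384, 384, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        model(pixel_values=pixel_values)  # warmup
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

for batch_size in (1, 4, 16):
    eager = time_vision_tower("eager", batch_size)
    sdpa = time_vision_tower("sdpa", batch_size)
    print(f"batch={batch_size}: eager {eager * 1e3:.1f} ms, sdpa {sdpa * 1e3:.1f} ms, speedup {eager / sdpa:.2f}x")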

@qubvel qubvel merged commit a177821 into huggingface:main Jul 8, 2024
23 checks passed
@lucasjinreal commented:

Thanks for the work!

Which version on PyPI will support this feature?

@qubvel (Member, Author) commented Jul 8, 2024

Hi @lucasjinreal, it's going to be included in the next release, most probably 4.43.0. You can try it now by installing transformers from source:

pip install -U git+https://github.com/huggingface/transformers.git

@lucasjinreal commented:

I will wait for 4.43; my server is unable to access the network. Is there an estimated time for when 4.43 will be out?

@amyeroberts (Collaborator) commented Jul 8, 2024

@lucasjinreal Just to confirm, you can't access GitHub from your server? If it's just a matter of internet access, you would also need it to install the latest release from PyPI.

We typically release on a monthly schedule. You can see the list of releases here. The next minor release will probably be in 2-3 weeks.

@lucasjinreal commented:

Thanks, I will try updating in a few weeks.

Successfully merging this pull request may close these issues.

Add siglip flashattention support?
7 participants