Add FA2 and sdpa support for SigLIP #31499
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice! We can probably combine this with CLIP cc @sayakpaul. For reference, there's a FA2 SigLIP implementation for IDEFICS2, but I'm not sure how much testing was done for the equivalence between the eager and FA2 classes.
@amyeroberts there are some discrepancies with the attention mask; I am digging deeper into the equivalence testing.
Force-pushed from c593ac3 to a12367b
@amyeroberts @molbap please review if you have time.
Thanks for adding this!
- Main comment is about the attention selection: we should instead be instantiating the model components with `from_config` (or possibly `_from_config`?) and passing in `attn_implementation=config._attn_implementation`, as sketched below.
- Could we extend this to CLIP and add both at the same time?
- There should be a section added to the model doc page + benchmarks showing some timings for the improvements from SDPA and FA2, e.g. like here for Mistral.
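For context on the first point, a minimal usage sketch of the selection this enables, assuming the standard `attn_implementation` values supported by Transformers (illustration only, not code from this PR):

import torch
from transformers import SiglipModel

# The requested backend is stored on the config as `_attn_implementation`
# and has to reach the text/vision sub-models when they are built.
model = SiglipModel.from_pretrained(
    "google/siglip-so400m-patch14-384",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
)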
@@ -543,6 +825,33 @@ def _init_weights(self, module):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    @classmethod
    def _autoset_attn_implementation(
c.f. #31203 (comment)
I let this slip in for IDEFICS2, but it never should have been included
@@ -55,6 +67,178 @@
from transformers import SiglipProcessor


class SiglipModelTesterMixin(ModelTesterMixin):
cc @ydshieh for comments / opinions on this mixin structure within the model's testing file
IIRC, it's just to override test_eager_matches_sdpa_inference (which is a large block in ModelTesterMixin).
So we don't really need to have this new class. However, there are 3 or more model test classes in this file, so it's nice to have SiglipModelTesterMixin and just override it with a mini block like:

def test_eager_matches_sdpa_inference(self, torch_dtype: str):
    super().test_eager_matches_sdpa_inference(
        torch_dtype=torch_dtype,
        logit_keys=("pooler_output", "last_hidden_state"),
        use_attention_mask_options=(False,),
    )
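For illustration, a sketch of the test-file layout being discussed; class names other than SiglipModelTesterMixin are assumptions, and ModelTesterMixin lives in tests/test_modeling_common.py of the transformers repo:

import unittest
from tests.test_modeling_common import ModelTesterMixin


class SiglipModelTesterMixin(ModelTesterMixin):
    def test_eager_matches_sdpa_inference(
        self, torch_dtype: str, logit_keys=(), use_attention_mask_options=(True, False)
    ):
        # The large shared override of the common test lives here once.
        ...


class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase):
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        # Each concrete test class only forwards its model-specific arguments.
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=("pooler_output", "last_hidden_state"),
            use_attention_mask_options=(False,),
        )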
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
Is it flaky for siglip? Do we know why? I know we have this decorator for the common tests, but not for the model-specific implementations
I didn't notice it being flaky for SigLIP; however, I don't know how much of that is hardware-, CUDA-, or FA2-version specific. So I decided to also mark it flaky, as in the initial common implementation. I will remove it to keep it consistent with the other model-specific tests.
I tried to make it with `_from_config`. I'm not sure we can change internal model components too, for example change their attention classes. See implementation 669c537.
P.S. I am looking at #30390 with similar questions 👀
I hope we can; I will check this! P.S. Given the work done in #30390, it will probably be better to continue the work on this rather than merge both PRs.
Addressed in d41955d
Regarding the CLIP part: #30390. Feel free to cherry-pick commits if that helps. I got stuck on some FLAX tests that I never got time to resolve.
Hi, regarding this part: even if an internal component is just an nn.Module, it still selects its attention class from `config._attn_implementation`, as in:

class GemmaDecoderLayer(nn.Module):
    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
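A hypothetical SigLIP analogue of that pattern, as this PR would need it; the mapping and class names are assumed for illustration rather than quoted from the diff:

import torch.nn as nn
from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import (
    SiglipAttention,
    SiglipFlashAttention2,
    SiglipSdpaAttention,
)

# Assumed mapping from the config's attention implementation to the attention class
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
}


class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # The attention backend is picked from the config, mirroring the Gemma snippet above
        self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)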
atols = {
    ("cpu", False, torch.float32): 1e-5,
    ("cpu", False, torch.bfloat16): 3e-2,
The maximum diff for bfloat16 is increased from 1e-2 to 3e-2 compared to the common model test.
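For context, a small sketch of how such a tolerance table is typically consumed when comparing eager and SDPA outputs; the key layout (device, kernels enabled, dtype) follows the excerpt above, and the helper name is made up for illustration:

import torch

atols = {
    ("cpu", False, torch.float32): 1e-5,
    ("cpu", False, torch.bfloat16): 3e-2,  # loosened from 1e-2 for SigLIP
}

def get_atol(device: str, enable_kernels: bool, dtype: torch.dtype, default: float = 1e-3) -> float:
    # Fall back to a generic tolerance for combinations not listed above
    return atols.get((device, enable_kernels, dtype), default)

# e.g. torch.testing.assert_close(out_eager, out_sdpa, rtol=0.0,
#      atol=get_atol("cpu", False, torch.bfloat16))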
Force-pushed from baf5b7b to 23457a2
I changed the attention implementation propagation as follows (commit 23457a2):

# First, initialize the text and vision models with the proper attention implementation
text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)

# Second, get the text and vision submodules (for backward compatibility)
self.text_model = text_model.text_model
self.vision_model = vision_model.vision_model

With this approach, the underlying modules will exhibit the same behavior for the selected attention implementation. The disadvantage of this method is that post_init ends up being called twice. @amyeroberts please let me know what you think.
Did a small review, then off so posting it: mostly typos/minor suggestions :)
🔥 🔥 🔥 🔥 🔥 🔥
Thanks for adding and iterating on this! Re the double call to post_init: it's not ideal, but I think it should be fine, as on the second call all the layers should be marked as initialized.
@@ -786,7 +1069,7 @@ def forward(

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
-        if attention_mask is not None:
+        if attention_mask is not None and not self._use_flash_attention_2:
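For background on why the expansion is skipped for FA2, here is a sketch of the two mask formats involved; the shapes and values are illustrative assumptions, not code from the PR. Eager and SDPA consume a 4D additive mask, while the Flash Attention 2 path keeps the 2D padding mask and lets the kernel deal with padding:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# 2D padding mask as passed by the user: (batch_size, seq_len), 1 = keep, 0 = pad
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Eager / SDPA: expand to a 4D additive mask of shape (batch_size, 1, seq_len, seq_len),
# with 0.0 for kept positions and a large negative value for padded ones.
expanded_mask = _prepare_4d_attention_mask(attention_mask, torch.float16)
print(expanded_mask.shape)  # torch.Size([1, 1, 5, 5])

# Flash Attention 2: the 2D mask is passed through unchanged, so the expansion above is skipped.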
Do we need special attention_mask preparation for the SDPA case?
I guess we don't need it. I additionally tested with the following code, and with the same attention mask preparation the outputs matched for both implementations, eager and sdpa:
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers import SiglipConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention, SiglipSdpaAttention
torch.manual_seed(235093093)
dtype = torch.float16
device = "cuda"
# Configure
config = SiglipConfig()
hidden_size = 6
num_attention_heads = 1
seq_len = 5
batch_size = 1
config.vision_config.hidden_size = hidden_size
config.vision_config.num_attention_heads = num_attention_heads
# Eager attention
attention = SiglipAttention(config.vision_config)
attention = attention.to(dtype).to(device)
# SDPA attention
attention_sdpa = SiglipSdpaAttention(config.vision_config)
attention_sdpa.load_state_dict(attention.state_dict())
attention_sdpa = attention_sdpa.to(dtype).to(device)
# Prepare inputs
dummy_input = torch.rand(
[batch_size, seq_len, hidden_size], dtype=dtype, device=device,
)
dummy_attention_mask = torch.ones(
[batch_size, seq_len], dtype=dtype, device=device,
)
# padding
dummy_attention_mask[:1, -2:] = 0
print("Dummy attention mask:\n", dummy_attention_mask)
# Prepare attention mask
dummy_attention_mask_eager = _prepare_4d_attention_mask(
    dummy_attention_mask, dummy_input.dtype
)  # (batch_size, seq_len) -> (batch_size, 1, seq_len, seq_len)
# the same for SDPA
dummy_attention_mask_sdpa = dummy_attention_mask_eager
with torch.no_grad():
attn_output, attn_weights = attention(dummy_input, dummy_attention_mask_eager)
attn_output_sdpa, attn_weights_sdpa = attention_sdpa(dummy_input, dummy_attention_mask_sdpa)
print("\nEager:\n", attn_output)
print("\nSDPA:\n", attn_output_sdpa)
print("\nDiff:\n", attn_output - attn_output_sdpa)
diff_with_sdpa = torch.abs(attn_output - attn_output_sdpa).max()
print("\nDiff with SDPA:", diff_with_sdpa)
Dummy attention mask:
tensor([[1., 1., 1., 0., 0.]], device='cuda:0', dtype=torch.float16)
Eager:
tensor([[[ 0.0267, 0.3291, 0.3442, 0.6152, -0.1914, 0.1541],
[ 0.0174, 0.3289, 0.3398, 0.6138, -0.1968, 0.1560],
[ 0.0283, 0.3301, 0.3447, 0.6162, -0.1914, 0.1542],
[ 0.0176, 0.3289, 0.3401, 0.6138, -0.1968, 0.1559],
[ 0.0228, 0.3293, 0.3423, 0.6147, -0.1941, 0.1550]]],
device='cuda:0', dtype=torch.float16)
SDPA:
tensor([[[ 0.0267, 0.3291, 0.3442, 0.6152, -0.1914, 0.1541],
[ 0.0174, 0.3289, 0.3398, 0.6138, -0.1968, 0.1560],
[ 0.0283, 0.3301, 0.3447, 0.6162, -0.1914, 0.1542],
[ 0.0176, 0.3289, 0.3401, 0.6138, -0.1968, 0.1559],
[ 0.0228, 0.3293, 0.3423, 0.6147, -0.1941, 0.1550]]],
device='cuda:0', dtype=torch.float16)
Diff:
tensor([[[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
Diff with SDPA: tensor(0., device='cuda:0', dtype=torch.float16)
## Expected speedups

Below is an expected speedup diagram that compares inference time between the native implementation in transformers using the `google/siglip-so400m-patch14-384` checkpoint in `float16` precision and the Flash Attention 2 / SDPA version of the model, using different batch sizes.
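A rough sketch of how such a comparison could be produced; the timing loop, batch size, and the choice to benchmark only the vision tower are illustrative assumptions, not the script used for the diagram:

import time
import torch
from transformers import SiglipVisionModel

def benchmark(attn_implementation: str, batch_size: int = 8, steps: int = 20) -> float:
    model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,  # "eager", "sdpa" or "flash_attention_2"
    ).to("cuda").eval()
    pixel_values = torch.rand(batch_size, 3, 384, 384, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        for _ in range(3):  # warmup
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps  # average seconds per forward pass

print("eager:", benchmark("eager"))
print("sdpa: ", benchmark("sdpa"))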
❤️
Thanks for the work! Which version on PyPI will support this feature?
Hi @lucasjinreal, it's going to be included in the next release, most probably 4.43.0. You can try it now by installing transformers from source: pip install -U git+https://github.com/huggingface/transformers.git
I will wait for 4.43; my server is unable to access the network. Any estimate of when 4.43 will be out?
@lucasjinreal Just to confirm, you can't access GitHub from your server? If it's just related to internet access, you would also need access to install the latest release from PyPI. We typically release on a monthly schedule. You can see the list of releases here. The next minor release will probably be in 2-3 weeks.
Thanks, will try updating in a few weeks.
What does this PR do?
Add Flash Attention 2 and SDPA (torch.nn.functional.scaled_dot_product_attention) attention implementations for the SigLIP model.

Fixes #31138
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.