diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 4ccdd1deaf4ca1..1441f4a5e6bf80 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -239,6 +239,30 @@ def forward(
         return outputs
 
 
+class ViTSdpaSelfAttention(ViTSelfAttention):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0, is_causal=False, scale=None)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
+
+
 class ViTSelfOutput(nn.Module):
     """
     The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
@@ -296,6 +320,12 @@ def forward(
         return outputs
 
 
+class ViTSdpaAttention(ViTAttention):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__(config)
+        self.attention = ViTSdpaSelfAttention(config)
+
+
 class ViTIntermediate(nn.Module):
     def __init__(self, config: ViTConfig) -> None:
         super().__init__()
@@ -327,6 +357,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
         return hidden_states
 
 
+VIT_ATTENTION_CLASSES = {
+    "eager": ViTAttention,
+    "sdpa": ViTSdpaAttention,
+}
+
+
 class ViTLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""
 
@@ -334,7 +370,7 @@ def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = ViTAttention(config)
+        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
         self.intermediate = ViTIntermediate(config)
         self.output = ViTOutput(config)
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -431,6 +467,7 @@ class ViTPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
+    _supports_sdpa = True
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py
index 7298543a563438..0a0f9da1b27f67 100644
--- a/tests/models/vit/test_modeling_vit.py
+++ b/tests/models/vit/test_modeling_vit.py
@@ -23,6 +23,7 @@
     require_torch,
     require_torch_accelerator,
     require_torch_fp16,
+    require_torch_sdpa,
     require_vision,
     slow,
     torch_device,
@@ -201,6 +202,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
+    has_attentions = False
 
     def setUp(self):
        self.model_tester = ViTModelTester(self)
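
For context, here is a minimal usage sketch (not part of the diff) of how the new path is selected once `_supports_sdpa = True` is in place: passing `attn_implementation="sdpa"` to `from_pretrained` makes `ViTLayer` pick `ViTSdpaAttention` out of `VIT_ATTENTION_CLASSES` instead of the eager implementation. The checkpoint name, input shape, and printed output below are illustrative assumptions, not part of the change.

```python
# Minimal sketch, assuming a PyTorch version that provides
# torch.nn.functional.scaled_dot_product_attention and using
# "google/vit-base-patch16-224" purely as an illustrative checkpoint.
import torch
from transformers import ViTModel

# Select the SDPA attention path; the eager implementation remains the default ("eager").
model = ViTModel.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa")
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch: (batch, channels, height, width)
with torch.no_grad():
    outputs = model(pixel_values)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 197, 768]) for a base/16 model at 224x224
# Note: the SDPA path returns None for attention probabilities
# (hence `has_attentions = False` in the test changes above).
```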