add sdpa for vit
lyaronskaya authored and Sebastien Ehrhardt committed Apr 27, 2024
1 parent 73014b5 commit d324cd0
Showing 2 changed files with 40 additions and 1 deletion.
39 changes: 38 additions & 1 deletion src/transformers/models/vit/modeling_vit.py
@@ -239,6 +239,30 @@ def forward(
        return outputs


class ViTSdpaSelfAttention(ViTSelfAttention):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
Expand Down Expand Up @@ -296,6 +320,12 @@ def forward(
return outputs


class ViTSdpaAttention(ViTAttention):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)
        self.attention = ViTSdpaSelfAttention(config)


class ViTIntermediate(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
@@ -327,14 +357,20 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


VIT_ATTENTION_CLASSES = {
    "eager": ViTAttention,
    "sdpa": ViTSdpaAttention,
}


class ViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTAttention(config)
        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -431,6 +467,7 @@ class ViTPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
    _supports_sdpa = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
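With _supports_sdpa = True and the VIT_ATTENTION_CLASSES mapping in place, the SDPA path can be requested through the attn_implementation argument when loading a ViT model. A minimal usage sketch, not part of this commit; the google/vit-base-patch16-224-in21k checkpoint and the dummy input are only illustrative:

import torch
from transformers import ViTModel

# Request the scaled_dot_product_attention path added in this commit;
# attn_implementation="eager" keeps the original ViTAttention code path.
model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224-in21k", attn_implementation="sdpa"
)
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch of one 224x224 RGB image
with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 197, 768])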
2 changes: 2 additions & 0 deletions tests/models/vit/test_modeling_vit.py
@@ -23,6 +23,7 @@
    require_torch,
    require_torch_accelerator,
    require_torch_fp16,
    require_torch_sdpa,
    require_vision,
    slow,
    torch_device,
@@ -201,6 +202,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = False

    def setUp(self):
        self.model_tester = ViTModelTester(self)
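The has_attentions = False flag turns off the common tests that inspect returned attention probabilities, since the SDPA self-attention returns None in their place. As an illustration only (this snippet is not part of the commit, and the tiny ViTConfig values and the use of the private _attn_implementation attribute are assumptions), one way to sanity-check that the eager and SDPA paths agree numerically:

import copy

import torch
from transformers import ViTConfig, ViTModel

# Tiny config so the check runs quickly; the values are arbitrary.
config = ViTConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    image_size=30,
    patch_size=2,
)

eager_config = copy.deepcopy(config)
eager_config._attn_implementation = "eager"  # private attribute; assumed settable here
sdpa_config = copy.deepcopy(config)
sdpa_config._attn_implementation = "sdpa"

eager_model = ViTModel(eager_config).eval()
sdpa_model = ViTModel(sdpa_config).eval()
sdpa_model.load_state_dict(eager_model.state_dict())  # share weights so outputs are comparable

pixel_values = torch.randn(1, 3, 30, 30)
with torch.no_grad():
    eager_out = eager_model(pixel_values).last_hidden_state
    sdpa_out = sdpa_model(pixel_values).last_hidden_state

torch.testing.assert_close(eager_out, sdpa_out, rtol=1e-4, atol=1e-5)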
