add sdpa to vit-based models
lyaronskaya authored and Sebastien Ehrhardt committed Apr 27, 2024
1 parent cffbf39 commit eca6644
Showing 11 changed files with 572 additions and 91 deletions.
src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py

@@ -172,6 +172,38 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
class ASTSdpaSelfAttention(ASTSelfAttention):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None
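For context, the fused `scaled_dot_product_attention` call above replaces the eager softmax-attention math; a minimal sketch of the equivalent eager computation, assuming `(batch, heads, seq_len, head_dim)` tensors (names and shapes here are illustrative, not from the diff):

```python
import torch

# Hedged sketch of the eager attention that the fused SDPA call replaces;
# the function name and tensor layout are illustrative assumptions.
def eager_attention(query, key, value, dropout_p=0.0):
    scale = query.size(-1) ** -0.5  # SDPA's default scaling when scale=None
    scores = torch.matmul(query, key.transpose(-1, -2)) * scale
    probs = torch.nn.functional.softmax(scores, dim=-1)
    probs = torch.nn.functional.dropout(probs, p=dropout_p)
    return torch.matmul(probs, value)  # (batch, heads, seq_len, head_dim)
```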


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
class ASTSelfOutput(nn.Module):
"""
Expand Down Expand Up @@ -231,6 +263,13 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
class ASTSdpaAttention(ASTAttention):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__(config)
        self.attention = ASTSdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
    def __init__(self, config: ASTConfig) -> None:

@@ -264,15 +303,21 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


AST_ATTENTION_CLASSES = {
    "eager": ASTAttention,
    "sdpa": ASTSdpaAttention,
}


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
class ASTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = ASTIntermediate(config)
        self.output = ASTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -369,6 +414,7 @@ class ASTPreTrainedModel(PreTrainedModel):
    base_model_prefix = "audio_spectrogram_transformer"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
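Since the class now advertises `_supports_sdpa = True`, the backend can be selected at load time; a minimal usage sketch (the checkpoint name is an assumption, any AST checkpoint should work):

```python
from transformers import ASTModel

# Hedged usage sketch: request the SDPA attention backend explicitly.
# The checkpoint name below is an assumption, not taken from the diff.
model = ASTModel.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    attn_implementation="sdpa",
)
```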
50 changes: 48 additions & 2 deletions src/transformers/models/deit/modeling_deit.py
@@ -193,6 +193,38 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT
class DeiTSdpaSelfAttention(DeiTSelfAttention):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None
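DeiT gets the same drop-in class; with dropout disabled, the fused path should agree with the eager formulation up to floating-point tolerance. A self-contained sanity check on random tensors (shapes are assumptions, not from the diff):

```python
import torch

# Hedged sanity check: fused SDPA vs. eager attention on random inputs,
# dropout off, causal masking off. Shapes are illustrative assumptions.
q = torch.randn(2, 12, 197, 64)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

fused = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=False)
scores = torch.matmul(q, k.transpose(-1, -2)) / (q.size(-1) ** 0.5)
eager = torch.matmul(torch.softmax(scores, dim=-1), v)

assert torch.allclose(fused, eager, atol=1e-4)
```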


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
class DeiTSelfOutput(nn.Module):
"""
Expand Down Expand Up @@ -252,6 +284,13 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT
class DeiTSdpaAttention(DeiTAttention):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__(config)
        self.attention = DeiTSdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
class DeiTIntermediate(nn.Module):
    def __init__(self, config: DeiTConfig) -> None:

@@ -285,15 +324,21 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


DEIT_ATTENTION_CLASSES = {
    "eager": DeiTAttention,
    "sdpa": DeiTSdpaAttention,
}


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = DeiTIntermediate(config)
        self.output = DeiTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -391,6 +436,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeiTLayer"]
    _supports_sdpa = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
50 changes: 47 additions & 3 deletions src/transformers/models/videomae/modeling_videomae.py
100644 → 100755
@@ -137,7 +137,6 @@ def forward(self, pixel_values, bool_masked_pos):

        # add position embeddings
        embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
        # only keep visible patches
        # ~bool_masked_pos means visible
        if bool_masked_pos is not None:
@@ -271,6 +270,40 @@ def forward(
        return outputs


class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
        keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
        values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
        queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)

        key_layer = self.transpose_for_scores(keys)
        value_layer = self.transpose_for_scores(values)
        query_layer = self.transpose_for_scores(queries)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None
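Unlike the other models, VideoMAE's key projection carries no learned bias, so the code above substitutes an all-zeros `k_bias` to keep the three projection calls symmetric; that is functionally identical to passing no bias at all. A small self-contained check of the equivalence (dimensions are assumptions):

```python
import torch
from torch import nn

# Hedged sketch of the zero-key-bias trick: a zeros bias on a bias-free
# linear projection is a no-op, so both calls below produce the same output.
hidden = torch.randn(1, 8, 768)        # (batch, seq_len, hidden), assumed dims
key = nn.Linear(768, 768, bias=False)
k_bias = torch.zeros(768)

with_zeros = nn.functional.linear(hidden, key.weight, k_bias)
without = nn.functional.linear(hidden, key.weight, None)
assert torch.allclose(with_zeros, without)
```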


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
class VideoMAESelfOutput(nn.Module):
"""
Expand Down Expand Up @@ -330,6 +363,13 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
class VideoMAESdpaAttention(VideoMAEAttention):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__(config)
        self.attention = VideoMAESdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:

@@ -363,15 +403,18 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = VideoMAEIntermediate(config)
        self.output = VideoMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -468,6 +511,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
    base_model_prefix = "videomae"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
50 changes: 48 additions & 2 deletions src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -373,6 +373,38 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
"""
Expand Down Expand Up @@ -432,6 +464,13 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
class ViTMAESdpaAttention(ViTMAEAttention):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__(config)
        self.attention = ViTMAESdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:

@@ -465,15 +504,21 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


VITMAE_ATTENTION_CLASSES = {
    "eager": ViTMAEAttention,
    "sdpa": ViTMAESdpaAttention,
}


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = ViTMAEIntermediate(config)
        self.output = ViTMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -570,6 +615,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
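To see whether the new backend pays off on a given machine, the two implementations can be timed side by side; a rough micro-benchmark sketch (checkpoint name and input shape are assumptions, and timings vary with hardware and PyTorch version):

```python
import time

import torch
from transformers import ViTMAEModel

# Hedged micro-benchmark sketch comparing the eager and SDPA backends.
# Checkpoint name and input shape are assumptions, not from the diff.
pixel_values = torch.randn(8, 3, 224, 224)
for impl in ("eager", "sdpa"):
    model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation=impl).eval()
    with torch.no_grad():
        start = time.perf_counter()
        model(pixel_values=pixel_values)
    print(impl, f"{time.perf_counter() - start:.3f}s")
```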
