diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 151ca765a216f8..3d17506b570dcc 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -191,10 +191,12 @@ FlashAttention is more memory efficient, meaning you can train on much larger se PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. For now, Transformers supports SDPA inference and training for the following architectures: +* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) +* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) @@ -214,12 +216,18 @@ For now, Transformers supports SDPA inference and training for the following arc * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) +* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel) +* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel) +* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel) +* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel) +* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell) * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 1d70e57c2fd128..523ab85f14f7cd 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -169,6 +169,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST +class ASTSdpaSelfAttention(ASTSelfAttention): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST class ASTSelfOutput(nn.Module): """ @@ -228,6 +260,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST +class ASTSdpaAttention(ASTAttention): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + self.attention = ASTSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST class ASTIntermediate(nn.Module): def __init__(self, config: ASTConfig) -> None: @@ -261,7 +300,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST +AST_ATTENTION_CLASSES = { + "eager": ASTAttention, + "sdpa": ASTSdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST class ASTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -269,7 +314,7 @@ def __init__(self, config: ASTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ASTAttention(config) + self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ASTIntermediate(config) self.output = ASTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -366,6 +411,7 @@ class ASTPreTrainedModel(PreTrainedModel): base_model_prefix = "audio_spectrogram_transformer" main_input_name = "input_values" supports_gradient_checkpointing = True + _supports_sdpa = True # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 2480b99586192f..fe811ecc4a70c9 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -190,6 +190,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT +class DeiTSdpaSelfAttention(DeiTSelfAttention): + def __init__(self, config: DeiTConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT class DeiTSelfOutput(nn.Module): """ @@ -249,6 +281,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT +class DeiTSdpaAttention(DeiTAttention): + def __init__(self, config: DeiTConfig) -> None: + super().__init__(config) + self.attention = DeiTSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT class DeiTIntermediate(nn.Module): def __init__(self, config: DeiTConfig) -> None: @@ -282,7 +321,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT +DEIT_ATTENTION_CLASSES = { + "eager": DeiTAttention, + "sdpa": DeiTSdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT class DeiTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -290,7 +335,7 @@ def __init__(self, config: DeiTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = DeiTAttention(config) + self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = DeiTIntermediate(config) self.output = DeiTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -388,6 +433,7 @@ class DeiTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["DeiTLayer"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py old mode 100644 new mode 100755 index 100bee54389569..05f74328b4c85f --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -134,7 +134,6 @@ def forward(self, pixel_values, bool_masked_pos): # add position embeddings embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach() - # only keep visible patches # ~bool_masked_pos means visible if bool_masked_pos is not None: @@ -268,6 +267,40 @@ def forward( return outputs +class VideoMAESdpaSelfAttention(VideoMAESelfAttention): + def __init__(self, config: VideoMAEConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None + keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias) + values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias) + queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias) + + key_layer = self.transpose_for_scores(keys) + value_layer = self.transpose_for_scores(values) + query_layer = self.transpose_for_scores(queries) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE class VideoMAESelfOutput(nn.Module): """ @@ -327,6 +360,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE +class VideoMAESdpaAttention(VideoMAEAttention): + def __init__(self, config: VideoMAEConfig) -> None: + super().__init__(config) + self.attention = VideoMAESdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE class VideoMAEIntermediate(nn.Module): def __init__(self, config: VideoMAEConfig) -> None: @@ -360,7 +400,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE +VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE class VideoMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -368,7 +411,7 @@ def __init__(self, config: VideoMAEConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VideoMAEAttention(config) + self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = VideoMAEIntermediate(config) self.output = VideoMAEOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -465,6 +508,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel): base_model_prefix = "videomae" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8aa43c5c43c500..dfda7bf731ba0b 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -236,6 +236,37 @@ def forward( return outputs +class ViTSdpaSelfAttention(ViTSelfAttention): + def __init__(self, config: ViTConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + class ViTSelfOutput(nn.Module): """ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the @@ -293,6 +324,12 @@ def forward( return outputs +class ViTSdpaAttention(ViTAttention): + def __init__(self, config: ViTConfig) -> None: + super().__init__(config) + self.attention = ViTSdpaSelfAttention(config) + + class ViTIntermediate(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() @@ -324,6 +361,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states +VIT_ATTENTION_CLASSES = { + "eager": ViTAttention, + "sdpa": ViTSdpaAttention, +} + + class ViTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -331,7 +374,7 @@ def __init__(self, config: ViTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTAttention(config) + self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTIntermediate(config) self.output = ViTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -428,6 +471,7 @@ class ViTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 20579e0d3db2cc..359b5e3fb9b08f 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -248,6 +248,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTHybrid +class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid class ViTHybridSelfOutput(nn.Module): """ @@ -307,6 +339,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTHybrid +class ViTHybridSdpaAttention(ViTHybridAttention): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + self.attention = ViTHybridSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid class ViTHybridIntermediate(nn.Module): def __init__(self, config: ViTHybridConfig) -> None: @@ -340,6 +379,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states +VIT_HYBRID_ATTENTION_CLASSES = { + "eager": ViTHybridAttention, + "sdpa": ViTHybridSdpaAttention, +} + + class ViTHybridLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -347,7 +392,7 @@ def __init__(self, config: ViTHybridConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTHybridAttention(config) + self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTHybridIntermediate(config) self.output = ViTHybridOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -447,6 +492,7 @@ class ViTHybridPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index b652c9e71f9106..3d6ea3f15c2b8e 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -245,14 +245,14 @@ def random_masking(self, sequence, noise=None): ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset - ids_keep = ids_shuffle[:, :len_keep] + ids_keep = ids_shuffle[:, :len_keep].to(sequence.device) sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([batch_size, seq_length], device=sequence.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask - mask = torch.gather(mask, dim=1, index=ids_restore) + mask = torch.gather(mask, dim=1, index=ids_restore.to(sequence.device)) return sequence_unmasked, mask, ids_restore @@ -370,6 +370,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE +class ViTMAESdpaSelfAttention(ViTMAESelfAttention): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE class ViTMAESelfOutput(nn.Module): """ @@ -429,6 +461,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE +class ViTMAESdpaAttention(ViTMAEAttention): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__(config) + self.attention = ViTMAESdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE class ViTMAEIntermediate(nn.Module): def __init__(self, config: ViTMAEConfig) -> None: @@ -462,7 +501,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE +VITMAE_ATTENTION_CLASSES = { + "eager": ViTMAEAttention, + "sdpa": ViTMAESdpaAttention, +} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE class ViTMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -470,7 +515,7 @@ def __init__(self, config: ViTMAEConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTMAEAttention(config) + self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTMAEIntermediate(config) self.output = ViTMAEOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -567,6 +612,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel): base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -764,7 +810,11 @@ def forward( # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token - x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device), + ) # unshuffle x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token # add pos embed diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 0632738455d1ab..d684868a645507 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -28,7 +28,12 @@ from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_vit_msn import ViTMSNConfig @@ -75,7 +80,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: patch_window_width = width // self.config.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 - patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1 + patch_window_height, patch_window_width = ( + patch_window_height + 0.1, + patch_window_width + 0.1, + ) patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( @@ -222,6 +230,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN +class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN class ViTMSNSelfOutput(nn.Module): """ @@ -281,6 +321,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN +class ViTMSNSdpaAttention(ViTMSNAttention): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__(config) + self.attention = ViTMSNSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN class ViTMSNIntermediate(nn.Module): def __init__(self, config: ViTMSNConfig) -> None: @@ -314,7 +361,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN +VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN class ViTMSNLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -322,7 +372,7 @@ def __init__(self, config: ViTMSNConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ViTMSNAttention(config) + self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = ViTMSNIntermediate(config) self.output = ViTMSNOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -420,6 +470,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTMSNAttention"] + _supports_sdpa = True # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # when creating pre-training scripts. @@ -553,7 +604,9 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + pixel_values, + bool_masked_pos=bool_masked_pos, + interpolate_pos_encoding=interpolate_pos_encoding, ) encoder_outputs = self.encoder( diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 9d6536b6c27258..a2223eb706c4ea 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -171,9 +171,15 @@ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor: patch_pos_embed = patch_pos_embed.view(batch_size, hidden_size, patch_height, patch_width) height, width = img_size - new_patch_heigth, new_patch_width = height // self.config.patch_size, width // self.config.patch_size + new_patch_heigth, new_patch_width = ( + height // self.config.patch_size, + width // self.config.patch_size, + ) patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, size=(new_patch_heigth, new_patch_width), mode="bicubic", align_corners=False + patch_pos_embed, + size=(new_patch_heigth, new_patch_width), + mode="bicubic", + align_corners=False, ) patch_pos_embed = patch_pos_embed.flatten(2).transpose(1, 2) scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=1) @@ -199,9 +205,15 @@ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor: ) patch_pos_embed = patch_pos_embed.view(depth * batch_size, hidden_size, patch_height, patch_width) height, width = img_size - new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size + new_patch_height, new_patch_width = ( + height // self.config.patch_size, + width // self.config.patch_size, + ) patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False + patch_pos_embed, + size=(new_patch_height, new_patch_width), + mode="bicubic", + align_corners=False, ) patch_pos_embed = ( patch_pos_embed.flatten(2) @@ -307,6 +319,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos +class YolosSdpaSelfAttention(YolosSelfAttention): + def __init__(self, config: YolosConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos class YolosSelfOutput(nn.Module): """ @@ -366,6 +410,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos +class YolosSdpaAttention(YolosAttention): + def __init__(self, config: YolosConfig) -> None: + super().__init__(config) + self.attention = YolosSdpaSelfAttention(config) + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos class YolosIntermediate(nn.Module): def __init__(self, config: YolosConfig) -> None: @@ -399,7 +450,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos +YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention} + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS class YolosLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -407,7 +461,7 @@ def __init__(self, config: YolosConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = YolosAttention(config) + self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config) self.intermediate = YolosIntermediate(config) self.output = YolosOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -531,6 +585,7 @@ class YolosPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = [] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" @@ -701,10 +756,16 @@ def __init__(self, config: YolosConfig): # Object detection heads # We add one for the "no object" class self.class_labels_classifier = YolosMLPPredictionHead( - input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=config.num_labels + 1, num_layers=3 + input_dim=config.hidden_size, + hidden_dim=config.hidden_size, + output_dim=config.num_labels + 1, + num_layers=3, ) self.bbox_predictor = YolosMLPPredictionHead( - input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=4, num_layers=3 + input_dim=config.hidden_size, + hidden_dim=config.hidden_size, + output_dim=4, + num_layers=3, ) # Initialize weights and apply final processing @@ -796,7 +857,9 @@ def forward( if labels is not None: # First: create the matcher matcher = YolosHungarianMatcher( - class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost + class_cost=self.config.class_cost, + bbox_cost=self.config.bbox_cost, + giou_cost=self.config.giou_cost, ) # Second: create the criterion losses = ["labels", "boxes", "cardinality"] diff --git a/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py b/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py index 564ca4d48c6a7f..9afad8adb9a665 100644 --- a/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py +++ b/tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py @@ -117,6 +117,7 @@ def get_config(self): initializer_range=self.initializer_range, frequency_stride=self.frequency_stride, time_stride=self.time_stride, + attn_implementation="eager", ) def create_and_check_model(self, config, input_values, labels): diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 9a54f16dab689f..39bec3ef392650 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -80,6 +80,7 @@ def __init__( num_labels=3, scope=None, encoder_stride=2, + mask_ratio=0.5, ): self.parent = parent self.batch_size = batch_size @@ -103,6 +104,9 @@ def __init__( # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 2 + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -130,6 +134,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/deit/test_modeling_tf_deit.py b/tests/models/deit/test_modeling_tf_deit.py index 26980e84207d50..dfdbfcbf437dc2 100644 --- a/tests/models/deit/test_modeling_tf_deit.py +++ b/tests/models/deit/test_modeling_tf_deit.py @@ -121,6 +121,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py index e5b1c6b78e40dd..67c6a5ea0de3c2 100644 --- a/tests/models/videomae/test_modeling_videomae.py +++ b/tests/models/videomae/test_modeling_videomae.py @@ -132,6 +132,7 @@ def get_config(self): decoder_intermediate_size=self.intermediate_size, decoder_num_attention_heads=self.num_attention_heads, decoder_num_hidden_layers=self.num_hidden_layers, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): @@ -197,7 +198,8 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): # hence we define a single mask, which we then repeat for each example in the batch mask = torch.ones((self.model_tester.num_masks,)) mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))]) - bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool() + batch_size = inputs_dict["pixel_values"].shape[0] + bool_masked_pos = mask.expand(batch_size, -1).bool() inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device) if return_labels: diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index d512ff25fe35ac..b346f2014451fc 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -26,6 +26,7 @@ from transformers import is_tf_available, is_torch_available, is_vision_available from transformers.testing_utils import ( + _run_slow_tests, is_pt_tf_cross_test, require_tf, require_torch, @@ -465,11 +466,15 @@ def check_pt_tf_equivalence(self, tf_model, pt_model, tf_inputs_dict): self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict): + if _run_slow_tests: + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) # Output all for aggressive testing encoder_decoder_config.output_hidden_states = True # All models tested in this file have attentions - encoder_decoder_config.output_attentions = True + encoder_decoder_config.output_attentions = _run_slow_tests pt_model = VisionEncoderDecoderModel(encoder_decoder_config) @@ -480,11 +485,17 @@ def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict): self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict): + # When taking a model from tf we are using the default attention + # mode (sdpa) so we are not expecting attention + config_output_attention = config.output_attentions + config.output_attentions = False + encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + # Output all for aggressive testing encoder_decoder_config.output_hidden_states = True # TODO: A generalizable way to determine this attribute - encoder_decoder_config.output_attentions = True + encoder_decoder_config.output_attentions = False tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config) # Make sure model is built before saving @@ -495,6 +506,8 @@ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict): pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) + # Revert mutable objet modification + config.output_attentions = config_output_attention def test_encoder_decoder_model(self): config_inputs_dict = self.prepare_config_and_inputs() @@ -554,9 +567,9 @@ def test_pt_tf_model_equivalence(self): # Output all for aggressive testing config.output_hidden_states = True decoder_config.output_hidden_states = True - # All models tested in this file have attentions - config.output_attentions = True - decoder_config.output_attentions = True + # All models tested in this file have attentions in slow mode + config.output_attentions = _run_slow_tests + decoder_config.output_attentions = _run_slow_tests tf_inputs_dict = config_inputs_dict # `encoder_hidden_states` is not used in model call/forward diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index 3239b507a8172f..4ed9baadd5e48f 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -23,6 +23,7 @@ from transformers import DonutProcessor, NougatProcessor, TrOCRProcessor from transformers.testing_utils import ( + _run_slow_tests, require_levenshtein, require_nltk, require_sentencepiece, @@ -323,6 +324,7 @@ def test_save_and_load_from_encoder_decoder_pretrained(self): input_ids_dict = self.prepare_config_and_inputs() self.check_save_and_load_encoder_decoder_model(**input_ids_dict) + @slow def test_encoder_decoder_model_output_attentions(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_output_attentions(**input_ids_dict) @@ -457,6 +459,8 @@ def check_encoder_decoder_model_output_attentions( ) def get_encoder_decoder_model(self, config, decoder_config): + if _run_slow_tests: + config._attn_implementation = "eager" encoder_model = DeiTModel(config).eval() decoder_model = BertLMHeadModel(decoder_config).eval() return encoder_model, decoder_model @@ -522,6 +526,8 @@ def get_pretrained_model_and_inputs(self): return model, inputs def get_encoder_decoder_model(self, config, decoder_config): + if _run_slow_tests: + config._attn_implementation = "eager" encoder_model = ViTModel(config).eval() decoder_model = BertLMHeadModel(decoder_config).eval() return encoder_model, decoder_model @@ -650,6 +656,8 @@ def test_real_model_save_load_from_pretrained(self): @require_torch class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase): def get_encoder_decoder_model(self, config, decoder_config): + if _run_slow_tests: + config._attn_implementation = "eager" encoder_model = ViTModel(config).eval() decoder_model = TrOCRForCausalLM(decoder_config).eval() return encoder_model, decoder_model diff --git a/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py index 4a1ee2462e4f6e..1de8dd584f7f38 100644 --- a/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py @@ -21,7 +21,13 @@ import numpy as np -from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_torch, + require_vision, + slow, + torch_device, +) from transformers.utils import is_flax_available, is_torch_available, is_vision_available from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask diff --git a/tests/models/vit/test_modeling_flax_vit.py b/tests/models/vit/test_modeling_flax_vit.py index af56f4717b888f..4d3c59a7bdbae1 100644 --- a/tests/models/vit/test_modeling_flax_vit.py +++ b/tests/models/vit/test_modeling_flax_vit.py @@ -87,6 +87,7 @@ def prepare_config_and_inputs(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + attn_implementation="eager", ) return config, pixel_values diff --git a/tests/models/vit/test_modeling_tf_vit.py b/tests/models/vit/test_modeling_tf_vit.py index dee2c8f18c171a..153ddf2be47d8d 100644 --- a/tests/models/vit/test_modeling_tf_vit.py +++ b/tests/models/vit/test_modeling_tf_vit.py @@ -111,6 +111,7 @@ def get_config(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index 7298543a563438..b316f183c77410 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -68,6 +68,7 @@ def __init__( initializer_range=0.02, scope=None, encoder_stride=2, + mask_ratio=0.5, ): self.parent = parent self.batch_size = batch_size @@ -91,6 +92,9 @@ def __init__( # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 1 + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -118,6 +122,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_hybrid/test_modeling_vit_hybrid.py b/tests/models/vit_hybrid/test_modeling_vit_hybrid.py index d48a8853921649..8c37134e3f9f8a 100644 --- a/tests/models/vit_hybrid/test_modeling_vit_hybrid.py +++ b/tests/models/vit_hybrid/test_modeling_vit_hybrid.py @@ -122,6 +122,7 @@ def get_config(self): backbone_featmap_shape=self.backbone_featmap_shape, backbone_config=backbone_config, backbone=None, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 6a77e95102c969..e4c703f18c551a 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -127,6 +127,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, mask_ratio=self.mask_ratio, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index ffb679d646ffda..4d5345dfa2051c 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -63,8 +63,8 @@ def __init__( type_sequence_label_size=10, initializer_range=0.02, num_labels=3, - mask_ratio=0.6, scope=None, + mask_ratio=0.5, ): self.parent = parent self.batch_size = batch_size @@ -89,6 +89,9 @@ def __init__( # (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1))) + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -120,6 +123,7 @@ def get_config(self): decoder_intermediate_size=self.intermediate_size, decoder_num_attention_heads=self.num_attention_heads, decoder_num_hidden_layers=self.num_hidden_layers, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/vit_msn/test_modeling_vit_msn.py b/tests/models/vit_msn/test_modeling_vit_msn.py index 5fe494c105cb62..395572b618465f 100644 --- a/tests/models/vit_msn/test_modeling_vit_msn.py +++ b/tests/models/vit_msn/test_modeling_vit_msn.py @@ -106,6 +106,7 @@ def get_config(self): hidden_dropout_prob=self.hidden_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob, initializer_range=self.initializer_range, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index 64a439f27a4e45..5fe2929bfc2fe2 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -123,6 +123,7 @@ def get_config(self): initializer_range=self.initializer_range, num_detection_tokens=self.num_detection_tokens, num_labels=self.num_labels, + attn_implementation="eager", ) def create_and_check_model(self, config, pixel_values, labels): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cd46934b5fcfe4..f55c7f9aeac555 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2788,7 +2788,9 @@ def test_equivalence_flax_to_pt(self): with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + pt_model_loaded = model_class.from_pretrained( + tmpdirname, from_flax=True, attn_implementation="eager" + ) # send pytorch model to the correct device pt_model_loaded.to(torch_device) @@ -3724,6 +3726,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) + # FIXME: we deactivate boolean mask because pretrained models + # will not load the mask token + if "use_mask_token" in inspect.signature(model_class).parameters: + deactivate_mask = True + else: + deactivate_mask = False is_encoder_decoder = model.config.is_encoder_decoder @@ -3840,6 +3848,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): "decoder_attention_mask": dummy_attention_mask, "output_hidden_states": True, } + else: processed_inputs = { model.main_input_name: dummy_input, @@ -3850,6 +3859,37 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): if "attention_mask" in inspect.signature(model_eager.forward).parameters: processed_inputs["attention_mask"] = dummy_attention_mask + if ( + "bool_masked_pos" in inspect.signature(model_eager.forward).parameters + ) and not deactivate_mask: + dummy_mask = torch.ones((self.model_tester.num_masks,)) + + # In case of additional token (like class) we define a custome `mask_length` + if hasattr(self.model_tester, "mask_length"): + dummy_mask = torch.cat( + [ + dummy_mask, + torch.zeros(self.model_tester.mask_length - dummy_mask.size(0)), + ] + ) + else: + dummy_mask = torch.cat( + [ + dummy_mask, + torch.zeros(self.model_tester.seq_length - dummy_mask.size(0)), + ] + ) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int( + (self.model_tester.image_size // self.model_tester.patch_size) ** 2 + ) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) + # TODO: test gradients as well (& for FA2 as well!) with torch.no_grad(): with torch.backends.cuda.sdp_kernel( diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 22d6b241f0048c..b9b23934bb9765 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -371,7 +371,9 @@ def test_equivalence_flax_to_pt(self): with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + pt_model_loaded = pt_model_class.from_pretrained( + tmpdirname, from_flax=True, attn_implementation="eager" + ) # send pytorch model to the correct device pt_model_loaded.to(torch_device) diff --git a/utils/check_support_list.py b/utils/check_support_list.py index f6aaa2bb67dce4..3cb0b616022426 100644 --- a/utils/check_support_list.py +++ b/utils/check_support_list.py @@ -84,7 +84,7 @@ def check_sdpa_support_list(): archs_supporting_sdpa.append(model_name) for arch in archs_supporting_sdpa: - if arch not in doctext: + if arch not in doctext and arch not in doctext.replace("-", "_"): raise ValueError( f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation." )