From 74be54cdea58d08fcf4a974c31f2e6c6918b2497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 16:03:54 +0200 Subject: [PATCH 001/100] add sdpa --- .../models/llama/modeling_llama.py | 123 +++++++++++++++++- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 7 + 3 files changed, 126 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3d9e317750ca75..e9a0d61b9518e2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_available, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -590,15 +591,127 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l ) +class LlamaSDPAAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + raise ValueError("output_attentions=True can not be supported with PyTorch SDPA.") + attn_weights = None + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + torch.nn.functional.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + torch.nn.functional.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + torch.nn.functional.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Note that Llama does not use dropout in the attention, hence the hard-coded + # dropout_p=0.0 independent of self.training. + # The batch_size = 1 case is handled separately to allow to dispatch on flash attention + # kernel. + if bsz == 1: + if query_states.shape[2] > 1: + is_causal = True + else: + is_causal = False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=is_causal + ) + else: + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # NOTE: As of PyTorch 2.1, this case can not dispatch to flash attention v2 due to the + # attention mask passed. Possible solution: use nested tensors. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ( - LlamaAttention(config=config) - if not getattr(config, "_flash_attn_2_enabled", False) - else LlamaFlashAttention2(config=config) - ) + + if getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = LlamaFlashAttention2(config=config) + elif is_torch_sdpa_available(): + self.self_attn = LlamaSDPAAttention(config=config) + else: + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index c5b80c617c85de..b2566f424b007b 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -175,6 +175,7 @@ is_torch_mps_available, is_torch_neuroncore_available, is_torch_npu_available, + is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 837fb24af42a61..c10838e6e4aef5 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -256,6 +256,13 @@ def get_torch_version(): return _torch_version +def is_torch_sdpa_available(): + # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons: + # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259 + # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310 + return _torch_version >= version.parse("2.1") + + def is_torchvision_available(): return _torchvision_available From 9d14f0dea3d5941604a7694006d944ef1b7951d2 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 16:32:21 +0000 Subject: [PATCH 002/100] wip --- src/transformers/generation/utils.py | 1 + .../models/llama/modeling_llama.py | 104 +++++++++++++++++- src/transformers/utils/import_utils.py | 2 +- 3 files changed, 102 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3b1bef6f040084..1403c6792d95b3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2450,6 +2450,7 @@ def greedy_search( # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + print("----- call self") # forward pass to get next token outputs = self( **model_inputs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e9a0d61b9518e2..b4b53f29b1e54e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -96,6 +96,74 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) +def _inplace_unmask_padding(expanded_mask, attention_mask): + print("expanded_mask", expanded_mask.shape) + print("expanded_mask", expanded_mask) + + bsz = attention_mask.shape[0] + + # Get the index of the first non-zero value for every sample in the batch. + tmp = torch.arange(attention_mask.shape[1], 0, -1) + indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) + + # Construct a matrix of shape (batch_size, maximum_left_padding_length) that is used + # in an aten::index_put to override expanded_mask values for pad tokens. + # For example, if attention_mask is + # [[1, 1, 1, 1, 1] + # [0, 0, 1, 1, 1] + # [0, 0, 0, 1, 1]] + # + # and the original expanded_mask is + # + # [[[[1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1]]], + # + # [[[0, 0, 1, 1, 1], + # [0, 0, 1, 1, 1], + # [0, 0, 1, 1, 1], + # [0, 0, 1, 1, 1], + # [0, 0, 1, 1, 1]]], + # + # [[[0, 0, 0, 1, 1], + # [0, 0, 0, 1, 1], + # [0, 0, 0, 1, 1], + # [0, 0, 0, 1, 1], + # [0, 0, 0, 1, 1]]]] + # + # then the modified expanded_mask will be + # + # [[[[1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1]]], + # + # [[[1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [0, 0, 1, 1, 1], + # [0, 0, 1, 1, 1], + # [0, 0, 1, 1, 1]]], + # + # [[[1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1], + # [0, 0, 0, 1, 1], + # [0, 0, 0, 1, 1]]]] + + max_len = torch.max(indices) + range_tensor = torch.arange(max_len).unsqueeze(0) + range_tensor = range_tensor.repeat(indices.size(0), 1) + range_tensor[range_tensor >= indices] = 0 + + print("range_tensor", range_tensor) + + expanded_mask[torch.arange(bsz).unsqueeze(1), 0, range_tensor] = 0 + + print("expanded_mask after", expanded_mask) + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -681,11 +749,29 @@ def forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) + attention_mask = torch.max(attention_mask, torch.tensor(torch.finfo(key_states.dtype).min)) + + """ + if attention_mask.shape == (2, 1, 7, 7): + attention_mask[1, 0, 0, 0] = 0 + attention_mask[1, 0, 1, 0] = 0 + attention_mask[1, 0, 2, 0] = 0 + """ + # NOTE: As of PyTorch 2.1, this case can not dispatch to flash attention v2 due to the # attention mask passed. Possible solution: use nested tensors. - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + #print("attention_mask", attention_mask.shape) + #print("attention_mask", attention_mask) + #print("query_states", query_states.shape) + #print("query_states", query_states) + #print("key_states", key_states.shape) + #print("key_states", key_states) + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + #print("attn_output", attn_output.shape) + #print("attn_output", attn_output) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -707,10 +793,14 @@ def __init__(self, config: LlamaConfig): if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = LlamaFlashAttention2(config=config) + #self.self_attn = LlamaSDPAAttention(config=config) + self.self_attn = LlamaAttention(config=config) + """ elif is_torch_sdpa_available(): self.self_attn = LlamaSDPAAttention(config=config) else: self.self_attn = LlamaAttention(config=config) + """ self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -931,10 +1021,16 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # does not support unattended sequences in the attention mask. Details: LINK + #if input_shape[-1] > 1 and is_torch_sdpa_available() and attention_mask.device.type == "cuda": + # _inplace_unmask_padding(expanded_attn_mask, attention_mask) + combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) - + print("combined_attention_mask", combined_attention_mask) return combined_attention_mask @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c10838e6e4aef5..9a355c404600bc 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -260,7 +260,7 @@ def is_torch_sdpa_available(): # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons: # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259 # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310 - return _torch_version >= version.parse("2.1") + return version.parse(_torch_version) >= version.parse("2.1") def is_torchvision_available(): From f803de3386956b66bef2ce230c0be746e8921b3b Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:12:51 +0000 Subject: [PATCH 003/100] cleaning --- src/transformers/generation/utils.py | 1 - .../models/llama/modeling_llama.py | 70 ++----------------- 2 files changed, 4 insertions(+), 67 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1403c6792d95b3..3b1bef6f040084 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2450,7 +2450,6 @@ def greedy_search( # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - print("----- call self") # forward pass to get next token outputs = self( **model_inputs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b4b53f29b1e54e..422b1788cbd829 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -96,74 +96,21 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -def _inplace_unmask_padding(expanded_mask, attention_mask): - print("expanded_mask", expanded_mask.shape) - print("expanded_mask", expanded_mask) +def _inplace_unmask_padding(expanded_mask, attention_mask): bsz = attention_mask.shape[0] # Get the index of the first non-zero value for every sample in the batch. tmp = torch.arange(attention_mask.shape[1], 0, -1) indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) - # Construct a matrix of shape (batch_size, maximum_left_padding_length) that is used - # in an aten::index_put to override expanded_mask values for pad tokens. - # For example, if attention_mask is - # [[1, 1, 1, 1, 1] - # [0, 0, 1, 1, 1] - # [0, 0, 0, 1, 1]] - # - # and the original expanded_mask is - # - # [[[[1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1]]], - # - # [[[0, 0, 1, 1, 1], - # [0, 0, 1, 1, 1], - # [0, 0, 1, 1, 1], - # [0, 0, 1, 1, 1], - # [0, 0, 1, 1, 1]]], - # - # [[[0, 0, 0, 1, 1], - # [0, 0, 0, 1, 1], - # [0, 0, 0, 1, 1], - # [0, 0, 0, 1, 1], - # [0, 0, 0, 1, 1]]]] - # - # then the modified expanded_mask will be - # - # [[[[1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1]]], - # - # [[[1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [0, 0, 1, 1, 1], - # [0, 0, 1, 1, 1], - # [0, 0, 1, 1, 1]]], - # - # [[[1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1], - # [0, 0, 0, 1, 1], - # [0, 0, 0, 1, 1]]]] - max_len = torch.max(indices) range_tensor = torch.arange(max_len).unsqueeze(0) range_tensor = range_tensor.repeat(indices.size(0), 1) range_tensor[range_tensor >= indices] = 0 - print("range_tensor", range_tensor) - expanded_mask[torch.arange(bsz).unsqueeze(1), 0, range_tensor] = 0 - print("expanded_mask after", expanded_mask) - class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -760,18 +707,10 @@ def forward( # NOTE: As of PyTorch 2.1, this case can not dispatch to flash attention v2 due to the # attention mask passed. Possible solution: use nested tensors. - #print("attention_mask", attention_mask.shape) - #print("attention_mask", attention_mask) - #print("query_states", query_states.shape) - #print("query_states", query_states) - #print("key_states", key_states.shape) - #print("key_states", key_states) with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - #print("attn_output", attn_output.shape) - #print("attn_output", attn_output) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -793,7 +732,7 @@ def __init__(self, config: LlamaConfig): if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = LlamaFlashAttention2(config=config) - #self.self_attn = LlamaSDPAAttention(config=config) + # self.self_attn = LlamaSDPAAttention(config=config) self.self_attn = LlamaAttention(config=config) """ elif is_torch_sdpa_available(): @@ -1024,13 +963,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # does not support unattended sequences in the attention mask. Details: LINK - #if input_shape[-1] > 1 and is_torch_sdpa_available() and attention_mask.device.type == "cuda": - # _inplace_unmask_padding(expanded_attn_mask, attention_mask) + if input_shape[-1] > 1 and is_torch_sdpa_available() and attention_mask.device.type == "cuda": + _inplace_unmask_padding(expanded_attn_mask, attention_mask) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) - print("combined_attention_mask", combined_attention_mask) return combined_attention_mask @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) From c0bcbfa41d4cb84512b577c784ae969040730bc4 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:15:14 +0000 Subject: [PATCH 004/100] add ref --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 422b1788cbd829..356190c05b39d9 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -962,7 +962,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em ) # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # does not support unattended sequences in the attention mask. Details: LINK + # does not support unattended sequences in the attention mask. Details: https://github.com/huggingface/transformers/pull/26572 if input_shape[-1] > 1 and is_torch_sdpa_available() and attention_mask.device.type == "cuda": _inplace_unmask_padding(expanded_attn_mask, attention_mask) From 38332d767a935edd8d13bc5cbbe7423bfed08ce9 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:16:30 +0000 Subject: [PATCH 005/100] yet more cleaning --- src/transformers/models/llama/modeling_llama.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 356190c05b39d9..167bc7ae47af64 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -698,19 +698,11 @@ def forward( attention_mask = torch.max(attention_mask, torch.tensor(torch.finfo(key_states.dtype).min)) - """ - if attention_mask.shape == (2, 1, 7, 7): - attention_mask[1, 0, 0, 0] = 0 - attention_mask[1, 0, 1, 0] = 0 - attention_mask[1, 0, 2, 0] = 0 - """ - # NOTE: As of PyTorch 2.1, this case can not dispatch to flash attention v2 due to the # attention mask passed. Possible solution: use nested tensors. - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) From 3b47502cc745c381ec56e1af5ba60afdf2674588 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:17:12 +0000 Subject: [PATCH 006/100] and more :) --- src/transformers/models/llama/modeling_llama.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 167bc7ae47af64..99291fb30be3bb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -724,14 +724,10 @@ def __init__(self, config: LlamaConfig): if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = LlamaFlashAttention2(config=config) - # self.self_attn = LlamaSDPAAttention(config=config) - self.self_attn = LlamaAttention(config=config) - """ elif is_torch_sdpa_available(): self.self_attn = LlamaSDPAAttention(config=config) else: self.self_attn = LlamaAttention(config=config) - """ self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 79c12a94fd4c11a29d0e4149b6c3fad39f41a06c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Tue, 31 Oct 2023 15:32:13 +0100 Subject: [PATCH 007/100] wip llama --- src/transformers/modeling_attn_mask_utils.py | 82 +++++++++++++++++++ .../models/llama/modeling_llama.py | 35 +++++++- 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 434b32ce7f899d..a767b520260866 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -13,6 +13,8 @@ # limitations under the License. from typing import List, Optional, Tuple, Union +from .utils import is_torch_sdpa_available + import torch @@ -114,6 +116,11 @@ def to_4d( ) expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if query_length > 1 and is_torch_sdpa_available() and attention_mask_2d.device.type == "cuda": + expanded_4d_mask = _unmask_unattended(expanded_4d_mask, attention_mask_2d, unmasked_value=0.0) + return expanded_4d_mask @staticmethod @@ -160,6 +167,81 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + @staticmethod + def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]): + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. Details: + https://github.com/pytorch/pytorch/issues/110213 + + expanded_mask is [bsz, 1, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + attention_mask is [bsz, src_seq_len]. + + If attention_mask is + ``` + [[0, 0, 1] + [1, 1, 1] + [0, 1, 1]] + ``` + and expanded_mask is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified expanded_mask will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # TODO: unmasked_value may be unnecessary following Patrick's refactor + + # Get the index of the first non-zero value for every sample in the batch. + # In the above example, indices = [[2], [0], [1]]] + tmp = torch.arange(attention_mask.shape[1], 0, -1) + indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) + + # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the + # expanded mask will be completely unattended. + left_masked_rows = torch.where(indices > 0)[0] + + if left_masked_rows.shape[0] == 0: + return expanded_mask + indices = indices[left_masked_rows] + + max_len = torch.max(indices) + range_tensor = torch.arange(max_len).unsqueeze(0) + range_tensor = range_tensor.repeat(indices.size(0), 1) + + # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above. + range_tensor[range_tensor >= indices] = 0 + + mask_slice = (left_masked_rows.unsqueeze(1),) + + # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case + if expanded_mask.dim() == 4: + mask_slice += (0, range_tensor) + else: + mask_slice += (range_tensor,) + + expanded_mask[mask_slice] = unmasked_value + + return expanded_mask + def _prepare_4d_causal_attention_mask( attention_mask: Optional[torch.Tensor], diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 997043f65ee299..6ee2539aec4aa7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -603,6 +603,40 @@ class LlamaSDPAAttention(LlamaAttention): `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ + def set_sdpa_attention_mask( + self, + batch_size: int, + attention_mask: Optional[torch.Tensor], + padding_mask: Optional[torch.Tensor], + kv_seq_len: int, + query_length: int, + ): + """ + Prepares the correct argument to be used by torch.nn.functional.scaled_dot_product_attention. + We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention + kernel. + NOTE: As of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask passed. Possible + solution: use nested tensors. + """ + # TODO: remove padding mask + if batch_size == 1 and padding_mask is None: + if query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. + is_causal = False + attention_mask = None + elif kv_seq_len == query_length: + is_causal = True + attention_mask = None + else: + # Unfortunately, for query_length > 1, we can not generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We set is_causal=False in SDPA and rely on Transformers attention_mask instead. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + is_causal = False + elif attention_mask is not None: + is_causal = False + else: + is_causal = True + + return attention_mask, is_causal # Adapted from LlamaAttention.forward def forward( @@ -613,7 +647,6 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: raise ValueError("output_attentions=True can not be supported with PyTorch SDPA.") From dc929cd4c6c8da18d8f39908c9aee97bfdc45c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:56:25 +0100 Subject: [PATCH 008/100] working llama --- src/transformers/modeling_attn_mask_utils.py | 50 +++++++++-- .../models/llama/modeling_llama.py | 84 ++++++------------- 2 files changed, 69 insertions(+), 65 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index a767b520260866..6b645d360da420 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -116,11 +116,6 @@ def to_4d( ) expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if query_length > 1 and is_torch_sdpa_available() and attention_mask_2d.device.type == "cuda": - expanded_4d_mask = _unmask_unattended(expanded_4d_mask, attention_mask_2d, unmasked_value=0.0) - return expanded_4d_mask @staticmethod @@ -282,6 +277,51 @@ def _prepare_4d_causal_attention_mask( return attention_mask +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct attn_mask argument to be used by torch.nn.functional.scaled_dot_product_attention. + + We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention + kernel. + + Note that as of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask is passed. A possible solution is to use nested tensors. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + query_length = input_shape[-1] + + if attention_mask is not None: + if batch_size == 1 and torch.all(attention_mask == 1): + if query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + attention_mask = None + elif key_value_length == query_length: + attention_mask = None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + is_causal = False + + if attention_mask is not None: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + expanded_4d_mask = AttentionMaskConverter._unmask_unattended(expanded_4d_mask, attention_mask, unmasked_value=0.0) + + return expanded_4d_mask + def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6ee2539aec4aa7..d1877541c6ee6d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -603,41 +603,6 @@ class LlamaSDPAAttention(LlamaAttention): `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ - def set_sdpa_attention_mask( - self, - batch_size: int, - attention_mask: Optional[torch.Tensor], - padding_mask: Optional[torch.Tensor], - kv_seq_len: int, - query_length: int, - ): - """ - Prepares the correct argument to be used by torch.nn.functional.scaled_dot_product_attention. - We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention - kernel. - NOTE: As of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask passed. Possible - solution: use nested tensors. - """ - # TODO: remove padding mask - if batch_size == 1 and padding_mask is None: - if query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. - is_causal = False - attention_mask = None - elif kv_seq_len == query_length: - is_causal = True - attention_mask = None - else: - # Unfortunately, for query_length > 1, we can not generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We set is_causal=False in SDPA and rely on Transformers attention_mask instead. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - is_causal = False - elif attention_mask is not None: - is_causal = False - else: - is_causal = True - - return attention_mask, is_causal - # Adapted from LlamaAttention.forward def forward( self, @@ -701,32 +666,25 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + if attention_mask is not None: + is_causal = False + elif query_length == 1: + # causal attention and bi-directional attention are the same. + is_causal = False + else: + is_causal = True + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + # Note that Llama does not use dropout in the attention, hence the hard-coded # dropout_p=0.0 independent of self.training. - # The batch_size = 1 case is handled separately to allow to dispatch on flash attention - # kernel. - if bsz == 1: - if query_states.shape[2] > 1: - is_causal = True - else: - is_causal = False - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=is_causal - ) - else: - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attention_mask = torch.max(attention_mask, torch.tensor(torch.finfo(key_states.dtype).min)) - - # NOTE: As of PyTorch 2.1, this case can not dispatch to flash attention v2 due to the - # attention mask passed. Possible solution: use nested tensors. - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=is_causal + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -943,6 +901,7 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.use_sdpa = isinstance(self.layers[0], LlamaSDPAAttention) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1003,6 +962,11 @@ def forward( if getattr(self.config, "_flash_attn_2_enabled", False): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.use_sdpa: + # Alternatively, 4d mask or None is passed to the layers + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( From 17954ddc3440938cce32cde7229f95d58ae7f3a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Tue, 31 Oct 2023 17:11:56 +0100 Subject: [PATCH 009/100] add output_attentions=True support --- src/transformers/modeling_attn_mask_utils.py | 5 +++++ src/transformers/models/llama/modeling_llama.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 6b645d360da420..1bb9ea919fd971 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -284,6 +284,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( inputs_embeds: torch.Tensor, past_key_values_length: int, sliding_window: Optional[int] = None, + output_attentions: bool = False, ): """ Prepares the correct attn_mask argument to be used by torch.nn.functional.scaled_dot_product_attention. @@ -293,6 +294,10 @@ def _prepare_4d_causal_attention_mask_for_sdpa( Note that as of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask is passed. A possible solution is to use nested tensors. """ + if output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + return _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window) attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d1877541c6ee6d..c51189bc913ad6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -614,8 +614,16 @@ def forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - raise ValueError("output_attentions=True can not be supported with PyTorch SDPA.") - attn_weights = None + # output_attentions=True can not be supported when using SDPA, falling back on the manual implementation. + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) bsz, q_len, _ = hidden_states.size() @@ -696,7 +704,7 @@ def forward( else: attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value class LlamaDecoderLayer(nn.Module): @@ -965,7 +973,7 @@ def forward( elif self.use_sdpa: # Alternatively, 4d mask or None is passed to the layers attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, output_attentions=output_attentions ) else: # 4d mask is passed through the layers From f48f4fa9c2681407c6c4a4c839d16e52ef60c5e7 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:59:30 +0000 Subject: [PATCH 010/100] bigcode sdpa support --- src/transformers/modeling_attn_mask_utils.py | 8 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 176 ++++++++++++++++-- .../models/llama/modeling_llama.py | 4 +- 3 files changed, 171 insertions(+), 17 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 1bb9ea919fd971..20e2851e675ef1 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -203,8 +203,6 @@ def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor [0, 1, 1]]]] ``` """ - # TODO: unmasked_value may be unnecessary following Patrick's refactor - # Get the index of the first non-zero value for every sample in the batch. # In the above example, indices = [[2], [0], [1]]] tmp = torch.arange(attention_mask.shape[1], 0, -1) @@ -301,8 +299,9 @@ def _prepare_4d_causal_attention_mask_for_sdpa( attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length - query_length = input_shape[-1] + batch_size, query_length = input_shape + print("query_length", query_length) if attention_mask is not None: if batch_size == 1 and torch.all(attention_mask == 1): if query_length == 1: @@ -323,7 +322,8 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - expanded_4d_mask = AttentionMaskConverter._unmask_unattended(expanded_4d_mask, attention_mask, unmasked_value=0.0) + if query_length > 1: + expanded_4d_mask = AttentionMaskConverter._unmask_unattended(expanded_4d_mask, attention_mask, unmasked_value=0.0) return expanded_4d_mask diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index fcbbfca5cedac7..4329787bcd4077 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -34,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torch_sdpa_available, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig @@ -127,6 +129,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.scale_attention_softmax_in_fp32 = ( config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 ) + self.attn_pdrop = config.attn_pdrop if self.is_cross_attention: if self.multi_query: @@ -507,6 +510,138 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) +class GPTBigCodeSDPAAttention(GPTBigCodeAttention): + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + if head_mask is not None: + raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.") + + scale = None + if not self.scale_attn_weights: + scale = 1 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + kv_seq_len = key.shape[-2] + + if self.multi_query: + query_length = query_shape[1] + + # NOTE: Maybe there is better than this? + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + + # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. + key = key.unsqueeze(1) + value = value.unsqueeze(1) + + # Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to mem-efficient attention + # and flash attention (No available kernel. Aborting execution.) from the shapes + # query = [batch_size, num_heads, query_length, head_dim] + # key = [batch_size, 1, past_length, head_dim] + # value = [batch_size, 1, past_length, head_dim] + # which is unfortunate. Hopefully can be improved in the future. These expand should not be too expansive as they do not do memory copy. + key = key.expand(-1, self.num_heads, -1, -1) + value = value.expand(-1, self.num_heads, -1, -1) + else: + query_length = query_shape[-1] + + dropout_p = self.attn_pdrop if self.training else 0.0 + + if attention_mask is not None: + is_causal = False + elif query_length == 1: + # causal attention and bi-directional attention are the same. + is_causal = False + else: + is_causal = True + + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + if self.multi_query: + # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) + sdpa_result = sdpa_result.transpose(1, 2) + + # Reshape is kind of expensive here, as it does a memory copy, + # but I did not manage to make away without it (logits do not match when using view) + # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) + sdpa_result = sdpa_result.reshape(query_shape) + + return sdpa_result, None + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + if not output_attentions and head_mask is None: + # Difference with the original implementation: there is no need to transpose the key here, + # as SDPA expects seq_length to be at index -2 for the key as well + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + else: + # output_attentions=True, head_mask not None can not be supported when using SDPA. + attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs + class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): @@ -533,21 +668,26 @@ def __init__(self, config, layer_idx=None): self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = ( - GPTBigCodeAttention(config, layer_idx=layer_idx) - if not getattr(config, "_flash_attn_2_enabled", False) - else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) - ) + + if getattr(config, "_flash_attn_2_enabled", False): + self.attn = GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) + elif is_torch_sdpa_available(): + self.attn = GPTBigCodeSDPAAttention(config, layer_idx=layer_idx) + else: + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = ( - GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) - if not getattr(config, "_flash_attn_2_enabled", False) - else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) - ) + + if getattr(config, "_flash_attn_2_enabled", False): + self.attn = GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) + elif is_torch_sdpa_available(): + self.attn = GPTBigCodeSDPAAttention(config, is_cross_attention=True, layer_idx=layer_idx) + else: + self.attn = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -769,6 +909,8 @@ def __init__(self, config): self.gradient_checkpointing = False + self._use_sdpa = isinstance(self.h[0].attn, GPTBigCodeSDPAAttention) + # Initialize weights and apply final processing self.post_init() @@ -866,7 +1008,19 @@ def forward( # MQA models: (batch_size, query_length, n_heads, key_length) # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if self._use_sdpa and query_length > 1: + if self.multi_query: + # gpt_bigcode using MQA has the bad taste to use a causal mask with shape + # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. + self_attention_mask = self_attention_mask.transpose(1, 2) + + self_attention_mask = AttentionMaskConverter._unmask_unattended(self_attention_mask, attention_mask, unmasked_value=True) + + attention_mask = self_attention_mask # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c51189bc913ad6..f3a7da001f0869 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -909,7 +909,7 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.use_sdpa = isinstance(self.layers[0], LlamaSDPAAttention) + self._use_sdpa = isinstance(self.layers[0].self_attn, LlamaSDPAAttention) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -970,7 +970,7 @@ def forward( if getattr(self.config, "_flash_attn_2_enabled", False): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.use_sdpa: + elif self._use_sdpa: # Alternatively, 4d mask or None is passed to the layers attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, output_attentions=output_attentions From dfc47a5d7e2cf24612a68fe1e0b1a1b365a4319b Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:12:58 +0000 Subject: [PATCH 011/100] fixes --- src/transformers/modeling_attn_mask_utils.py | 1 - .../models/gpt_bigcode/modeling_gpt_bigcode.py | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 20e2851e675ef1..1cadb7a23c5039 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -301,7 +301,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa( key_value_length = input_shape[-1] + past_key_values_length batch_size, query_length = input_shape - print("query_length", query_length) if attention_mask is not None: if batch_size == 1 and torch.all(attention_mask == 1): if query_length == 1: diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 4329787bcd4077..b00be2f7b98ba3 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -1012,13 +1012,23 @@ def forward( # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if self._use_sdpa and query_length > 1: + if self._use_sdpa: if self.multi_query: # gpt_bigcode using MQA has the bad taste to use a causal mask with shape # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. self_attention_mask = self_attention_mask.transpose(1, 2) - self_attention_mask = AttentionMaskConverter._unmask_unattended(self_attention_mask, attention_mask, unmasked_value=True) + if query_length > 1: + self_attention_mask = AttentionMaskConverter._unmask_unattended(self_attention_mask, attention_mask, unmasked_value=True) + + if head_mask is None and not output_attentions: + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full([], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device), + ) attention_mask = self_attention_mask From eba83c10ab3810d461cc439ffa931e523db4b8df Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 08:38:28 +0000 Subject: [PATCH 012/100] gpt-bigcode support, require torch>=2.1.1 --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 6 ++---- src/transformers/models/llama/modeling_llama.py | 1 - src/transformers/utils/import_utils.py | 3 ++- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b00be2f7b98ba3..ee0cf8d74ea5fb 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -546,8 +546,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): else: query_length = query_shape[-1] - dropout_p = self.attn_pdrop if self.training else 0.0 - if attention_mask is not None: is_causal = False elif query_length == 1: @@ -555,13 +553,13 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): is_causal = False else: is_causal = True - + sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, - dropout_p=dropout_p, + dropout_p=self.attn_pdrop if self.training else 0.0, is_causal=is_causal, scale=scale, ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f3a7da001f0869..d7f7b38da220eb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -622,7 +622,6 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 3701f88107d7cb..3801daf4791e33 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -263,7 +263,8 @@ def is_torch_sdpa_available(): # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons: # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259 # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310 - return version.parse(_torch_version) >= version.parse("2.1") + # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577 + return version.parse(_torch_version) >= version.parse("2.1.1") def is_torchvision_available(): From 569353553cb1004e3292af948c16dc89681ec487 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 11:46:16 +0000 Subject: [PATCH 013/100] add falcon support --- src/transformers/modeling_attn_mask_utils.py | 19 ++- .../models/falcon/modeling_falcon.py | 110 +++++++++++------- 2 files changed, 82 insertions(+), 47 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 1cadb7a23c5039..3a74980c36cf68 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -169,9 +169,11 @@ def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. Details: https://github.com/pytorch/pytorch/issues/110213 - expanded_mask is [bsz, 1, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + expanded_mask is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. attention_mask is [bsz, src_seq_len]. + The dimension num_masks is most often 1, but it can also be the number of heads in the case of alibi. + If attention_mask is ``` [[0, 0, 1] @@ -223,14 +225,19 @@ def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above. range_tensor[range_tensor >= indices] = 0 - mask_slice = (left_masked_rows.unsqueeze(1),) - # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case if expanded_mask.dim() == 4: - mask_slice += (0, range_tensor) + num_masks = expanded_mask.shape[1] + if num_masks == 1: + # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] + mask_slice = (left_masked_rows[:, None], 0, range_tensor) + else: + # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len] + mask_slice = (left_masked_rows[:, None, None], torch.arange(num_masks)[None, :, None], range_tensor[:, None, :]) else: - mask_slice += (range_tensor,) - + # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] + mask_slice = (left_masked_rows[:, None], range_tensor) + expanded_mask[mask_slice] = unmasked_value return expanded_mask diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 511c55a8488283..abf6bee1f35134 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -459,13 +459,6 @@ def forward( if alibi is None: if hasattr(F, "scaled_dot_product_attention") and not output_attentions: - # TODO: deprecate this once we add FA2 support in Falcon - logger.warning_once( - "The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the" - " future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call " - "`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations." - ) - attn_output = F.scaled_dot_product_attention( query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False ) @@ -475,6 +468,7 @@ def forward( attention_scores /= math.sqrt(self.head_dim) attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). attn_output = attention_scores @ value_layer_ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) @@ -487,41 +481,50 @@ def forward( return output_tensor, present, attention_scores else: return output_tensor, present - else: - matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) + if hasattr(F, "scaled_dot_product_attention") and not output_attentions and head_mask is None: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer_, + key_layer_, + value_layer_, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + ) + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape(batch_size, query_length, self.num_heads * self.head_dim) + + output_tensor = self.dense(context_layer) + else: + matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) - # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by - # adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically - # equivalent and more performant, but there might be a numerical difference. If you're reading this - # and you'd like to experiment and maybe file a PR, feel free! - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) - if head_mask is not None: - attention_probs = attention_probs * head_mask + if head_mask is not None: + attention_probs = attention_probs * head_mask - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) - # change view [batch_size, q_length, num_heads * head_dim] - context_layer = self._merge_heads(context_layer) + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) - output_tensor = self.dense(context_layer) + output_tensor = self.dense(context_layer) if output_attentions: return output_tensor, present, attention_probs @@ -1049,12 +1052,6 @@ def forward( else: past_key_values = self._convert_to_rw_cache(past_key_values) - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) @@ -1096,12 +1093,43 @@ def forward( if getattr(self.config, "_flash_attn_2_enabled", False): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif hasattr(F, "scaled_dot_product_attention") and not output_attentions: + if alibi is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, output_attentions=output_attentions + ) + elif head_mask is None: + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + + attention_mask_2d = attention_mask + # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + attention_mask = torch.masked_fill(alibi / math.sqrt(self.config.hidden_size // self.num_heads), attention_mask < -1, torch.finfo(alibi.dtype).min) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if seq_length > 1: + attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, attention_mask_2d, unmasked_value=0.0) + else: + # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From ca87380570079b493f9bbc35b527c54e0d2fe2d7 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:37:46 +0000 Subject: [PATCH 014/100] fix conflicts falcon --- src/transformers/models/falcon/modeling_falcon.py | 14 +++++++------- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 0c566c2a9320a8..cfce2023806d1c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -472,9 +472,9 @@ def forward( else: if hasattr(F, "scaled_dot_product_attention") and not output_attentions and head_mask is None: context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer_, - key_layer_, - value_layer_, + query_layer, + key_layer, + value_layer, attn_mask=attention_mask, dropout_p=self.attention_dropout.p if self.training else 0.0, ) @@ -483,7 +483,7 @@ def forward( output_tensor = self.dense(context_layer) else: - matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) + matmul_result = query_layer @ key_layer.transpose(-1, -2) # change view to [batch_size, num_heads, q_length, kv_length] attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) @@ -493,7 +493,7 @@ def forward( # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` if input_dtype == torch.float16 or input_dtype == torch.bfloat16: attention_scores = attention_scores.to(torch.float32) - + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) attention_logits *= self.inv_norm_factor attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) @@ -507,13 +507,13 @@ def forward( attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) + context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1) # change view [batch_size, q_length, num_heads * head_dim] context_layer = self._merge_heads(context_layer) output_tensor = self.dense(context_layer) - + if output_attentions: return output_tensor, present, attention_probs else: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a8b70b5cf186db..c93f6b91030f0f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -701,7 +701,7 @@ def forward( if attention_mask is not None: is_causal = False - elif query_length == 1: + elif q_len == 1: # causal attention and bi-directional attention are the same. is_causal = False else: From 969dda94fdbab9e1abfdb03014014469d6f194e9 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:50:37 +0000 Subject: [PATCH 015/100] style --- src/transformers/modeling_attn_mask_utils.py | 41 +++++++++++-------- .../models/falcon/modeling_falcon.py | 28 +++++++++---- .../gpt_bigcode/modeling_gpt_bigcode.py | 25 +++++++---- .../models/llama/modeling_llama.py | 13 +++++- 4 files changed, 73 insertions(+), 34 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 3a74980c36cf68..ed88e69f24971e 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -13,8 +13,6 @@ # limitations under the License. from typing import List, Optional, Tuple, Union -from .utils import is_torch_sdpa_available - import torch @@ -163,14 +161,16 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) @staticmethod - def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]): + def _unmask_unattended( + expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + ): """ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when - using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. Details: - https://github.com/pytorch/pytorch/issues/110213 + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 - expanded_mask is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. - attention_mask is [bsz, src_seq_len]. + expanded_mask is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. attention_mask + is [bsz, src_seq_len]. The dimension num_masks is most often 1, but it can also be the number of heads in the case of alibi. @@ -225,7 +225,7 @@ def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above. range_tensor[range_tensor >= indices] = 0 - # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case + # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case if expanded_mask.dim() == 4: num_masks = expanded_mask.shape[1] if num_masks == 1: @@ -233,11 +233,15 @@ def _unmask_unattended(expanded_mask: torch.Tensor, attention_mask: torch.Tensor mask_slice = (left_masked_rows[:, None], 0, range_tensor) else: # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len] - mask_slice = (left_masked_rows[:, None, None], torch.arange(num_masks)[None, :, None], range_tensor[:, None, :]) + mask_slice = ( + left_masked_rows[:, None, None], + torch.arange(num_masks)[None, :, None], + range_tensor[:, None, :], + ) else: # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] mask_slice = (left_masked_rows[:, None], range_tensor) - + expanded_mask[mask_slice] = unmasked_value return expanded_mask @@ -282,6 +286,7 @@ def _prepare_4d_causal_attention_mask( return attention_mask + # Adapted from _prepare_4d_causal_attention_mask def _prepare_4d_causal_attention_mask_for_sdpa( attention_mask: Optional[torch.Tensor], @@ -294,15 +299,17 @@ def _prepare_4d_causal_attention_mask_for_sdpa( """ Prepares the correct attn_mask argument to be used by torch.nn.functional.scaled_dot_product_attention. - We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention - kernel. + We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention kernel. - Note that as of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask is passed. A possible solution is to use nested tensors. + Note that as of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask is passed. A + possible solution is to use nested tensors. """ if output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - return _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window) + return _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window + ) attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length @@ -319,7 +326,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. # Reference: https://github.com/pytorch/pytorch/issues/108108 - is_causal = False + pass if attention_mask is not None: expanded_4d_mask = attn_mask_converter.to_4d( @@ -329,7 +336,9 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 if query_length > 1: - expanded_4d_mask = AttentionMaskConverter._unmask_unattended(expanded_4d_mask, attention_mask, unmasked_value=0.0) + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, attention_mask, unmasked_value=0.0 + ) return expanded_4d_mask diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index cfce2023806d1c..5474f2b3140e5c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -24,7 +24,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, AttentionMaskConverter +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -493,7 +497,7 @@ def forward( # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` if input_dtype == torch.float16 or input_dtype == torch.bfloat16: attention_scores = attention_scores.to(torch.float32) - + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) attention_logits *= self.inv_norm_factor attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) @@ -513,7 +517,7 @@ def forward( context_layer = self._merge_heads(context_layer) output_tensor = self.dense(context_layer) - + if output_attentions: return output_tensor, present, attention_probs else: @@ -1036,7 +1040,11 @@ def forward( elif hasattr(F, "scaled_dot_product_attention") and not output_attentions: if alibi is None: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, output_attentions=output_attentions + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + output_attentions=output_attentions, ) elif head_mask is None: alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) @@ -1047,12 +1055,18 @@ def forward( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) - attention_mask = torch.masked_fill(alibi / math.sqrt(self.config.hidden_size // self.num_heads), attention_mask < -1, torch.finfo(alibi.dtype).min) - + attention_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + attention_mask < -1, + torch.finfo(alibi.dtype).min, + ) + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 if seq_length > 1: - attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, attention_mask_2d, unmasked_value=0.0) + attention_mask = AttentionMaskConverter._unmask_unattended( + attention_mask, attention_mask_2d, unmasked_value=0.0 + ) else: # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. attention_mask = _prepare_4d_causal_attention_mask( diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b42f1f5f96e93e..e7146121e11527 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -21,8 +21,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...activations import ACT2FN +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -510,10 +510,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + class GPTBigCodeSDPAAttention(GPTBigCodeAttention): def _attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: - raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.") + raise ValueError( + "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository." + ) scale = None if not self.scale_attn_weights: @@ -523,7 +526,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # MHA models: (batch_size, num_heads, query_length, head_dim) query_shape = query.shape batch_size = query_shape[0] - kv_seq_len = key.shape[-2] + key.shape[-2] if self.multi_query: query_length = query_shape[1] @@ -666,7 +669,7 @@ def __init__(self, config, layer_idx=None): self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - + if getattr(config, "_flash_attn_2_enabled", False): self.attn = GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) elif is_torch_sdpa_available(): @@ -678,7 +681,7 @@ def __init__(self, config, layer_idx=None): if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - + if getattr(config, "_flash_attn_2_enabled", False): self.attn = GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) elif is_torch_sdpa_available(): @@ -1017,17 +1020,21 @@ def forward( self_attention_mask = self_attention_mask.transpose(1, 2) if query_length > 1: - self_attention_mask = AttentionMaskConverter._unmask_unattended(self_attention_mask, attention_mask, unmasked_value=True) - + self_attention_mask = AttentionMaskConverter._unmask_unattended( + self_attention_mask, attention_mask, unmasked_value=True + ) + if head_mask is None and not output_attentions: # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. dtype = self.wte.weight.dtype self_attention_mask = torch.where( self_attention_mask, torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), - torch.full([], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device), + torch.full( + [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device + ), ) - + attention_mask = self_attention_mask # If a 2D or 3D attention mask is provided for the cross-attention diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c93f6b91030f0f..ed6a411c778e6d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,7 +29,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -629,6 +633,7 @@ class LlamaSDPAAttention(LlamaAttention): `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ + # Adapted from LlamaAttention.forward def forward( self, @@ -998,7 +1003,11 @@ def forward( elif self._use_sdpa: # Alternatively, 4d mask or None is passed to the layers attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, output_attentions=output_attentions + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + output_attentions=output_attentions, ) else: # 4d mask is passed through the layers From 06766ecf84bc61b3f886d1f9d945f0a7faae0a64 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 14:16:38 +0000 Subject: [PATCH 016/100] fix attention_mask definition --- src/transformers/modeling_attn_mask_utils.py | 35 +++++ src/transformers/models/bart/modeling_bart.py | 144 ++++++++++++++++-- 2 files changed, 168 insertions(+), 11 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index ed88e69f24971e..62d42e39a15bbf 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -339,6 +339,8 @@ def _prepare_4d_causal_attention_mask_for_sdpa( expanded_4d_mask = AttentionMaskConverter._unmask_unattended( expanded_4d_mask, attention_mask, unmasked_value=0.0 ) + else: + expanded_4d_mask = None return expanded_4d_mask @@ -358,6 +360,39 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: """ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, output_attentions: bool = False): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + if output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + batch_size, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + if batch_size == 1 and torch.all(mask == 1): + if query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + return None + elif key_value_length == tgt_len: + return None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) def _create_4d_causal_attention_mask( input_shape: Union[torch.Size, Tuple, List], diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index efca985f67842e..22a848a76968b2 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -17,7 +17,7 @@ import math import warnings from typing import List, Optional, Tuple, Union - +from enum import Enum import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -44,6 +44,7 @@ is_flash_attn_2_available, logging, replace_return_docstrings, + is_torch_sdpa_available, ) from .configuration_bart import BartConfig @@ -488,20 +489,117 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) +class BartSDPAAttention(BartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + return super().forward(hidden_states, key_value_states=key_value_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, **kwargs) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + key_states = key_states + value_states = value_states + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.is_causal and attention_mask is not None, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value BART_ATTENTION_CLASSES = { - "default": BartAttention, + "eager": BartAttention, + "sdpa": BartSDPAAttention, "flash_attention_2": BartFlashAttention2, } +class BartAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = BART_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = BartAttentionType.flash_attention_2 + elif is_torch_sdpa_available(): + self._attn_type = BartAttentionType.sdpa + else: + self._attn_type = BartAttentionType.eager + + print("attn_type", self._attn_type) + + self.self_attn = BART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -571,8 +669,14 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = BART_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = BartAttentionType.flash_attention_2 + elif is_torch_sdpa_available(): + self._attn_type = BartAttentionType.sdpa + else: + self._attn_type = BartAttentionType.eager + + self.self_attn = BART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -585,7 +689,7 @@ def __init__(self, config: BartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = BART_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -945,6 +1049,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No embed_dim, ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self._attn_type = self.layers[0]._attn_type self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -1032,8 +1137,11 @@ def forward( # expand attention_mask if attention_mask is not None: - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == BartAttentionType.flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None + elif self._attn_type == BartAttentionType.sdpa: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype, output_attentions=output_attentions) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) @@ -1120,6 +1228,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.d_model, ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._attn_type = self.layers[0]._attn_type self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1238,9 +1347,17 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == BartAttentionType.flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_type == BartAttentionType.sdpa: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + output_attentions=output_attentions, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( @@ -1249,8 +1366,13 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == BartAttentionType.flash_attention_2: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self._attn_type == BartAttentionType.sdpa: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], output_attentions=output_attentions + ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( From 5c648d45bbbea2ebd651afa59a03357611c3468e Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 15:15:38 +0000 Subject: [PATCH 017/100] remove output_attentions from attnmaskconverter --- src/transformers/modeling_attn_mask_utils.py | 19 ++----- src/transformers/models/bart/modeling_bart.py | 55 ++++++++++++++----- .../models/falcon/modeling_falcon.py | 3 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 27 ++++----- .../models/llama/modeling_llama.py | 6 +- 5 files changed, 66 insertions(+), 44 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 62d42e39a15bbf..ea6c90922c3152 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -294,7 +294,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa( inputs_embeds: torch.Tensor, past_key_values_length: int, sliding_window: Optional[int] = None, - output_attentions: bool = False, ): """ Prepares the correct attn_mask argument to be used by torch.nn.functional.scaled_dot_product_attention. @@ -304,12 +303,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa( Note that as of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask is passed. A possible solution is to use nested tensors. """ - if output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - return _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window - ) attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length @@ -360,7 +353,8 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: """ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) -def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, output_attentions: bool = False): + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)` @@ -373,15 +367,13 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len (`int`): The target length or query length the created mask shall have. """ - if output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. batch_size, key_value_length = mask.shape tgt_len = tgt_len if tgt_len is not None else key_value_length if batch_size == 1 and torch.all(mask == 1): - if query_length == 1: + if tgt_len == 1: # For query_length == 1, causal attention and bi-directional attention are the same. return None elif key_value_length == tgt_len: @@ -394,6 +386,7 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, else: return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + def _create_4d_causal_attention_mask( input_shape: Union[torch.Size, Tuple, List], dtype: torch.dtype, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 22a848a76968b2..8b7f1afdaa4ee1 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -16,8 +16,9 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union from enum import Enum +from typing import List, Optional, Tuple, Union + import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -25,7 +26,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,9 +48,9 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torch_sdpa_available, logging, replace_return_docstrings, - is_torch_sdpa_available, ) from .configuration_bart import BartConfig @@ -489,6 +495,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + class BartSDPAAttention(BartAttention): def forward( self, @@ -502,7 +509,15 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: - return super().forward(hidden_states, key_value_states=key_value_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, **kwargs) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder @@ -516,7 +531,11 @@ def forward( # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]: + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] @@ -574,17 +593,20 @@ def forward( return attn_output, None, past_key_value + BART_ATTENTION_CLASSES = { "eager": BartAttention, "sdpa": BartSDPAAttention, "flash_attention_2": BartFlashAttention2, } + class BartAttentionType(str, Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" + class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() @@ -597,8 +619,6 @@ def __init__(self, config: BartConfig): else: self._attn_type = BartAttentionType.eager - print("attn_type", self._attn_type) - self.self_attn = BART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, @@ -675,7 +695,7 @@ def __init__(self, config: BartConfig): self._attn_type = BartAttentionType.sdpa else: self._attn_type = BartAttentionType.eager - + self.self_attn = BART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, @@ -1139,9 +1159,11 @@ def forward( if attention_mask is not None: if self._attn_type == BartAttentionType.flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None - elif self._attn_type == BartAttentionType.sdpa: + elif self._attn_type == BartAttentionType.sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype, output_attentions=output_attentions) + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) @@ -1350,13 +1372,14 @@ def forward( if self._attn_type == BartAttentionType.flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_type == BartAttentionType.sdpa: + elif self._attn_type == BartAttentionType.sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, past_key_values_length, - output_attentions=output_attentions, ) else: # 4d mask is passed through the layers @@ -1368,10 +1391,14 @@ def forward( if encoder_hidden_states is not None and encoder_attention_mask is not None: if self._attn_type == BartAttentionType.flash_attention_2: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._attn_type == BartAttentionType.sdpa: + elif self._attn_type == BartAttentionType.sdpa and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], output_attentions=output_attentions + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 5474f2b3140e5c..d210d189a0b23c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1038,13 +1038,14 @@ def forward( # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif hasattr(F, "scaled_dot_product_attention") and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. if alibi is None: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, - output_attentions=output_attentions, ) elif head_mask is None: alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index e7146121e11527..eb5d8aedc3e5e9 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -1011,29 +1011,30 @@ def forward( # MHA models: (batch_size, n_heads, query_length, key_length) self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if self._use_sdpa: + if self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. if self.multi_query: # gpt_bigcode using MQA has the bad taste to use a causal mask with shape # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. self_attention_mask = self_attention_mask.transpose(1, 2) if query_length > 1: + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 self_attention_mask = AttentionMaskConverter._unmask_unattended( self_attention_mask, attention_mask, unmasked_value=True ) - if head_mask is None and not output_attentions: - # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. - dtype = self.wte.weight.dtype - self_attention_mask = torch.where( - self_attention_mask, - torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), - torch.full( - [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device - ), - ) + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full( + [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device + ), + ) attention_mask = self_attention_mask diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ed6a411c778e6d..83dd6ebd27dfd8 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1000,14 +1000,14 @@ def forward( if getattr(self.config, "_flash_attn_2_enabled", False): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa: - # Alternatively, 4d mask or None is passed to the layers + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, - output_attentions=output_attentions, ) else: # 4d mask is passed through the layers From 674bff43a775dbed7f90b5db14e68a08efa557e8 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:35:01 +0000 Subject: [PATCH 018/100] support whisper without removing any Copied from statement --- src/transformers/modeling_attn_mask_utils.py | 2 +- src/transformers/modeling_utils.py | 20 +++ src/transformers/models/bart/modeling_bart.py | 3 +- .../models/blenderbot/modeling_blenderbot.py | 32 +++- .../modeling_blenderbot_small.py | 36 ++++- .../models/falcon/modeling_falcon.py | 35 +++-- .../gpt_bigcode/modeling_gpt_bigcode.py | 13 +- .../models/llama/modeling_llama.py | 18 +-- .../models/m2m_100/modeling_m2m_100.py | 32 +++- .../models/marian/modeling_marian.py | 33 +++- .../models/mbart/modeling_mbart.py | 30 +++- .../models/pegasus/modeling_pegasus.py | 32 +++- .../models/plbart/modeling_plbart.py | 74 +++++++-- .../speech_to_text/modeling_speech_to_text.py | 39 ++++- .../modeling_time_series_transformer.py | 44 +++++- .../models/whisper/modeling_whisper.py | 143 +++++++++++++++++- 16 files changed, 485 insertions(+), 101 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index ea6c90922c3152..4f5e71789b6910 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -309,7 +309,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( batch_size, query_length = input_shape if attention_mask is not None: - if batch_size == 1 and torch.all(attention_mask == 1): + if torch.all(attention_mask == 1): if query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. attention_mask = None diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fcb51e6a56be2a..7e338ae5eeec4a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -79,6 +79,7 @@ is_peft_available, is_remote_url, is_safetensors_available, + is_torch_sdpa_available, is_torch_tpu_available, logging, replace_return_docstrings, @@ -1323,6 +1324,22 @@ def _check_and_enable_flash_attn_2( config._flash_attn_2_enabled = True return config + @classmethod + def _check_and_enable_sdpa(cls, config) -> PretrainedConfig: + """ + Enables the use of SDPA natively in Transformers if supported by the model, and if BetterTransformer is not + being used. + """ + if not cls._supports_sdpa: + return config + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + config._sdpa_enabled = True + return config + def enable_input_require_grads(self): """ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping @@ -3231,6 +3248,9 @@ def from_pretrained( if use_flash_attention_2: config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map) + elif is_torch_sdpa_available(): + # use_flash_attention_2 takes priority. + config = cls._check_and_enable_sdpa(config) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 8b7f1afdaa4ee1..87f544a7149ba4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -614,7 +614,7 @@ def __init__(self, config: BartConfig): if getattr(config, "_flash_attn_2_enabled", False): self._attn_type = BartAttentionType.flash_attention_2 - elif is_torch_sdpa_available(): + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = BartAttentionType.sdpa else: self._attn_type = BartAttentionType.eager @@ -843,6 +843,7 @@ class BartPreTrainedModel(PreTrainedModel): _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index f49f90f794fc94..2bd51a25c94127 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -19,6 +19,7 @@ import math import os import warnings +from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -40,6 +41,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -252,7 +254,14 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention} +BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Blenderbot +class BlenderbotAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -260,9 +269,15 @@ class BlenderbotEncoderLayer(nn.Module): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = BlenderbotAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = BlenderbotAttentionType.sdpa + else: + self._attn_type = BlenderbotAttentionType.eager + + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -332,9 +347,14 @@ class BlenderbotDecoderLayer(nn.Module): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = BlenderbotAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = BlenderbotAttentionType.sdpa + else: + self._attn_type = BlenderbotAttentionType.eager - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type]( + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -347,7 +367,7 @@ def __init__(self, config: BlenderbotConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 292b5a8c6e8bf6..b3868435cbdff7 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -17,6 +17,7 @@ import copy import math +from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -38,6 +39,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -254,9 +256,15 @@ class BlenderbotSmallEncoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = BlenderbotSmallAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = BlenderbotSmallAttentionType.sdpa + else: + self._attn_type = BlenderbotSmallAttentionType.eager + + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -321,7 +329,17 @@ def forward( return outputs -BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention} +# TODO: Implement attention with SDPA for TimeSeriesTransformer. +BLENDERBOT_SMALL_ATTENTION_CLASSES = { + "eager": BlenderbotSmallAttention, +} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->BlenderbotSmall +class BlenderbotSmallAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL @@ -330,8 +348,14 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = BlenderbotSmallAttentionType.flash_attention_2 + elif is_torch_sdpa_available(): + self._attn_type = BlenderbotSmallAttentionType.sdpa + else: + self._attn_type = BlenderbotSmallAttentionType.eager + + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -344,7 +368,7 @@ def __init__(self, config: BlenderbotSmallConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d210d189a0b23c..b1172d22958ed4 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -451,7 +451,12 @@ def forward( ) attn_output = F.scaled_dot_product_attention( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + is_causal=self.is_causal and attention_mask is None, ) attention_scores = None else: @@ -481,6 +486,7 @@ def forward( value_layer, attn_mask=attention_mask, dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None, ) context_layer = context_layer.transpose(1, 2) context_layer = context_layer.reshape(batch_size, query_length, self.num_heads * self.head_dim) @@ -905,6 +911,7 @@ class FalconPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn_2 = True + _supports_sdpa = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -1056,18 +1063,22 @@ def forward( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) - attention_mask = torch.masked_fill( - alibi / math.sqrt(self.config.hidden_size // self.num_heads), - attention_mask < -1, - torch.finfo(alibi.dtype).min, - ) - - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1: - attention_mask = AttentionMaskConverter._unmask_unattended( - attention_mask, attention_mask_2d, unmasked_value=0.0 + # We take care to integrate alibi bias in the attention_mask here. + if attention_mask is None: + attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) + else: + attention_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + attention_mask < -1, + torch.finfo(alibi.dtype).min, ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if seq_length > 1: + attention_mask = AttentionMaskConverter._unmask_unattended( + attention_mask, attention_mask_2d, unmasked_value=0.0 + ) else: # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. attention_mask = _prepare_4d_causal_attention_mask( diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index eb5d8aedc3e5e9..d03ea37081ee1d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -549,21 +549,13 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): else: query_length = query_shape[-1] - if attention_mask is not None: - is_causal = False - elif query_length == 1: - # causal attention and bi-directional attention are the same. - is_causal = False - else: - is_causal = True - sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=self.attn_pdrop if self.training else 0.0, - is_causal=is_causal, + is_causal=self.is_causal and attention_mask is None, scale=scale, ) @@ -672,7 +664,7 @@ def __init__(self, config, layer_idx=None): if getattr(config, "_flash_attn_2_enabled", False): self.attn = GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) - elif is_torch_sdpa_available(): + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self.attn = GPTBigCodeSDPAAttention(config, layer_idx=layer_idx) else: self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) @@ -769,6 +761,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 83dd6ebd27dfd8..aaa82d24e5f72d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -704,14 +704,6 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if attention_mask is not None: - is_causal = False - elif q_len == 1: - # causal attention and bi-directional attention are the same. - is_causal = False - else: - is_causal = True - if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( @@ -721,7 +713,12 @@ def forward( # Note that Llama does not use dropout in the attention, hence the hard-coded # dropout_p=0.0 independent of self.training. attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=is_causal + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=self.is_causal and attention_mask is None, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -744,7 +741,7 @@ def __init__(self, config: LlamaConfig): if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = LlamaFlashAttention2(config=config) - elif is_torch_sdpa_available(): + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self.self_attn = LlamaSDPAAttention(config=config) else: self.self_attn = LlamaAttention(config=config) @@ -843,6 +840,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index c05948540f7865..232a497610d382 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -16,6 +16,7 @@ import math +from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -37,6 +38,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -325,9 +327,15 @@ class M2M100EncoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = M2M100_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = M2M100AttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = M2M100AttentionType.sdpa + else: + self._attn_type = M2M100AttentionType.eager + + self.self_attn = M2M100_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -392,7 +400,14 @@ def forward( return outputs -M2M100_ATTENTION_CLASSES = {"default": M2M100Attention} +M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->M2M100 +class M2M100AttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 @@ -400,9 +415,14 @@ class M2M100DecoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = M2M100AttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = M2M100AttentionType.sdpa + else: + self._attn_type = M2M100AttentionType.eager - self.self_attn = M2M100_ATTENTION_CLASSES[attn_type]( + self.self_attn = M2M100_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -415,7 +435,7 @@ def __init__(self, config: M2M100Config): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = M2M100_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index cabf0c68f8b62b..dda2a87d0a7bc4 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -17,6 +17,7 @@ import copy import math +from enum import Enum from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -39,6 +40,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -272,9 +274,15 @@ class MarianEncoderLayer(nn.Module): def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = MarianAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = MarianAttentionType.sdpa + else: + self._attn_type = MarianAttentionType.eager + + self.self_attn = MARIAN_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -339,7 +347,14 @@ def forward( return outputs -MARIAN_ATTENTION_CLASSES = {"default": MarianAttention} +MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Marian +class MarianAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN @@ -348,8 +363,14 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = MarianAttentionType.flash_attention_2 + elif is_torch_sdpa_available(): + self._attn_type = MarianAttentionType.sdpa + else: + self._attn_type = MarianAttentionType.eager + + self.self_attn = MARIAN_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -362,7 +383,7 @@ def __init__(self, config: MarianConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = MARIAN_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 97fdf9ed87998b..bbe691a1bde440 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -15,6 +15,7 @@ """ PyTorch MBART model.""" import copy import math +from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -41,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -490,13 +492,26 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query } +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->MBart +class MBartAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" + + class MBartEncoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = MBART_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = MBartAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = MBartAttentionType.sdpa + else: + self._attn_type = MBartAttentionType.eager + + self.self_attn = MBART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -565,9 +580,14 @@ class MBartDecoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = MBartAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = MBartAttentionType.sdpa + else: + self._attn_type = MBartAttentionType.eager - self.self_attn = MBART_ATTENTION_CLASSES[attn_type]( + self.self_attn = MBART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -580,7 +600,7 @@ def __init__(self, config: MBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = MBART_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 18af4d518a899b..65298fb80e5098 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,6 +16,7 @@ import copy import math +from enum import Enum from typing import List, Optional, Tuple, Union import numpy as np @@ -38,6 +39,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -267,7 +269,14 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention} +PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Pegasus +class PegasusAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -275,9 +284,15 @@ class PegasusEncoderLayer(nn.Module): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = PegasusAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = PegasusAttentionType.sdpa + else: + self._attn_type = PegasusAttentionType.eager + + self.self_attn = PEGASUS_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -347,9 +362,14 @@ class PegasusDecoderLayer(nn.Module): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = PegasusAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = PegasusAttentionType.sdpa + else: + self._attn_type = PegasusAttentionType.eager - self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type]( + self.self_attn = PEGASUS_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -362,7 +382,7 @@ def __init__(self, config: PegasusConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = PEGASUS_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index ad298c6d389048..f9578b2ed5a58c 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -15,6 +15,7 @@ """ PyTorch PLBART model.""" import copy import math +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -23,7 +24,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -38,6 +44,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -265,9 +272,15 @@ class PLBartEncoderLayer(nn.Module): def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = PLBART_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = PLBartAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = PLBartAttentionType.sdpa + else: + self._attn_type = PLBartAttentionType.eager + + self.self_attn = PLBART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -332,7 +345,15 @@ def forward( return outputs -PLBART_ATTENTION_CLASSES = {"default": PLBartAttention} +# TODO: Implement attention with SDPA for PLBart. +PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention, "sdpa": PLBartAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->PLBart +class PLBartAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART @@ -341,8 +362,14 @@ def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = PLBART_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = PLBartAttentionType.flash_attention_2 + elif is_torch_sdpa_available(): + self._attn_type = PLBartAttentionType.sdpa + else: + self._attn_type = PLBartAttentionType.eager + + self.self_attn = PLBART_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -355,7 +382,7 @@ def __init__(self, config: PLBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = PLBART_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -670,6 +697,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = embed_dim, ) self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self._attn_type = self.layers[0]._attn_type self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -757,8 +785,13 @@ def forward( # expand attention_mask if attention_mask is not None: - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == PLBartAttentionType.flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None + elif self._attn_type == PLBartAttentionType.sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) @@ -846,6 +879,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.d_model, ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._attn_type = self.layers[0]._attn_type self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -964,9 +998,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == PLBartAttentionType.flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_type == PLBartAttentionType.sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( @@ -975,8 +1018,19 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == PLBartAttentionType.flash_attention_2: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif ( + self._attn_type == PLBartAttentionType.sdpa and cross_attn_head_mask is None and not output_attentions + ): + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 57c74c8c42e2a6..bb795f2c9f429b 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -15,6 +15,7 @@ """ PyTorch Speech2Text model.""" import math +from enum import Enum from typing import Optional, Tuple, Union import torch @@ -30,7 +31,13 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_sdpa_available, + logging, + replace_return_docstrings, +) from .configuration_speech_to_text import Speech2TextConfig @@ -326,7 +333,14 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention} +SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Speech2Text +class Speech2TextAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -334,9 +348,15 @@ class Speech2TextEncoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = Speech2TextAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = Speech2TextAttentionType.sdpa + else: + self._attn_type = Speech2TextAttentionType.eager + + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -406,9 +426,14 @@ class Speech2TextDecoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = Speech2TextAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = Speech2TextAttentionType.sdpa + else: + self._attn_type = Speech2TextAttentionType.eager - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type]( + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -421,7 +446,7 @@ def __init__(self, config: Speech2TextConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 904c02b4f04308..c3c2b554d7bea0 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -15,6 +15,7 @@ # limitations under the License. """ PyTorch Time Series Transformer model.""" +from enum import Enum from typing import List, Optional, Tuple, Union import numpy as np @@ -32,7 +33,13 @@ ) from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_sdpa_available, + logging, + replace_return_docstrings, +) from .configuration_time_series_transformer import TimeSeriesTransformerConfig @@ -434,9 +441,15 @@ class TimeSeriesTransformerEncoderLayer(nn.Module): def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = TimeSeriesTransformerAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = TimeSeriesTransformerAttentionType.sdpa + else: + self._attn_type = TimeSeriesTransformerAttentionType.eager + + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -501,7 +514,18 @@ def forward( return outputs -TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention} +# TODO: Implement attention with SDPA for TimeSeriesTransformer. +TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = { + "eager": TimeSeriesTransformerAttention, + "sdpa": TimeSeriesTransformerAttention, +} + + +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->TimeSeriesTransformer +class TimeSeriesTransformerAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER @@ -510,8 +534,14 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = TimeSeriesTransformerAttentionType.flash_attention_2 + elif is_torch_sdpa_available(): + self._attn_type = TimeSeriesTransformerAttentionType.sdpa + else: + self._attn_type = TimeSeriesTransformerAttentionType.eager + + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -524,7 +554,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index a107adf74e169a..10eca9913d3f05 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,6 +15,7 @@ """ PyTorch Whisper model.""" import math +from enum import Enum from typing import Optional, Tuple, Union import numpy as np @@ -26,7 +27,7 @@ from ...activations import ACT2FN from ...generation.logits_process import WhisperTimeStampLogitsProcessor -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -40,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -672,20 +674,133 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +class WhisperSDPAAttention(WhisperAttention): + # Copied from transformers.models.bart.modeling_bart.BartSDPAAttention.forward with BART->whisper + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + key_states = key_states + value_states = value_states + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.is_causal and attention_mask is not None, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + WHISPER_ATTENTION_CLASSES = { - "default": WhisperAttention, + "eager": WhisperAttention, "flash_attention_2": WhisperFlashAttention2, + "sdpa": WhisperSDPAAttention, } +# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Whisper +class WhisperAttentionType(str, Enum): + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" + + # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type]( + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = WhisperAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = WhisperAttentionType.sdpa + else: + self._attn_type = WhisperAttentionType.eager + + self.self_attn = WHISPER_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -755,9 +870,14 @@ class WhisperDecoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + if getattr(config, "_flash_attn_2_enabled", False): + self._attn_type = WhisperAttentionType.flash_attention_2 + elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + self._attn_type = WhisperAttentionType.sdpa + else: + self._attn_type = WhisperAttentionType.eager - self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type]( + self.self_attn = WHISPER_ATTENTION_CLASSES[self._attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -770,7 +890,7 @@ def __init__(self, config: WhisperConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type]( + self.encoder_attn = WHISPER_ATTENTION_CLASSES[self._attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -879,6 +999,7 @@ class WhisperPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.init_std @@ -1202,6 +1323,7 @@ def __init__(self, config: WhisperConfig): self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._attn_type = self.layers[0]._attn_type self.layer_norm = nn.LayerNorm(config.d_model) @@ -1311,9 +1433,14 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._attn_type == WhisperAttentionType.flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_type == WhisperAttentionType.sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( From dd89c3c588ccec76231931130c1ee2335ed5f874 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:47:51 +0000 Subject: [PATCH 019/100] fix mbart default to eager renaming --- src/transformers/models/mbart/modeling_mbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index bbe691a1bde440..fc000eb589d287 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -487,7 +487,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query MBART_ATTENTION_CLASSES = { - "default": MBartAttention, + "eager": MBartAttention, "flash_attention_2": MBartFlashAttention2, } From f31c7b3ed69eb8fbf2bb55f2bf68184bfcba9722 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 6 Nov 2023 15:15:23 +0000 Subject: [PATCH 020/100] fix typo in falcon --- src/transformers/models/falcon/modeling_falcon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b1172d22958ed4..159417f3450aeb 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1064,7 +1064,7 @@ def forward( ) # We take care to integrate alibi bias in the attention_mask here. - if attention_mask is None: + if attention_mask_2d is None: attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) else: attention_mask = torch.masked_fill( From 280c07883858b207875f11ce2f9bd5bbdc0a1c71 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 8 Nov 2023 13:52:09 +0000 Subject: [PATCH 021/100] fix is_causal in SDPA --- src/transformers/modeling_attn_mask_utils.py | 7 ++----- src/transformers/models/bart/modeling_bart.py | 3 ++- src/transformers/models/falcon/modeling_falcon.py | 12 +++--------- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 6 +++--- src/transformers/models/whisper/modeling_whisper.py | 3 ++- 6 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 4f5e71789b6910..6d898c1d50aec4 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -48,7 +48,7 @@ def to_causal_4d( key_value_length: int, dtype: torch.dtype = torch.float32, device: Union[torch.device, "str"] = "cpu", - ) -> torch.Tensor: + ) -> Optional[torch.Tensor]: """ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative bias to upper right hand triangular matrix (causal mask). @@ -367,9 +367,6 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len (`int`): The target length or query length the created mask shall have. """ - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - batch_size, key_value_length = mask.shape tgt_len = tgt_len if tgt_len is not None else key_value_length if batch_size == 1 and torch.all(mask == 1): @@ -393,7 +390,7 @@ def _create_4d_causal_attention_mask( device: torch.device, past_key_values_length: int = 0, sliding_window: Optional[int] = None, -): +) -> Optional[torch.Tensor]: """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 87f544a7149ba4..f5f519819cc653 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -574,7 +574,8 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - is_causal=self.is_causal and attention_mask is not None, + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal=self.is_causal and attention_mask is None and tgt_len > 1, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 159417f3450aeb..2278462b5e8719 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -443,20 +443,14 @@ def forward( if alibi is None: if hasattr(F, "scaled_dot_product_attention") and not output_attentions: - # TODO: deprecate this once we add FA2 support in Falcon - logger.warning_once( - "The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the" - " future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call " - "`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations." - ) - attn_output = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask, 0.0, - is_causal=self.is_causal and attention_mask is None, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, ) attention_scores = None else: @@ -486,7 +480,7 @@ def forward( value_layer, attn_mask=attention_mask, dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None, + is_causal=self.is_causal and attention_mask is None and query_length > 1, ) context_layer = context_layer.transpose(1, 2) context_layer = context_layer.reshape(batch_size, query_length, self.num_heads * self.head_dim) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d03ea37081ee1d..031f983cc09395 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -555,7 +555,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): value, attn_mask=attention_mask, dropout_p=self.attn_pdrop if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, scale=scale, ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index aaa82d24e5f72d..f8dac0f5e02248 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -710,15 +710,15 @@ def forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) - # Note that Llama does not use dropout in the attention, hence the hard-coded - # dropout_p=0.0 independent of self.training. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, + # Llama does not use dropout in the attention, hence the hard-coded dropout_p=0.0 independent of self.training. dropout_p=0.0, - is_causal=self.is_causal and attention_mask is None, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 10eca9913d3f05..efb3bccefd8169 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -753,7 +753,8 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - is_causal=self.is_causal and attention_mask is not None, + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal=self.is_causal and attention_mask is None and tgt_len > 1, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): From e41ecfa29d712da98e32f333664cf8dc9cdd8077 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 08:40:44 +0000 Subject: [PATCH 022/100] check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained --- src/transformers/configuration_utils.py | 7 +++++++ src/transformers/models/bart/modeling_bart.py | 2 +- .../models/blenderbot/modeling_blenderbot.py | 3 ++- .../blenderbot_small/modeling_blenderbot_small.py | 3 ++- src/transformers/models/falcon/modeling_falcon.py | 10 +++++----- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 3 ++- src/transformers/models/marian/modeling_marian.py | 3 ++- src/transformers/models/mbart/modeling_mbart.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 11 ++++++----- src/transformers/models/pegasus/modeling_pegasus.py | 3 ++- src/transformers/models/plbart/modeling_plbart.py | 3 ++- .../models/speech_to_text/modeling_speech_to_text.py | 3 ++- .../modeling_time_series_transformer.py | 3 ++- src/transformers/models/whisper/modeling_whisper.py | 2 +- 15 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index d4f59020434eb4..cc2fa55e9f7ceb 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -860,8 +860,11 @@ def to_diff_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(serializable_config_dict) + # TODO: This is to be refactored with e.g. an attribute model.attn_type (or attn_implementation) or model.config.attn_type. if "_flash_attn_2_enabled" in serializable_config_dict: del serializable_config_dict["_flash_attn_2_enabled"] + if "_sdpa_enabled" in serializable_config_dict: + del serializable_config_dict["_sdpa_enabled"] return serializable_config_dict @@ -879,8 +882,12 @@ def to_dict(self) -> Dict[str, Any]: del output["_auto_class"] if "_commit_hash" in output: del output["_commit_hash"] + + # TODO: This is to be refactored with e.g. an attribute model.attn_type (or attn_implementation) or model.config.attn_type. if "_flash_attn_2_enabled" in output: del output["_flash_attn_2_enabled"] + if "_sdpa_enabled" in output: + del output["_sdpa_enabled"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index f5f519819cc653..d84266dac84895 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -613,7 +613,7 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = BartAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = BartAttentionType.sdpa diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 2bd51a25c94127..2eeceaa278c019 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -41,6 +41,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -270,7 +271,7 @@ def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = BlenderbotAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = BlenderbotAttentionType.sdpa diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index b3868435cbdff7..9465e9fbf81c60 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -39,6 +39,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -257,7 +258,7 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = BlenderbotSmallAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = BlenderbotSmallAttentionType.sdpa diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 2278462b5e8719..51890fb7e7eea9 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -733,11 +733,11 @@ def __init__(self, config: FalconConfig): hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = ( - FalconAttention(config) - if not getattr(config, "_flash_attn_2_enabled", False) - else FalconFlashAttention2(config) - ) + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): + self.self_attention = FalconFlashAttention2(config) + else: + self.self_attention = FalconAttention(config) + self.mlp = FalconMLP(config) self.hidden_dropout = config.hidden_dropout self.config = config diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f8dac0f5e02248..6122b13b3f49f2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -739,7 +739,7 @@ def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self.self_attn = LlamaFlashAttention2(config=config) elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self.self_attn = LlamaSDPAAttention(config=config) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 232a497610d382..044795f45f1da7 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -38,6 +38,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -328,7 +329,7 @@ def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = M2M100AttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = M2M100AttentionType.sdpa diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index dda2a87d0a7bc4..ee1a62c001b189 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -40,6 +40,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -275,7 +276,7 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = MarianAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = MarianAttentionType.sdpa diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index fc000eb589d287..594988727aaae6 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -504,7 +504,7 @@ def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = MBartAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = MBartAttentionType.sdpa diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f1699a8f48adf3..c66150b6b8600f 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -581,11 +581,12 @@ class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ( - MistralAttention(config=config) - if not getattr(config, "_flash_attn_2_enabled", False) - else MistralFlashAttention2(config) - ) + + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = MistralFlashAttention2(config) + else: + self.self_attn = MistralAttention(config=config) + self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 65298fb80e5098..d05232858bde1d 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -39,6 +39,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -285,7 +286,7 @@ def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = PegasusAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = PegasusAttentionType.sdpa diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index f9578b2ed5a58c..73762724d55b0f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -44,6 +44,7 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -273,7 +274,7 @@ def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = PLBartAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = PLBartAttentionType.sdpa diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index bb795f2c9f429b..da5d842c50dcdf 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -34,6 +34,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -349,7 +350,7 @@ def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = Speech2TextAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = Speech2TextAttentionType.sdpa diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index c3c2b554d7bea0..224360bca055cf 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -36,6 +36,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_torch_sdpa_available, logging, replace_return_docstrings, @@ -442,7 +443,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = TimeSeriesTransformerAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = TimeSeriesTransformerAttentionType.sdpa diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index efb3bccefd8169..f664c95f5e23ad 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -794,7 +794,7 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): self._attn_type = WhisperAttentionType.flash_attention_2 elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): self._attn_type = WhisperAttentionType.sdpa From 6f7964db55dd131879a3d3b0aca67b8f6ba8e3ee Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 08:59:16 +0000 Subject: [PATCH 023/100] add warnings when falling back on the manual implementation --- src/transformers/modeling_utils.py | 3 +++ src/transformers/models/bart/modeling_bart.py | 2 ++ src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 4 +++- src/transformers/models/llama/modeling_llama.py | 3 ++- src/transformers/models/whisper/modeling_whisper.py | 3 ++- 5 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1260aea76f7652..de2d0ed10612e8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1124,6 +1124,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Flash Attention 2 support _supports_flash_attn_2 = False + # SDPA support + _supports_sdpa = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d84266dac84895..356d786f7d0a3a 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -509,6 +509,8 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once("BartModel is using BartSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") return super().forward( hidden_states, key_value_states=key_value_states, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 880e015b976762..474096fd0e58bf 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -499,6 +499,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class GPTBigCodeSDPAAttention(GPTBigCodeAttention): def _attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: + # The super dispatch is done in the forward. raise ValueError( "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository." ) @@ -604,7 +605,8 @@ def forward( # as SDPA expects seq_length to be at index -2 for the key as well attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) else: - # output_attentions=True, head_mask not None can not be supported when using SDPA. + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once("GPTBigCodeModel is using GPTBigCodeSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True and head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2219e6ff10d035..5e0351d366b861 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -645,7 +645,8 @@ def forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - # output_attentions=True can not be supported when using SDPA, falling back on the manual implementation. + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once("LlamaModel is using LlamaSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 91d7ca328dd574..7f851a0fd82a4f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -675,7 +675,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class WhisperSDPAAttention(WhisperAttention): - # Copied from transformers.models.bart.modeling_bart.BartSDPAAttention.forward with BART->whisper + # Copied from transformers.models.bart.modeling_bart.BartSDPAAttention.forward with BART->whisper, Bart->Whisper def forward( self, hidden_states: torch.Tensor, @@ -688,6 +688,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: + logger.warning_once("WhisperModel is using WhisperSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but manually specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") return super().forward( hidden_states, key_value_states=key_value_states, From 0e38a95ba3d800505302d309705c2d93c0f85552 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:10:37 +0000 Subject: [PATCH 024/100] precise doc --- docs/source/en/llm_tutorial_optimization.md | 2 +- docs/source/en/perf_infer_gpu_one.md | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/source/en/llm_tutorial_optimization.md b/docs/source/en/llm_tutorial_optimization.md index 497e624820d4ed..6ef4ec29ff51ec 100644 --- a/docs/source/en/llm_tutorial_optimization.md +++ b/docs/source/en/llm_tutorial_optimization.md @@ -441,7 +441,7 @@ flush() ``` For comparison, let's run the same function, but enable Flash Attention instead. -To do so, we convert the model to [BetterTransformers](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is based on Flash Attention. +To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is able to use Flash Attention. ```python model.to_bettertransformer() diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index ba339c1a3068fa..3ac29044ee510f 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -128,6 +128,13 @@ FlashAttention is more memory efficient, meaning you can train on much larger se ## BetterTransformer + + +Part of BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to support natively SDPA in Transformers. + + + + Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post. @@ -156,11 +163,13 @@ model = model.reverse_bettertransformer() model.save_pretrained("saved_model") ``` -### FlashAttention +### FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention + +PyTorch's `torch.nn.functional.scaled_dot_product_attention` (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is being added natively in Transformers, and you can check whether your model is using SDPA with the attribute `model.config.attn_implementation`. -SDPA can also call FlashAttention kernels under the hood. FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it. +Note that FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it. -To enable FlashAttention or to check whether it is available in a given setting (hardware, problem size), use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: +By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether it is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: ```diff import torch From 1bd07aa192014d10e3c319d13fea52f7c2d6b98b Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:32:54 +0000 Subject: [PATCH 025/100] wip replace _flash_attn_enabled by config.attn_implementation --- src/transformers/configuration_utils.py | 24 +++++----- src/transformers/modeling_utils.py | 29 ++++++++++-- src/transformers/models/bark/modeling_bark.py | 19 ++++---- src/transformers/models/bart/modeling_bart.py | 44 +++++++---------- .../models/blenderbot/modeling_blenderbot.py | 23 ++------- .../modeling_blenderbot_small.py | 26 ++-------- .../models/distilbert/modeling_distilbert.py | 13 ++--- .../models/falcon/modeling_falcon.py | 5 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 27 +++++------ .../models/gpt_neo/modeling_gpt_neo.py | 5 +- .../models/llama/modeling_llama.py | 9 ++-- .../models/m2m_100/modeling_m2m_100.py | 23 ++------- .../models/marian/modeling_marian.py | 26 ++-------- .../models/mbart/modeling_mbart.py | 36 ++++---------- .../models/mistral/modeling_mistral.py | 8 ++-- .../models/pegasus/modeling_pegasus.py | 29 ++---------- .../models/plbart/modeling_plbart.py | 47 +++++++------------ .../speech_to_text/modeling_speech_to_text.py | 23 ++------- .../modeling_time_series_transformer.py | 26 ++-------- .../models/whisper/modeling_whisper.py | 40 +++++----------- 20 files changed, 163 insertions(+), 319 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 072f1d007f098d..cada5a60467ecd 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -422,6 +422,18 @@ def num_labels(self, num_labels: int): self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)} self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + @property + def attn_implementation(self): + return self._attn_implementation + + @attn_implementation.setter + def attn_implementation(self, value): + if hasattr(self, "attn_implementation_set") and self.attn_implementation_set: + raise NotImplementedError("Modifying the attention implementation through this attribute is currently not implemented.") + self.attn_implementation_set = True + + self._attn_implementation = value + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the @@ -861,12 +873,6 @@ def to_diff_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(serializable_config_dict) - # TODO: This is to be refactored with e.g. an attribute model.attn_type (or attn_implementation) or model.config.attn_type. - if "_flash_attn_2_enabled" in serializable_config_dict: - del serializable_config_dict["_flash_attn_2_enabled"] - if "_sdpa_enabled" in serializable_config_dict: - del serializable_config_dict["_sdpa_enabled"] - return serializable_config_dict def to_dict(self) -> Dict[str, Any]: @@ -884,12 +890,6 @@ def to_dict(self) -> Dict[str, Any]: if "_commit_hash" in output: del output["_commit_hash"] - # TODO: This is to be refactored with e.g. an attribute model.attn_type (or attn_implementation) or model.config.attn_type. - if "_flash_attn_2_enabled" in output: - del output["_flash_attn_2_enabled"] - if "_sdpa_enabled" in output: - del output["_sdpa_enabled"] - # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index de2d0ed10612e8..87397e011fc0a7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1155,6 +1155,27 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self.warnings_issued = {} self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + # TODO: This is TEMPORARY and need to be discussed + if config.attn_implementation == "flash_attention_2": + if not self._supports_flash_attn_2: + raise ValueError(f'Passed config.attn_implementation == "flash_attention_2" but {self.__class__.__name__} does not support Flash Attention yet.') + + if not is_flash_attn_2_available(): + raise ImportError( + "Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for" + " installing it. Make sure to have at least the version 2.1.0" + ) + + if config.attn_implementation == "sdpa": + if not self._supports_sdpa: + raise ValueError(f'Passed config.attn_implementation == "sdpa" but {self.__class__.__name__} does not support SDPA yet.') + + if not is_torch_sdpa_available(): + raise ImportError( + "SDPA is not available. Please use torch>=2.1.1 in order to use SDPA." + ) + + def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's @@ -1265,7 +1286,7 @@ def _check_and_enable_flash_attn_2( The method checks if the current setup is compatible with Flash Attention as it requires the model to be in half precision and not ran on CPU. - If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model + If all checks pass, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module """ if not cls._supports_flash_attn_2: @@ -1325,7 +1346,7 @@ def _check_and_enable_flash_attn_2( "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." ) - config._flash_attn_2_enabled = True + config.attn_implementation = "flash_attention_2" return config @classmethod @@ -1341,7 +1362,7 @@ def _check_and_enable_sdpa(cls, config) -> PretrainedConfig: if _is_bettertransformer: return config - config._sdpa_enabled = True + config.attn_implementation = "sdpa" return config def enable_input_require_grads(self): @@ -3260,6 +3281,8 @@ def from_pretrained( elif is_torch_sdpa_available(): # use_flash_attention_2 takes priority. config = cls._check_and_enable_sdpa(config) + else: + config.attn_implementation = "eager" with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index f8b9eab5d397b3..83178a724c6594 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -373,7 +373,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query BARK_ATTENTION_CLASSES = { - "default": BarkSelfAttention, + "eager": BarkSelfAttention, "flash_attention_2": BarkSelfFlashAttention2, } @@ -420,8 +420,7 @@ def __init__(self, config, is_causal=False): self.layernorm_1 = nn.LayerNorm(config.hidden_size) self.layernorm_2 = nn.LayerNorm(config.hidden_size) - attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal) + self.attn = BARK_ATTENTION_CLASSES[config.attn_implementation](config, is_causal=is_causal) self.mlp = BarkMLP(config) @@ -654,6 +653,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.dropout) self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias) @@ -789,7 +789,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None else: attention_mask = attention_mask.view(batch_size, -1) @@ -1249,6 +1249,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.dropout) self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.layernorm_final = nn.LayerNorm(config.hidden_size) @@ -1418,7 +1419,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] @@ -1876,12 +1877,12 @@ def _check_and_enable_flash_attn_2( The method checks if the current setup is compatible with Flash Attention as it requires the model to be in half precision and not ran on CPU. - If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model + If all checks pass, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module """ config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map) - config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) - config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) - config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) + config.semantic_config.attn_implementation = config.attn_implementation + config.coarse_acoustics_config.attn_implementation = config.attn_implementation + config.fine_acoustics_config.attn_implementation = config.attn_implementation return config diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 356d786f7d0a3a..7d49f383f583c4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -48,7 +48,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -604,7 +603,7 @@ def forward( } -class BartAttentionType(str, Enum): +class BartAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -615,20 +614,13 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = BartAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = BartAttentionType.sdpa - else: - self._attn_type = BartAttentionType.eager - - self.self_attn = BART_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = BART_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -692,14 +684,7 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = BartAttentionType.flash_attention_2 - elif is_torch_sdpa_available(): - self._attn_type = BartAttentionType.sdpa - else: - self._attn_type = BartAttentionType.eager - - self.self_attn = BART_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = BART_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -712,7 +697,7 @@ def __init__(self, config: BartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = BART_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -1073,7 +1058,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No embed_dim, ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._attn_type = self.layers[0]._attn_type + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config.attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -1161,9 +1147,9 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._attn_type == BartAttentionType.flash_attention_2: + if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None - elif self._attn_type == BartAttentionType.sdpa and head_mask is None and not output_attentions: + elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1254,7 +1240,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.d_model, ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._attn_type = self.layers[0]._attn_type + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config.attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1373,10 +1361,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - if self._attn_type == BartAttentionType.flash_attention_2: + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_type == BartAttentionType.sdpa and not output_attentions and cross_attn_head_mask is None: + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -1393,9 +1381,9 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._attn_type == BartAttentionType.flash_attention_2: + if self._use_flash_attention_2: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._attn_type == BartAttentionType.sdpa and cross_attn_head_mask is None and not output_attentions: + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index ad9d73b9638ccf..d6b2b83a373dfe 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -41,8 +41,6 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -259,7 +257,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Blenderbot -class BlenderbotAttentionType(str, Enum): +class BlenderbotAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -271,14 +269,7 @@ def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = BlenderbotAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = BlenderbotAttentionType.sdpa - else: - self._attn_type = BlenderbotAttentionType.eager - - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -348,14 +339,8 @@ class BlenderbotDecoderLayer(nn.Module): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = BlenderbotAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = BlenderbotAttentionType.sdpa - else: - self._attn_type = BlenderbotAttentionType.eager - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -368,7 +353,7 @@ def __init__(self, config: BlenderbotConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 08dc22785dd5ad..af3e7c55782b8e 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -39,8 +39,6 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -258,20 +256,13 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = BlenderbotSmallAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = BlenderbotSmallAttentionType.sdpa - else: - self._attn_type = BlenderbotSmallAttentionType.eager - - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -337,7 +328,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->BlenderbotSmall -class BlenderbotSmallAttentionType(str, Enum): +class BlenderbotSmallAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -349,14 +340,7 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = BlenderbotSmallAttentionType.flash_attention_2 - elif is_torch_sdpa_available(): - self._attn_type = BlenderbotSmallAttentionType.sdpa - else: - self._attn_type = BlenderbotSmallAttentionType.eager - - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -369,7 +353,7 @@ def __init__(self, config: BlenderbotSmallConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 144fde42e0bf00..08d8d6cebfb8cd 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -463,11 +463,11 @@ def __init__(self, config: PretrainedConfig): if config.dim % config.n_heads != 0: raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") - self.attention = ( - MultiHeadSelfAttention(config) - if not getattr(config, "_flash_attn_2_enabled", False) - else DistilBertFlashAttention2(config) - ) + if config.attn_implementation == "flash_attention_2": + self.attention = DistilBertFlashAttention2(config) + else: + self.attention = MultiHeadSelfAttention(config) + self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) self.ffn = FFN(config) @@ -687,6 +687,7 @@ def __init__(self, config: PretrainedConfig): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" # Initialize weights and apply final processing self.post_init() @@ -792,7 +793,7 @@ def forward( embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: if attention_mask is None: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 51890fb7e7eea9..fcf92b6499c0c4 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -733,7 +733,7 @@ def __init__(self, config: FalconConfig): hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and config.attn_implementation == "flash_attention_2": self.self_attention = FalconFlashAttention2(config) else: self.self_attention = FalconAttention(config) @@ -944,6 +944,7 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -1035,7 +1036,7 @@ def forward( ) position_ids = position_ids.unsqueeze(0) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif hasattr(F, "scaled_dot_product_attention") and not output_attentions: diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 474096fd0e58bf..6f136eebcb3254 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -35,7 +35,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_sdpa_available, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig @@ -642,6 +641,12 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states +GPTBIGCODE_ATTENTION_CLASSES = { + "eager": GPTBigCodeAttention, + "flash_attention_2": GPTBigCodeFlashAttention2, + "sdpa": GPTBigCodeSDPAAttention, +} + class GPTBigCodeBlock(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() @@ -650,24 +655,15 @@ def __init__(self, config, layer_idx=None): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - if getattr(config, "_flash_attn_2_enabled", False): - self.attn = GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self.attn = GPTBigCodeSDPAAttention(config, layer_idx=layer_idx) - else: - self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) + self.attn = GPTBIGCODE_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - if getattr(config, "_flash_attn_2_enabled", False): - self.attn = GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) - elif is_torch_sdpa_available(): - self.attn = GPTBigCodeSDPAAttention(config, is_cross_attention=True, layer_idx=layer_idx) - else: - self.attn = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config.attn_implementation](config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -891,7 +887,8 @@ def __init__(self, config): self.gradient_checkpointing = False - self._use_sdpa = isinstance(self.h[0].attn, GPTBigCodeSDPAAttention) + self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" # Initialize weights and apply final processing self.post_init() @@ -973,7 +970,7 @@ def forward( key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None encoder_attention_mask = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 368920f3769c6f..a7955f09b8e8d6 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -477,7 +477,7 @@ def __init__(self, config, layer_id=0): if self.attention_type in ["global", "local"]: self.attention = ( GPTNeoSelfAttention(config, self.attention_type) - if not getattr(config, "_flash_attn_2_enabled", False) + if config.attn_implementation == "eager" else GPTNeoFlashAttention2(config, self.attention_type) ) else: @@ -698,6 +698,7 @@ def __init__(self, config): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.drop = nn.Dropout(float(config.embed_dropout)) self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.gradient_checkpointing = False @@ -775,7 +776,7 @@ def forward( hidden_states = inputs_embeds + position_embeds # Attention mask. - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5e0351d366b861..4616792ca2c18a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -740,9 +740,9 @@ def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): + if is_flash_attn_2_available() and config.attn_implementation == "flash_attention_2": self.self_attn = LlamaFlashAttention2(config=config) - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): + elif is_torch_sdpa_available() and config.attn_implementation == "sdpa": self.self_attn = LlamaSDPAAttention(config=config) else: self.self_attn = LlamaAttention(config=config) @@ -938,7 +938,8 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_sdpa = isinstance(self.layers[0].self_attn, LlamaSDPAAttention) + self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -996,7 +997,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions: diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 044795f45f1da7..8491f45480d602 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -38,8 +38,6 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -329,14 +327,7 @@ def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = M2M100AttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = M2M100AttentionType.sdpa - else: - self._attn_type = M2M100AttentionType.eager - - self.self_attn = M2M100_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = M2M100_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -405,7 +396,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->M2M100 -class M2M100AttentionType(str, Enum): +class M2M100AttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -416,14 +407,8 @@ class M2M100DecoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = M2M100AttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = M2M100AttentionType.sdpa - else: - self._attn_type = M2M100AttentionType.eager - self.self_attn = M2M100_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = M2M100_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -436,7 +421,7 @@ def __init__(self, config: M2M100Config): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = M2M100_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = M2M100_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index ee1a62c001b189..6016e8c6e954ca 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -40,8 +40,6 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -276,20 +274,13 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = MarianAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = MarianAttentionType.sdpa - else: - self._attn_type = MarianAttentionType.eager - - self.self_attn = MARIAN_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = MARIAN_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -352,7 +343,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Marian -class MarianAttentionType(str, Enum): +class MarianAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -364,14 +355,7 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = MarianAttentionType.flash_attention_2 - elif is_torch_sdpa_available(): - self._attn_type = MarianAttentionType.sdpa - else: - self._attn_type = MarianAttentionType.eager - - self.self_attn = MARIAN_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = MARIAN_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -384,7 +368,7 @@ def __init__(self, config: MarianConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MARIAN_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = MARIAN_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 594988727aaae6..00d03dbade07d6 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -15,7 +15,6 @@ """ PyTorch MBART model.""" import copy import math -from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -42,7 +41,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -492,26 +490,12 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query } -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->MBart -class MBartAttentionType(str, Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - class MBartEncoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = MBartAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = MBartAttentionType.sdpa - else: - self._attn_type = MBartAttentionType.eager - - self.self_attn = MBART_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = MBART_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -580,14 +564,8 @@ class MBartDecoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = MBartAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = MBartAttentionType.sdpa - else: - self._attn_type = MBartAttentionType.eager - self.self_attn = MBART_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = MBART_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -600,7 +578,7 @@ def __init__(self, config: MBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBART_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = MBART_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -939,6 +917,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N embed_dim, ) self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layer_norm = nn.LayerNorm(config.d_model) @@ -1027,7 +1006,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1116,6 +1095,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.d_model, ) self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.layernorm_embedding = nn.LayerNorm(config.d_model) self.layer_norm = nn.LayerNorm(config.d_model) @@ -1235,7 +1215,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1246,7 +1226,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f9756abfcf9e0c..056cd68d0a65fe 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -580,7 +580,7 @@ def __init__(self, config: MistralConfig): super().__init__() self.hidden_size = config.hidden_size - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): + if config.attn_implementation == "flash_attention_2": self.self_attn = MistralFlashAttention2(config) else: self.self_attn = MistralAttention(config=config) @@ -773,6 +773,7 @@ def __init__(self, config: MistralConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -837,8 +838,7 @@ def forward( if ( attention_mask is not None - and hasattr(self.config, "_flash_attn_2_enabled") - and self.config._flash_attn_2_enabled + and self._use_flash_attention_2 and past_key_values is not None ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size @@ -849,7 +849,7 @@ def forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index d05232858bde1d..57d1adfa6c211f 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,7 +16,6 @@ import copy import math -from enum import Enum from typing import List, Optional, Tuple, Union import numpy as np @@ -39,8 +38,6 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -273,27 +270,13 @@ def forward( PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention} -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Pegasus -class PegasusAttentionType(str, Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS class PegasusEncoderLayer(nn.Module): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = PegasusAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = PegasusAttentionType.sdpa - else: - self._attn_type = PegasusAttentionType.eager - - self.self_attn = PEGASUS_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = PEGASUS_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -363,14 +346,8 @@ class PegasusDecoderLayer(nn.Module): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = PegasusAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = PegasusAttentionType.sdpa - else: - self._attn_type = PegasusAttentionType.eager - self.self_attn = PEGASUS_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = PEGASUS_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -383,7 +360,7 @@ def __init__(self, config: PegasusConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PEGASUS_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 73762724d55b0f..ce9df7641b80e4 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -44,8 +44,6 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -274,20 +272,13 @@ def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = PLBartAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = PLBartAttentionType.sdpa - else: - self._attn_type = PLBartAttentionType.eager - - self.self_attn = PLBART_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = PLBART_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -351,7 +342,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->PLBart -class PLBartAttentionType(str, Enum): +class PLBartAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -363,14 +354,7 @@ def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = PLBartAttentionType.flash_attention_2 - elif is_torch_sdpa_available(): - self._attn_type = PLBartAttentionType.sdpa - else: - self._attn_type = PLBartAttentionType.eager - - self.self_attn = PLBART_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = PLBART_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -383,7 +367,7 @@ def __init__(self, config: PLBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PLBART_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = PLBART_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -698,7 +682,8 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = embed_dim, ) self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._attn_type = self.layers[0]._attn_type + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config.attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -786,9 +771,9 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._attn_type == PLBartAttentionType.flash_attention_2: + if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None - elif self._attn_type == PLBartAttentionType.sdpa and head_mask is None and not output_attentions: + elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -880,7 +865,9 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.d_model, ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._attn_type = self.layers[0]._attn_type + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config.attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -999,10 +986,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - if self._attn_type == PLBartAttentionType.flash_attention_2: + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_type == PLBartAttentionType.sdpa and not output_attentions and cross_attn_head_mask is None: + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -1019,11 +1006,9 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._attn_type == PLBartAttentionType.flash_attention_2: + if self._use_flash_attention_2: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif ( - self._attn_type == PLBartAttentionType.sdpa and cross_attn_head_mask is None and not output_attentions - ): + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index da5d842c50dcdf..6dbf94382c323d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -34,8 +34,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -338,7 +336,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Speech2Text -class Speech2TextAttentionType(str, Enum): +class Speech2TextAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -350,14 +348,7 @@ def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = Speech2TextAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = Speech2TextAttentionType.sdpa - else: - self._attn_type = Speech2TextAttentionType.eager - - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -427,14 +418,8 @@ class Speech2TextDecoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = Speech2TextAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = Speech2TextAttentionType.sdpa - else: - self._attn_type = Speech2TextAttentionType.eager - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -447,7 +432,7 @@ def __init__(self, config: Speech2TextConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 224360bca055cf..279615b63ea528 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -36,8 +36,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -443,20 +441,13 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = TimeSeriesTransformerAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = TimeSeriesTransformerAttentionType.sdpa - else: - self._attn_type = TimeSeriesTransformerAttentionType.eager - - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -523,7 +514,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->TimeSeriesTransformer -class TimeSeriesTransformerAttentionType(str, Enum): +class TimeSeriesTransformerAttentionType(Enum): eager = "eager" sdpa = "sdpa" flash_attention_2 = "flash_attention_2" @@ -535,14 +526,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = TimeSeriesTransformerAttentionType.flash_attention_2 - elif is_torch_sdpa_available(): - self._attn_type = TimeSeriesTransformerAttentionType.sdpa - else: - self._attn_type = TimeSeriesTransformerAttentionType.eager - - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -555,7 +539,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7f851a0fd82a4f..5a1e731f09155e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,7 +15,6 @@ """ PyTorch Whisper model.""" import math -from enum import Enum from typing import Optional, Tuple, Union import numpy as np @@ -41,7 +40,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -688,7 +686,10 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: - logger.warning_once("WhisperModel is using WhisperSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but manually specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "WhisperModel is using WhisperSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + ) return super().forward( hidden_states, key_value_states=key_value_states, @@ -782,27 +783,13 @@ def forward( } -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Whisper -class WhisperAttentionType(str, Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - if is_flash_attn_2_available() and getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = WhisperAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = WhisperAttentionType.sdpa - else: - self._attn_type = WhisperAttentionType.eager - - self.self_attn = WHISPER_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = WHISPER_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -872,14 +859,8 @@ class WhisperDecoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self._attn_type = WhisperAttentionType.flash_attention_2 - elif is_torch_sdpa_available() and getattr(config, "_sdpa_enabled", False): - self._attn_type = WhisperAttentionType.sdpa - else: - self._attn_type = WhisperAttentionType.eager - self.self_attn = WHISPER_ATTENTION_CLASSES[self._attn_type]( + self.self_attn = WHISPER_ATTENTION_CLASSES[config.attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -892,7 +873,7 @@ def __init__(self, config: WhisperConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WHISPER_ATTENTION_CLASSES[self._attn_type]( + self.encoder_attn = WHISPER_ATTENTION_CLASSES[config.attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -1325,7 +1306,8 @@ def __init__(self, config: WhisperConfig): self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._attn_type = self.layers[0]._attn_type + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config.attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) @@ -1435,10 +1417,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._attn_type == WhisperAttentionType.flash_attention_2: + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_type == WhisperAttentionType.sdpa and head_mask is None and not output_attentions: + elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, past_key_values_length From feae821d6b3e733ace8c36a319e32ae0164398cf Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:16:03 +0000 Subject: [PATCH 026/100] fix typo --- src/transformers/configuration_utils.py | 5 ++++- src/transformers/modeling_utils.py | 2 +- src/transformers/models/bart/modeling_bart.py | 2 +- .../models/blenderbot_small/modeling_blenderbot_small.py | 2 +- src/transformers/models/marian/modeling_marian.py | 2 +- src/transformers/models/plbart/modeling_plbart.py | 2 +- .../modeling_time_series_transformer.py | 2 +- 7 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index cada5a60467ecd..6e1ca87449cc19 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -424,7 +424,10 @@ def num_labels(self, num_labels: int): @property def attn_implementation(self): - return self._attn_implementation + if not hasattr(self, "_attn_implementation"): + return "eager" + else: + return self._attn_implementation @attn_implementation.setter def attn_implementation(self, value): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 87397e011fc0a7..429e031ff4911f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1155,7 +1155,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self.warnings_issued = {} self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None - # TODO: This is TEMPORARY and need to be discussed + # TODO: This is TEMPORARY and need to be discussed, should it rather be in XXXPreTrainedModel __init__? if config.attn_implementation == "flash_attention_2": if not self._supports_flash_attn_2: raise ValueError(f'Passed config.attn_implementation == "flash_attention_2" but {self.__class__.__name__} does not support Flash Attention yet.') diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7d49f383f583c4..9db78ddd075638 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -620,7 +620,7 @@ def __init__(self, config: BartConfig): dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index af3e7c55782b8e..7ef465f571a9de 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -262,7 +262,7 @@ def __init__(self, config: BlenderbotSmallConfig): dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 6016e8c6e954ca..43d6b4f1cd0918 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -280,7 +280,7 @@ def __init__(self, config: MarianConfig): dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index ce9df7641b80e4..cfd9833faa7bb2 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -278,7 +278,7 @@ def __init__(self, config: PLBartConfig): dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 279615b63ea528..ce7bc90a55b5db 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -447,7 +447,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): dropout=config.attention_dropout, config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(config.attn_implementation) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout From 2032e648965b091a4fa408a62cca7eaf5783f888 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:30:00 +0000 Subject: [PATCH 027/100] add tests --- src/transformers/modeling_utils.py | 6 +- .../models/falcon/modeling_falcon.py | 2 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 2 +- .../models/llama/modeling_llama.py | 5 +- src/transformers/testing_utils.py | 8 + tests/models/llama/test_modeling_llama.py | 45 +++++ tests/test_modeling_common.py | 179 +++++++++++++++++- 7 files changed, 232 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 429e031ff4911f..4c9a7902a2055c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2605,6 +2605,10 @@ def from_pretrained( adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + # TODO: remove this one once we have a proper config.attn_implementation setter. + # This only temporary for testing. SDPA should otherwise should be transparent to the user. + _use_sdpa = kwargs.pop("_use_sdpa", True) + if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -3278,7 +3282,7 @@ def from_pretrained( if use_flash_attention_2: config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map) - elif is_torch_sdpa_available(): + elif _use_sdpa and is_torch_sdpa_available(): # use_flash_attention_2 takes priority. config = cls._check_and_enable_sdpa(config) else: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index fcf92b6499c0c4..c633b776e5a16f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -733,7 +733,7 @@ def __init__(self, config: FalconConfig): hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - if is_flash_attn_2_available() and config.attn_implementation == "flash_attention_2": + if config.attn_implementation == "flash_attention_2": self.self_attention = FalconFlashAttention2(config) else: self.self_attention = FalconAttention(config) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 6f136eebcb3254..5de388ba48d4a4 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -997,7 +997,7 @@ def forward( # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. self_attention_mask = self_attention_mask.transpose(1, 2) - if query_length > 1: + if query_length > 1 and attention_mask is not None: # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 self_attention_mask = AttentionMaskConverter._unmask_unattended( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4616792ca2c18a..18da0451de7ab8 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -41,7 +41,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_sdpa_available, logging, replace_return_docstrings, ) @@ -740,9 +739,9 @@ def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - if is_flash_attn_2_available() and config.attn_implementation == "flash_attention_2": + if config.attn_implementation == "flash_attention_2": self.self_attn = LlamaFlashAttention2(config=config) - elif is_torch_sdpa_available() and config.attn_implementation == "sdpa": + elif config.attn_implementation == "sdpa": self.self_attn = LlamaSDPAAttention(config=config) else: self.self_attn = LlamaAttention(config=config) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index eb21cbac2303e6..1a3511836ae676 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -107,6 +107,7 @@ is_torch_fp16_available_on_device, is_torch_neuroncore_available, is_torch_npu_available, + is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, @@ -439,6 +440,13 @@ def require_flash_attn(test_case): """ return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_torch_sdpa(test_case): + """ + Decorator marking a test that requires PyTorch's SDPA. + + These tests are skipped when requirements are not met (torch version). + """ + return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) def require_peft(test_case): """ diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 21fb4f44d2b8d0..d66fa6bd860ce0 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -26,6 +26,7 @@ require_torch, require_torch_accelerator, require_torch_gpu, + require_torch_sdpa, slow, torch_device, ) @@ -418,6 +419,50 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(output_native, output_fa_2) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + import torch + + max_new_tokens = 30 + + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + model_sdpa = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. + model_eager = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + _use_sdpa=False, + ).to(torch_device) + + for name, submodule in model_eager.named_modules(): + if "SDPA" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"] + + for padding_side in ["left", "right"]: + tokenizer.padding_side = padding_side + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + + res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False + ) + self.assertTrue(torch.allclose(res_eager, res_sdpa)) @require_torch class LlamaIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0edc23c7af2048..d9988ffaf13c46 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import copy import gc @@ -28,6 +27,7 @@ from typing import Dict, List, Tuple import numpy as np +from parameterized import parameterized from pytest import mark import transformers @@ -70,6 +70,7 @@ require_torch, require_torch_gpu, require_torch_multi_gpu, + require_torch_sdpa, slow, torch_device, ) @@ -2834,7 +2835,7 @@ def test_flash_attn_2_conversion(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") model = model_class(config) @@ -2859,7 +2860,7 @@ def test_flash_attn_2_inference(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -2956,7 +2957,7 @@ def test_flash_attn_2_inference_padding_right(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3049,7 +3050,7 @@ def test_flash_attn_2_generate_left_padding(self): for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3092,7 +3093,7 @@ def test_flash_attn_2_generate_padding_right(self): for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3108,7 +3109,7 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input = dummy_input.to(torch.float16) dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do left padding + # make sure we do right padding dummy_attention_mask[:, :-1] = 1 dummy_attention_mask[:, -1:] = 0 @@ -3126,6 +3127,166 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.equal(out, out_fa)) + @require_torch_sdpa + @slow + @parameterized.expand([("left",), ("right",)]) + def test_eager_matches_sdpa_inference(self, padding_side: str): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, + ) + model_sdpa.to(torch_device) + + # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. + model_eager = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, _use_sdpa=False, + ) + model_eager.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + if padding_side == "left": + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + elif padding_side == "right": + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + if is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs_eager = model_eager(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_sdpa = model_sdpa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs_eager = model_eager(dummy_input, output_hidden_states=True) + outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True) + + logits = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + if is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_sdpa[1:], logits_eager[1:], atol=4e-2, rtol=4e-2) + + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + import torch + + max_new_tokens = 30 + + # TODO: Implement a test for SDPA simply testing forward, not generate. + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + _use_sdpa=False, + ).to(torch_device) + + for name, submodule in model_eager.named_modules(): + if "SDPA" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -3137,7 +3298,7 @@ def test_flash_attn_2_generate_use_cache(self): for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -3178,7 +3339,7 @@ def test_flash_attn_2_fp32_ln(self): for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) From d98c2f99bd4564075184d09eb98b5511e83b5795 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:49:06 +0000 Subject: [PATCH 028/100] style --- src/transformers/configuration_utils.py | 4 +++- src/transformers/modeling_utils.py | 13 +++++++------ src/transformers/models/bart/modeling_bart.py | 4 +++- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 9 +++++++-- src/transformers/models/llama/modeling_llama.py | 4 +++- .../models/mistral/modeling_mistral.py | 6 +----- src/transformers/testing_utils.py | 2 ++ tests/models/llama/test_modeling_llama.py | 7 +++---- tests/test_modeling_common.py | 15 +++++++++++---- 9 files changed, 40 insertions(+), 24 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6e1ca87449cc19..11ca827006e2e4 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -432,7 +432,9 @@ def attn_implementation(self): @attn_implementation.setter def attn_implementation(self, value): if hasattr(self, "attn_implementation_set") and self.attn_implementation_set: - raise NotImplementedError("Modifying the attention implementation through this attribute is currently not implemented.") + raise NotImplementedError( + "Modifying the attention implementation through this attribute is currently not implemented." + ) self.attn_implementation_set = True self._attn_implementation = value diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4c9a7902a2055c..b68ecdcbed271d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1158,7 +1158,9 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): # TODO: This is TEMPORARY and need to be discussed, should it rather be in XXXPreTrainedModel __init__? if config.attn_implementation == "flash_attention_2": if not self._supports_flash_attn_2: - raise ValueError(f'Passed config.attn_implementation == "flash_attention_2" but {self.__class__.__name__} does not support Flash Attention yet.') + raise ValueError( + f'Passed config.attn_implementation == "flash_attention_2" but {self.__class__.__name__} does not support Flash Attention yet.' + ) if not is_flash_attn_2_available(): raise ImportError( @@ -1168,13 +1170,12 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): if config.attn_implementation == "sdpa": if not self._supports_sdpa: - raise ValueError(f'Passed config.attn_implementation == "sdpa" but {self.__class__.__name__} does not support SDPA yet.') - - if not is_torch_sdpa_available(): - raise ImportError( - "SDPA is not available. Please use torch>=2.1.1 in order to use SDPA." + raise ValueError( + f'Passed config.attn_implementation == "sdpa" but {self.__class__.__name__} does not support SDPA yet.' ) + if not is_torch_sdpa_available(): + raise ImportError("SDPA is not available. Please use torch>=2.1.1 in order to use SDPA.") def post_init(self): """ diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9db78ddd075638..2b9e120d9afb3f 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -509,7 +509,9 @@ def forward( """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once("BartModel is using BartSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") + logger.warning_once( + "BartModel is using BartSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + ) return super().forward( hidden_states, key_value_states=key_value_states, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 5de388ba48d4a4..8d67f339e68010 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -605,7 +605,9 @@ def forward( attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) else: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once("GPTBigCodeModel is using GPTBigCodeSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True and head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") + logger.warning_once( + "GPTBigCodeModel is using GPTBigCodeSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True and head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + ) attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: @@ -647,6 +649,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl "sdpa": GPTBigCodeSDPAAttention, } + class GPTBigCodeBlock(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() @@ -663,7 +666,9 @@ def __init__(self, config, layer_idx=None): if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config.attn_implementation](config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config.attn_implementation]( + config, is_cross_attention=True, layer_idx=layer_idx + ) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 18da0451de7ab8..d05c9c1d149aa3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -645,7 +645,9 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once("LlamaModel is using LlamaSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.") + logger.warning_once( + "LlamaModel is using LlamaSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 056cd68d0a65fe..99456e98bad2f3 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -836,11 +836,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if ( - attention_mask is not None - and self._use_flash_attention_2 - and past_key_values is not None - ): + if attention_mask is not None and self._use_flash_attention_2 and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1a3511836ae676..4dcca595a1dc61 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -440,6 +440,7 @@ def require_flash_attn(test_case): """ return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) + def require_torch_sdpa(test_case): """ Decorator marking a test that requires PyTorch's SDPA. @@ -448,6 +449,7 @@ def require_torch_sdpa(test_case): """ return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) + def require_peft(test_case): """ Decorator marking a test that requires PEFT. diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index d66fa6bd860ce0..630f8856133328 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -457,13 +457,12 @@ def test_eager_matches_sdpa_generate(self): inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) - res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False - ) + res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False - ) + res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch class LlamaIntegrationTest(unittest.TestCase): @unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d9988ffaf13c46..fb83ac2e151d09 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3145,13 +3145,16 @@ def test_eager_matches_sdpa_inference(self, padding_side: str): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_sdpa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, + tmpdirname, + torch_dtype=torch.bfloat16, ) model_sdpa.to(torch_device) # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. model_eager = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, _use_sdpa=False, + tmpdirname, + torch_dtype=torch.bfloat16, + _use_sdpa=False, ) model_eager.to(torch_device) @@ -3173,8 +3176,12 @@ def test_eager_matches_sdpa_inference(self, padding_side: str): if is_encoder_decoder: decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - outputs_eager = model_eager(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - outputs_sdpa = model_sdpa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_eager = model_eager( + dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True + ) + outputs_sdpa = model_sdpa( + dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True + ) else: outputs_eager = model_eager(dummy_input, output_hidden_states=True) outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True) From ab59f9d6897567bce99de70ede786dea4ecf16d9 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:09:19 +0000 Subject: [PATCH 029/100] add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace --- src/transformers/modeling_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b68ecdcbed271d..5b1b4c2513e559 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import copy import functools import gc import importlib.metadata @@ -2822,6 +2823,9 @@ def from_pretrained( else: model_kwargs = kwargs + # We do not want to modify inplace the PretrainedConfig passed to from_pretrained. + config = copy.deepcopy(config) + quantizer = None quantization_method_from_config = None if hasattr(config, "quantization_config"): From 98a3825b7d333114e48b85905efb27fb347e4171 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:42:09 +0000 Subject: [PATCH 030/100] obey to config.attn_implementation if a config is passed in from_pretrained --- src/transformers/configuration_utils.py | 5 +++ src/transformers/modeling_utils.py | 43 +++++++++++++------ src/transformers/models/bark/modeling_bark.py | 8 +++- tests/models/llama/test_modeling_llama.py | 8 +++- tests/test_modeling_common.py | 16 +++++-- 5 files changed, 60 insertions(+), 20 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 11ca827006e2e4..54e5ba7485ef96 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -878,6 +878,9 @@ def to_diff_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(serializable_config_dict) + if "attention_implementation" in serializable_config_dict: + del serializable_config_dict["attention_implementation"] + return serializable_config_dict def to_dict(self) -> Dict[str, Any]: @@ -894,6 +897,8 @@ def to_dict(self) -> Dict[str, Any]: del output["_auto_class"] if "_commit_hash" in output: del output["_commit_hash"] + if "attention_implementation" in output: + del output["attention_implementation"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5b1b4c2513e559..1d8d7f6c5fb7f7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1275,7 +1275,11 @@ def can_generate(cls) -> bool: @classmethod def _check_and_enable_flash_attn_2( - cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + enable: bool = True, ) -> PretrainedConfig: """ If you don't know about Flash Attention, check out the official repository of flash attention: @@ -1348,11 +1352,12 @@ def _check_and_enable_flash_attn_2( "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." ) - config.attn_implementation = "flash_attention_2" + if enable: + config.attn_implementation = "flash_attention_2" return config @classmethod - def _check_and_enable_sdpa(cls, config) -> PretrainedConfig: + def _check_and_enable_sdpa(cls, config, enable: bool = True) -> PretrainedConfig: """ Enables the use of SDPA natively in Transformers if supported by the model, and if BetterTransformer is not being used. @@ -1364,7 +1369,8 @@ def _check_and_enable_sdpa(cls, config) -> PretrainedConfig: if _is_bettertransformer: return config - config.attn_implementation = "sdpa" + if enable: + config.attn_implementation = "sdpa" return config def enable_input_require_grads(self): @@ -2607,10 +2613,6 @@ def from_pretrained( adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) - # TODO: remove this one once we have a proper config.attn_implementation setter. - # This only temporary for testing. SDPA should otherwise should be transparent to the user. - _use_sdpa = kwargs.pop("_use_sdpa", True) - if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -3285,12 +3287,29 @@ def from_pretrained( elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) + if ( + hasattr(config, "attn_implementation") + and config.attn_implementation != "flash_attention_2" + and use_flash_attention_2 + ): + raise ValueError( + f"Both config.attn_implementation ({config.attn_implementation}) and use_flash_attention_2=True were passed to from_pretrained and are incompatible." + ) + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config. + if hasattr(config, "attn_implementation"): + auto_dispatch_attention = False + else: + auto_dispatch_attention = True + if use_flash_attention_2: - config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map) - elif _use_sdpa and is_torch_sdpa_available(): + config = cls._check_and_enable_flash_attn_2( + config, torch_dtype=torch_dtype, device_map=device_map, enable=auto_dispatch_attention + ) + elif is_torch_sdpa_available(): # use_flash_attention_2 takes priority. - config = cls._check_and_enable_sdpa(config) - else: + config = cls._check_and_enable_sdpa(config, enable=auto_dispatch_attention) + elif not hasattr(config, "attn_implementation"): config.attn_implementation = "eager" with ContextManagers(init_contexts): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 83178a724c6594..65195c0c7df981 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -1860,7 +1860,11 @@ def generate( @classmethod def _check_and_enable_flash_attn_2( - cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + enable: bool = True, ): """ `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model @@ -1880,7 +1884,7 @@ def _check_and_enable_flash_attn_2( If all checks pass, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module """ - config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map) + config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map, enable=enable) config.semantic_config.attn_implementation = config.attn_implementation config.coarse_acoustics_config.attn_implementation = config.attn_implementation diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 630f8856133328..710cc5ecaaafcd 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -15,6 +15,7 @@ """ Testing suite for the PyTorch LLaMA model. """ +import copy import unittest import pytest @@ -430,6 +431,7 @@ def test_eager_matches_sdpa_generate(self): max_new_tokens = 30 tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf") model_sdpa = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", @@ -437,12 +439,14 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. + # Force using the eager implementation. + cfg = copy.deepcopy(config) + cfg.attn_implementation = "eager" model_eager = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", + config=cfg, torch_dtype=torch.float16, low_cpu_mem_usage=True, - _use_sdpa=False, ).to(torch_device) for name, submodule in model_eager.named_modules(): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fb83ac2e151d09..a3210524e4bc10 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3150,14 +3150,20 @@ def test_eager_matches_sdpa_inference(self, padding_side: str): ) model_sdpa.to(torch_device) - # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. + # Force using the eager implementation. + cfg = copy.deepcopy(config) + cfg.attn_implementation = "eager" model_eager = model_class.from_pretrained( tmpdirname, + config=cfg, torch_dtype=torch.bfloat16, - _use_sdpa=False, ) model_eager.to(torch_device) + for name, submodule in model_eager.named_modules(): + if "SDPA" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + dummy_input = inputs_dict[model.main_input_name][:1] if dummy_input.dtype in [torch.float32, torch.float16]: dummy_input = dummy_input.to(torch.bfloat16) @@ -3271,12 +3277,14 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - # TODO: replace the _use_sdpa=False by `model.config.attn_implementation = "eager"` once the setter is implemented. + # Force using the eager implementation. + cfg = copy.deepcopy(config) + cfg.attn_implementation = "eager" model_eager = model_class.from_pretrained( tmpdirname, + config=cfg, torch_dtype=torch.float16, low_cpu_mem_usage=True, - _use_sdpa=False, ).to(torch_device) for name, submodule in model_eager.named_modules(): From 098a62e10f3ad0491b0bc057b39e0700039223c2 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 18:24:26 +0000 Subject: [PATCH 031/100] fix is_torch_sdpa_available when torch is not installed --- src/transformers/utils/import_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index cf523c0ec4b2ae..23a539ebaf8686 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -262,6 +262,11 @@ def get_torch_version(): def is_torch_sdpa_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": + return False + # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons: # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259 # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310 From b960912852c6a3fd16cff8dd3cc13aecf22b6c1f Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 Nov 2023 18:40:50 +0000 Subject: [PATCH 032/100] remove dead code --- src/transformers/models/bart/modeling_bart.py | 7 ------- .../models/blenderbot/modeling_blenderbot.py | 8 -------- .../blenderbot_small/modeling_blenderbot_small.py | 8 -------- src/transformers/models/m2m_100/modeling_m2m_100.py | 8 -------- src/transformers/models/marian/modeling_marian.py | 8 -------- src/transformers/models/plbart/modeling_plbart.py | 10 +--------- .../models/speech_to_text/modeling_speech_to_text.py | 8 -------- .../modeling_time_series_transformer.py | 9 --------- 8 files changed, 1 insertion(+), 65 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 2b9e120d9afb3f..07b095790d83c7 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -16,7 +16,6 @@ import copy import math import warnings -from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -605,12 +604,6 @@ def forward( } -class BartAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index d6b2b83a373dfe..1ebb89eab4023d 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -19,7 +19,6 @@ import math import os import warnings -from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -256,13 +255,6 @@ def forward( BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention} -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Blenderbot -class BlenderbotAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT class BlenderbotEncoderLayer(nn.Module): def __init__(self, config: BlenderbotConfig): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 7ef465f571a9de..18a677f99e2100 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -17,7 +17,6 @@ import copy import math -from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -327,13 +326,6 @@ def forward( } -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->BlenderbotSmall -class BlenderbotSmallAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig): diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 8491f45480d602..697fbfad48635e 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -16,7 +16,6 @@ import math -from enum import Enum from typing import List, Optional, Tuple, Union import torch @@ -395,13 +394,6 @@ def forward( M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention} -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->M2M100 -class M2M100AttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 class M2M100DecoderLayer(nn.Module): def __init__(self, config: M2M100Config): diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 43d6b4f1cd0918..4d976749520134 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -17,7 +17,6 @@ import copy import math -from enum import Enum from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -342,13 +341,6 @@ def forward( MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention} -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Marian -class MarianAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN class MarianDecoderLayer(nn.Module): def __init__(self, config: MarianConfig): diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index cfd9833faa7bb2..3b52cf363ab1fc 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -15,7 +15,6 @@ """ PyTorch PLBART model.""" import copy import math -from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -338,14 +337,7 @@ def forward( # TODO: Implement attention with SDPA for PLBart. -PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention, "sdpa": PLBartAttention} - - -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->PLBart -class PLBartAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" +PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention} # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 6dbf94382c323d..c08dc7e51af67d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -15,7 +15,6 @@ """ PyTorch Speech2Text model.""" import math -from enum import Enum from typing import Optional, Tuple, Union import torch @@ -335,13 +334,6 @@ def forward( SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention} -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->Speech2Text -class Speech2TextAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT class Speech2TextEncoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index ce7bc90a55b5db..6afd95f00ced09 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -15,7 +15,6 @@ # limitations under the License. """ PyTorch Time Series Transformer model.""" -from enum import Enum from typing import List, Optional, Tuple, Union import numpy as np @@ -509,17 +508,9 @@ def forward( # TODO: Implement attention with SDPA for TimeSeriesTransformer. TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = { "eager": TimeSeriesTransformerAttention, - "sdpa": TimeSeriesTransformerAttention, } -# Copied from transformers.models.bart.modeling_bart.BartAttentionType with Bart->TimeSeriesTransformer -class TimeSeriesTransformerAttentionType(Enum): - eager = "eager" - sdpa = "sdpa" - flash_attention_2 = "flash_attention_2" - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerDecoderLayer(nn.Module): def __init__(self, config: TimeSeriesTransformerConfig): From f1df402e679e3b589b2cd61297c97de0047d5eb3 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:10:03 +0900 Subject: [PATCH 033/100] Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index d1344f507b7a1a..a1bba57e899141 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -193,7 +193,8 @@ def _unmask_unattended( using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. Details: https://github.com/pytorch/pytorch/issues/110213 - expanded_mask is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. attention_mask + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. The dimension num_masks is most often 1, but it can also be the number of heads in the case of alibi. From f49c2a314ddebc251319e2cf85798b6b317cbdd9 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:11:09 +0900 Subject: [PATCH 034/100] Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index a1bba57e899141..79700b915344f8 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -321,7 +321,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( sliding_window: Optional[int] = None, ): """ - Prepares the correct attn_mask argument to be used by torch.nn.functional.scaled_dot_product_attention. + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention kernel. From 3a22d8dfd95e17bc52a92273c34d76bd10248864 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:12:26 +0900 Subject: [PATCH 035/100] Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 79700b915344f8..629fce4d5b0245 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -325,7 +325,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention kernel. - Note that as of PyTorch 2.1, SDPA can not dispatch to flash attention in case an attention mask is passed. A + Note that as of PyTorch 2.1, SDPA cannot dispatch to flash attention if an attention mask is passed. A possible solution is to use nested tensors. """ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) From f0fa993ab9cff94985f16a945b60a223c2735337 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:12:46 +0900 Subject: [PATCH 036/100] Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 629fce4d5b0245..8655d5d7d73dec 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -341,7 +341,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( elif key_value_length == query_length: attention_mask = None else: - # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. # Reference: https://github.com/pytorch/pytorch/issues/108108 pass From f0840401e7f8c6d4363d82606a9ef0d48946301e Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:13:18 +0900 Subject: [PATCH 037/100] Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 8655d5d7d73dec..ea3ae5c014806c 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -342,7 +342,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( attention_mask = None else: # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. # Reference: https://github.com/pytorch/pytorch/issues/108108 pass From 885bbe42fc519648b8d1c851d30c8a8c9274a1f5 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:16:56 +0900 Subject: [PATCH 038/100] Update src/transformers/models/bart/modeling_bart.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/bart/modeling_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 07b095790d83c7..33022337e08103 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1145,7 +1145,7 @@ def forward( if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) From 4dd55234b23bdce38bf27452e188b3f41510ca10 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:24:59 +0000 Subject: [PATCH 039/100] remove duplicate pretraining_tp code --- .../models/llama/modeling_llama.py | 79 ++++++++----------- 1 file changed, 32 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6d550f6b3cd2de..13c8bb49902c61 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -294,6 +294,7 @@ def __init__(self, config: LlamaConfig): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self.pretraining_tp = config.pretraining_tp if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -337,6 +338,25 @@ def _init_rope(self): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def compute_qkv_pretraining_tp(self, hidden_states: torch.Tensor): + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + return query_states, key_states, value_states + def forward( self, hidden_states: torch.Tensor, @@ -354,23 +374,8 @@ def forward( bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - + if self.pretraining_tp > 1: + query_states, key_states, value_states = self.compute_qkv_pretraining_tp(hidden_states) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -426,10 +431,10 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) @@ -662,28 +667,8 @@ def forward( bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - torch.nn.functional.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - torch.nn.functional.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - torch.nn.functional.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) + if self.pretraining_tp > 1: + query_states, key_states, value_states = self.compute_qkv_pretraining_tp(hidden_states) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -729,10 +714,10 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) From 349c99bb8a17a1c33fbb6cdd7a4f434b2d76feb3 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:28:28 +0000 Subject: [PATCH 040/100] add dropout in llama --- src/transformers/models/llama/modeling_llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 13c8bb49902c61..3f41f4b8ef167e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -504,7 +504,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - dropout_rate = 0.0 if not self.training else self.attention_dropout + dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -705,8 +705,7 @@ def forward( key_states, value_states, attn_mask=attention_mask, - # Llama does not use dropout in the attention, hence the hard-coded dropout_p=0.0 independent of self.training. - dropout_p=0.0, + dropout_p=self.attention_dropout if self.training else 0.0, # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal=self.is_causal and attention_mask is None and q_len > 1, ) From 5e56014dc56bc86b23a4fb01c1186734d70fb17c Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:32:00 +0000 Subject: [PATCH 041/100] precise comment on attn_mask --- src/transformers/modeling_attn_mask_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index ea3ae5c014806c..c2e6d2bf5f4c21 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -323,10 +323,9 @@ def _prepare_4d_causal_attention_mask_for_sdpa( """ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. - We ignore the attention mask in some cases for batch_size = 1 to allow to dispatch to the flash attention kernel. - - Note that as of PyTorch 2.1, SDPA cannot dispatch to flash attention if an attention mask is passed. A - possible solution is to use nested tensors. + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). """ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) From 951f70e61512057d70c767d31efdde2eed17fa11 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:35:40 +0000 Subject: [PATCH 042/100] add fmt: off for _unmask_unattended docstring --- src/transformers/modeling_attn_mask_utils.py | 49 ++++++++++---------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index c2e6d2bf5f4c21..eed662a72d9cb1 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -188,48 +188,49 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] def _unmask_unattended( expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] ): + # fmt: off """ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. Details: https://github.com/pytorch/pytorch/issues/110213 `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. - `attention_mask` - is [bsz, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. The dimension num_masks is most often 1, but it can also be the number of heads in the case of alibi. - If attention_mask is + For example, if `attention_mask` is ``` [[0, 0, 1] - [1, 1, 1] - [0, 1, 1]] + [1, 1, 1] + [0, 1, 1]] ``` - and expanded_mask is (e.g. here left-padding case) + and `expanded_mask` is (e.g. here left-padding case) ``` [[[[0, 0, 0], - [0, 0, 0], - [0, 0, 1]]], - [[[1, 0, 0], - [1, 1, 0], - [1, 1, 1]]], - [[[0, 0, 0], - [0, 1, 0], - [0, 1, 1]]]] + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] ``` - then the modified expanded_mask will be + then the modified `expanded_mask` will be ``` [[[[1, 1, 1], <-- modified - [1, 1, 1], <-- modified - [0, 0, 1]]], - [[[1, 0, 0], - [1, 1, 0], - [1, 1, 1]]], - [[[1, 1, 1], <-- modified - [0, 1, 0], - [0, 1, 1]]]] + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] ``` """ + # Get the index of the first non-zero value for every sample in the batch. # In the above example, indices = [[2], [0], [1]]] tmp = torch.arange(attention_mask.shape[1], 0, -1) @@ -323,7 +324,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( """ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. - In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). """ From c4e207e8b97d3ac0fa9e9be46a4a9716c470e7a8 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:40:04 +0000 Subject: [PATCH 043/100] precise num_masks comment --- src/transformers/modeling_attn_mask_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index eed662a72d9cb1..dcc2e86a612540 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -197,7 +197,7 @@ def _unmask_unattended( `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. `attention_mask` is [bsz, src_seq_len]. - The dimension num_masks is most often 1, but it can also be the number of heads in the case of alibi. + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. For example, if `attention_mask` is ``` From e752d93f0e17da585aacb03bae46f452d1ba7d84 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:45:11 +0000 Subject: [PATCH 044/100] nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion --- .../models/llama/modeling_llama.py | 61 ++++++++----------- 1 file changed, 24 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3f41f4b8ef167e..5addb0058ecc55 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -338,25 +338,6 @@ def _init_rope(self): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def compute_qkv_pretraining_tp(self, hidden_states: torch.Tensor): - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - return query_states, key_states, value_states - def forward( self, hidden_states: torch.Tensor, @@ -374,8 +355,22 @@ def forward( bsz, q_len, _ = hidden_states.size() - if self.pretraining_tp > 1: - query_states, key_states, value_states = self.compute_qkv_pretraining_tp(hidden_states) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -431,10 +426,10 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) @@ -667,12 +662,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - if self.pretraining_tp > 1: - query_states, key_states, value_states = self.compute_qkv_pretraining_tp(hidden_states) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -713,12 +705,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value From a072c5d7330df1828628cef1473256234e68e88f Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:37:29 +0000 Subject: [PATCH 045/100] cleanup modeling_utils --- src/transformers/configuration_utils.py | 27 +++-- src/transformers/modeling_utils.py | 147 +++++++++++++----------- 2 files changed, 90 insertions(+), 84 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 54e5ba7485ef96..b224a397a079e9 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -236,6 +236,8 @@ class PretrainedConfig(PushToHubMixin): This attribute is currently not being used during model loading time, but this may change in the future versions. But we can already start preparing for the future by saving the dtype with save_pretrained. + attn_implementation (`str`, *optional*): + The attention implementation to use in the model. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. > TensorFlow specific parameters @@ -374,6 +376,9 @@ def __init__(self, **kwargs): # Config hash self._commit_hash = kwargs.pop("_commit_hash", None) + # Attention implementation to use, if relevant. + self._attn_implementation = kwargs.pop("attn_implementation", None) + # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) @@ -424,19 +429,13 @@ def num_labels(self, num_labels: int): @property def attn_implementation(self): - if not hasattr(self, "_attn_implementation"): - return "eager" - else: - return self._attn_implementation + return self._attn_implementation @attn_implementation.setter def attn_implementation(self, value): - if hasattr(self, "attn_implementation_set") and self.attn_implementation_set: - raise NotImplementedError( - "Modifying the attention implementation through this attribute is currently not implemented." - ) - self.attn_implementation_set = True - + # No specific check is implemented here, as we want to allow syntax as `config.attn_implementation = "flash_attention_2"` before the model + # loading. + # Modifying this property alone on an already loaded model (model.config) has no impact, `model.use_attn_implementation("flash_attention_2")` should be used instead. self._attn_implementation = value def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): @@ -878,8 +877,8 @@ def to_diff_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(serializable_config_dict) - if "attention_implementation" in serializable_config_dict: - del serializable_config_dict["attention_implementation"] + if "_attn_implementation" in serializable_config_dict: + del serializable_config_dict["_attn_implementation"] return serializable_config_dict @@ -897,8 +896,8 @@ def to_dict(self) -> Dict[str, Any]: del output["_auto_class"] if "_commit_hash" in output: del output["_commit_hash"] - if "attention_implementation" in output: - del output["attention_implementation"] + if "_attn_implementation" in output: + del output["_attn_implementation"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a918f2f2b76379..c308cd56741d0a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1152,32 +1152,12 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): ) # Save config and origin of the pretrained weights if given in model self.config = config + self.config = self._autoset_attn_implementation(self.config, torch_dtype=torch.get_default_dtype(), check_device_map=False) + self.name_or_path = config.name_or_path self.warnings_issued = {} self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None - # TODO: This is TEMPORARY and need to be discussed, should it rather be in XXXPreTrainedModel __init__? - if config.attn_implementation == "flash_attention_2": - if not self._supports_flash_attn_2: - raise ValueError( - f'Passed config.attn_implementation == "flash_attention_2" but {self.__class__.__name__} does not support Flash Attention yet.' - ) - - if not is_flash_attn_2_available(): - raise ImportError( - "Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for" - " installing it. Make sure to have at least the version 2.1.0" - ) - - if config.attn_implementation == "sdpa": - if not self._supports_sdpa: - raise ValueError( - f'Passed config.attn_implementation == "sdpa" but {self.__class__.__name__} does not support SDPA yet.' - ) - - if not is_torch_sdpa_available(): - raise ImportError("SDPA is not available. Please use torch>=2.1.1 in order to use SDPA.") - def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's @@ -1211,8 +1191,7 @@ def _from_config(cls, config, **kwargs): if torch_dtype is not None: dtype_orig = cls._set_default_torch_dtype(torch_dtype) - if use_flash_attention_2: - config = cls._check_and_enable_flash_attn_2(config, torch_dtype) + config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, check_device_map=False) if is_deepspeed_zero3_enabled(): import deepspeed @@ -1231,6 +1210,61 @@ def _from_config(cls, config, **kwargs): return model + def use_attn_implementation(self, attn_implementation: str): + """ + Specifies the attention implementation to use in the model. + + Args: + attn_implementation (`str`): + The attention implementation to use. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). + """ + # TODO: Implement it. An implementation could be to define `self._eager_attn_class = XXXAttention`, `self._sdpa_attn_class = XXXSdpaAttention`, `self._flash_attn_class = XXXFlashAttention2` in the __init__ of XXXPreTrainedModel, and leverage those attributes here to replace the correct submodules. + raise NotImplementedError("model.use_attn_implementation is currently not implemented.") + + if attn_implementation == "sdpa": + self.config = self._check_and_enable_sdpa(self.config, enable=True) + elif attn_implementation == "flash_attention_2": + # TODO: define torch_dtype properly + torch_dtype = None + self.config = self._check_and_enable_flash_attn_2( + self.config, torch_dtype=torch_dtype, device_map=getattr(self, "hf_device_map", None), enable=True + ) + + @classmethod + def _autoset_attn_implementation(cls, config, use_flash_attention_2: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True): + """ + Automatically checks and dispatches to a default attention implementation. In order of priority: + 1. An implementation specified in `config.attn_implementation`. + 2. If specified, flash attention through use_flash_attention_2=True. + 3. SDPA implementation, if available and supported by the model type. + 4. Manual implementation otherwise. + """ + config = copy.deepcopy(config) # We do not want to modify the config inplace. + + if config.attn_implementation is None: + auto_dispatch_attention = True + else: + if (config.attn_implementation != "flash_attention_2" and use_flash_attention_2): + raise ValueError( + f'Both config.attn_implementation ("{config.attn_implementation}") and use_flash_attention_2=True are used, and are incompatible.' + ) + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config. + auto_dispatch_attention = False + + if use_flash_attention_2: + cls._check_and_enable_flash_attn_2( + config, torch_dtype=torch_dtype, device_map=device_map, enable=auto_dispatch_attention, check_device_map=check_device_map, + ) + elif is_torch_sdpa_available() and cls._supports_sdpa: + # use_flash_attention_2 takes priority over SDPA. + config = cls._check_and_enable_sdpa(config, enable=auto_dispatch_attention) + elif auto_dispatch_attention: + config.attn_implementation = "eager" + + return config + @classmethod def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: """ @@ -1285,25 +1319,17 @@ def _check_and_enable_flash_attn_2( config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, enable: bool = True, ) -> PretrainedConfig: """ - If you don't know about Flash Attention, check out the official repository of flash attention: - https://github.com/Dao-AILab/flash-attention - - For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this - specific section of the documentation to learn more about it: - https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models + Checks the availability of Flash Attention 2 and compatibility with the current model. - The method checks if the current setup is compatible with Flash Attention as it requires the model to be in - half precision and not ran on CPU. - - If all checks pass, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model - can initialize the correct attention module + If all checks pass and `enable` is True, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. """ if not cls._supports_flash_attn_2: raise ValueError( - "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to " + f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to " "request support for this architecture: https://github.com/huggingface/transformers/issues/new" ) @@ -1337,20 +1363,23 @@ def _check_and_enable_flash_attn_2( " unexpected behaviour." ) - if device_map is None: + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": if torch.cuda.is_available(): logger.warning( - "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU" + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU" " after initializing it on CPU with `model.to('cuda')`." ) else: raise ValueError( - "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. " + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. " "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " "or initialising the model on CPU and then moving it to GPU." ) elif ( - device_map is not None + check_device_map + and device_map is not None and isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()) ): @@ -1365,11 +1394,15 @@ def _check_and_enable_flash_attn_2( @classmethod def _check_and_enable_sdpa(cls, config, enable: bool = True) -> PretrainedConfig: """ - Enables the use of SDPA natively in Transformers if supported by the model, and if BetterTransformer is not - being used. + Checks the availability of SDPA for a given model. + + If all checks pass and `enable` is True, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. """ if not cls._supports_sdpa: - return config + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention. Please open an issue on GitHub to " + "request support for this architecture: https://github.com/huggingface/transformers/issues/new" + ) _is_bettertransformer = getattr(cls, "use_bettertransformer", False) if _is_bettertransformer: @@ -2852,9 +2885,6 @@ def from_pretrained( else: model_kwargs = kwargs - # We do not want to modify inplace the PretrainedConfig passed to from_pretrained. - config = copy.deepcopy(config) - quantizer = None quantization_method_from_config = None if hasattr(config, "quantization_config"): @@ -3314,30 +3344,7 @@ def from_pretrained( elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) - if ( - hasattr(config, "attn_implementation") - and config.attn_implementation != "flash_attention_2" - and use_flash_attention_2 - ): - raise ValueError( - f"Both config.attn_implementation ({config.attn_implementation}) and use_flash_attention_2=True were passed to from_pretrained and are incompatible." - ) - - # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config. - if hasattr(config, "attn_implementation"): - auto_dispatch_attention = False - else: - auto_dispatch_attention = True - - if use_flash_attention_2: - config = cls._check_and_enable_flash_attn_2( - config, torch_dtype=torch_dtype, device_map=device_map, enable=auto_dispatch_attention - ) - elif is_torch_sdpa_available(): - # use_flash_attention_2 takes priority. - config = cls._check_and_enable_sdpa(config, enable=auto_dispatch_attention) - elif not hasattr(config, "attn_implementation"): - config.attn_implementation = "eager" + config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) From f7009735f8704c2dccd20049aea5d8bf58e6ce3a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 13:07:04 +0000 Subject: [PATCH 046/100] backward compatibility --- src/transformers/configuration_utils.py | 11 +++++++++-- src/transformers/modeling_utils.py | 10 ++++++---- tests/test_configuration_utils.py | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index b224a397a079e9..41fd86ae25901e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -429,12 +429,19 @@ def num_labels(self, num_labels: int): @property def attn_implementation(self): - return self._attn_implementation + if hasattr(self, "_attn_implementation"): + if self._attn_implementation is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation + else: + return None @attn_implementation.setter def attn_implementation(self, value): # No specific check is implemented here, as we want to allow syntax as `config.attn_implementation = "flash_attention_2"` before the model - # loading. + # loading. Proper implementation/external library availability checks are done at load time. # Modifying this property alone on an already loaded model (model.config) has no impact, `model.use_attn_implementation("flash_attention_2")` should be used instead. self._attn_implementation = value diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c308cd56741d0a..1688fe671b52c6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1151,8 +1151,8 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) # Save config and origin of the pretrained weights if given in model + config = self._autoset_attn_implementation(config, torch_dtype=torch.get_default_dtype(), check_device_map=False) self.config = config - self.config = self._autoset_attn_implementation(self.config, torch_dtype=torch.get_default_dtype(), check_device_map=False) self.name_or_path = config.name_or_path self.warnings_issued = {} @@ -1191,6 +1191,7 @@ def _from_config(cls, config, **kwargs): if torch_dtype is not None: dtype_orig = cls._set_default_torch_dtype(torch_dtype) + config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, check_device_map=False) if is_deepspeed_zero3_enabled(): @@ -1240,9 +1241,9 @@ def _autoset_attn_implementation(cls, config, use_flash_attention_2: Optional[bo 3. SDPA implementation, if available and supported by the model type. 4. Manual implementation otherwise. """ - config = copy.deepcopy(config) # We do not want to modify the config inplace. - - if config.attn_implementation is None: + # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig.attn_implementation` is never `None`, for backward compatibility. + if config._attn_implementation is None: auto_dispatch_attention = True else: if (config.attn_implementation != "flash_attention_2" and use_flash_attention_2): @@ -3344,6 +3345,7 @@ def from_pretrained( elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) + config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map) with ContextManagers(init_contexts): diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py index a6e9e6b0390abe..8286bc4a36df95 100644 --- a/tests/test_configuration_utils.py +++ b/tests/test_configuration_utils.py @@ -198,7 +198,7 @@ def test_config_common_kwargs_is_complete(self): missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs] # If this part of the test fails, you have arguments to addin config_common_kwargs above. self.assertListEqual( - missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"] + missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "_attn_implementation", "transformers_version"] ) keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)] if len(keys_with_defaults) > 0: From e2677643b0dfb615bc69a99160c1b91f9ff5b9a4 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 13:19:15 +0000 Subject: [PATCH 047/100] fix style as requested --- src/transformers/models/bart/modeling_bart.py | 6 +++--- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 6 +++--- .../models/llama/modeling_llama.py | 18 ++++++++++-------- .../models/whisper/modeling_whisper.py | 8 ++++---- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 33022337e08103..c7a3e03bded42f 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -494,7 +494,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class BartSDPAAttention(BartAttention): +class BartSdpaAttention(BartAttention): def forward( self, hidden_states: torch.Tensor, @@ -509,7 +509,7 @@ def forward( if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "BartModel is using BartSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) return super().forward( hidden_states, @@ -599,7 +599,7 @@ def forward( BART_ATTENTION_CLASSES = { "eager": BartAttention, - "sdpa": BartSDPAAttention, + "sdpa": BartSdpaAttention, "flash_attention_2": BartFlashAttention2, } diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 8d67f339e68010..c620ec7c1be3cf 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -495,7 +495,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class GPTBigCodeSDPAAttention(GPTBigCodeAttention): +class GPTBigCodeSdpaAttention(GPTBigCodeAttention): def _attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: # The super dispatch is done in the forward. @@ -606,7 +606,7 @@ def forward( else: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "GPTBigCodeModel is using GPTBigCodeSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True and head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) @@ -646,7 +646,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl GPTBIGCODE_ATTENTION_CLASSES = { "eager": GPTBigCodeAttention, "flash_attention_2": GPTBigCodeFlashAttention2, - "sdpa": GPTBigCodeSDPAAttention, + "sdpa": GPTBigCodeSdpaAttention, } diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5addb0058ecc55..aa1383ef965820 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -629,7 +629,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class LlamaSDPAAttention(LlamaAttention): +class LlamaSdpaAttention(LlamaAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to @@ -649,7 +649,7 @@ def forward( if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) return super().forward( hidden_states=hidden_states, @@ -710,17 +710,19 @@ def forward( return attn_output, None, past_key_value +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - if config.attn_implementation == "flash_attention_2": - self.self_attn = LlamaFlashAttention2(config=config) - elif config.attn_implementation == "sdpa": - self.self_attn = LlamaSDPAAttention(config=config) - else: - self.self_attn = LlamaAttention(config=config) + self.self_attn = LLAMA_ATTENTION_CLASSES[config.attn_implementation](config=config) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 5a1e731f09155e..8cab83ceac911a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -672,8 +672,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class WhisperSDPAAttention(WhisperAttention): - # Copied from transformers.models.bart.modeling_bart.BartSDPAAttention.forward with BART->whisper, Bart->Whisper +class WhisperSdpaAttention(WhisperAttention): + # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper def forward( self, hidden_states: torch.Tensor, @@ -688,7 +688,7 @@ def forward( if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "WhisperModel is using WhisperSDPAAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True or layer_head_mask not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) return super().forward( hidden_states, @@ -779,7 +779,7 @@ def forward( WHISPER_ATTENTION_CLASSES = { "eager": WhisperAttention, "flash_attention_2": WhisperFlashAttention2, - "sdpa": WhisperSDPAAttention, + "sdpa": WhisperSdpaAttention, } From d044d8189259d3ec998ef6c78a51d74a0824bf55 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 13:27:09 +0000 Subject: [PATCH 048/100] style --- src/transformers/modeling_utils.py | 30 ++++++++++++++----- .../models/falcon/modeling_falcon.py | 24 +++++++-------- tests/test_configuration_utils.py | 3 +- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1688fe671b52c6..bd6fbae58d69b8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1151,7 +1151,9 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) # Save config and origin of the pretrained weights if given in model - config = self._autoset_attn_implementation(config, torch_dtype=torch.get_default_dtype(), check_device_map=False) + config = self._autoset_attn_implementation( + config, torch_dtype=torch.get_default_dtype(), check_device_map=False + ) self.config = config self.name_or_path = config.name_or_path @@ -1192,7 +1194,9 @@ def _from_config(cls, config, **kwargs): dtype_orig = cls._set_default_torch_dtype(torch_dtype) config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. - config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, check_device_map=False) + config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, check_device_map=False + ) if is_deepspeed_zero3_enabled(): import deepspeed @@ -1232,8 +1236,14 @@ def use_attn_implementation(self, attn_implementation: str): ) @classmethod - def _autoset_attn_implementation(cls, config, use_flash_attention_2: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True): + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: Optional[bool] = None, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): """ Automatically checks and dispatches to a default attention implementation. In order of priority: 1. An implementation specified in `config.attn_implementation`. @@ -1246,7 +1256,7 @@ def _autoset_attn_implementation(cls, config, use_flash_attention_2: Optional[bo if config._attn_implementation is None: auto_dispatch_attention = True else: - if (config.attn_implementation != "flash_attention_2" and use_flash_attention_2): + if config.attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( f'Both config.attn_implementation ("{config.attn_implementation}") and use_flash_attention_2=True are used, and are incompatible.' ) @@ -1256,7 +1266,11 @@ def _autoset_attn_implementation(cls, config, use_flash_attention_2: Optional[bo if use_flash_attention_2: cls._check_and_enable_flash_attn_2( - config, torch_dtype=torch_dtype, device_map=device_map, enable=auto_dispatch_attention, check_device_map=check_device_map, + config, + torch_dtype=torch_dtype, + device_map=device_map, + enable=auto_dispatch_attention, + check_device_map=check_device_map, ) elif is_torch_sdpa_available() and cls._supports_sdpa: # use_flash_attention_2 takes priority over SDPA. @@ -3346,7 +3360,9 @@ def from_pretrained( init_contexts.append(init_empty_weights()) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. - config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map) + config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map + ) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index f5b75695df2a2b..1046da1ca87551 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -465,16 +465,16 @@ def forward( attn_output = attn_output.permute(0, 2, 1, 3) attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - output_tensor = self.dense(attn_output) + attn_output = self.dense(attn_output) if output_attentions: - return output_tensor, present, attention_scores + return attn_output, present, attention_scores else: - return output_tensor, present + return attn_output, present else: if hasattr(F, "scaled_dot_product_attention") and not output_attentions and head_mask is None: - context_layer = torch.nn.functional.scaled_dot_product_attention( + attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, @@ -482,10 +482,10 @@ def forward( dropout_p=self.attention_dropout.p if self.training else 0.0, is_causal=self.is_causal and attention_mask is None and query_length > 1, ) - context_layer = context_layer.transpose(1, 2) - context_layer = context_layer.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - output_tensor = self.dense(context_layer) + attn_output = self.dense(attn_output) else: matmul_result = query_layer @ key_layer.transpose(-1, -2) @@ -511,17 +511,17 @@ def forward( attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1) + attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) # change view [batch_size, q_length, num_heads * head_dim] - context_layer = self._merge_heads(context_layer) + attn_output = self._merge_heads(attn_output) - output_tensor = self.dense(context_layer) + attn_output = self.dense(attn_output) if output_attentions: - return output_tensor, present, attention_probs + return attn_output, present, attention_probs else: - return output_tensor, present + return attn_output, present class FalconFlashAttention2(FalconAttention): diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py index 8286bc4a36df95..df521efb180ac8 100644 --- a/tests/test_configuration_utils.py +++ b/tests/test_configuration_utils.py @@ -198,7 +198,8 @@ def test_config_common_kwargs_is_complete(self): missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs] # If this part of the test fails, you have arguments to addin config_common_kwargs above. self.assertListEqual( - missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "_attn_implementation", "transformers_version"] + missing_keys, + ["is_encoder_decoder", "_name_or_path", "_commit_hash", "_attn_implementation", "transformers_version"], ) keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)] if len(keys_with_defaults) > 0: From a9e760626499729349a1eeb6ff9c478e49ce3882 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:25:11 +0000 Subject: [PATCH 049/100] improve documentation --- docs/source/en/perf_infer_gpu_one.md | 24 ++++++++- .../models/idefics/modeling_idefics.py | 1 + tests/utils/test_doc_samples.py | 52 ++++++++++++++++++- 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 3ac29044ee510f..6069f926ff4f27 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -36,7 +36,19 @@ FlashAttention-2 is experimental and may change considerably in future versions. 1. additionally parallelizing the attention computation over sequence length 2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them -FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. +FlashAttention-2 is currently supported for the following architectures: +* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel) +* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) +* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) +* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) +* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) +* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) +* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) +* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) +* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) +* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) + +You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites): @@ -165,7 +177,15 @@ model.save_pretrained("saved_model") ### FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention -PyTorch's `torch.nn.functional.scaled_dot_product_attention` (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is being added natively in Transformers, and you can check whether your model is using SDPA with the attribute `model.config.attn_implementation`. +PyTorch's `torch.nn.functional.scaled_dot_product_attention` (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available. + +For now, Transformers supports inference and training through SDPA for the following architectures: +* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) +* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) +* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) +* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) +* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) +* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) Note that FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it. diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 46672f2e26b723..6209d831a35895 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -960,6 +960,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of Idefics isn't meant for training from scratch - only diff --git a/tests/utils/test_doc_samples.py b/tests/utils/test_doc_samples.py index 84c5a4d2bf5008..46455aa81b89e3 100644 --- a/tests/utils/test_doc_samples.py +++ b/tests/utils/test_doc_samples.py @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import doctest import logging import os import unittest +from glob import glob from pathlib import Path from typing import List, Union @@ -26,6 +26,56 @@ logger = logging.getLogger() +@require_torch +class TestDocLists(unittest.TestCase): + def test_flash_support_list(self): + with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f: + doctext = f.read() + + doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1] + doctext = doctext.split("You can request to add FlashAttention-2 support")[0] + + patterns = glob("./src/transformers/models/**/modeling_*.py") + patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py") + patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py") + patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax)) + archs_supporting_fa2 = [] + for filename in patterns: + with open(filename, "r") as f: + text = f.read() + + if "_supports_flash_attn_2 = True" in text: + model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "") + archs_supporting_fa2.append(model_name) + + for arch in archs_supporting_fa2: + if arch not in doctext: + raise ValueError(f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation.") + + def test_sdpa_support_list(self): + with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f: + doctext = f.read() + + doctext = doctext.split("For now, Transformers supports inference and training through SDPA for the following architectures:")[1] + doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0] + + patterns = glob("./src/transformers/models/**/modeling_*.py") + patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py") + patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py") + patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax)) + archs_supporting_sdpa = [] + for filename in patterns: + with open(filename, "r") as f: + text = f.read() + + if "_supports_sdpa = True" in text: + model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "") + archs_supporting_sdpa.append(model_name) + + for arch in archs_supporting_sdpa: + if arch not in doctext: + raise ValueError(f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation.") + @unittest.skip("Temporarily disable the doc tests.") @require_torch From 1727210af5d691cbb86cff90301c7d218f2a632d Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:31:08 +0000 Subject: [PATCH 050/100] test pass --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd6fbae58d69b8..53f917f62a839b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1253,7 +1253,7 @@ def _autoset_attn_implementation( """ # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. # The property `PretrainedConfig.attn_implementation` is never `None`, for backward compatibility. - if config._attn_implementation is None: + if not hasattr(config, "_attn_implementation") or config._attn_implementation is None: auto_dispatch_attention = True else: if config.attn_implementation != "flash_attention_2" and use_flash_attention_2: From ae866808f50f5b9a1c3a473cea4bff5b8032b4e6 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:31:29 +0000 Subject: [PATCH 051/100] style --- tests/utils/test_doc_samples.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_doc_samples.py b/tests/utils/test_doc_samples.py index 46455aa81b89e3..953654537843ee 100644 --- a/tests/utils/test_doc_samples.py +++ b/tests/utils/test_doc_samples.py @@ -26,6 +26,7 @@ logger = logging.getLogger() + @require_torch class TestDocLists(unittest.TestCase): def test_flash_support_list(self): @@ -50,13 +51,17 @@ def test_flash_support_list(self): for arch in archs_supporting_fa2: if arch not in doctext: - raise ValueError(f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation.") + raise ValueError( + f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation." + ) def test_sdpa_support_list(self): with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f: doctext = f.read() - doctext = doctext.split("For now, Transformers supports inference and training through SDPA for the following architectures:")[1] + doctext = doctext.split( + "For now, Transformers supports inference and training through SDPA for the following architectures:" + )[1] doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0] patterns = glob("./src/transformers/models/**/modeling_*.py") @@ -74,7 +79,9 @@ def test_sdpa_support_list(self): for arch in archs_supporting_sdpa: if arch not in doctext: - raise ValueError(f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation.") + raise ValueError( + f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation." + ) @unittest.skip("Temporarily disable the doc tests.") From 5706ecb650cfcc93919e2e6023d5297b232db512 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 15:24:47 +0000 Subject: [PATCH 052/100] add _unmask_unattended tests --- src/transformers/modeling_attn_mask_utils.py | 4 +- tests/test_modeling_attn_mask_utils.py | 134 +++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 tests/test_modeling_attn_mask_utils.py diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index dcc2e86a612540..facf77b2ed9149 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -201,8 +201,8 @@ def _unmask_unattended( For example, if `attention_mask` is ``` - [[0, 0, 1] - [1, 1, 1] + [[0, 0, 1], + [1, 1, 1], [0, 1, 1]] ``` and `expanded_mask` is (e.g. here left-padding case) diff --git a/tests/test_modeling_attn_mask_utils.py b/tests/test_modeling_attn_mask_utils.py new file mode 100644 index 00000000000000..b60d2d5a69ddec --- /dev/null +++ b/tests/test_modeling_attn_mask_utils.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2019 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from transformers.testing_utils import ( + require_torch, + slow, +) +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + +class TestAttnMaskConverter(unittest.TestCase): + @require_torch + @slow + def test_unmask_unattended_left_padding(self): + attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64) + + expanded_mask = torch.Tensor( + [ + [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]], + [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], + [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]], + ] + ).to(torch.int64) + + reference_output = torch.Tensor( + [ + [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]], + [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], + [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]], + ] + ).to(torch.int64) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1) + + self.assertTrue(torch.equal(result, reference_output)) + + attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64) + + attn_mask_converter = AttentionMaskConverter(is_causal=True) + past_key_values_length = 0 + key_value_length = attention_mask.shape[-1] + past_key_values_length + + expanded_mask = attn_mask_converter.to_4d( + attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + ) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) + min_inf = torch.finfo(torch.float32).min + reference_output = torch.Tensor( + [ + [ + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [min_inf, min_inf, 0, min_inf, min_inf], + [min_inf, min_inf, 0, 0, min_inf], + [min_inf, min_inf, 0, 0, 0], + ] + ], + [ + [ + [0, min_inf, min_inf, min_inf, min_inf], + [0, 0, min_inf, min_inf, min_inf], + [0, 0, 0, min_inf, min_inf], + [0, 0, 0, 0, min_inf], + [0, 0, 0, 0, 0], + ] + ], + [ + [ + [0, 0, 0, 0, 0], + [min_inf, 0, min_inf, min_inf, min_inf], + [min_inf, 0, 0, min_inf, min_inf], + [min_inf, 0, 0, 0, min_inf], + [min_inf, 0, 0, 0, 0], + ] + ], + ] + ) + + self.assertTrue(torch.equal(reference_output, result)) + + @require_torch + @slow + def test_unmask_unattended_right_padding(self): + attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64) + + attn_mask_converter = AttentionMaskConverter(is_causal=True) + past_key_values_length = 0 + key_value_length = attention_mask.shape[-1] + past_key_values_length + + expanded_mask = attn_mask_converter.to_4d( + attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + ) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) + + self.assertTrue(torch.equal(expanded_mask, result)) + + @require_torch + @slow + def test_unmask_unattended_random_mask(self): + attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64) + + attn_mask_converter = AttentionMaskConverter(is_causal=True) + past_key_values_length = 0 + key_value_length = attention_mask.shape[-1] + past_key_values_length + + expanded_mask = attn_mask_converter.to_4d( + attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + ) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) + + self.assertTrue(torch.equal(expanded_mask, result)) From d2326e2c6d4f864078b5143310e279e3061bcfd0 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 16:07:23 +0000 Subject: [PATCH 053/100] skip meaningless tests for idefics --- tests/models/idefics/test_modeling_idefics.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 1e2bb12ee4609e..dbacf8b2868fb6 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -16,11 +16,14 @@ import unittest +from parameterized import parameterized + from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( TestCasePlus, require_bitsandbytes, require_torch, + require_torch_sdpa, require_vision, slow, torch_device, @@ -309,6 +312,12 @@ def prepare_config_and_inputs_for_common(self): def prepare_pixel_values(self): return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + @require_torch_sdpa + @slow + @parameterized.expand([("left",), ("right",)]) + def test_eager_matches_sdpa_inference(self, padding_side: str): + self.skipTest("Idefics has a hard requirement on SDPA, skipping this test") + @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch @@ -557,6 +566,12 @@ def test_model_from_pretrained(self): model = IdeficsModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_torch_sdpa + @slow + @parameterized.expand([("left",), ("right",)]) + def test_eager_matches_sdpa_inference(self, padding_side: str): + self.skipTest("Idefics has a hard requirement on SDPA, skipping this test") + @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch From c0f849e15e92f714d65f1904cb5509843e6027b9 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 16:42:36 +0000 Subject: [PATCH 054/100] hard_check SDPA requirements when specifically requested --- src/transformers/modeling_utils.py | 55 +++++++++++-------- src/transformers/models/bark/modeling_bark.py | 8 ++- .../models/falcon/modeling_falcon.py | 17 +++++- .../models/idefics/modeling_idefics.py | 12 ++++ .../models/plbart/modeling_plbart.py | 2 +- 5 files changed, 67 insertions(+), 27 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 53f917f62a839b..82537b96cf0ad9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1227,12 +1227,15 @@ def use_attn_implementation(self, attn_implementation: str): raise NotImplementedError("model.use_attn_implementation is currently not implemented.") if attn_implementation == "sdpa": - self.config = self._check_and_enable_sdpa(self.config, enable=True) + self.config = self._check_and_enable_sdpa(self.config, hard_check_only=False) elif attn_implementation == "flash_attention_2": # TODO: define torch_dtype properly torch_dtype = None self.config = self._check_and_enable_flash_attn_2( - self.config, torch_dtype=torch_dtype, device_map=getattr(self, "hf_device_map", None), enable=True + self.config, + torch_dtype=torch_dtype, + device_map=getattr(self, "hf_device_map", None), + hard_check_only=False, ) @classmethod @@ -1253,29 +1256,29 @@ def _autoset_attn_implementation( """ # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. # The property `PretrainedConfig.attn_implementation` is never `None`, for backward compatibility. - if not hasattr(config, "_attn_implementation") or config._attn_implementation is None: - auto_dispatch_attention = True - else: + if hasattr(config, "_attn_implementation") or config._attn_implementation is not None: if config.attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( f'Both config.attn_implementation ("{config.attn_implementation}") and use_flash_attention_2=True are used, and are incompatible.' ) - # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config. - auto_dispatch_attention = False + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. + hard_check_only = True + else: + hard_check_only = False if use_flash_attention_2: cls._check_and_enable_flash_attn_2( config, torch_dtype=torch_dtype, device_map=device_map, - enable=auto_dispatch_attention, + hard_check_only=hard_check_only, check_device_map=check_device_map, ) - elif is_torch_sdpa_available() and cls._supports_sdpa: + elif cls._supports_sdpa: # use_flash_attention_2 takes priority over SDPA. - config = cls._check_and_enable_sdpa(config, enable=auto_dispatch_attention) - elif auto_dispatch_attention: + config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only) + elif not hard_check_only: config.attn_implementation = "eager" return config @@ -1335,12 +1338,12 @@ def _check_and_enable_flash_attn_2( torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True, - enable: bool = True, + hard_check_only: bool = False, ) -> PretrainedConfig: """ Checks the availability of Flash Attention 2 and compatibility with the current model. - If all checks pass and `enable` is True, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. """ if not cls._supports_flash_attn_2: raise ValueError( @@ -1402,28 +1405,36 @@ def _check_and_enable_flash_attn_2( "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." ) - if enable: + if not hard_check_only: config.attn_implementation = "flash_attention_2" return config @classmethod - def _check_and_enable_sdpa(cls, config, enable: bool = True) -> PretrainedConfig: + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: """ Checks the availability of SDPA for a given model. - If all checks pass and `enable` is True, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. """ - if not cls._supports_sdpa: - raise ValueError( - f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention. Please open an issue on GitHub to " - "request support for this architecture: https://github.com/huggingface/transformers/issues/new" - ) + if hard_check_only: + if not cls._supports_sdpa: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please open an issue on GitHub to " + "request support for this architecture: https://github.com/huggingface/transformers/issues/new" + ) + if not is_torch_sdpa_available(): + raise ImportError( + "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1." + ) + + if not is_torch_sdpa_available() or not cls._supports_sdpa: + return config _is_bettertransformer = getattr(cls, "use_bettertransformer", False) if _is_bettertransformer: return config - if enable: + if not hard_check_only: config.attn_implementation = "sdpa" return config diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 65195c0c7df981..9425bb0b018f61 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -1864,7 +1864,7 @@ def _check_and_enable_flash_attn_2( config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, - enable: bool = True, + hard_check_only: bool = False, ): """ `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model @@ -1881,10 +1881,12 @@ def _check_and_enable_flash_attn_2( The method checks if the current setup is compatible with Flash Attention as it requires the model to be in half precision and not ran on CPU. - If all checks pass, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module """ - config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map, enable=enable) + config = super()._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only + ) config.semantic_config.attn_implementation = config.attn_implementation config.coarse_acoustics_config.attn_implementation = config.attn_implementation diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 1046da1ca87551..e2247b95bde88a 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -16,7 +16,7 @@ import math import warnings -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -47,6 +47,9 @@ from .configuration_falcon import FalconConfig +if TYPE_CHECKING: + from ...configuration_utils import PretrainedConfig + if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -932,6 +935,18 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) + # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig": + # We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Falcon supports SDPA from torch==2.0.0 (no requirement on 2.1). + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config.attn_implementation = "sdpa" + return config + @add_start_docstrings( "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.", diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 6209d831a35895..57975f67e6f2b1 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -976,6 +976,18 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + # We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Falcon supports SDPA from torch==2.0.0 (no requirement on 2.1). + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config.attn_implementation = "sdpa" + return config + LLAMA_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3b52cf363ab1fc..adf4a7c96acde7 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -766,7 +766,7 @@ def forward( if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) From 0fa8de00ab8137cbb07f5e887642cecc9ae4df7d Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:04:20 +0000 Subject: [PATCH 055/100] standardize the use if XXX_ATTENTION_CLASSES --- docs/source/en/perf_infer_gpu_one.md | 2 +- src/transformers/configuration_utils.py | 2 +- src/transformers/modeling_utils.py | 2 +- .../models/distilbert/modeling_distilbert.py | 12 +++++++----- src/transformers/models/falcon/modeling_falcon.py | 12 +++++++----- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 12 +++++++----- src/transformers/models/mistral/modeling_mistral.py | 11 +++++++---- 8 files changed, 32 insertions(+), 23 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 6069f926ff4f27..153e50c46f8118 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -189,7 +189,7 @@ For now, Transformers supports inference and training through SDPA for the follo Note that FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it. -By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether it is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: +By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: ```diff import torch diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 41fd86ae25901e..25f8ab32de916a 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -237,7 +237,7 @@ class PretrainedConfig(PushToHubMixin): This attribute is currently not being used during model loading time, but this may change in the future versions. But we can already start preparing for the future by saving the dtype with save_pretrained. attn_implementation (`str`, *optional*): - The attention implementation to use in the model. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + The attention implementation to use in the model. Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. > TensorFlow specific parameters diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 82537b96cf0ad9..e68fdf1e6aefb2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1223,7 +1223,7 @@ def use_attn_implementation(self, attn_implementation: str): attn_implementation (`str`): The attention implementation to use. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). """ - # TODO: Implement it. An implementation could be to define `self._eager_attn_class = XXXAttention`, `self._sdpa_attn_class = XXXSdpaAttention`, `self._flash_attn_class = XXXFlashAttention2` in the __init__ of XXXPreTrainedModel, and leverage those attributes here to replace the correct submodules. + # TODO: Implement it. An implementation could be to define `self._eager_attn_class = XXXAttention`, `self._sdpa_attn_class = XXXSdpaAttention`, `self._flash_attn_class = XXXFlashAttention2` in the __init__ of XXXPreTrainedModel, and leverage those attributes here to replace the correct submodules. Alternatively, the `XXX_ATTENTION_CLASSES` in each modeling file can be leveraged. raise NotImplementedError("model.use_attn_implementation is currently not implemented.") if attn_implementation == "sdpa": diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 08d8d6cebfb8cd..df75a01ef5b3ca 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -455,6 +455,12 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: return x +DISTILBERT_ATTENTION_CLASSES = { + "eager": MultiHeadSelfAttention, + "flash_attention_2": DistilBertFlashAttention2, +} + + class TransformerBlock(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() @@ -463,11 +469,7 @@ def __init__(self, config: PretrainedConfig): if config.dim % config.n_heads != 0: raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") - if config.attn_implementation == "flash_attention_2": - self.attention = DistilBertFlashAttention2(config) - else: - self.attention = MultiHeadSelfAttention(config) - + self.attention = DISTILBERT_ATTENTION_CLASSES[config.attn_implementation](config) self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) self.ffn = FFN(config) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e2247b95bde88a..322db4898be41e 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -736,17 +736,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +FALCON_ATTENTION_CLASSES = { + "eager": FalconAttention, + "flash_attention_2": FalconFlashAttention2, +} + + class FalconDecoderLayer(nn.Module): def __init__(self, config: FalconConfig): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - if config.attn_implementation == "flash_attention_2": - self.self_attention = FalconFlashAttention2(config) - else: - self.self_attention = FalconAttention(config) - + self.self_attention = FALCON_ATTENTION_CLASSES[config.attn_implementation](config) self.mlp = FalconMLP(config) self.hidden_dropout = config.hidden_dropout self.config = config diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index c620ec7c1be3cf..2499de70cec450 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -351,7 +351,7 @@ def forward( key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) - attn_dropout = self.config.attn_pdrop if self.training else 0.0 + attn_dropout = self.attn_pdrop if self.training else 0.0 softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype upcast = query.dtype != softmax_dtype diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7b155c185668ae..014caa2ce29e8b 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -471,6 +471,12 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +GPT_NEO_ATTENTION_CLASSES = { + "eager": GPTNeoSelfAttention, + "flash_attention_2": GPTNeoFlashAttention2, +} + + class GPTNeoAttention(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -479,11 +485,7 @@ def __init__(self, config, layer_id=0): self.attention_type = self.attention_layers[layer_id] if self.attention_type in ["global", "local"]: - self.attention = ( - GPTNeoSelfAttention(config, self.attention_type) - if config.attn_implementation == "eager" - else GPTNeoFlashAttention2(config, self.attention_type) - ) + self.attention = GPT_NEO_ATTENTION_CLASSES[config.attn_implementation](config, self.attention_type) else: raise NotImplementedError( "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: " diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ff65607d902ddd..4da5ffa53f1ccc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -575,15 +575,18 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, +} + + class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig): super().__init__() self.hidden_size = config.hidden_size - if config.attn_implementation == "flash_attention_2": - self.self_attn = MistralFlashAttention2(config) - else: - self.self_attn = MistralAttention(config=config) + self.self_attn = MISTRAL_ATTENTION_CLASSES[config.attn_implementation](config) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 637e473da517d7ae2286622ee5d5b57a3063ddd9 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:48:59 +0000 Subject: [PATCH 056/100] fix SDPA bug with mem-efficient backend on CUDA when using fp32 --- src/transformers/models/idefics/modeling_idefics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 57975f67e6f2b1..70ed4192dc74ad 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -29,7 +29,7 @@ from ... import PreTrainedModel from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ModelOutput from ...modeling_utils import PretrainedConfig from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -1253,7 +1253,7 @@ def forward( attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = _prepare_4d_causal_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) From 55ec3253e7a0545fb7a289d68d5b78aca054bd19 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:51:52 +0000 Subject: [PATCH 057/100] fix test --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e68fdf1e6aefb2..45b0cae996e32e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1256,7 +1256,7 @@ def _autoset_attn_implementation( """ # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. # The property `PretrainedConfig.attn_implementation` is never `None`, for backward compatibility. - if hasattr(config, "_attn_implementation") or config._attn_implementation is not None: + if hasattr(config, "_attn_implementation") and config._attn_implementation is not None: if config.attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( f'Both config.attn_implementation ("{config.attn_implementation}") and use_flash_attention_2=True are used, and are incompatible.' From 33ef389a0e4d98d49225b7953807e963b017938a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 22 Nov 2023 18:05:35 +0000 Subject: [PATCH 058/100] rely on SDPA is_causal parameter to handle the causal mask in some cases --- src/transformers/models/idefics/modeling_idefics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 70ed4192dc74ad..4421637c2e3030 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -578,6 +578,7 @@ def __init__( self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.dropout = dropout + self.is_causal = True if (self.head_dim * num_heads) != self.hidden_size: raise ValueError( @@ -693,6 +694,8 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, ) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): From 2e6bc3e0fc2ce4e5fd964eba8d4dd742ad3f010e Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Nov 2023 09:39:01 +0000 Subject: [PATCH 059/100] fix FALCON_ATTENTION_CLASSES --- src/transformers/models/falcon/modeling_falcon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 322db4898be41e..77364712c53f78 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -738,6 +738,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: FALCON_ATTENTION_CLASSES = { "eager": FalconAttention, + "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA "flash_attention_2": FalconFlashAttention2, } From 5913dee1acf35665c1e88b9674f405933c10afdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Nov 2023 11:31:15 +0100 Subject: [PATCH 060/100] remove _flash_attn_2_enabled occurences --- src/transformers/models/mistral/modeling_mistral.py | 7 +------ src/transformers/models/opt/modeling_opt.py | 11 +++++++---- tests/test_modeling_common.py | 2 +- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4da5ffa53f1ccc..cac8058d53d417 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -839,12 +839,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if ( - attention_mask is not None - and hasattr(self.config, "_flash_attn_2_enabled") - and self.config._flash_attn_2_enabled - and use_cache - ): + if attention_mask is not None and self._use_flash_attention_2 and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 2192f327bc49f9..7701b597f44d05 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -475,15 +475,18 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): super().__init__() self.embed_dim = config.hidden_size - if not getattr(config, "_flash_attn_2_enabled", False): - self.self_attn = OPTAttention(config=config, is_decoder=True) - else: - self.self_attn = OptFlashAttention2(config=config, is_decoder=True) + self.self_attn = OPT_ATTENTION_CLASSES[config.attn_implementation](config=config, is_decoder=True) self.do_layer_norm_before = config.do_layer_norm_before self.dropout = config.dropout diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9793cf95bc593b..61c91c9389d58d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3486,7 +3486,7 @@ def test_flash_attn_2_from_config(self): model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname) - self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False)) + self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2") fa2_correctly_converted = False From 11ab3ae71f014721a36fdc30343b2f1f88bfe7e6 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Nov 2023 11:45:30 +0000 Subject: [PATCH 061/100] fix test --- src/transformers/models/bart/modeling_bart.py | 2 -- src/transformers/models/whisper/modeling_whisper.py | 2 -- tests/models/whisper/test_modeling_whisper.py | 13 ++++++++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index c7a3e03bded42f..133124031a24fb 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -503,7 +503,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: @@ -518,7 +517,6 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, - **kwargs, ) # if key_value_states are provided this layer is used as a cross-attention layer diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2bf7f65fdaee09..24822948b9a229 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -683,7 +683,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: @@ -698,7 +697,6 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, - **kwargs, ) # if key_value_states are provided this layer is used as a cross-attention layer diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f77d81d76e52be..c98caf1bf0bb6d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2336,13 +2336,20 @@ def test_encoder_outputs(self): with torch.no_grad(): outputs = model(**inputs)[0] - input_ids = inputs["input_features"] + encoder = model.encoder + + encoder_inputs = {"input_features": inputs["input_features"]} del inputs["input_features"] - encoder = model.encoder + if "head_mask" in inputs: + encoder_inputs["head_mask"] = inputs["head_mask"] + if "attention_mask" in inputs: + encoder_inputs["attention_mask"] = inputs["attention_mask"] + if "output_attentions" in inputs: + encoder_inputs["output_attentions"] = inputs["output_attentions"] with torch.no_grad(): - inputs["encoder_outputs"] = encoder(input_ids) + inputs["encoder_outputs"] = encoder(**encoder_inputs) outputs_embeds = model(**inputs)[0] self.assertTrue((outputs_embeds == outputs).all()) From b74894de9458f9b62a4fc3325f9fdaa59e2d1cda Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Nov 2023 12:54:49 +0000 Subject: [PATCH 062/100] add OPT to the list of supported flash models --- docs/source/en/perf_infer_gpu_one.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 153e50c46f8118..c2d2c8e76b4172 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -46,6 +46,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) +* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. From 4ff1057460678a555d21de79a22e4cea02f6ee8a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Nov 2023 16:39:56 +0000 Subject: [PATCH 063/100] improve test --- tests/models/llama/test_modeling_llama.py | 23 ++++++++----- tests/test_modeling_common.py | 42 +++++++++++++++-------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 5c3c83a2001fc8..cfe668ea6f246f 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -15,7 +15,6 @@ """ Testing suite for the PyTorch LLaMA model. """ -import copy import unittest import pytest @@ -428,12 +427,9 @@ def test_eager_matches_sdpa_generate(self): """ Overwritting the common test as the test is flaky on tiny models """ - import torch - max_new_tokens = 30 tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf") model_sdpa = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", @@ -441,20 +437,29 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - # Force using the eager implementation. - cfg = copy.deepcopy(config) - cfg.attn_implementation = "eager" + self.assertTrue(model_sdpa.config.attn_implementation == "sdpa") + model_eager = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", - config=cfg, torch_dtype=torch.float16, low_cpu_mem_usage=True, + attn_implementation="eager", ).to(torch_device) + self.assertTrue(model_eager.config.attn_implementation == "eager") + for name, submodule in model_eager.named_modules(): - if "SDPA" in submodule.__class__.__name__: + if "SdpaAttention" in submodule.__class__.__name__: raise ValueError("The eager model should not have SDPA attention layers") + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"] for padding_side in ["left", "right"]: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 61c91c9389d58d..acc730c538ef77 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3148,23 +3148,28 @@ def test_eager_matches_sdpa_inference(self, padding_side: str): model_sdpa = model_class.from_pretrained( tmpdirname, torch_dtype=torch.bfloat16, - ) - model_sdpa.to(torch_device) + ).to(torch_device) + self.assertTrue(model_sdpa.config.attn_implementation == "sdpa") - # Force using the eager implementation. - cfg = copy.deepcopy(config) - cfg.attn_implementation = "eager" model_eager = model_class.from_pretrained( tmpdirname, - config=cfg, torch_dtype=torch.bfloat16, - ) - model_eager.to(torch_device) + attn_implementation="eager", + ).to(torch_device) + self.assertTrue(model_eager.config.attn_implementation == "eager") for name, submodule in model_eager.named_modules(): - if "SDPA" in submodule.__class__.__name__: + if "SdpaAttention" in submodule.__class__.__name__: raise ValueError("The eager model should not have SDPA attention layers") + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + dummy_input = inputs_dict[model.main_input_name][:1] if dummy_input.dtype in [torch.float32, torch.float16]: dummy_input = dummy_input.to(torch.bfloat16) @@ -3278,20 +3283,29 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - # Force using the eager implementation. - cfg = copy.deepcopy(config) - cfg.attn_implementation = "eager" + self.assertTrue(model_sdpa.attn_implementation == "sdpa") + model_eager = model_class.from_pretrained( tmpdirname, - config=cfg, torch_dtype=torch.float16, low_cpu_mem_usage=True, + attn_implementation="eager", ).to(torch_device) + self.assertTrue(model_eager.attn_implementation == "eager") + for name, submodule in model_eager.named_modules(): - if "SDPA" in submodule.__class__.__name__: + if "SdpaAttention" in submodule.__class__.__name__: raise ValueError("The eager model should not have SDPA attention layers") + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + # Just test that a large cache works as expected res_eager = model_eager.generate( dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False From 8bd6c812cbc7a09e011a6e22ff04c7bd726d24b1 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Nov 2023 13:58:31 +0000 Subject: [PATCH 064/100] properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test --- tests/models/idefics/test_modeling_idefics.py | 8 +- tests/test_modeling_common.py | 358 ++++++++++++------ 2 files changed, 255 insertions(+), 111 deletions(-) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index dbacf8b2868fb6..28530c72194585 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -314,8 +314,8 @@ def prepare_pixel_values(self): @require_torch_sdpa @slow - @parameterized.expand([("left",), ("right",)]) - def test_eager_matches_sdpa_inference(self, padding_side: str): + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_inference(self, torch_dtype: str): self.skipTest("Idefics has a hard requirement on SDPA, skipping this test") @@ -568,8 +568,8 @@ def test_model_from_pretrained(self): @require_torch_sdpa @slow - @parameterized.expand([("left",), ("right",)]) - def test_eager_matches_sdpa_inference(self, padding_side: str): + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_inference(self, torch_dtype: str): self.skipTest("Idefics has a hard requirement on SDPA, skipping this test") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index acc730c538ef77..1c7e1e73d9413d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2830,8 +2830,6 @@ def test_model_is_small(self): @mark.flash_attn_test @slow def test_flash_attn_2_conversion(self): - import torch - config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -2857,8 +2855,6 @@ def test_flash_attn_2_conversion(self): @mark.flash_attn_test @slow def test_flash_attn_2_inference(self): - import torch - for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -2954,8 +2950,6 @@ def test_flash_attn_2_inference(self): @mark.flash_attn_test @slow def test_flash_attn_2_inference_padding_right(self): - import torch - for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3047,8 +3041,6 @@ def test_flash_attn_2_inference_padding_right(self): @mark.flash_attn_test @slow def test_flash_attn_2_generate_left_padding(self): - import torch - for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3090,8 +3082,6 @@ def test_flash_attn_2_generate_left_padding(self): @mark.flash_attn_test @slow def test_flash_attn_2_generate_padding_right(self): - import torch - for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3130,14 +3120,23 @@ def test_flash_attn_2_generate_padding_right(self): @require_torch_sdpa @slow - @parameterized.expand([("left",), ("right",)]) - def test_eager_matches_sdpa_inference(self, padding_side: str): - import torch + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_device == "cpu" and torch_dtype == "float16": + self.skipTest("float16 not supported on cpu") + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 for model_class in self.all_model_classes: - if not model_class._supports_sdpa: - self.skipTest(f"{model_class.__name__} does not support SDPA") - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3145,17 +3144,25 @@ def test_eager_matches_sdpa_inference(self, padding_side: str): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model_sdpa = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.bfloat16, - ).to(torch_device) + model_sdpa = ( + model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + ) + .to(torch_device) + .eval() + ) self.assertTrue(model_sdpa.config.attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.bfloat16, - attn_implementation="eager", - ).to(torch_device) + model_eager = ( + model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + .to(torch_device) + .eval() + ) self.assertTrue(model_eager.config.attn_implementation == "eager") for name, submodule in model_eager.named_modules(): @@ -3170,86 +3177,229 @@ def test_eager_matches_sdpa_inference(self, padding_side: str): if not has_sdpa and model_sdpa.config.model_type != "falcon": raise ValueError("The SDPA model should have SDPA attention layers") - dummy_input = inputs_dict[model.main_input_name][:1] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) - - dummy_attention_mask = inputs_dict.get("attention_mask", None) - - if dummy_attention_mask is not None: - dummy_attention_mask = dummy_attention_mask[:1] - if padding_side == "left": - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - elif padding_side == "right": - dummy_attention_mask[:, 1:] = 1 - dummy_attention_mask[:, :1] = 0 - - if is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - - outputs_eager = model_eager( - dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True - ) - outputs_sdpa = model_sdpa( - dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True - ) - else: - outputs_eager = model_eager(dummy_input, output_hidden_states=True) - outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True) - - logits = ( - outputs_eager.hidden_states[-1] - if not is_encoder_decoder - else outputs_eager.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_sdpa.hidden_states[-1] - if not is_encoder_decoder - else outputs_sdpa.decoder_hidden_states[-1] - ) - - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - - if is_encoder_decoder: - other_inputs = { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - outputs_eager = model_eager(dummy_input, **other_inputs) - outputs_sdpa = model_sdpa(dummy_input, **other_inputs) - else: - other_inputs = { - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - outputs_eager = model_eager(dummy_input, **other_inputs) - outputs_sdpa = model_sdpa(dummy_input, **other_inputs) - - logits_eager = ( - outputs_eager.hidden_states[-1] - if not is_encoder_decoder - else outputs_eager.decoder_hidden_states[-1] - ) - logits_sdpa = ( - outputs_sdpa.hidden_states[-1] - if not is_encoder_decoder - else outputs_sdpa.decoder_hidden_states[-1] - ) - - assert torch.allclose(logits_sdpa[1:], logits_eager[1:], atol=4e-2, rtol=4e-2) + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for batch_size in [1, 5]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + if dummy_input.shape[0] != batch_size: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = torch.cat( + ( + dummy_input, + torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ), + ), + dim=0, + ).to(torch_device) + else: + dummy_input = torch.cat( + ( + dummy_input, + torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ), + ), + dim=0, + ).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + # print("was none!!!") + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + else: + seqlen = dummy_input.shape[-1] + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + # print("dummy_attention_mask here", dummy_attention_mask.shape) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + dummy_attention_mask = torch.cat( + ( + dummy_attention_mask, + torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ), + ), + dim=0, + ).to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + # print("model_class", model_class) + # print("is_encoder_decoder", is_encoder_decoder) + if is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size] + if decoder_input_ids.shape[0] != batch_size: + decoder_input_ids = torch.cat( + ( + decoder_input_ids, + torch.ones( + batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ), + ), + dim=0, + ).to(torch_device) + + # TODO: never an `attention_mask` arg here? + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + else: + other_inputs = { + "output_hidden_states": True, + } + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + other_inputs["attention_mask"] = dummy_attention_mask + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + if torch_device == "cpu": + if torch_dtype == torch.float32: + atol = 1e-6 + rtol = 1e-4 + elif torch_dtype == torch.bfloat16: + atol = 1e-2 + rtol = 1e-2 + elif torch_device == "cuda": + if not enable_kernels: + if torch_dtype == torch.float32: + atol = 1e-6 + rtol = 1e-4 + elif torch_dtype == torch.bfloat16: + atol = 1e-2 + rtol = 1e-2 + elif torch_dtype == torch.float16: + atol = 1e-3 + rtol = 1e-3 + else: + if torch_dtype == torch.float32: + atol = 1e-6 + rtol = 1e-4 + elif torch_dtype == torch.bfloat16: + atol = 1e-2 + rtol = 3e-2 + elif torch_dtype == torch.float16: + atol = 5e-3 + rtol = 5e-3 + else: + atol = 1e-7 + rtol = 1e-4 + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean reldiff: {((x - ref).abs() / (ref.abs() + 1e-12)).mean()}, torch atol = {atol}, torch rtol = {rtol}" + + # Masked tokens output slightly deviates - we don't mind that. + if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self): - import torch - max_new_tokens = 30 # TODO: Implement a test for SDPA simply testing forward, not generate. @@ -3322,8 +3472,6 @@ def test_eager_matches_sdpa_generate(self): @mark.flash_attn_test @slow def test_flash_attn_2_generate_use_cache(self): - import torch - max_new_tokens = 30 for model_class in self.all_generative_model_classes: @@ -3369,8 +3517,6 @@ def test_flash_attn_2_generate_use_cache(self): @mark.flash_attn_test @slow def test_flash_attn_2_fp32_ln(self): - import torch - for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3469,8 +3615,6 @@ def test_flax_from_pt_safetensors(self): @mark.flash_attn_test @slow def test_flash_attn_2_from_config(self): - import torch - for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") From a11c114596c900df645136c594cef203200ca083 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Nov 2023 14:01:24 +0000 Subject: [PATCH 065/100] remove remaining _flash_attn_2_enabled occurence --- src/transformers/models/opt/modeling_opt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 7701b597f44d05..2c76044a7be522 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -719,6 +719,7 @@ def __init__(self, config: OPTConfig): self.final_layer_norm = None self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -817,7 +818,7 @@ def forward( mask_seq_length = past_key_values_length + seq_length # embed positions - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None attention_mask = ( From b5593a16a913d745bc558734512445b2238b9283 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Sat, 25 Nov 2023 00:21:43 +0900 Subject: [PATCH 066/100] Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6973e183c0ce38..a3503c2aa15813 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1250,7 +1250,7 @@ def _autoset_attn_implementation( """ Automatically checks and dispatches to a default attention implementation. In order of priority: 1. An implementation specified in `config.attn_implementation`. - 2. If specified, flash attention through use_flash_attention_2=True. + 2. If use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) 3. SDPA implementation, if available and supported by the model type. 4. Manual implementation otherwise. """ From 1bc983a44590589203ec586e92e6c97c18fac52d Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Sat, 25 Nov 2023 00:22:05 +0900 Subject: [PATCH 067/100] Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a3503c2aa15813..6c1d23cb1c5304 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1252,7 +1252,7 @@ def _autoset_attn_implementation( 1. An implementation specified in `config.attn_implementation`. 2. If use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) 3. SDPA implementation, if available and supported by the model type. - 4. Manual implementation otherwise. + 4. The default model's implementation otherwise (`LlamaAttention` for example) . """ # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. # The property `PretrainedConfig.attn_implementation` is never `None`, for backward compatibility. From 316b448de3912d104cb26d6f54186df1a41e0392 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Sat, 25 Nov 2023 00:26:56 +0900 Subject: [PATCH 068/100] Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6c1d23cb1c5304..95adf352723f7f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1251,7 +1251,7 @@ def _autoset_attn_implementation( Automatically checks and dispatches to a default attention implementation. In order of priority: 1. An implementation specified in `config.attn_implementation`. 2. If use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) - 3. SDPA implementation, if available and supported by the model type. + 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) 4. The default model's implementation otherwise (`LlamaAttention` for example) . """ # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. From 52178ba7de9764996d9faf52b0077700e5ddf161 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Sat, 25 Nov 2023 00:27:39 +0900 Subject: [PATCH 069/100] Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index facf77b2ed9149..e0067ec33a62bf 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -230,6 +230,7 @@ def _unmask_unattended( [0, 1, 1]]]] ``` """ + # fmt: on # Get the index of the first non-zero value for every sample in the batch. # In the above example, indices = [[2], [0], [1]]] From 231e354f055bde3dbd81ffc70c0073bfdf6bcd66 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Sat, 25 Nov 2023 00:28:14 +0900 Subject: [PATCH 070/100] Update docs/source/en/perf_infer_gpu_one.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index c2d2c8e76b4172..6544e2800d83c4 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -188,7 +188,7 @@ For now, Transformers supports inference and training through SDPA for the follo * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) -Note that FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it. +Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type before using it. By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: From f907b3fe7af9295514e6be17f39745d9e98e5a1c Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Nov 2023 15:36:03 +0000 Subject: [PATCH 071/100] remove use_attn_implementation --- src/transformers/modeling_utils.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 95adf352723f7f..bf08b4e6aedaec 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1215,29 +1215,6 @@ def _from_config(cls, config, **kwargs): return model - def use_attn_implementation(self, attn_implementation: str): - """ - Specifies the attention implementation to use in the model. - - Args: - attn_implementation (`str`): - The attention implementation to use. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). - """ - # TODO: Implement it. An implementation could be to define `self._eager_attn_class = XXXAttention`, `self._sdpa_attn_class = XXXSdpaAttention`, `self._flash_attn_class = XXXFlashAttention2` in the __init__ of XXXPreTrainedModel, and leverage those attributes here to replace the correct submodules. Alternatively, the `XXX_ATTENTION_CLASSES` in each modeling file can be leveraged. - raise NotImplementedError("model.use_attn_implementation is currently not implemented.") - - if attn_implementation == "sdpa": - self.config = self._check_and_enable_sdpa(self.config, hard_check_only=False) - elif attn_implementation == "flash_attention_2": - # TODO: define torch_dtype properly - torch_dtype = None - self.config = self._check_and_enable_flash_attn_2( - self.config, - torch_dtype=torch_dtype, - device_map=getattr(self, "hf_device_map", None), - hard_check_only=False, - ) - @classmethod def _autoset_attn_implementation( cls, From 0e9e9f2b2c52b9e8dc78732cf1c1d4a31d6f5d2e Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Nov 2023 15:41:34 +0000 Subject: [PATCH 072/100] fix docstring & slight bug --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bf08b4e6aedaec..51f8972f499bcb 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1236,11 +1236,11 @@ def _autoset_attn_implementation( if hasattr(config, "_attn_implementation") and config._attn_implementation is not None: if config.attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( - f'Both config.attn_implementation ("{config.attn_implementation}") and use_flash_attention_2=True are used, and are incompatible.' + f'Both config.attn_implementation ("{config.attn_implementation}"), with `config.attn_implementation != "flash_attention_2"`. This is not compatible, we recommend you to just use `config.attn_implementation = "flash_attention_2"` or to pass the argument `attn_implementation="flash_attention_2"`.' ) # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. - hard_check_only = True + hard_check_only = not use_flash_attention_2 else: hard_check_only = False From 5c77b944749268b11f30c01a010d84fd692d670a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 4 Dec 2023 18:18:30 +0000 Subject: [PATCH 073/100] make attn_implementation internal (_attn_implementation) --- src/transformers/configuration_utils.py | 33 ++++++++++--------- src/transformers/modeling_utils.py | 20 +++++------ src/transformers/models/bark/modeling_bark.py | 12 +++---- src/transformers/models/bart/modeling_bart.py | 16 ++++----- .../models/blenderbot/modeling_blenderbot.py | 6 ++-- .../modeling_blenderbot_small.py | 6 ++-- .../models/distilbert/modeling_distilbert.py | 4 +-- .../models/falcon/modeling_falcon.py | 6 ++-- .../gpt_bigcode/modeling_gpt_bigcode.py | 10 +++--- .../models/gpt_neo/modeling_gpt_neo.py | 4 +-- .../models/idefics/modeling_idefics.py | 2 +- .../models/llama/modeling_llama.py | 10 +++--- .../models/m2m_100/modeling_m2m_100.py | 6 ++-- .../models/marian/modeling_marian.py | 6 ++-- .../models/mbart/modeling_mbart.py | 10 +++--- .../models/mistral/modeling_mistral.py | 4 +-- src/transformers/models/opt/modeling_opt.py | 4 +-- .../models/pegasus/modeling_pegasus.py | 6 ++-- .../models/plbart/modeling_plbart.py | 14 ++++---- .../speech_to_text/modeling_speech_to_text.py | 6 ++-- .../modeling_time_series_transformer.py | 6 ++-- .../models/whisper/modeling_whisper.py | 12 +++---- tests/deepspeed/test_deepspeed.py | 2 ++ tests/models/llama/test_modeling_llama.py | 4 +-- tests/test_configuration_utils.py | 2 +- tests/test_modeling_common.py | 14 ++++---- 26 files changed, 114 insertions(+), 111 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 25f8ab32de916a..a8e4c599b40042 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -377,7 +377,7 @@ def __init__(self, **kwargs): self._commit_hash = kwargs.pop("_commit_hash", None) # Attention implementation to use, if relevant. - self._attn_implementation = kwargs.pop("attn_implementation", None) + self._attn_implementation_internal = kwargs.pop("attn_implementation", None) # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) @@ -428,22 +428,20 @@ def num_labels(self, num_labels: int): self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) @property - def attn_implementation(self): - if hasattr(self, "_attn_implementation"): - if self._attn_implementation is None: + def _attn_implementation(self): + # This property is made private for now (as it may not be changed alone and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: # `config.attn_implementation` should never be None, for backward compatibility. return "eager" else: - return self._attn_implementation + return self._attn_implementation_internal else: - return None + return "eager" - @attn_implementation.setter - def attn_implementation(self, value): - # No specific check is implemented here, as we want to allow syntax as `config.attn_implementation = "flash_attention_2"` before the model - # loading. Proper implementation/external library availability checks are done at load time. - # Modifying this property alone on an already loaded model (model.config) has no impact, `model.use_attn_implementation("flash_attention_2")` should be used instead. - self._attn_implementation = value + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ @@ -770,6 +768,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": if "_commit_hash" in kwargs and "_commit_hash" in config_dict: kwargs["_commit_hash"] = config_dict["_commit_hash"] + # We remove it from kwargs so that it does not appear in `return_unused_kwargs`. + config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None) + config = cls(**config_dict) if hasattr(config, "pruned_heads"): @@ -884,8 +885,8 @@ def to_diff_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(serializable_config_dict) - if "_attn_implementation" in serializable_config_dict: - del serializable_config_dict["_attn_implementation"] + if "_attn_implementation_internal" in serializable_config_dict: + del serializable_config_dict["_attn_implementation_internal"] return serializable_config_dict @@ -903,8 +904,8 @@ def to_dict(self) -> Dict[str, Any]: del output["_auto_class"] if "_commit_hash" in output: del output["_commit_hash"] - if "_attn_implementation" in output: - del output["_attn_implementation"] + if "_attn_implementation_internal" in output: + del output["_attn_implementation_internal"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3be68b91c00765..ce8c7f66d97b4d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1230,17 +1230,17 @@ def _autoset_attn_implementation( ): """ Automatically checks and dispatches to a default attention implementation. In order of priority: - 1. An implementation specified in `config.attn_implementation`. + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). 2. If use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) 4. The default model's implementation otherwise (`LlamaAttention` for example) . """ - # Here we use config._attn_implementation to check whether the attention implementation was explicitely set by the user. - # The property `PretrainedConfig.attn_implementation` is never `None`, for backward compatibility. - if hasattr(config, "_attn_implementation") and config._attn_implementation is not None: - if config.attn_implementation != "flash_attention_2" and use_flash_attention_2: + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility. + if hasattr(config, "_attn_implementation") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( - f'Both config.attn_implementation ("{config.attn_implementation}"), with `config.attn_implementation != "flash_attention_2"`. This is not compatible, we recommend you to just use `config.attn_implementation = "flash_attention_2"` or to pass the argument `attn_implementation="flash_attention_2"`.' + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible. We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. @@ -1260,7 +1260,7 @@ def _autoset_attn_implementation( # use_flash_attention_2 takes priority over SDPA. config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only) elif not hard_check_only: - config.attn_implementation = "eager" + config._attn_implementation = "eager" return config @@ -1401,7 +1401,7 @@ def _check_and_enable_flash_attn_2( "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." ) if not hard_check_only: - config.attn_implementation = "flash_attention_2" + config._attn_implementation = "flash_attention_2" return config @classmethod @@ -1409,7 +1409,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra """ Checks the availability of SDPA for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. """ if hard_check_only: if not cls._supports_sdpa: @@ -1430,7 +1430,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra return config if not hard_check_only: - config.attn_implementation = "sdpa" + config._attn_implementation = "sdpa" return config def enable_input_require_grads(self): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 1e7549ff297632..06ed527d97905f 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -436,7 +436,7 @@ def __init__(self, config, is_causal=False): self.layernorm_1 = nn.LayerNorm(config.hidden_size) self.layernorm_2 = nn.LayerNorm(config.hidden_size) - self.attn = BARK_ATTENTION_CLASSES[config.attn_implementation](config, is_causal=is_causal) + self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal) self.mlp = BarkMLP(config) @@ -669,7 +669,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.dropout) self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias) @@ -1265,7 +1265,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.dropout) self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layernorm_final = nn.LayerNorm(config.hidden_size) @@ -1904,7 +1904,7 @@ def _check_and_enable_flash_attn_2( config, torch_dtype, device_map, hard_check_only=hard_check_only ) - config.semantic_config.attn_implementation = config.attn_implementation - config.coarse_acoustics_config.attn_implementation = config.attn_implementation - config.fine_acoustics_config.attn_implementation = config.attn_implementation + config.semantic_config._attn_implementation = config._attn_implementation + config.coarse_acoustics_config._attn_implementation = config._attn_implementation + config.fine_acoustics_config._attn_implementation = config._attn_implementation return config diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a2f1eb34532c0c..beb161ec989e5a 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -522,7 +522,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) @@ -623,7 +623,7 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -693,7 +693,7 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -706,7 +706,7 @@ def __init__(self, config: BartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -1067,8 +1067,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No embed_dim, ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" - self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -1249,8 +1249,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.d_model, ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" - self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1ebb89eab4023d..4512c3b503a4be 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -261,7 +261,7 @@ def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -332,7 +332,7 @@ def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -345,7 +345,7 @@ def __init__(self, config: BlenderbotConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 18a677f99e2100..dc4fa30b875ef2 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -255,7 +255,7 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -332,7 +332,7 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -345,7 +345,7 @@ def __init__(self, config: BlenderbotSmallConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index d2e9279e14842a..6e38ee84e98f6c 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -485,7 +485,7 @@ def __init__(self, config: PretrainedConfig): if config.dim % config.n_heads != 0: raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") - self.attention = DISTILBERT_ATTENTION_CLASSES[config.attn_implementation](config) + self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config) self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) self.ffn = FFN(config) @@ -705,7 +705,7 @@ def __init__(self, config: PretrainedConfig): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 596dfbc3fe82d5..d93d4cc17cc9f9 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -765,7 +765,7 @@ def __init__(self, config: FalconConfig): hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FALCON_ATTENTION_CLASSES[config.attn_implementation](config) + self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config) self.mlp = FalconMLP(config) self.hidden_dropout = config.hidden_dropout self.config = config @@ -963,7 +963,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "Pretr return config if not hard_check_only: - config.attn_implementation = "sdpa" + config._attn_implementation = "sdpa" return config @@ -984,7 +984,7 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3692198783e9df..6092f15e232088 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -620,7 +620,7 @@ def forward( # as SDPA expects seq_length to be at index -2 for the key as well attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) else: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) @@ -674,7 +674,7 @@ def __init__(self, config, layer_idx=None): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBIGCODE_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx=layer_idx) + self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -682,7 +682,7 @@ def __init__(self, config, layer_idx=None): if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config.attn_implementation]( + self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation]( config, is_cross_attention=True, layer_idx=layer_idx ) @@ -908,8 +908,8 @@ def __init__(self, config): self.gradient_checkpointing = False - self._use_sdpa = config.attn_implementation == "sdpa" - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 8799e93f61e63a..a6a73dbb8cfd9f 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -501,7 +501,7 @@ def __init__(self, config, layer_id=0): self.attention_type = self.attention_layers[layer_id] if self.attention_type in ["global", "local"]: - self.attention = GPT_NEO_ATTENTION_CLASSES[config.attn_implementation](config, self.attention_type) + self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type) else: raise NotImplementedError( "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: " @@ -720,7 +720,7 @@ def __init__(self, config): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.drop = nn.Dropout(float(config.embed_dropout)) self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.gradient_checkpointing = False diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 4421637c2e3030..1e44d2e4a0cc46 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -988,7 +988,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra return config if not hard_check_only: - config.attn_implementation = "sdpa" + config._attn_implementation = "sdpa" return config diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 78d20b05d9ca2a..d77c5fc0b41771 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -664,7 +664,7 @@ def forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) @@ -739,7 +739,9 @@ def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config.attn_implementation](config=config) + print("config._attn_implementation", config._attn_implementation) + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -932,8 +934,8 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_sdpa = config.attn_implementation == "sdpa" - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 697fbfad48635e..656c526536c563 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -326,7 +326,7 @@ def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -400,7 +400,7 @@ def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -413,7 +413,7 @@ def __init__(self, config: M2M100Config): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = M2M100_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 4d976749520134..d52a060d4723c8 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -273,7 +273,7 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MARIAN_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -347,7 +347,7 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MARIAN_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -360,7 +360,7 @@ def __init__(self, config: MarianConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MARIAN_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 318d7a42550067..3d25d75b3ef28f 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -511,7 +511,7 @@ def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBART_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -581,7 +581,7 @@ def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBART_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -594,7 +594,7 @@ def __init__(self, config: MBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBART_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -933,7 +933,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N embed_dim, ) self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layer_norm = nn.LayerNorm(config.d_model) @@ -1111,7 +1111,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.d_model, ) self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layernorm_embedding = nn.LayerNorm(config.d_model) self.layer_norm = nn.LayerNorm(config.d_model) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 912f1325bb8630..d153c6b7edb91d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -602,7 +602,7 @@ def __init__(self, config: MistralConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MISTRAL_ATTENTION_CLASSES[config.attn_implementation](config) + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -792,7 +792,7 @@ def __init__(self, config: MistralConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 3a36d585c2cf02..9b22a51abc77aa 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -502,7 +502,7 @@ def __init__(self, config: OPTConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = OPT_ATTENTION_CLASSES[config.attn_implementation](config=config, is_decoder=True) + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True) self.do_layer_norm_before = config.do_layer_norm_before self.dropout = config.dropout @@ -735,7 +735,7 @@ def __init__(self, config: OPTConfig): self.final_layer_norm = None self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.gradient_checkpointing = False # Initialize weights and apply final processing diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 57d1adfa6c211f..1a75c43e58e0ee 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -276,7 +276,7 @@ def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PEGASUS_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -347,7 +347,7 @@ def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PEGASUS_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -360,7 +360,7 @@ def __init__(self, config: PegasusConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index adf4a7c96acde7..f03f90183a59a5 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -271,7 +271,7 @@ def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PLBART_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -346,7 +346,7 @@ def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PLBART_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -359,7 +359,7 @@ def __init__(self, config: PLBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PLBART_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -674,8 +674,8 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = embed_dim, ) self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" - self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -857,8 +857,8 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.d_model, ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" - self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index c08dc7e51af67d..76e088415ab88d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -340,7 +340,7 @@ def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -411,7 +411,7 @@ def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -424,7 +424,7 @@ def __init__(self, config: Speech2TextConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 48187cd3f55d26..b6e86735c6a3d0 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -442,7 +442,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -519,7 +519,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -532,7 +532,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 92d8a638e29ad8..3cfe0c900e2779 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -702,7 +702,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." ) @@ -804,7 +804,7 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -875,7 +875,7 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[config.attn_implementation]( + self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -888,7 +888,7 @@ def __init__(self, config: WhisperConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WHISPER_ATTENTION_CLASSES[config.attn_implementation]( + self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -1328,8 +1328,8 @@ def __init__(self, config: WhisperConfig): self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2" - self._use_sdpa = config.attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 2352cf522f29a7..71630aae81b37a 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -562,6 +562,8 @@ def test_gradient_accumulation(self, stage, dtype): self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5) # see the note above how to get identical loss on a small bs + print("no_grad_accum_loss", no_grad_accum_loss) + print("yes_grad_accum_loss", yes_grad_accum_loss) self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2) def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index cfe668ea6f246f..03a12ebb7f87f9 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -437,7 +437,7 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - self.assertTrue(model_sdpa.config.attn_implementation == "sdpa") + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") model_eager = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", @@ -446,7 +446,7 @@ def test_eager_matches_sdpa_generate(self): attn_implementation="eager", ).to(torch_device) - self.assertTrue(model_eager.config.attn_implementation == "eager") + self.assertTrue(model_eager.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): if "SdpaAttention" in submodule.__class__.__name__: diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py index df521efb180ac8..800197728d1234 100644 --- a/tests/test_configuration_utils.py +++ b/tests/test_configuration_utils.py @@ -199,7 +199,7 @@ def test_config_common_kwargs_is_complete(self): # If this part of the test fails, you have arguments to addin config_common_kwargs above. self.assertListEqual( missing_keys, - ["is_encoder_decoder", "_name_or_path", "_commit_hash", "_attn_implementation", "transformers_version"], + ["is_encoder_decoder", "_name_or_path", "_commit_hash", "__attn_implementation", "transformers_version"], ) keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)] if len(keys_with_defaults) > 0: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 44d2eeda300539..bc2c778b4858a8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3157,8 +3157,9 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): .to(torch_device) .eval() ) - self.assertTrue(model_sdpa.config.attn_implementation == "sdpa") + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + print("----- load model eager") model_eager = ( model_class.from_pretrained( tmpdirname, @@ -3168,7 +3169,8 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): .to(torch_device) .eval() ) - self.assertTrue(model_eager.config.attn_implementation == "eager") + self.assertTrue(model_eager.config._attn_implementation == "eager") + break for name, submodule in model_eager.named_modules(): if "SdpaAttention" in submodule.__class__.__name__: @@ -3227,7 +3229,6 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): else: dummy_attention_mask = inputs_dict.get("attention_mask", None) if dummy_attention_mask is None: - # print("was none!!!") if is_encoder_decoder: seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] else: @@ -3235,7 +3236,6 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): dummy_attention_mask = ( torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) ) - # print("dummy_attention_mask here", dummy_attention_mask.shape) dummy_attention_mask = dummy_attention_mask[:batch_size] if dummy_attention_mask.shape[0] != batch_size: @@ -3262,8 +3262,6 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): for enable_kernels in [False, True]: failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" - # print("model_class", model_class) - # print("is_encoder_decoder", is_encoder_decoder) if is_encoder_decoder: decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size] if decoder_input_ids.shape[0] != batch_size: @@ -3438,7 +3436,7 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - self.assertTrue(model_sdpa.attn_implementation == "sdpa") + self.assertTrue(model_sdpa._attn_implementation == "sdpa") model_eager = model_class.from_pretrained( tmpdirname, @@ -3447,7 +3445,7 @@ def test_eager_matches_sdpa_generate(self): attn_implementation="eager", ).to(torch_device) - self.assertTrue(model_eager.attn_implementation == "eager") + self.assertTrue(model_eager._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): if "SdpaAttention" in submodule.__class__.__name__: From cd9e2096f637e9e5b122c0ea65f30d3fcacbf5c3 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 4 Dec 2023 18:22:07 +0000 Subject: [PATCH 074/100] typos --- src/transformers/models/llama/modeling_llama.py | 2 -- tests/deepspeed/test_deepspeed.py | 2 -- tests/test_modeling_common.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d77c5fc0b41771..e992a408dbcc2f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -739,8 +739,6 @@ def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - print("config._attn_implementation", config._attn_implementation) - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config) self.mlp = LlamaMLP(config) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 71630aae81b37a..2352cf522f29a7 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -562,8 +562,6 @@ def test_gradient_accumulation(self, stage, dtype): self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5) # see the note above how to get identical loss on a small bs - print("no_grad_accum_loss", no_grad_accum_loss) - print("yes_grad_accum_loss", yes_grad_accum_loss) self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2) def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bc2c778b4858a8..e5921507feece9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3159,7 +3159,6 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): ) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - print("----- load model eager") model_eager = ( model_class.from_pretrained( tmpdirname, @@ -3170,7 +3169,6 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): .eval() ) self.assertTrue(model_eager.config._attn_implementation == "eager") - break for name, submodule in model_eager.named_modules(): if "SdpaAttention" in submodule.__class__.__name__: From e475f25e27735d9d80e47e667923df2c9f52b93d Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 5 Dec 2023 11:05:35 +0000 Subject: [PATCH 075/100] fix tests --- src/transformers/modeling_utils.py | 5 +++-- tests/test_configuration_utils.py | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ce8c7f66d97b4d..13d608f50f1823 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1236,8 +1236,9 @@ def _autoset_attn_implementation( 4. The default model's implementation otherwise (`LlamaAttention` for example) . """ # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. - # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility. - if hasattr(config, "_attn_implementation") and config._attn_implementation_internal is not None: + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible. We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py index 800197728d1234..2ceb1695fa96cc 100644 --- a/tests/test_configuration_utils.py +++ b/tests/test_configuration_utils.py @@ -199,7 +199,13 @@ def test_config_common_kwargs_is_complete(self): # If this part of the test fails, you have arguments to addin config_common_kwargs above. self.assertListEqual( missing_keys, - ["is_encoder_decoder", "_name_or_path", "_commit_hash", "__attn_implementation", "transformers_version"], + [ + "is_encoder_decoder", + "_name_or_path", + "_commit_hash", + "_attn_implementation_internal", + "transformers_version", + ], ) keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)] if len(keys_with_defaults) > 0: From 48a6bfc3b227ca807f1827e0fa5bfde56188584d Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:09:10 +0000 Subject: [PATCH 076/100] deprecate use_flash_attention_2=True --- docs/source/en/model_doc/bark.md | 6 +- docs/source/en/model_doc/distilbert.md | 2 +- docs/source/en/model_doc/gpt_bigcode.md | 2 +- docs/source/en/model_doc/gpt_neo.md | 2 +- docs/source/en/model_doc/mistral.md | 2 +- docs/source/en/model_doc/opt.md | 2 +- docs/source/en/perf_infer_gpu_one.md | 10 +- docs/source/en/quantization.md | 2 +- docs/source/ja/perf_infer_gpu_one.md | 10 +- docs/source/zh/main_classes/quantization.md | 4 +- src/transformers/modeling_utils.py | 17 ++- .../models/llama/modeling_llama.py | 120 ++---------------- tests/models/bark/test_modeling_bark.py | 11 +- .../distilbert/test_modeling_distilbert.py | 11 +- tests/models/llama/test_modeling_llama.py | 2 +- tests/models/mistral/test_modeling_mistral.py | 15 ++- tests/models/whisper/test_modeling_whisper.py | 9 +- tests/test_modeling_common.py | 38 +++--- 18 files changed, 89 insertions(+), 176 deletions(-) diff --git a/docs/source/en/model_doc/bark.md b/docs/source/en/model_doc/bark.md index 2160159bd783a3..7c02e4be701187 100644 --- a/docs/source/en/model_doc/bark.md +++ b/docs/source/en/model_doc/bark.md @@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation ##### Usage -To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference: +To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference: ```python -model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device) +model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device) ``` ##### Performance comparison @@ -114,7 +114,7 @@ import torch device = "cuda" if torch.cuda.is_available() else "cpu" # load in fp16 and use Flash Attention 2 -model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device) +model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device) # enable CPU offload model.enable_cpu_offload() diff --git a/docs/source/en/model_doc/distilbert.md b/docs/source/en/model_doc/distilbert.md index 233a182a553fa6..bd39260d3ca492 100644 --- a/docs/source/en/model_doc/distilbert.md +++ b/docs/source/en/model_doc/distilbert.md @@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below: >>> device = "cuda" # the device to load the model onto >>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased') ->>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2") >>> text = "Replace me by any text you'd like." diff --git a/docs/source/en/model_doc/gpt_bigcode.md b/docs/source/en/model_doc/gpt_bigcode.md index 0f3bc72d03a55f..b3cb078e2a140c 100644 --- a/docs/source/en/model_doc/gpt_bigcode.md +++ b/docs/source/en/model_doc/gpt_bigcode.md @@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below: >>> from transformers import AutoModelForCausalLM, AutoTokenizer >>> device = "cuda" # the device to load the model onto ->>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2") >>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder") >>> prompt = "def hello_world():" diff --git a/docs/source/en/model_doc/gpt_neo.md b/docs/source/en/model_doc/gpt_neo.md index 96b6a8c96fe71a..3c7858c998207e 100644 --- a/docs/source/en/model_doc/gpt_neo.md +++ b/docs/source/en/model_doc/gpt_neo.md @@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below: >>> from transformers import AutoModelForCausalLM, AutoTokenizer >>> device = "cuda" # the device to load the model onto ->>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2") >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") >>> prompt = "def hello_world():" diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md index 8e37bc2caf888d..8e4d75ef2382c3 100644 --- a/docs/source/en/model_doc/mistral.md +++ b/docs/source/en/model_doc/mistral.md @@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below: >>> from transformers import AutoModelForCausalLM, AutoTokenizer >>> device = "cuda" # the device to load the model onto ->>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2") >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") >>> prompt = "My favourite condiment is" diff --git a/docs/source/en/model_doc/opt.md b/docs/source/en/model_doc/opt.md index 3da7b22fab747d..1b02b888994ecf 100644 --- a/docs/source/en/model_doc/opt.md +++ b/docs/source/en/model_doc/opt.md @@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below: >>> from transformers import OPTForCausalLM, GPT2Tokenizer >>> device = "cuda" # the device to load the model onto ->>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2") >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m") >>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the " diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index de6267eb975802..95d084a1459350 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -55,7 +55,7 @@ Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs. -To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]: +To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]: ```python import torch @@ -67,13 +67,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) ``` FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2. + +Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`. @@ -90,14 +92,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, load_in_8bit=True, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) # load in 4bit model = AutoModelForCausalLM.from_pretrained( model_id, load_in_4bit=True, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) ``` diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index 60903e36ad5968..ce55339bef349e 100644 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one ```py from transformers import AutoModelForCausalLM, AutoTokenizer -model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0") +model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0") ``` ## AutoGPTQ diff --git a/docs/source/ja/perf_infer_gpu_one.md b/docs/source/ja/perf_infer_gpu_one.md index d6a18a6f3e2047..6d7466e022220a 100644 --- a/docs/source/ja/perf_infer_gpu_one.md +++ b/docs/source/ja/perf_infer_gpu_one.md @@ -44,7 +44,7 @@ Flash Attention 2は、モデルのdtypeが`fp16`または`bf16`の場合にの ### Quick usage -モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`use_flash_attention_2`を追加します。 +モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`attn_implementation="flash_attention_2"`を追加します。 ```python @@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) ``` @@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, load_in_8bit=True, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) ``` @@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, load_in_4bit=True, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) ``` @@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, load_in_4bit=True, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) lora_config = LoraConfig( diff --git a/docs/source/zh/main_classes/quantization.md b/docs/source/zh/main_classes/quantization.md index 0a2c1eb4039c36..3c7e4d9212a1d0 100644 --- a/docs/source/zh/main_classes/quantization.md +++ b/docs/source/zh/main_classes/quantization.md @@ -66,12 +66,12 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0") ### 结合 AWQ 和 Flash Attention -您可以将AWQ量化与Flash Attention结合起来,得到一个既被量化又更快速的模型。只需使用`from_pretrained`加载模型,并传递`use_flash_attention_2=True`参数。 +您可以将AWQ量化与Flash Attention结合起来,得到一个既被量化又更快速的模型。只需使用`from_pretrained`加载模型,并传递`attn_implementation="flash_attention_2"`参数。 ```python from transformers import AutoModelForCausalLM, AutoTokenizer -model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0") +model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0") ``` ### 基准测试 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 13d608f50f1823..df035d1ab5df95 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1186,8 +1186,6 @@ def _from_config(cls, config, **kwargs): Args: torch_dtype (`torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under this dtype. - use_flash_attention_2 (`bool`, *optional*): - Whether to load the model with Flash Attention 2 modules. """ torch_dtype = kwargs.pop("torch_dtype", None) use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) @@ -1198,6 +1196,7 @@ def _from_config(cls, config, **kwargs): dtype_orig = cls._set_default_torch_dtype(torch_dtype) config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. + config._attn_implementation = kwargs.pop("attn_implementation", None) config = cls._autoset_attn_implementation( config, use_flash_attention_2=use_flash_attention_2, check_device_map=False ) @@ -1223,7 +1222,7 @@ def _from_config(cls, config, **kwargs): def _autoset_attn_implementation( cls, config, - use_flash_attention_2: Optional[bool] = None, + use_flash_attention_2: bool = False, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True, @@ -1231,7 +1230,7 @@ def _autoset_attn_implementation( """ Automatically checks and dispatches to a default attention implementation. In order of priority: 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). - 2. If use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) 4. The default model's implementation otherwise (`LlamaAttention` for example) . """ @@ -1245,11 +1244,17 @@ def _autoset_attn_implementation( ) # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. - hard_check_only = not use_flash_attention_2 + hard_check_only = True else: hard_check_only = False if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + config._attn_implementation = "flash_attention_2" + + if config._attn_implementation == "flash_attention_2": cls._check_and_enable_flash_attn_2( config, torch_dtype=torch_dtype, @@ -1258,7 +1263,7 @@ def _autoset_attn_implementation( check_device_map=check_device_map, ) elif cls._supports_sdpa: - # use_flash_attention_2 takes priority over SDPA. + # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only) elif not hard_check_only: config._attn_implementation = "eager" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e992a408dbcc2f..583fc8064e859f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,11 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 @@ -295,7 +291,6 @@ def __init__(self, config: LlamaConfig): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.pretraining_tp = config.pretraining_tp if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -372,6 +367,7 @@ def forward( value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) + else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -510,7 +506,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - dropout_rate = self.attention_dropout if self.training else 0.0 + dropout_rate = 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -646,101 +642,15 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class LlamaSdpaAttention(LlamaAttention): - """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from LlamaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -LLAMA_ATTENTION_CLASSES = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, - "sdpa": LlamaSdpaAttention, -} - - class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config) - + self.self_attn = ( + LlamaAttention(config=config) + if not getattr(config, "_flash_attn_2_enabled", False) + else LlamaFlashAttention2(config=config) + ) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -835,7 +745,6 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_sdpa = True def _init_weights(self, module): std = self.config.initializer_range @@ -932,8 +841,6 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_sdpa = config._attn_implementation == "sdpa" - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -991,18 +898,9 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: + if getattr(self.config, "_flash_attn_2_enabled", False): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 713fb6c3ee790a..1246fa561583a6 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -890,13 +890,11 @@ def test_flash_attn_2_inference(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False - ) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) dummy_input = inputs_dict["input_ids"][:1] @@ -949,12 +947,13 @@ def test_flash_attn_2_inference_padding_right(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + tmpdirname, + torch_dtype=torch.bfloat16, ) model.to(torch_device) diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 8194c4285916de..9ab9d01577a974 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -319,13 +319,11 @@ def test_flash_attn_2_inference(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False - ) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] @@ -373,12 +371,13 @@ def test_flash_attn_2_inference_padding_right(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + tmpdirname, + torch_dtype=torch.bfloat16, ) model.to(torch_device) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 03a12ebb7f87f9..664c3e8a16cace 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -413,7 +413,7 @@ def test_flash_attn_2_generate_padding_right(self): output_native = tokenizer.batch_decode(output_native) model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True + "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2" ) output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index dba013b2057499..838a94d2dbe9f9 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -387,9 +387,9 @@ def test_flash_attn_2_generate_padding_right(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True - ).to(torch_device) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) @@ -397,7 +397,10 @@ def test_flash_attn_2_generate_padding_right(self): model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, ).to(torch_device) with self.assertRaises(ValueError): @@ -437,7 +440,7 @@ def test_flash_attn_2_generate_use_cache(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", low_cpu_mem_usage=True, ).to(torch_device) @@ -507,7 +510,7 @@ def test_model_7b_long_prompt(self): "mistralai/Mistral-7B-v0.1", device_map="auto", load_in_4bit=True, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", ) input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f767165230aac8..64c3649f8d6524 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -891,12 +891,13 @@ def test_flash_attn_2_inference(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + tmpdirname, + torch_dtype=torch.bfloat16, ) model.to(torch_device) @@ -936,11 +937,11 @@ def test_flash_attn_2_inference_padding_right(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16) model.to(torch_device) dummy_input = inputs_dict[model.main_input_name][:1] diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e5921507feece9..096f5937f223e6 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2846,7 +2846,7 @@ def test_flash_attn_2_conversion(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" ).to(torch_device) for _, module in model.named_modules(): @@ -2870,12 +2870,12 @@ def test_flash_attn_2_inference(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model.to(torch_device) @@ -2965,12 +2965,12 @@ def test_flash_attn_2_inference_padding_right(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model.to(torch_device) @@ -3055,9 +3055,9 @@ def test_flash_attn_2_generate_left_padding(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True - ).to(torch_device) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) dummy_input = inputs_dict[model.main_input_name] if dummy_input.dtype in [torch.float32, torch.bfloat16]: @@ -3073,7 +3073,10 @@ def test_flash_attn_2_generate_left_padding(self): ) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, ).to(torch_device) out_fa = model.generate( @@ -3096,9 +3099,9 @@ def test_flash_attn_2_generate_padding_right(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True - ).to(torch_device) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) dummy_input = inputs_dict[model.main_input_name] if dummy_input.dtype in [torch.float32, torch.bfloat16]: @@ -3114,7 +3117,10 @@ def test_flash_attn_2_generate_padding_right(self): ) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, ).to(torch_device) out_fa = model.generate( @@ -3499,7 +3505,7 @@ def test_flash_attn_2_generate_use_cache(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", low_cpu_mem_usage=True, ).to(torch_device) @@ -3538,7 +3544,7 @@ def test_flash_attn_2_fp32_ln(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", low_cpu_mem_usage=True, load_in_4bit=True, ) @@ -3623,7 +3629,7 @@ def test_flash_attn_2_from_config(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() # TODO: to change it in the future with other relevant auto classes fa2_model = AutoModelForCausalLM.from_config( - config, use_flash_attention_2=True, torch_dtype=torch.bfloat16 + config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16 ).to(torch_device) dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) From 8e7f8b5a0e822357b05127f16bf57dfd9a72b5a6 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 6 Dec 2023 17:47:29 +0000 Subject: [PATCH 077/100] fix test --- tests/test_modeling_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 096f5937f223e6..153ad0f25cf232 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3440,7 +3440,7 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - self.assertTrue(model_sdpa._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") model_eager = model_class.from_pretrained( tmpdirname, @@ -3449,7 +3449,7 @@ def test_eager_matches_sdpa_generate(self): attn_implementation="eager", ).to(torch_device) - self.assertTrue(model_eager._attn_implementation == "eager") + self.assertTrue(model_eager.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): if "SdpaAttention" in submodule.__class__.__name__: From 7a85efcd7b3feab5f8a10714063ef2d2f76a5365 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 6 Dec 2023 17:54:15 +0000 Subject: [PATCH 078/100] add back llama that was removed by mistake --- .../models/llama/modeling_llama.py | 119 ++++++++++++++++-- 1 file changed, 110 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 583fc8064e859f..dbca74f0791063 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,7 +29,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 @@ -367,7 +371,6 @@ def forward( value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) - else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -506,7 +509,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - dropout_rate = 0.0 if not self.training else self.attention_dropout + dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -642,15 +645,101 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ( - LlamaAttention(config=config) - if not getattr(config, "_flash_attn_2_enabled", False) - else LlamaFlashAttention2(config=config) - ) + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config) + self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -745,6 +834,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.initializer_range @@ -841,6 +931,8 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -898,9 +990,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( From 3649553a683ada8b412328b5a1f66494b75850be Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 6 Dec 2023 18:41:34 +0000 Subject: [PATCH 079/100] fix tests --- tests/models/falcon/test_modeling_falcon.py | 73 ++++++++++++++++++++- tests/test_modeling_common.py | 1 - 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 75204637bd0784..fa7ea2af816cb0 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -15,6 +15,7 @@ """ Testing suite for the PyTorch Falcon model. """ +import tempfile import unittest from parameterized import parameterized @@ -26,7 +27,7 @@ is_torch_available, set_seed, ) -from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device +from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -437,6 +438,76 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + max_new_tokens = 30 + + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + # NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason). + # for name, submodule in model_eager.named_modules(): + # if "SdpaAttention" in submodule.__class__.__name__: + # raise ValueError("The eager model should not have SDPA attention layers") + + # has_sdpa = False + # for name, submodule in model_sdpa.named_modules(): + # if "SdpaAttention" in submodule.__class__.__name__: + # has_sdpa = True + # break + # if not has_sdpa: + # raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch class FalconLanguageGenerationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 153ad0f25cf232..50ad369ea5fc8d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3409,7 +3409,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): def test_eager_matches_sdpa_generate(self): max_new_tokens = 30 - # TODO: Implement a test for SDPA simply testing forward, not generate. if len(self.all_generative_model_classes) == 0: self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") From c1b87b8287cc4aa76c005a2e07f6b3a4ca4c3178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:08:39 +0100 Subject: [PATCH 080/100] remove _flash_attn_2_enabled occurences bis --- docs/source/en/perf_infer_gpu_one.md | 3 +++ .../models/gpt_neox/modeling_gpt_neox.py | 15 +++++++++------ src/transformers/models/llava/modeling_llava.py | 3 +-- src/transformers/models/phi/modeling_phi.py | 15 +++++++++------ tests/models/phi/test_modeling_phi.py | 2 +- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 95d084a1459350..6a31c41b6fa5d2 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -42,11 +42,14 @@ FlashAttention-2 is currently supported for the following architectures: * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) +* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) +* [Llava](https://huggingface.co/docs/transformers/model_doc/llava) * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) +* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 30feda146eabbf..d1c10f58d9d67a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -658,6 +658,12 @@ def forward(self, hidden_states): return hidden_states +GPT_NEOX_ATTENTION_CLASSES = { + "eager": GPTNeoXAttention, + "flash_attention_2": GPTNeoXFlashAttention2, +} + + class GPTNeoXLayer(nn.Module): def __init__(self, config): super().__init__() @@ -666,11 +672,7 @@ def __init__(self, config): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_dropout = nn.Dropout(config.hidden_dropout) self.post_mlp_dropout = nn.Dropout(config.hidden_dropout) - self.attention = ( - GPTNeoXAttention(config) - if not getattr(config, "_flash_attn_2_enabled", False) - else GPTNeoXFlashAttention2(config) - ) + self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config) self.mlp = GPTNeoXMLP(config) def forward( @@ -785,6 +787,7 @@ def __init__(self, config): self.emb_dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.gradient_checkpointing = False @@ -861,7 +864,7 @@ def forward( if attention_mask is not None: assert batch_size > 0, "batch_size has to be defined and > 0" attention_mask = attention_mask.view(batch_size, -1) - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: attention_mask = attention_mask if 0 in attention_mask else None else: # We create a 3D attention mask from a 2D tensor mask. diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 88e5b62121c746..0e29f536d10223 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -231,9 +231,8 @@ def __init__(self, config: LlavaConfig): self.multi_modal_projector = LlavaMultiModalProjector(config) self.vocab_size = config.vocab_size - use_flash_attention_2 = getattr(config, "_flash_attn_2_enabled", False) self.language_model = AutoModelForCausalLM.from_config( - config.text_config, use_flash_attention_2=use_flash_attention_2 + config.text_config, attn_implementation=config._attn_implementation ) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 44be9c749f0e80..9b14819615a0c6 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -602,14 +602,16 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +PHI_ATTENTION_CLASSES = { + "eager": PhiAttention, + "flash_attention_2": PhiFlashAttention2, +} + + class PhiDecoderLayer(nn.Module): def __init__(self, config: PhiConfig): super().__init__() - self.self_attn = ( - PhiAttention(config=config) - if not getattr(config, "_flash_attn_2_enabled", False) - else PhiFlashAttention2(config=config) - ) + self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config) self.mlp = PhiMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.resid_dropout = nn.Dropout(config.resid_pdrop) @@ -794,6 +796,7 @@ def __init__(self, config: PhiConfig): self.embed_dropout = nn.Dropout(config.embd_pdrop) self.layers = nn.ModuleList([PhiDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -856,7 +859,7 @@ def forward( inputs_embeds = self.embed_dropout(inputs_embeds) # Attention mask. - if getattr(self.config, "_flash_attn_2_enabled", False): + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index 76bc5c2104306a..516dd1ee626e7f 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -389,7 +389,7 @@ def test_flash_attn_2_generate_padding_right(self): output_native = tokenizer.batch_decode(output_native) model = PhiForCausalLM.from_pretrained( - "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True + "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2" ) output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) From 8950b60bf16401fa9e1656a83cf77c7d49c2ec85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:29:12 +0100 Subject: [PATCH 081/100] add check & test that passed attn_implementation is valid --- src/transformers/modeling_utils.py | 10 +++++++++- tests/test_modeling_utils.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 15287bde7cdae0..28bd17525b5ff9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1247,6 +1247,14 @@ def _autoset_attn_implementation( f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible. We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) + if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: + message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + raise ValueError(message + ".") + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. hard_check_only = True else: @@ -1266,7 +1274,7 @@ def _autoset_attn_implementation( hard_check_only=hard_check_only, check_device_map=check_device_map, ) - elif cls._supports_sdpa: + elif cls._supports_sdpa or config._attn_implementation == "sdpa": # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only) elif not hard_check_only: diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index fd4297de1c1fa6..08022ebd4d1b6a 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -50,6 +50,7 @@ require_torch, require_torch_accelerator, require_torch_multi_accelerator, + require_torch_sdpa, require_usr_bin_time, slow, torch_device, @@ -1621,3 +1622,31 @@ def test_causal_mask_sliding(self): self.check_to_causal(mask_converter, q_len=3, kv_len=7) # non auto-regressive case self.check_to_causal(mask_converter, q_len=7, kv_len=7) + + +@require_torch_sdpa +class TestAttentionImplementation(unittest.TestCase): + def test_error_no_sdpa_available(self): + with self.assertRaises(ValueError) as cm: + _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa") + + self.assertTrue( + "does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention" + in str(cm.exception) + ) + + _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel") + + def test_error_no_flash_available(self): + with self.assertRaises(ValueError) as cm: + _ = AutoModel.from_pretrained( + "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2" + ) + + self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception)) + + def test_error_wrong_attn_implementation(self): + with self.assertRaises(ValueError) as cm: + _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo") + + self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception)) From 18c2678e974854f39e53a700682cf93a3f1d74d6 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:04:40 +0000 Subject: [PATCH 082/100] fix falcon torchscript export --- src/transformers/models/falcon/modeling_falcon.py | 10 ++++++---- tests/models/whisper/test_modeling_whisper.py | 1 + tests/test_modeling_common.py | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d93d4cc17cc9f9..33a52a8231ff01 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -285,6 +285,7 @@ def __init__(self, config: FalconConfig): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self._use_sdpa = config._attn_implementation == "sdpa" if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( @@ -446,7 +447,7 @@ def forward( present = None if alibi is None: - if hasattr(F, "scaled_dot_product_attention") and not output_attentions: + if self._use_sdpa and not output_attentions: attn_output = F.scaled_dot_product_attention( query_layer, key_layer, @@ -477,8 +478,8 @@ def forward( return attn_output, present else: - if hasattr(F, "scaled_dot_product_attention") and not output_attentions and head_mask is None: - attn_output = torch.nn.functional.scaled_dot_product_attention( + if self._use_sdpa and not output_attentions and head_mask is None: + attn_output = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, @@ -985,6 +986,7 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -1079,7 +1081,7 @@ def forward( if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif hasattr(F, "scaled_dot_product_attention") and not output_attentions: + elif self._use_sdpa and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. if alibi is None: diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 64c3649f8d6524..9de3b8ff2c21b6 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -982,6 +982,7 @@ def _create_and_check_torchscript(self, config, inputs_dict): configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.torchscript = True + configs_no_init._attn_implementation = "eager" for model_class in self.all_model_classes: model = model_class(config=configs_no_init) model.to(torch_device) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 50ad369ea5fc8d..82c4fb6814551c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -778,6 +778,7 @@ def _create_and_check_torchscript(self, config, inputs_dict): configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.torchscript = True + configs_no_init._attn_implementation = "eager" for model_class in self.all_model_classes: model = model_class(config=configs_no_init) model.to(torch_device) From d96e0d2db5bddfd653184e9e4c0785b4ce3cd690 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:47:06 +0000 Subject: [PATCH 083/100] fix device of mask in tests --- tests/models/llama/test_modeling_llama.py | 2 +- tests/models/mistral/test_modeling_mistral.py | 2 +- tests/models/persimmon/test_modeling_persimmon.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 1de49bf19eebb1..26aecaeb1ad9b7 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) token_type_ids = None if self.use_token_type_ids: diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index e59a1726b0735d..fcb1f2495aab8a 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -107,7 +107,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) token_type_ids = None if self.use_token_type_ids: diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index e25a31b3ff4f0d..864db992772772 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) token_type_ids = None if self.use_token_type_ids: From 9e133c9cfdca64cd8be7d1d5445883ad7eb806b1 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 12:52:09 +0000 Subject: [PATCH 084/100] add tip about torch.jit.trace and move bt doc below sdpa --- docs/source/en/perf_infer_gpu_one.md | 84 +++++++++++++++------------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 6a31c41b6fa5d2..ef58b879ef5f71 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -142,46 +142,9 @@ FlashAttention is more memory efficient, meaning you can train on much larger se -## BetterTransformer - - - -Part of BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to support natively SDPA in Transformers. - - - - - - -Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post. - - - -BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are: - -1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps -2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors - -BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood. - -Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation). - -Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method: - -```python -model = model.to_bettertransformer() -``` - -You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling: - -```py -model = model.reverse_bettertransformer() -model.save_pretrained("saved_model") -``` - -### FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention +## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention -PyTorch's `torch.nn.functional.scaled_dot_product_attention` (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available. +PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available. For now, Transformers supports inference and training through SDPA for the following architectures: * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) @@ -222,6 +185,49 @@ RuntimeError: No available kernel. Aborting execution. pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 ``` + + +As of Transformers 4.36, attention modules using `torch.nn.functional.scaled_dot_product_attention` do not support tracing through [`torch.jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html). Please load your model with the argument `attn_implementation="eager"` in [`~PreTrainedModel.from_pretrained`] in order to export to TorchScript through `torch.jit.trace`. + + + +## BetterTransformer + + + +Part of BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to support natively SDPA in Transformers. + + + + + + +Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post. + + + +BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are: + +1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps +2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors + +BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood. + +Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation). + +Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method: + +```python +model = model.to_bettertransformer() +``` + +You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling: + +```py +model = model.reverse_bettertransformer() +model.save_pretrained("saved_model") +``` + ## bitsandbytes bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory. From 76a1e17da3b71742ee62d510a95b04457f4a179b Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 13:02:58 +0000 Subject: [PATCH 085/100] fix parameterized.expand order --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 82c4fb6814551c..16f26fc1e3e2b0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3130,9 +3130,9 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) def test_eager_matches_sdpa_inference(self, torch_dtype: str): if not self.all_model_classes[0]._supports_sdpa: self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") From 65aeba652db5d7730986037f48abf554a1c8acd5 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 13:07:03 +0000 Subject: [PATCH 086/100] move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there --- tests/test_modeling_attn_mask_utils.py | 134 ------------------------- tests/test_modeling_utils.py | 105 +++++++++++++++++++ 2 files changed, 105 insertions(+), 134 deletions(-) delete mode 100644 tests/test_modeling_attn_mask_utils.py diff --git a/tests/test_modeling_attn_mask_utils.py b/tests/test_modeling_attn_mask_utils.py deleted file mode 100644 index b60d2d5a69ddec..00000000000000 --- a/tests/test_modeling_attn_mask_utils.py +++ /dev/null @@ -1,134 +0,0 @@ -# coding=utf-8 -# Copyright 2019 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from transformers.testing_utils import ( - require_torch, - slow, -) -from transformers.utils import is_torch_available - - -if is_torch_available(): - import torch - - from transformers.modeling_attn_mask_utils import AttentionMaskConverter - - -class TestAttnMaskConverter(unittest.TestCase): - @require_torch - @slow - def test_unmask_unattended_left_padding(self): - attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64) - - expanded_mask = torch.Tensor( - [ - [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]], - [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], - [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]], - ] - ).to(torch.int64) - - reference_output = torch.Tensor( - [ - [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]], - [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], - [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]], - ] - ).to(torch.int64) - - result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1) - - self.assertTrue(torch.equal(result, reference_output)) - - attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64) - - attn_mask_converter = AttentionMaskConverter(is_causal=True) - past_key_values_length = 0 - key_value_length = attention_mask.shape[-1] + past_key_values_length - - expanded_mask = attn_mask_converter.to_4d( - attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 - ) - - result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) - min_inf = torch.finfo(torch.float32).min - reference_output = torch.Tensor( - [ - [ - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [min_inf, min_inf, 0, min_inf, min_inf], - [min_inf, min_inf, 0, 0, min_inf], - [min_inf, min_inf, 0, 0, 0], - ] - ], - [ - [ - [0, min_inf, min_inf, min_inf, min_inf], - [0, 0, min_inf, min_inf, min_inf], - [0, 0, 0, min_inf, min_inf], - [0, 0, 0, 0, min_inf], - [0, 0, 0, 0, 0], - ] - ], - [ - [ - [0, 0, 0, 0, 0], - [min_inf, 0, min_inf, min_inf, min_inf], - [min_inf, 0, 0, min_inf, min_inf], - [min_inf, 0, 0, 0, min_inf], - [min_inf, 0, 0, 0, 0], - ] - ], - ] - ) - - self.assertTrue(torch.equal(reference_output, result)) - - @require_torch - @slow - def test_unmask_unattended_right_padding(self): - attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64) - - attn_mask_converter = AttentionMaskConverter(is_causal=True) - past_key_values_length = 0 - key_value_length = attention_mask.shape[-1] + past_key_values_length - - expanded_mask = attn_mask_converter.to_4d( - attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 - ) - - result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) - - self.assertTrue(torch.equal(expanded_mask, result)) - - @require_torch - @slow - def test_unmask_unattended_random_mask(self): - attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64) - - attn_mask_converter = AttentionMaskConverter(is_causal=True) - past_key_values_length = 0 - key_value_length = attention_mask.shape[-1] + past_key_values_length - - expanded_mask = attn_mask_converter.to_4d( - attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 - ) - - result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) - - self.assertTrue(torch.equal(expanded_mask, result)) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 08022ebd4d1b6a..c193b8020a6faa 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1623,6 +1623,111 @@ def test_causal_mask_sliding(self): # non auto-regressive case self.check_to_causal(mask_converter, q_len=7, kv_len=7) + @require_torch + @slow + def test_unmask_unattended_left_padding(self): + attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64) + + expanded_mask = torch.Tensor( + [ + [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]], + [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], + [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]], + ] + ).to(torch.int64) + + reference_output = torch.Tensor( + [ + [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]], + [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], + [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]], + ] + ).to(torch.int64) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1) + + self.assertTrue(torch.equal(result, reference_output)) + + attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64) + + attn_mask_converter = AttentionMaskConverter(is_causal=True) + past_key_values_length = 0 + key_value_length = attention_mask.shape[-1] + past_key_values_length + + expanded_mask = attn_mask_converter.to_4d( + attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + ) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) + min_inf = torch.finfo(torch.float32).min + reference_output = torch.Tensor( + [ + [ + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [min_inf, min_inf, 0, min_inf, min_inf], + [min_inf, min_inf, 0, 0, min_inf], + [min_inf, min_inf, 0, 0, 0], + ] + ], + [ + [ + [0, min_inf, min_inf, min_inf, min_inf], + [0, 0, min_inf, min_inf, min_inf], + [0, 0, 0, min_inf, min_inf], + [0, 0, 0, 0, min_inf], + [0, 0, 0, 0, 0], + ] + ], + [ + [ + [0, 0, 0, 0, 0], + [min_inf, 0, min_inf, min_inf, min_inf], + [min_inf, 0, 0, min_inf, min_inf], + [min_inf, 0, 0, 0, min_inf], + [min_inf, 0, 0, 0, 0], + ] + ], + ] + ) + + self.assertTrue(torch.equal(reference_output, result)) + + @require_torch + @slow + def test_unmask_unattended_right_padding(self): + attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64) + + attn_mask_converter = AttentionMaskConverter(is_causal=True) + past_key_values_length = 0 + key_value_length = attention_mask.shape[-1] + past_key_values_length + + expanded_mask = attn_mask_converter.to_4d( + attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + ) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) + + self.assertTrue(torch.equal(expanded_mask, result)) + + @require_torch + @slow + def test_unmask_unattended_random_mask(self): + attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64) + + attn_mask_converter = AttentionMaskConverter(is_causal=True) + past_key_values_length = 0 + key_value_length = attention_mask.shape[-1] + past_key_values_length + + expanded_mask = attn_mask_converter.to_4d( + attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + ) + + result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) + + self.assertTrue(torch.equal(expanded_mask, result)) + @require_torch_sdpa class TestAttentionImplementation(unittest.TestCase): From 546dd51f72bf65d3a8e44e0e82ab25936ae25c96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 09:26:54 +0100 Subject: [PATCH 087/100] update sdpaattention class with the new cache --- src/transformers/models/llama/modeling_llama.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 885f563ccbeea0..1a7262eb4f380f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -668,7 +668,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -698,16 +698,13 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) From 2045915a2b0b30e81a433cfd43e859776132c4dd Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 17:38:10 +0900 Subject: [PATCH 088/100] Update src/transformers/configuration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index a8e4c599b40042..5c419ee0fc7cf3 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -429,7 +429,7 @@ def num_labels(self, num_labels: int): @property def _attn_implementation(self): - # This property is made private for now (as it may not be changed alone and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) if hasattr(self, "_attn_implementation_internal"): if self._attn_implementation_internal is None: # `config.attn_implementation` should never be None, for backward compatibility. From eb1188371ffed1485b28a0b24a4ee799a605e2aa Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 17:41:46 +0900 Subject: [PATCH 089/100] Update src/transformers/models/bark/modeling_bark.py --- src/transformers/models/bark/modeling_bark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 06ed527d97905f..703886d500ba12 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -1897,7 +1897,7 @@ def _check_and_enable_flash_attn_2( The method checks if the current setup is compatible with Flash Attention as it requires the model to be in half precision and not ran on CPU. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module """ config = super()._check_and_enable_flash_attn_2( From 920686edda363f86fb9e2bdded1481a600915cd0 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 09:46:09 +0000 Subject: [PATCH 090/100] address review comments --- src/transformers/modeling_utils.py | 17 +- src/transformers/models/bart/modeling_bart.py | 3 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 3 +- .../models/llama/modeling_llama.py | 3 +- .../models/whisper/modeling_whisper.py | 3 +- tests/test_modeling_common.py | 165 ++++++++---------- tests/test_modeling_utils.py | 32 +++- 7 files changed, 114 insertions(+), 112 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 63e61e3e06af1a..2588893d2575be 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1247,7 +1247,8 @@ def _autoset_attn_implementation( if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: raise ValueError( - f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible. We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: @@ -1354,15 +1355,14 @@ def _check_and_enable_flash_attn_2( ) if not is_flash_attn_2_available(): - flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) - preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." - if torch.version.cuda: - if importlib.util.find_spec("flash_attn") is None: - raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") - flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if importlib.util.find_spec("flash_attn") is None: + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if torch.version.cuda: if flash_attention_version < version.parse("2.1.0"): raise ImportError( f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" @@ -1370,9 +1370,6 @@ def _check_and_enable_flash_attn_2( else: raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") elif torch.version.hip: - if importlib.util.find_spec("flash_attn") is None: - raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") - flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) if flash_attention_version < version.parse("2.0.4"): raise ImportError( f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index beb161ec989e5a..f284bb73a227be 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -524,7 +524,8 @@ def forward( if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 6092f15e232088..eab28ad131e800 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -622,7 +622,8 @@ def forward( else: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( - "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None." + ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1a7262eb4f380f..4f4d6248925bfa 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -675,7 +675,8 @@ def forward( if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d578bfdabe44db..1d4a7a7a1ac483 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -705,7 +705,8 @@ def forward( if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( - "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards." + "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c207b1b3428ae7..8dbac4581822f8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3146,6 +3146,34 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): elif torch_dtype == "float32": torch_dtype = torch.float32 + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 1e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 1e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3154,25 +3182,18 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model_sdpa = ( - model_class.from_pretrained( - tmpdirname, - torch_dtype=torch_dtype, - ) - .to(torch_device) - .eval() - ) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = ( - model_class.from_pretrained( - tmpdirname, - torch_dtype=torch_dtype, - attn_implementation="eager", - ) - .to(torch_device) - .eval() + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", ) + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): @@ -3201,31 +3222,21 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): dummy_input = dummy_input[:batch_size] if dummy_input.shape[0] != batch_size: if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: - dummy_input = torch.cat( - ( - dummy_input, - torch.rand( - batch_size - dummy_input.shape[0], - *dummy_input.shape[1:], - dtype=torch_dtype, - device=torch_device, - ), - ), - dim=0, - ).to(torch_device) + extension = torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) else: - dummy_input = torch.cat( - ( - dummy_input, - torch.randint( - high=5, - size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), - dtype=dummy_input.dtype, - device=torch_device, - ), - ), - dim=0, - ).to(torch_device) + extension = torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) if not use_mask: dummy_attention_mask = None @@ -3242,18 +3253,14 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): dummy_attention_mask = dummy_attention_mask[:batch_size] if dummy_attention_mask.shape[0] != batch_size: - dummy_attention_mask = torch.cat( - ( - dummy_attention_mask, - torch.ones( - batch_size - dummy_attention_mask.shape[0], - *dummy_attention_mask.shape[1:], - dtype=dummy_attention_mask.dtype, - device=torch_device, - ), - ), - dim=0, - ).to(torch_device) + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) dummy_attention_mask[:] = 1 if padding_side == "left": @@ -3268,18 +3275,14 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): if is_encoder_decoder: decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size] if decoder_input_ids.shape[0] != batch_size: - decoder_input_ids = torch.cat( - ( - decoder_input_ids, - torch.ones( - batch_size - decoder_input_ids.shape[0], - *decoder_input_ids.shape[1:], - dtype=decoder_input_ids.dtype, - device=torch_device, - ), - ), - dim=0, - ).to(torch_device) + extension = torch.ones( + batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) # TODO: never an `attention_mask` arg here? other_inputs = { @@ -3317,41 +3320,13 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): else outputs_sdpa.decoder_hidden_states[-1] ) - if torch_device == "cpu": - if torch_dtype == torch.float32: - atol = 1e-6 - rtol = 1e-4 - elif torch_dtype == torch.bfloat16: - atol = 1e-2 - rtol = 1e-2 - elif torch_device == "cuda": - if not enable_kernels: - if torch_dtype == torch.float32: - atol = 1e-6 - rtol = 1e-4 - elif torch_dtype == torch.bfloat16: - atol = 1e-2 - rtol = 1e-2 - elif torch_dtype == torch.float16: - atol = 1e-3 - rtol = 1e-3 - else: - if torch_dtype == torch.float32: - atol = 1e-6 - rtol = 1e-4 - elif torch_dtype == torch.bfloat16: - atol = 1e-2 - rtol = 3e-2 - elif torch_dtype == torch.float16: - atol = 5e-3 - rtol = 5e-3 + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] else: atol = 1e-7 rtol = 1e-4 - def get_mean_reldiff(failcase, x, ref, atol, rtol): - return f"{failcase}: mean reldiff: {((x - ref).abs() / (ref.abs() + 1e-12)).mean()}, torch atol = {atol}, torch rtol = {rtol}" - # Masked tokens output slightly deviates - we don't mind that. if use_mask: if padding_side == "left": diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index c193b8020a6faa..6091e52e0effc2 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -50,7 +50,6 @@ require_torch, require_torch_accelerator, require_torch_multi_accelerator, - require_torch_sdpa, require_usr_bin_time, slow, torch_device, @@ -61,7 +60,13 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) -from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available +from transformers.utils.import_utils import ( + is_flash_attn_2_available, + is_flax_available, + is_tf_available, + is_torch_sdpa_available, + is_torchdynamo_available, +) sys.path.append(str(Path(__file__).parent.parent / "utils")) @@ -1729,7 +1734,6 @@ def test_unmask_unattended_random_mask(self): self.assertTrue(torch.equal(expanded_mask, result)) -@require_torch_sdpa class TestAttentionImplementation(unittest.TestCase): def test_error_no_sdpa_available(self): with self.assertRaises(ValueError) as cm: @@ -1755,3 +1759,25 @@ def test_error_wrong_attn_implementation(self): _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo") self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception)) + + def test_not_available_flash(self): + if is_flash_attn_2_available(): + self.skipTest("Please uninstall flash-attn package to run test_not_available_flash") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" + ) + + self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + + def test_not_available_sdpa(self): + if is_torch_sdpa_available(): + self.skipTest("This test requires torch<=2.0") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa" + ) + + self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception)) From 214685795f8e2527f4c42445e5d745f551b4a4ed Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 10:35:05 +0000 Subject: [PATCH 091/100] WIP torch.jit.trace fix. left: test both eager & sdpa --- src/transformers/modeling_attn_mask_utils.py | 11 ++++++++++- tests/test_modeling_common.py | 18 +++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index e0067ec33a62bf..fcd62facf568c7 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -334,9 +334,16 @@ def _prepare_4d_causal_attention_mask_for_sdpa( key_value_length = input_shape[-1] + past_key_values_length batch_size, query_length = input_shape + # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: Fix this as well when using torchdynamo with fullgraph=True. + is_tracing = torch.jit.is_tracing() + if attention_mask is not None: if torch.all(attention_mask == 1): - if query_length == 1: + if is_tracing: + pass + elif query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. attention_mask = None elif key_value_length == query_length: @@ -346,6 +353,8 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. # Reference: https://github.com/pytorch/pytorch/issues/108108 pass + elif is_tracing: + raise ValueError('Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.') if attention_mask is not None: expanded_4d_mask = attn_mask_converter.to_4d( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8dbac4581822f8..978e1969bc3f01 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -776,7 +776,6 @@ def _create_and_check_torchscript(self, config, inputs_dict): configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.torchscript = True - configs_no_init._attn_implementation = "eager" for model_class in self.all_model_classes: model = model_class(config=configs_no_init) model.to(torch_device) @@ -813,8 +812,21 @@ def _create_and_check_torchscript(self, config, inputs_dict): ) # when traced model is checked, an error is produced due to name mangling else: main_input = inputs[main_input_name] - model(main_input) - traced_model = torch.jit.trace(model, main_input) + + if model.config._attn_implementation == "sdpa": + trace_input = {main_input_name: main_input} + + if "attention_mask" in inputs: + trace_input["attention_mask"] = inputs["attention_mask"] + else: + self.skipTest("testing SDPA without attention_mask is not supported") + + model(main_input, attention_mask=inputs["attention_mask"]) + # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. + traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) + else: + model(main_input) + traced_model = torch.jit.trace(model, (main_input,)) except RuntimeError: self.fail("Couldn't trace module.") From 9b485913bc288162a7eba6e29ab85d26869d6164 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 13:40:42 +0000 Subject: [PATCH 092/100] add test for torch.jit.trace for both eager/sdpa --- docs/source/en/perf_infer_gpu_one.md | 6 - src/transformers/modeling_attn_mask_utils.py | 4 +- tests/test_modeling_common.py | 185 ++++++++++--------- tests/test_modeling_utils.py | 1 + 4 files changed, 99 insertions(+), 97 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index ef58b879ef5f71..b12670584a4ec5 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -185,12 +185,6 @@ RuntimeError: No available kernel. Aborting execution. pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 ``` - - -As of Transformers 4.36, attention modules using `torch.nn.functional.scaled_dot_product_attention` do not support tracing through [`torch.jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html). Please load your model with the argument `attn_implementation="eager"` in [`~PreTrainedModel.from_pretrained`] in order to export to TorchScript through `torch.jit.trace`. - - - ## BetterTransformer diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index fcd62facf568c7..8704a41d013e59 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -354,7 +354,9 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # Reference: https://github.com/pytorch/pytorch/issues/108108 pass elif is_tracing: - raise ValueError('Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.') + raise ValueError( + 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' + ) if attention_mask is not None: expanded_4d_mask = attn_mask_converter.to_4d( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 978e1969bc3f01..f0e6c0f1fce37f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -777,115 +777,120 @@ def _create_and_check_torchscript(self, config, inputs_dict): configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.torchscript = True for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class) + for attn_implementation in ["eager", "sdpa"]: + if attn_implementation == "sdpa" and not model_class._supports_sdpa: + continue - main_input_name = model_class.main_input_name + configs_no_init._attn_implementation = attn_implementation + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - main_input = inputs[main_input_name] - attention_mask = inputs["attention_mask"] - decoder_input_ids = inputs["decoder_input_ids"] - decoder_attention_mask = inputs["decoder_attention_mask"] - model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask) - traced_model = torch.jit.trace( - model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask) - ) - elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs - input_ids = inputs["input_ids"] - bbox = inputs["bbox"] - image = inputs["image"].tensor - model(input_ids, bbox, image) - traced_model = torch.jit.trace( - model, (input_ids, bbox, image), check_trace=False - ) # when traced model is checked, an error is produced due to name mangling - elif "bbox" in inputs: # Bros requires additional inputs (bbox) - input_ids = inputs["input_ids"] - bbox = inputs["bbox"] - model(input_ids, bbox) - traced_model = torch.jit.trace( - model, (input_ids, bbox), check_trace=False - ) # when traced model is checked, an error is produced due to name mangling - else: - main_input = inputs[main_input_name] + main_input_name = model_class.main_input_name - if model.config._attn_implementation == "sdpa": - trace_input = {main_input_name: main_input} + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + main_input = inputs[main_input_name] + attention_mask = inputs["attention_mask"] + decoder_input_ids = inputs["decoder_input_ids"] + decoder_attention_mask = inputs["decoder_attention_mask"] + model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask) + traced_model = torch.jit.trace( + model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs + input_ids = inputs["input_ids"] + bbox = inputs["bbox"] + image = inputs["image"].tensor + model(input_ids, bbox, image) + traced_model = torch.jit.trace( + model, (input_ids, bbox, image), check_trace=False + ) # when traced model is checked, an error is produced due to name mangling + elif "bbox" in inputs: # Bros requires additional inputs (bbox) + input_ids = inputs["input_ids"] + bbox = inputs["bbox"] + model(input_ids, bbox) + traced_model = torch.jit.trace( + model, (input_ids, bbox), check_trace=False + ) # when traced model is checked, an error is produced due to name mangling + else: + main_input = inputs[main_input_name] - if "attention_mask" in inputs: - trace_input["attention_mask"] = inputs["attention_mask"] - else: - self.skipTest("testing SDPA without attention_mask is not supported") + if model.config._attn_implementation == "sdpa": + trace_input = {main_input_name: main_input} - model(main_input, attention_mask=inputs["attention_mask"]) - # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. - traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) - else: - model(main_input) - traced_model = torch.jit.trace(model, (main_input,)) - except RuntimeError: - self.fail("Couldn't trace module.") + if "attention_mask" in inputs: + trace_input["attention_mask"] = inputs["attention_mask"] + else: + self.skipTest("testing SDPA without attention_mask is not supported") - with tempfile.TemporaryDirectory() as tmp_dir_name: - pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + model(main_input, attention_mask=inputs["attention_mask"]) + # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. + traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) + else: + model(main_input) + traced_model = torch.jit.trace(model, (main_input,)) + except RuntimeError: + self.fail("Couldn't trace module.") - try: - torch.jit.save(traced_model, pt_file_name) - except Exception: - self.fail("Couldn't save module.") + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") - try: - loaded_model = torch.jit.load(pt_file_name) - except Exception: - self.fail("Couldn't load module.") + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") - model.to(torch_device) - model.eval() + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") - loaded_model.to(torch_device) - loaded_model.eval() + model.to(torch_device) + model.eval() - model_state_dict = model.state_dict() - loaded_model_state_dict = loaded_model.state_dict() + loaded_model.to(torch_device) + loaded_model.eval() - non_persistent_buffers = {} - for key in loaded_model_state_dict.keys(): - if key not in model_state_dict.keys(): - non_persistent_buffers[key] = loaded_model_state_dict[key] + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() - loaded_model_state_dict = { - key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers - } + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } - self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) - model_buffers = list(model.buffers()) - for non_persistent_buffer in non_persistent_buffers.values(): - found_buffer = False - for i, model_buffer in enumerate(model_buffers): - if torch.equal(non_persistent_buffer, model_buffer): - found_buffer = True - break + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break - self.assertTrue(found_buffer) - model_buffers.pop(i) + self.assertTrue(found_buffer) + model_buffers.pop(i) - models_equal = True - for layer_name, p1 in model_state_dict.items(): - if layer_name in loaded_model_state_dict: - p2 = loaded_model_state_dict[layer_name] - if p1.data.ne(p2.data).sum() > 0: - models_equal = False + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False - self.assertTrue(models_equal) + self.assertTrue(models_equal) - # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. - # (Even with this call, there are still memory leak by ~0.04MB) - self.clear_torch_jit_class_registry() + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() def test_torch_fx(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6091e52e0effc2..bab3253d9d6af9 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1734,6 +1734,7 @@ def test_unmask_unattended_random_mask(self): self.assertTrue(torch.equal(expanded_mask, result)) +@require_torch class TestAttentionImplementation(unittest.TestCase): def test_error_no_sdpa_available(self): with self.assertRaises(ValueError) as cm: From cc7fc4ec4b0dea82fd7364c128b3e6a351a9478c Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:03:08 +0000 Subject: [PATCH 093/100] fix falcon with torch==2.0 that needs to use sdpa --- .../models/falcon/modeling_falcon.py | 10 ++++++- tests/models/llama/test_modeling_llama.py | 28 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 33a52a8231ff01..90a2deb6a26b6f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -43,6 +43,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal_than_2_0, logging, ) from .configuration_falcon import FalconConfig @@ -958,7 +959,14 @@ def _init_weights(self, module: nn.Module): # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig": - # We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Falcon supports SDPA from torch==2.0.0 (no requirement on 2.1). + # NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0). + if hard_check_only: + if not is_torch_greater_or_equal_than_2_0: + raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.") + + if not is_torch_greater_or_equal_than_2_0: + return config + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) if _is_bettertransformer: return config diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 24a4f013a13818..427f94f873cff2 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -14,6 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch LLaMA model. """ +import tempfile import unittest import pytest @@ -420,6 +421,33 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(output_native, output_fa_2) + @require_flash_attn + @require_torch_gpu + @slow + def test_use_flash_attention_2_true(self): + """ + NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended. + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_class(config) + model.save_pretrained(tmp_dir) + + new_model = LlamaForCausalLM.from_pretrained( + tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 + ).to("cuda") + + self.assertTrue(new_model.config._attn_implementation == "flash_attention_2") + + has_flash = False + for name, submodule in new_model.named_modules(): + if "FlashAttention" in submodule.__class__.__name__: + has_flash = True + break + if not has_flash: + raise ValueError("The flash model should have flash attention layers") + @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self): From 84867708d2654541029cef26f509e9c4356cd18b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 15:13:40 +0100 Subject: [PATCH 094/100] fix doc --- docs/source/en/model_doc/gpt_neox.md | 4 ++-- docs/source/en/model_doc/phi.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/gpt_neox.md b/docs/source/en/model_doc/gpt_neox.md index 1885d44450aab9..fd105a3e82e1ee 100644 --- a/docs/source/en/model_doc/gpt_neox.md +++ b/docs/source/en/model_doc/gpt_neox.md @@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation ### Usage -To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference: +To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference: ```python >>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast -model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device) +model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device) ... ``` diff --git a/docs/source/en/model_doc/phi.md b/docs/source/en/model_doc/phi.md index 03eac894162724..3076aa378cbe85 100644 --- a/docs/source/en/model_doc/phi.md +++ b/docs/source/en/model_doc/phi.md @@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below: >>> from transformers import PhiForCausalLM, AutoTokenizer >>> # define the model and tokenizer and push the model and tokens to the GPU. ->>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda") +>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda") >>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev") >>> # feel free to change the prompt to your liking. @@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t - forward - \ No newline at end of file + From c6181f2e0eb8fc08c7a047c2d260d7b0f284247c Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:19:22 +0000 Subject: [PATCH 095/100] hopefully last fix --- src/transformers/models/falcon/modeling_falcon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 90a2deb6a26b6f..4684f46ded0242 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -37,13 +37,13 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import is_torch_greater_or_equal_than_2_0 from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, - is_torch_greater_or_equal_than_2_0, logging, ) from .configuration_falcon import FalconConfig From 7ebfd1d0dadaf797ac4ab790b62ffe889537a820 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:44:52 +0000 Subject: [PATCH 096/100] fix key_value_length that has no default now in mask converter --- src/transformers/modeling_attn_mask_utils.py | 5 ++++- tests/test_modeling_utils.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 602e024d7a4a0b..ac6d97a845df08 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -360,7 +360,10 @@ def _prepare_4d_causal_attention_mask_for_sdpa( if attention_mask is not None: expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, ) # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 994a7496541f31..ddfaad5214dc50 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1728,7 +1728,7 @@ def test_unmask_unattended_left_padding(self): key_value_length = attention_mask.shape[-1] + past_key_values_length expanded_mask = attn_mask_converter.to_4d( - attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32 ) result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) @@ -1777,7 +1777,7 @@ def test_unmask_unattended_right_padding(self): key_value_length = attention_mask.shape[-1] + past_key_values_length expanded_mask = attn_mask_converter.to_4d( - attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32 ) result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) @@ -1794,7 +1794,7 @@ def test_unmask_unattended_random_mask(self): key_value_length = attention_mask.shape[-1] + past_key_values_length expanded_mask = attn_mask_converter.to_4d( - attention_mask, attention_mask.shape[-1], key_value_length, dtype=torch.float32 + attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32 ) result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0) From dacf149fcc648f8e2435b0656f33c89ed86c9c09 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 15:07:52 +0000 Subject: [PATCH 097/100] is it flacky? From 810de1a9caf26c70d3094192242514a417522758 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 19:10:49 +0000 Subject: [PATCH 098/100] fix speculative decoding bug --- src/transformers/modeling_attn_mask_utils.py | 26 +++++++++++++++---- src/transformers/models/bart/modeling_bart.py | 2 -- .../models/whisper/modeling_whisper.py | 2 -- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index ac6d97a845df08..734f443e1fc9d4 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -353,12 +353,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. # Reference: https://github.com/pytorch/pytorch/issues/108108 pass + elif query_length > 1 and key_value_length != query_length: + # See the comment above (https://github.com/pytorch/pytorch/issues/108108). + # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. + attention_mask = True elif is_tracing: raise ValueError( 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' ) - if attention_mask is not None: + if attention_mask is None: + expanded_4d_mask = None + elif attention_mask is True: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: expanded_4d_mask = attn_mask_converter.to_4d( attention_mask, input_shape[-1], @@ -372,8 +382,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa( expanded_4d_mask = AttentionMaskConverter._unmask_unattended( expanded_4d_mask, attention_mask, unmasked_value=0.0 ) - else: - expanded_4d_mask = None return expanded_4d_mask @@ -409,8 +417,16 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, """ batch_size, key_value_length = mask.shape tgt_len = tgt_len if tgt_len is not None else key_value_length - if batch_size == 1 and torch.all(mask == 1): - if tgt_len == 1: + + # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: Fix this as well when using torchdynamo with fullgraph=True. + is_tracing = torch.jit.is_tracing() + + if torch.all(mask == 1): + if is_tracing: + pass + elif tgt_len == 1: # For query_length == 1, causal attention and bi-directional attention are the same. return None elif key_value_length == tgt_len: diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index f284bb73a227be..16527216c7a501 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -582,8 +582,6 @@ def forward( past_key_value = (key_states, value_states) query_states = self._shape(query_states, tgt_len, bsz) - key_states = key_states - value_states = value_states attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 1d4a7a7a1ac483..21c4c82b7f40e7 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -763,8 +763,6 @@ def forward( past_key_value = (key_states, value_states) query_states = self._shape(query_states, tgt_len, bsz) - key_states = key_states - value_states = value_states attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, From f116cce7cc9e8c02a4a1cb22d543c75768a42668 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 19:38:24 +0000 Subject: [PATCH 099/100] tests do pass From 3f06a3a0aec8cc1ec3ad6bf66ebe277392c5ab37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 21:10:01 +0100 Subject: [PATCH 100/100] fix following #27907 --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3c7523684a2504..43d5c6faef86ed 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -699,7 +699,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)