From 26e967ef1e95de10649dfae9e024fb317af2ff87 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 23 Apr 2024 18:15:21 -0500 Subject: [PATCH 1/6] Reenable SDPA's FA2 during training with torch.compile --- src/transformers/modeling_attn_mask_utils.py | 7 ++-- .../models/cohere/modeling_cohere.py | 32 ++++++++++++------- .../models/gemma/modeling_gemma.py | 32 ++++++++++++------- .../models/llama/modeling_llama.py | 32 ++++++++++++------- 4 files changed, 67 insertions(+), 36 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index c69d9555b2af..3301d9557d5e 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -240,6 +240,7 @@ def _ignore_causal_mask_sdpa( inputs_embeds: torch.Tensor, past_key_values_length: int, sliding_window: Optional[int] = None, + is_training: bool = False, ) -> bool: """ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. @@ -263,11 +264,11 @@ def _ignore_causal_mask_sdpa( if attention_mask is None: # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). - # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag. + # Thus, we only set `ignore_causal_mask = True` if the model is set to training. # # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`). if ( - not is_tracing + (is_training or not is_tracing) and (query_length == 1 or key_value_length == query_length) and (sliding_window is None or key_value_length < sliding_window) ): @@ -279,7 +280,7 @@ def _ignore_causal_mask_sdpa( raise ValueError( f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." ) - elif not is_tracing and torch.all(attention_mask == 1): + elif (is_training or not is_tracing) and torch.all(attention_mask == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. ignore_causal_mask = True diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41bb4c051692..834c9ef83d83 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -590,16 +590,23 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. 
- attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, - ) + # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + if causal_mask is None and q_len > 1: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=True, + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -996,7 +1003,10 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # in order to dispatch on Flash Attention 2. if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, ): return None diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b6b207748a..cb528fcb1b41 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -570,16 +570,23 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, - ) + # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + if causal_mask is None and q_len > 1: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=True, + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) @@ -982,7 +989,10 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # in order to dispatch on Flash Attention 2. 
if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, ): return None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a6..347a63e92582 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -666,16 +666,23 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, - ) + # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + if causal_mask is None and q_len > 1: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=True, + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -1074,7 +1081,10 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # in order to dispatch on Flash Attention 2. 
if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, ): return None From bc8aa59aec7e74965a0fa89c2d8a9ea6356714a1 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 23 Apr 2024 18:46:47 -0500 Subject: [PATCH 2/6] fix Olmo's SDPA FA2 dispatching too --- src/transformers/models/olmo/modeling_olmo.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index e3b0e05127c5..b53954e35c9e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -647,14 +647,23 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, - ) + # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + if causal_mask is None and q_len > 1: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=True, + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -1057,7 +1066,10 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # in order to dispatch on Flash Attention 2. 
if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, ): return None From 76186c7031e0aebd85fbeb165a382e25b7c8d4a5 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Wed, 24 Apr 2024 12:39:46 -0500 Subject: [PATCH 3/6] update formatting --- src/transformers/models/cohere/modeling_cohere.py | 13 +++---------- src/transformers/models/gemma/modeling_gemma.py | 13 +++---------- src/transformers/models/llama/modeling_llama.py | 13 +++---------- src/transformers/models/olmo/modeling_olmo.py | 13 +++---------- 4 files changed, 12 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 834c9ef83d83..407d5f15d31e 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -591,21 +591,14 @@ def forward( value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=True, + query_states, key_states, value_states, None, dropout, True ) else: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + query_states, key_states, value_states, causal_mask, dropout ) attn_output = attn_output.transpose(1, 2).contiguous() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index cb528fcb1b41..c8f6fc5f0efd 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -571,21 +571,14 @@ def forward( value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=True, + query_states, key_states, value_states, None, dropout, True ) else: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + query_states, key_states, value_states, causal_mask, dropout ) attn_output = attn_output.transpose(1, 2).contiguous() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 347a63e92582..e796d3565231 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -667,21 +667,14 @@ def forward( value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + dropout = 
self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=True, + query_states, key_states, value_states, None, dropout, True ) else: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + query_states, key_states, value_states, causal_mask, dropout ) attn_output = attn_output.transpose(1, 2).contiguous() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b53954e35c9e..da431d4eaabf 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -648,21 +648,14 @@ def forward( value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=True, + query_states, key_states, value_states, None, dropout, True ) else: attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + query_states, key_states, value_states, causal_mask, dropout ) attn_output = attn_output.transpose(1, 2).contiguous() From f3f1347f57b17726c45baeff7d290c821fdc87f4 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Wed, 24 Apr 2024 12:50:21 -0500 Subject: [PATCH 4/6] improved SDPA comment --- src/transformers/models/cohere/modeling_cohere.py | 3 ++- src/transformers/models/gemma/modeling_gemma.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 3 ++- src/transformers/models/olmo/modeling_olmo.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 407d5f15d31e..9465e80faef2 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -590,7 +590,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of + # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c8f6fc5f0efd..d86eef1dbfc1 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -570,7 +570,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 backend via the if 
statement to support both torch.compile's dynamic=True & fullgraph=True + # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of + # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e796d3565231..8354125a86cd 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -666,7 +666,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of + # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index da431d4eaabf..520d2514350a 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -647,7 +647,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 backend via the if statement to support both torch.compile's dynamic=True & fullgraph=True + # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of + # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` dropout = self.attention_dropout if self.training else 0.0 if causal_mask is None and q_len > 1: attn_output = torch.nn.functional.scaled_dot_product_attention( From 60837f9d77bb24de5cba57ffaa8e29c1eefacf97 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Apr 2024 12:14:14 -0500 Subject: [PATCH 5/6] formatting and explanatory comment --- .../models/cohere/modeling_cohere.py | 22 +++++++++++-------- .../models/gemma/modeling_gemma.py | 22 +++++++++++-------- .../models/llama/modeling_llama.py | 22 +++++++++++-------- src/transformers/models/olmo/modeling_olmo.py | 22 +++++++++++-------- 4 files changed, 52 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9465e80faef2..e22d7a846f49 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -590,17 +590,21 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of - # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` - dropout = self.attention_dropout if self.training else 0.0 + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` if causal_mask is None and q_len > 1: - attn_output = 
torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, None, dropout, True - ) + is_causal = True else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, causal_mask, dropout - ) + is_causal = False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index d86eef1dbfc1..be45c6d600c8 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -570,17 +570,21 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of - # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` - dropout = self.attention_dropout if self.training else 0.0 + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` if causal_mask is None and q_len > 1: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, None, dropout, True - ) + is_causal = True else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, causal_mask, dropout - ) + is_causal = False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8354125a86cd..bf2137f3bf24 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -666,17 +666,21 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of - # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` - dropout = self.attention_dropout if self.training else 0.0 + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` if causal_mask is None and q_len > 1: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, None, dropout, True - ) + is_causal = True else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, causal_mask, dropout - ) + is_causal = False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) 
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 520d2514350a..8a75b2dd6fd0 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -647,17 +647,21 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention 2 or Efficient kernels via the if statement instead of - # setting `is_causal` inline to support both torch.compile's `dynamic=True` & `fullgraph=True` - dropout = self.attention_dropout if self.training else 0.0 + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` if causal_mask is None and q_len > 1: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, None, dropout, True - ) + is_causal = True else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, causal_mask, dropout - ) + is_causal = False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) From 0aab518d0e5fd8cff5517b7c3ad960853e5936e4 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 29 Apr 2024 10:30:18 -0500 Subject: [PATCH 6/6] is_causal if statement to one-liner --- src/transformers/models/cohere/modeling_cohere.py | 5 +---- src/transformers/models/gemma/modeling_gemma.py | 5 +---- src/transformers/models/llama/modeling_llama.py | 5 +---- src/transformers/models/olmo/modeling_olmo.py | 5 +---- 4 files changed, 4 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index e22d7a846f49..3d529fd1ec4f 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -592,10 +592,7 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` - if causal_mask is None and q_len > 1: - is_causal = True - else: - is_causal = False + is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index be45c6d600c8..97e4e5d49f8e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -572,10 +572,7 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` - if causal_mask is None and q_len > 1: - is_causal = True - else: - is_causal = False + is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index bf2137f3bf24..9a2566f2fdd2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -668,10 +668,7 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` - if causal_mask is None and q_len > 1: - is_causal = True - else: - is_causal = False + is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 8a75b2dd6fd0..87db966e2d8f 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -649,10 +649,7 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` - if causal_mask is None and q_len > 1: - is_causal = True - else: - is_causal = False + is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states,
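
For reference, below is a minimal, standalone sketch of the dispatch pattern that the series converges on after patch 6. It is not transformers code: the function name `sdpa_attention`, its signature, and the toy tensor shapes in the usage example are assumptions made purely for illustration. The pattern itself mirrors the final hunks above: compute `is_causal` as a plain Python bool, pass `attn_mask=None` whenever the causal mask has been dropped, and let `torch.nn.functional.scaled_dot_product_attention` select its Flash Attention backend.

from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_attention(
    query_states: torch.Tensor,           # (bsz, num_heads, q_len, head_dim)
    key_states: torch.Tensor,             # (bsz, num_heads, kv_len, head_dim)
    value_states: torch.Tensor,           # (bsz, num_heads, kv_len, head_dim)
    causal_mask: Optional[torch.Tensor],  # None when the causal mask could be ignored
    dropout_p: float = 0.0,
) -> torch.Tensor:
    q_len = query_states.shape[-2]

    # Mirrors the final form of the patch: an explicit bool computed outside the call,
    # True only when no mask is materialized and more than one query position attends.
    is_causal = True if causal_mask is None and q_len > 1 else False

    # With attn_mask=None and is_causal=True, SDPA is free to pick its Flash Attention
    # backend; with an explicit mask it falls back to the efficient or math kernels.
    return F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )


if __name__ == "__main__":
    # Toy shapes for illustration only.
    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)
    out = sdpa_attention(q, k, v, causal_mask=None)  # eligible for the Flash Attention path
    print(out.shape)  # torch.Size([2, 8, 16, 64])

Per the rationale stated in the comments added in patches 4 through 6, computing `is_causal` up front rather than inlining the condition into the SDPA call is what keeps this dispatch compatible with both `dynamic=True` and `fullgraph=True` under torch.compile.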