From a94a441de8426a28acd97ccf0ca2e4c6e05a0841 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 17 Apr 2024 11:40:31 +0200
Subject: [PATCH] simplify

---
 src/transformers/modeling_attn_mask_utils.py | 48 +++++---------
 1 file changed, 11 insertions(+), 37 deletions(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index ecd149237a38d1..8ae9b57b6c43be 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -239,6 +239,7 @@ def _ignore_causal_mask_sdpa(
         attention_mask: Optional[torch.Tensor],
         inputs_embeds: torch.Tensor,
         past_key_values_length: int,
+        sliding_window: Optional[int] = None,
     ) -> bool:
         """
         Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
@@ -265,8 +266,9 @@ def _ignore_causal_mask_sdpa(
             # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
             #
             # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
-            ignore_causal_mask = not is_tracing
-        else:
+            if sliding_window is None or key_value_length < sliding_window:
+                ignore_causal_mask = not is_tracing
+        elif sliding_window is None or key_value_length < sliding_window:
             if len(attention_mask.shape) == 4:
                 expected_shape = (batch_size, 1, query_length, key_value_length)
                 if tuple(attention_mask.shape) != expected_shape:
                     raise ValueError(
@@ -274,11 +276,9 @@ def _ignore_causal_mask_sdpa(
                         f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
                     )
             elif not is_tracing and torch.all(attention_mask == 1):
-                if query_length == 1:
+                if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True
-                elif key_value_length == query_length:
-                    ignore_causal_mask = True
 
                 # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
                 # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
@@ -358,7 +358,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
 
     key_value_length = input_shape[-1] + past_key_values_length
-    _, query_length = input_shape
 
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
@@ -369,37 +368,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )
 
-    ignore_causal_mask = False
-
-    if attention_mask is None:
-        if sliding_window is None or key_value_length < sliding_window:
-            ignore_causal_mask = not is_tracing
-    elif sliding_window is None or key_value_length < sliding_window:
-        # 4d mask is passed through
-        if len(attention_mask.shape) == 4:
-            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
-            if tuple(attention_mask.shape) != expected_shape:
-                raise ValueError(
-                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
-                )
-            else:
-                # if the 4D mask has correct shape - invert it and fill with negative infinity
-                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
-                attention_mask = inverted_mask.masked_fill(
-                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
-                )
-                return attention_mask
-
-        elif not is_tracing and torch.all(attention_mask == 1):
-            if query_length == 1:
-                # For query_length == 1, causal attention and bi-directional attention are the same.
-                ignore_causal_mask = True
-            elif key_value_length == query_length:
-                ignore_causal_mask = True
-
-        # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
-        # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
-        # Reference: https://github.com/pytorch/pytorch/issues/108108
+    ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+        attention_mask=attention_mask,
+        inputs_embeds=inputs_embeds,
+        past_key_values_length=past_key_values_length,
+        sliding_window=sliding_window,
+    )
 
     if ignore_causal_mask:
         expanded_4d_mask = None
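
Note (illustration, not part of the patch): the decision that `_ignore_causal_mask_sdpa` now centralizes can be summarized as a small standalone function. The sketch below mirrors the rule only for the `attention_mask is None` and all-ones 2D mask cases (the 4D shape check is omitted); the helper name `can_skip_causal_mask` and its argument list are made up for the example.

    from typing import Optional

    import torch


    def can_skip_causal_mask(
        attention_mask: Optional[torch.Tensor],
        query_length: int,
        key_value_length: int,
        sliding_window: Optional[int] = None,
        is_tracing: bool = False,
    ) -> bool:
        # A sliding-window mask only coincides with a plain causal mask as long as
        # the window has not been exceeded yet.
        within_window = sliding_window is None or key_value_length < sliding_window

        if attention_mask is None:
            # No user mask: rely on SDPA's `is_causal`, except under tracing/export,
            # where the `q_len > 1` control flow cannot be captured.
            return within_window and not is_tracing

        # An all-ones (no padding) mask adds no information; it can be dropped when
        # causal and bi-directional attention coincide (query_length == 1) or when
        # query and key/value lengths match, but only within the sliding window.
        if within_window and not is_tracing and torch.all(attention_mask == 1):
            return query_length == 1 or key_value_length == query_length

        return False


    # Decode step (query_length == 1) with an all-ones mask: the mask can be skipped.
    print(can_skip_causal_mask(torch.ones(1, 8), query_length=1, key_value_length=8))  # True
    # Same shapes, but past the sliding window: the mask must be kept.
    print(can_skip_causal_mask(torch.ones(1, 8), query_length=1, key_value_length=8, sliding_window=4))  # False

With this rule factored out, `_prepare_4d_causal_attention_mask_for_sdpa` no longer needs its own copy of the branching and simply delegates to `AttentionMaskConverter._ignore_causal_mask_sdpa`, which is what the last hunk does.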