diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 8ad68f39db9134..43da8917b23075 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -305,7 +305,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length - batch_size, query_length = input_shape + _, query_length = input_shape # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. @@ -316,7 +316,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa( or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) - if attention_mask is not None: + ignore_causal_mask = False + + if attention_mask is None: + if sliding_window is None or key_value_length < sliding_window: + ignore_causal_mask = not is_tracing + elif sliding_window is None or key_value_length < sliding_window: # 4d mask is passed through if len(attention_mask.shape) == 4: expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) @@ -335,26 +340,17 @@ def _prepare_4d_causal_attention_mask_for_sdpa( elif not is_tracing and torch.all(attention_mask == 1): if query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. - attention_mask = None + ignore_causal_mask = True elif key_value_length == query_length: - attention_mask = None - else: - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - pass - elif query_length > 1 and key_value_length != query_length: - # See the comment above (https://github.com/pytorch/pytorch/issues/108108). - # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. - attention_mask = True - elif is_tracing: - raise ValueError( - 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' - ) + ignore_causal_mask = True - if attention_mask is None: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + + if ignore_causal_mask: expanded_4d_mask = None - elif attention_mask is True: + elif attention_mask is None: expanded_4d_mask = attn_mask_converter.to_causal_4d( input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e219271e8ee5c3..c013967c78f116 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1006,6 +1006,7 @@ def forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, + sliding_window=self.config.sliding_window, ) else: # 4d mask is passed through the layers diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index baa33421d9533e..eb5794f640a080 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1191,6 +1191,7 @@ def forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, + sliding_window=self.config.sliding_window, ) else: # 4d mask is passed through the layers diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 7ca32c37685c3c..b5a1370ae1fc8f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1017,6 +1017,7 @@ def forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, + sliding_window=self.config.sliding_window, ) else: # 4d mask is passed through the layers diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index cab2ef5ff7e578..7b9e2f6fc0ab9f 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1183,6 +1183,7 @@ def forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, + sliding_window=self.config.sliding_window, ) else: # 4d mask is passed through the layers diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 85a76f87b8d6e5..ca4c8af23304f9 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -995,6 +995,7 @@ def forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, + sliding_window=self.config.sliding_window, ) else: # 4d mask is passed through the layers diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a396ac752d324a..fd23a3f5ee9ffa 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3841,6 +3841,57 @@ def test_eager_matches_sdpa_generate(self): self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch_sdpa + def test_sdpa_matches_eager_sliding_window(self): + WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"] + + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"No generative model classes for {self.__class__.__name__}") + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if config.model_type not in WINDOW_ATTENTION_MODELS: + self.skipTest(f"{config.model_type} does not use window attention") + + config.sliding_window = 2 + + dummy_input = inputs_dict[model_class.main_input_name] + attention_mask = inputs_dict["attention_mask"] + + self.assertTrue(dummy_input.ndim == 2) + self.assertTrue(dummy_input.shape[1] > 6) + + with tempfile.TemporaryDirectory() as tmpdir: + with torch.device(torch_device): + model_eager = AutoModelForCausalLM.from_config( + config, attn_implementation="eager", torch_dtype=torch.float32 + ) + + model_eager.save_pretrained(tmpdir) + + with torch.device(torch_device): + model_sdpa = AutoModelForCausalLM.from_pretrained( + tmpdir, attn_implementation="sdpa", torch_dtype=torch.float32 + ) + + model_eager = model_eager.eval() + model_sdpa = model_sdpa.eval() + + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + ): + res_eager = model_eager(**inputs_dict, return_dict=False)[0] + res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0] + + # Only non-padding tokens are expected to match. + self.assertTrue( + torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-3) + ) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test