Skip to content

Commit

Permalink
Fix SDPA sliding window compatibility (#30127)
Browse files Browse the repository at this point in the history
* fix sdpa + sliding window

* give credit

Co-authored-by: ehuaa <[email protected]>

* remove unnecessary warning

* fix typog

* add test

---------

Co-authored-by: ehuaa <[email protected]>
  • Loading branch information
fxmarty and ehuaa authored Apr 17, 2024
1 parent 5fabebd commit 40eb6d6
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 19 deletions.
34 changes: 15 additions & 19 deletions src/transformers/modeling_attn_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

key_value_length = input_shape[-1] + past_key_values_length
batch_size, query_length = input_shape
_, query_length = input_shape

# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
Expand All @@ -316,7 +316,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

if attention_mask is not None:
ignore_causal_mask = False

if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
elif sliding_window is None or key_value_length < sliding_window:
# 4d mask is passed through
if len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
Expand All @@ -335,26 +340,17 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
attention_mask = None
ignore_causal_mask = True
elif key_value_length == query_length:
attention_mask = None
else:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
pass
elif query_length > 1 and key_value_length != query_length:
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
attention_mask = True
elif is_tracing:
raise ValueError(
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
)
ignore_causal_mask = True

if attention_mask is None:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108

if ignore_causal_mask:
expanded_4d_mask = None
elif attention_mask is True:
elif attention_mask is None:
expanded_4d_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
Expand Down
51 changes: 51 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3841,6 +3841,57 @@ def test_eager_matches_sdpa_generate(self):

self.assertTrue(torch.allclose(res_eager, res_sdpa))

@require_torch_sdpa
def test_sdpa_matches_eager_sliding_window(self):
WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"]

if len(self.all_generative_model_classes) == 0:
self.skipTest(f"No generative model classes for {self.__class__.__name__}")

for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

if config.model_type not in WINDOW_ATTENTION_MODELS:
self.skipTest(f"{config.model_type} does not use window attention")

config.sliding_window = 2

dummy_input = inputs_dict[model_class.main_input_name]
attention_mask = inputs_dict["attention_mask"]

self.assertTrue(dummy_input.ndim == 2)
self.assertTrue(dummy_input.shape[1] > 6)

with tempfile.TemporaryDirectory() as tmpdir:
with torch.device(torch_device):
model_eager = AutoModelForCausalLM.from_config(
config, attn_implementation="eager", torch_dtype=torch.float32
)

model_eager.save_pretrained(tmpdir)

with torch.device(torch_device):
model_sdpa = AutoModelForCausalLM.from_pretrained(
tmpdir, attn_implementation="sdpa", torch_dtype=torch.float32
)

model_eager = model_eager.eval()
model_sdpa = model_sdpa.eval()

with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=False,
enable_math=True,
enable_mem_efficient=False,
):
res_eager = model_eager(**inputs_dict, return_dict=False)[0]
res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]

# Only non-padding tokens are expected to match.
self.assertTrue(
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-3)
)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
Expand Down

0 comments on commit 40eb6d6

Please sign in to comment.