Fix SDPA sliding window compatibility #30127

Merged · 5 commits · Apr 17, 2024
34 changes: 15 additions & 19 deletions src/transformers/modeling_attn_mask_utils.py
@@ -305,7 +305,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
 
     key_value_length = input_shape[-1] + past_key_values_length
-    batch_size, query_length = input_shape
+    _, query_length = input_shape
 
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
@@ -316,7 +316,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )
 
-    if attention_mask is not None:
+    ignore_causal_mask = False
+
+    if attention_mask is None:
+        if sliding_window is None or key_value_length < sliding_window:
+            ignore_causal_mask = not is_tracing
+    elif sliding_window is None or key_value_length < sliding_window:
Collaborator comment on lines +319 to +324:
There are basically two cases:

  1. You ignore the causal mask.
  2. You don't ignore it.

The code is really not super clear, but we will refactor this soon anyway.

         # 4d mask is passed through
         if len(attention_mask.shape) == 4:
             expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
@@ -335,26 +340,17 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
         elif not is_tracing and torch.all(attention_mask == 1):
             if query_length == 1:
                 # For query_length == 1, causal attention and bi-directional attention are the same.
-                attention_mask = None
+                ignore_causal_mask = True
             elif key_value_length == query_length:
-                attention_mask = None
-            else:
-                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
-                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
-                # Reference: https://github.com/pytorch/pytorch/issues/108108
-                pass
-    elif query_length > 1 and key_value_length != query_length:
-        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
-        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
-        attention_mask = True
-    elif is_tracing:
-        raise ValueError(
-            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
-        )
+                ignore_causal_mask = True
 
-    if attention_mask is None:
+    # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+    # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+    # Reference: https://github.com/pytorch/pytorch/issues/108108
+
+    if ignore_causal_mask:
         expanded_4d_mask = None
-    elif attention_mask is True:
+    elif attention_mask is None:
         expanded_4d_mask = attn_mask_converter.to_causal_4d(
             input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
         )
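For context, here is a minimal sketch (not from the PR; the helper names are made up) of why the two new conditions are safe. A sliding-window causal mask degenerates to the plain causal mask while key_value_length < sliding_window, and SDPA's built-in `is_causal=True` only reproduces the intended bottom-right-aligned mask when query_length == key_value_length or query_length == 1 (https://github.com/pytorch/pytorch/issues/108108):

    import torch

    # Hypothetical helpers; True means "may attend". Masks are bottom-right aligned,
    # i.e. the last query row corresponds to the last key position.
    def causal(q_len, kv_len):
        i = torch.arange(q_len).unsqueeze(-1) + (kv_len - q_len)  # absolute query positions
        j = torch.arange(kv_len)
        return j <= i

    def sliding_window_causal(q_len, kv_len, window):
        i = torch.arange(q_len).unsqueeze(-1) + (kv_len - q_len)
        j = torch.arange(kv_len)
        return (j <= i) & (j > i - window)

    # 1) While key_value_length < sliding_window, the window changes nothing, so the
    #    causal mask can be ignored and SDPA's `is_causal=True` used directly.
    assert torch.equal(causal(5, 5), sliding_window_causal(5, 5, window=8))
    assert not torch.equal(causal(5, 5), sliding_window_causal(5, 5, window=2))

    # 2) With a KV cache (query_length=2, key_value_length=6), the first query must
    #    still see all past keys; a top-left-aligned `is_causal=True` triangle would
    #    cut them off, hence `is_causal=False` plus an explicit mask.
    print(causal(2, 6).int())
    # tensor([[1, 1, 1, 1, 1, 0],
    #         [1, 1, 1, 1, 1, 1]])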
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -1006,6 +1006,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window,
             )
         else:
             # 4d mask is passed through the layers
1 change: 1 addition & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -1191,6 +1191,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window,
             )
         else:
             # 4d mask is passed through the layers
1 change: 1 addition & 0 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -1017,6 +1017,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window,
             )
         else:
             # 4d mask is passed through the layers
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1183,6 +1183,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window,
             )
         else:
             # 4d mask is passed through the layers
1 change: 1 addition & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -995,6 +995,7 @@ def forward(
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
+                sliding_window=self.config.sliding_window,
             )
         else:
             # 4d mask is passed through the layers
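Before moving on to the test, a minimal repro sketch of the symptom these five one-line changes fix (not part of the PR; the tiny config values are made up): without forwarding `sliding_window`, the SDPA path built a plain causal mask, so SDPA and eager logits diverged as soon as the input grew longer than the window.

    import tempfile

    import torch
    from transformers import AutoModelForCausalLM, MistralConfig, MistralForCausalLM

    # Tiny randomly initialized model with a window shorter than the 8-token input.
    config = MistralConfig(
        vocab_size=128, hidden_size=64, intermediate_size=128, num_hidden_layers=2,
        num_attention_heads=4, num_key_value_heads=2, sliding_window=2,
    )
    torch.manual_seed(0)
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    with tempfile.TemporaryDirectory() as tmpdir:
        MistralForCausalLM(config).save_pretrained(tmpdir)
        eager = AutoModelForCausalLM.from_pretrained(tmpdir, attn_implementation="eager").eval()
        sdpa = AutoModelForCausalLM.from_pretrained(tmpdir, attn_implementation="sdpa").eval()

    with torch.no_grad():
        logits_eager = eager(input_ids).logits
        logits_sdpa = sdpa(input_ids).logits

    # True with this fix; failed before because the SDPA branch dropped the window.
    print(torch.allclose(logits_eager, logits_sdpa, rtol=1e-3, atol=1e-3))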
51 changes: 51 additions & 0 deletions tests/test_modeling_common.py
@@ -3835,6 +3835,57 @@ def test_eager_matches_sdpa_generate(self):

         self.assertTrue(torch.allclose(res_eager, res_sdpa))
 
+    @require_torch_sdpa
+    def test_sdpa_matches_eager_sliding_window(self):
+        WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen2_moe", "starcoder2"]
+
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest(f"No generative model classes for {self.__class__.__name__}")
+
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if config.model_type not in WINDOW_ATTENTION_MODELS:
+                self.skipTest(f"{config.model_type} does not use window attention")
+
+            config.sliding_window = 2
+
+            dummy_input = inputs_dict[model_class.main_input_name]
+            attention_mask = inputs_dict["attention_mask"]
+
+            self.assertTrue(dummy_input.ndim == 2)
+            self.assertTrue(dummy_input.shape[1] > 6)
+
+            with tempfile.TemporaryDirectory() as tmpdir:
+                with torch.device(torch_device):
+                    model_eager = AutoModelForCausalLM.from_config(
+                        config, attn_implementation="eager", torch_dtype=torch.float32
+                    )
+
+                model_eager.save_pretrained(tmpdir)
+
+                with torch.device(torch_device):
+                    model_sdpa = AutoModelForCausalLM.from_pretrained(
+                        tmpdir, attn_implementation="sdpa", torch_dtype=torch.float32
+                    )
+
+                model_eager = model_eager.eval()
+                model_sdpa = model_sdpa.eval()
+
+                with torch.no_grad():
+                    with torch.backends.cuda.sdp_kernel(
+                        enable_flash=False,
+                        enable_math=True,
+                        enable_mem_efficient=False,
+                    ):
+                        res_eager = model_eager(**inputs_dict, return_dict=False)[0]
+                        res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]
+
+                # Only non-padding tokens are expected to match.
+                self.assertTrue(
+                    torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-3)
+                )
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
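A note on the test design: `torch.backends.cuda.sdp_kernel` pins SDPA to the math backend so the comparison with eager is not confounded by flash or memory-efficient kernels, and only non-padding positions are compared, since outputs at padded positions are not meaningful. To run the new test for a single model suite, something like the following should work (illustrative command, assuming a source checkout):

    python -m pytest tests/models/mistral/test_modeling_mistral.py -k sdpa_matches_eager_sliding_window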