Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Arthur <[email protected]>
fxmarty and ArthurZucker authored Apr 17, 2024
1 parent 38fb6f6 commit 70d903f
Showing 1 changed file with 3 additions and 2 deletions.
src/transformers/models/llama/modeling_llama.py (3 additions, 2 deletions)
@@ -1073,9 +1073,10 @@ def _update_causal_mask(
         if self.config._attn_implementation == "sdpa":
             # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
             # in order to dispatch on Flash Attention 2.
-            ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
-            )
+            ):
+                return None
 
         if ignore_causal_mask:
             return None
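
For context, a minimal sketch of why returning None instead of an explicit mask matters here: PyTorch's torch.nn.functional.scaled_dot_product_attention can only select its Flash Attention backend when no explicit attn_mask is supplied and causality is expressed through is_causal=True. The tensor shapes, dtype, and device handling below are illustrative assumptions and are not part of this commit.

# Illustrative sketch only; shapes, dtype, and device choice are assumptions.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, num_heads, seq_len, head_dim) query/key/value tensors
q = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)

# What _update_causal_mask returns when the causal mask can be safely ignored
causal_mask = None

if causal_mask is None:
    # No explicit mask: SDPA may dispatch to the Flash Attention 2 kernel,
    # with causality handled inside the kernel via is_causal=True.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    # An explicit mask (e.g. with padding) rules out the flash backend and
    # falls back to the mask-aware kernels.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)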
