From e5f88ae0765ca371d29fb49402bfc96d6a9d7298 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:22:33 +0100 Subject: [PATCH] Fix is_causal being a tensor (#35791) * fix is_causal being a tensor * convert in sdpa attention only when jit tracing --- src/transformers/integrations/sdpa_attention.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 38701690bf7c..df6e96131a91 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -45,6 +45,11 @@ def sdpa_attention_forward( if is_causal is None: is_causal = causal_mask is None and query.shape[2] > 1 + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. + # We convert it to a bool for the SDPA kernel that only accepts bools. + if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + attn_output = torch.nn.functional.scaled_dot_product_attention( query, key,