Unblock Llama2 ONNX export w/ sdpa by falling back to manual impl #28823

Closed · wants to merge 1 commit
20 changes: 16 additions & 4 deletions src/transformers/models/llama/modeling_llama.py
@@ -673,12 +673,22 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
_jit_tracing = torch.jit.is_tracing()
Reviewer comment (Contributor): This means that we call `torch.jit.is_tracing` as many times as there are layers.
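A hypothetical sketch of the hoisting this comment suggests — querying the tracer once per model forward and threading the result down, instead of calling `torch.jit.is_tracing()` in every decoder layer (the `fallback_to_manual` keyword is illustrative and not part of this PR):

```python
import torch

def decoder_stack_forward(layers, hidden_states, attention_mask=None):
    # Hypothetical hoisting: one tracer check for the whole stack instead of
    # one per layer, with the result passed down to each layer.
    fallback_to_manual = torch.jit.is_tracing()
    for layer in layers:
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            fallback_to_manual=fallback_to_manual,  # illustrative keyword
        )
    return hidden_states
```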

_fallback_to_manual = _jit_tracing or output_attentions
# TODO: Improve these warnings with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
if _jit_tracing:
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but SDPA can not be traced with torch.jit.trace when no attention_mask is provided. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)

if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)

if _fallback_to_manual:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
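For context on why this per-layer check works, `torch.jit.is_tracing()` returns a plain Python bool, so whichever branch it selects while `torch.jit.trace` records the model is the branch baked into the traced graph. A minimal, self-contained sketch:

```python
import torch

def f(x):
    # The condition is an ordinary Python bool, so the branch chosen at
    # trace time is the one recorded in the traced graph.
    if torch.jit.is_tracing():
        return x * 2  # recorded during torch.jit.trace
    return x + 1      # taken in normal eager execution

x = torch.full((1,), 3.0)
print(f(x))            # eager: tensor([4.])
traced = torch.jit.trace(f, x)
print(traced(x))       # traced: tensor([6.]) -- the tracing branch was recorded
```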
@@ -1029,9 +1039,11 @@ def forward(
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self._use_sdpa and not output_attentions:
-     # output_attentions=True can not be supported when using SDPA, and we fall back on
-     # the manual implementation that requires a 4D causal mask in all cases.
+ elif self._use_sdpa and not output_attentions and not torch.jit.is_tracing():
+     # NOTE: Fallback to eager attention in any of the following cases:
+     # - output_attentions=True can not be supported when using SDPA, and we fall back on
+     #   the manual implementation that requires a 4D causal mask in all cases.
+     # - Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
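Putting the diff in context, here is a hedged sketch of the export this PR unblocks: tracing a Llama-2 checkpoint with the SDPA implementation selected and no `attention_mask` supplied. The checkpoint id assumes access to the gated weights, and production exports are generally better served by `optimum.exporters.onnx`; this is only a minimal illustration:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: you can access these gated weights
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
model.config.use_cache = False    # keep the graph to a single logits output
model.config.return_dict = False  # torch.onnx.export expects tuple outputs
model.eval()

input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# Before this PR, torch.jit.trace of the SDPA path failed when no
# attention_mask was provided; with the fallback, the manual (eager)
# attention implementation is recorded instead.
torch.onnx.export(
    model,
    (input_ids,),
    "llama2.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
)
```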