Revert qwen2 breaking changes related to attention refactor (#36162)
* ditto

* add a test

* update

* test needs fa2

* update test and configuration

* test requires fa2

* style
ArthurZucker authored Feb 14, 2025
1 parent cb586a3 · commit 96f01a3
Showing 4 changed files with 468 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2/configuration_qwen2.py
@@ -174,7 +174,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window if use_sliding_window else None
+        self.sliding_window = sliding_window  # we check `use_sliding_window` in the modeling code
         self.max_window_layers = max_window_layers

         # for backward compatibility
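For context, a minimal sketch of what the reverted config line means in practice — `use_sliding_window` and `sliding_window` are real `Qwen2Config` arguments, but the values below are illustrative:

from transformers import Qwen2Config

# After this revert the config keeps the raw window size; whether it is
# actually used is decided in the modeling code via `use_sliding_window`.
config = Qwen2Config(use_sliding_window=False, sliding_window=4096)
print(config.use_sliding_window)  # False
print(config.sliding_window)      # 4096 (no longer coerced to None)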
57 changes: 43 additions & 14 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -10,7 +10,7 @@
 from torch import nn

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -616,7 +616,15 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right'"
+                        " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
+                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None

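The restored block above re-enables the left-padding check for batched generation with Flash Attention 2. A rough usage sketch — the checkpoint name and generation settings are placeholders, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"  # right padding can now trigger the ValueError above

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
    device_map="auto",
)
batch = tokenizer(["Hello", "A much longer prompt"], padding=True, return_tensors="pt").to(model.device)
generated = model.generate(**batch, max_new_tokens=20)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))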
@@ -625,21 +633,30 @@ def _update_causal_mask(
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
                 past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
                 is_training=self.training,
             ):
                 return None

         dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if using_static_cache:
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -656,6 +673,8 @@ def _update_causal_mask(
             device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
         )

         if (
@@ -667,7 +686,6 @@ def _update_causal_mask(
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
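With `SlidingWindowCache` imported and handled again, the non-flash mask path can also be exercised with a fixed-size sliding-window cache. A hedged sketch that reuses `model_id` and `batch` from the snippet above, loads the model with SDPA instead, and assumes `generate()` accepts `cache_implementation="sliding_window"` (the usual way such a cache is requested; not part of this diff):

# Sketch: with SDPA and a sliding-window cache, `using_sliding_window_cache`
# in `_update_causal_mask` is True and the 4D mask helper receives the config.
sdpa_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa", torch_dtype="auto", device_map="auto"
)
outputs = sdpa_model.generate(
    **batch,
    max_new_tokens=64,
    cache_implementation="sliding_window",  # assumed generate() option
)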
@@ -681,21 +699,20 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
-        **kwargs,
+        config: Qwen2Config,
+        past_key_values: Cache,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
         `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

         Args:
             attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                    `(batch_size, 1, query_length, key_value_length)`.
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
             sequence_length (`int`):
                 The sequence length being processed.
             target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                    to account for the 0 padding, the part of the cache that is not filled yet.
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
@@ -704,6 +721,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
                 Batch size.
+            config (`Qwen2Config`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
         """
         if attention_mask is not None and attention_mask.dim() == 4:
             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
@@ -713,12 +734,21 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask = torch.full(
                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
             )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+                # the check is needed to verify is current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                     causal_mask.device
@@ -727,7 +757,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-
         return causal_mask
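The sliding-window arithmetic restored above can be checked in isolation. A standalone toy version in plain PyTorch — window length and sequence sizes are illustrative only:

import torch

sliding_window = 3
sequence_length = target_length = 5
cache_position = torch.arange(sequence_length)

min_dtype = torch.finfo(torch.float32).min
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)

# Block future positions (strict upper triangle) ...
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# ... and positions that have slid out of the window.
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

causal_mask = causal_mask * diagonal_attend_mask
print(causal_mask == 0)  # True where a query position may attend to a key position
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [False,  True,  True,  True, False],
#         [False, False,  True,  True,  True]])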


4 changes: 2 additions & 2 deletions src/transformers/models/qwen2/modular_qwen2.py
@@ -17,10 +17,10 @@
     LlamaForSequenceClassification,
     LlamaForTokenClassification,
     LlamaMLP,
-    LlamaModel,
     apply_rotary_pos_emb,
     eager_attention_forward,
 )
+from ..mistral.modeling_mistral import MistralModel
 from .configuration_qwen2 import Qwen2Config


@@ -114,7 +114,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
         )


-class Qwen2Model(LlamaModel):
+class Qwen2Model(MistralModel):
     pass


(The diff for the fourth changed file is not shown here.)
