Fix static generation when compiling! #28937

Merged
42 commits merged on Feb 15, 2024

Changes from all commits (42 commits)
2187685
wow I was scared!
ArthurZucker Feb 9, 2024
4922c92
fix everything
ArthurZucker Feb 9, 2024
56768a0
nits
ArthurZucker Feb 9, 2024
b565051
make it BC?
ArthurZucker Feb 12, 2024
99afd1a
add todo
ArthurZucker Feb 12, 2024
edc498f
nits
ArthurZucker Feb 12, 2024
651c4bd
is_tracing should still be used to pass tracing tests
ArthurZucker Feb 12, 2024
f69626e
nits
ArthurZucker Feb 12, 2024
96136ac
some nits to make sure generation works with static cache uncompiled
ArthurZucker Feb 12, 2024
d5ebd80
fix sdpa
ArthurZucker Feb 12, 2024
70adcf6
fix FA2 for both static and dynamic in a better way?
ArthurZucker Feb 14, 2024
61ed4cb
style
ArthurZucker Feb 14, 2024
fedc563
fix-copies
ArthurZucker Feb 14, 2024
0195d58
fix fix copies
ArthurZucker Feb 14, 2024
07f3adb
fix sequential beam search
ArthurZucker Feb 14, 2024
9402c25
style
ArthurZucker Feb 14, 2024
86303c4
use `keys_to_ignore`
ArthurZucker Feb 14, 2024
fb9e907
nit
ArthurZucker Feb 14, 2024
9aa667e
correct dtype inference when init
ArthurZucker Feb 14, 2024
68a5f29
:( the fix for FA2 is still not optimal to investigate!
ArthurZucker Feb 14, 2024
3b9969b
styling
ArthurZucker Feb 14, 2024
162ab87
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 14, 2024
914b0d7
nits
ArthurZucker Feb 14, 2024
e79f79f
nit
ArthurZucker Feb 14, 2024
ee2317d
this might work better
ArthurZucker Feb 14, 2024
93b2691
add comment
ArthurZucker Feb 14, 2024
3619ed3
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 14, 2024
c23cdc4
"position_ids" -> "cache_position"
ArthurZucker Feb 14, 2024
717a8e7
style
ArthurZucker Feb 14, 2024
7fe0964
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 14, 2024
464c463
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 15, 2024
80148ab
nit
ArthurZucker Feb 15, 2024
c9f3c82
Remove changes that should not be propagated just yet
ArthurZucker Feb 15, 2024
5f54d84
Apply suggestions from code review
ArthurZucker Feb 15, 2024
b3fc042
Styling
ArthurZucker Feb 15, 2024
5fdb2da
make sure we raise an error for static cache with FA2 enabled
ArthurZucker Feb 15, 2024
03edf91
move to the bottom of the signature
ArthurZucker Feb 15, 2024
b762304
style
ArthurZucker Feb 15, 2024
9fbe901
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 15, 2024
7afe7d9
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 15, 2024
3772d1c
nit in the name
ArthurZucker Feb 15, 2024
cf0bc32
Merge branches 'fix-static-kv-cache' and 'fix-static-kv-cache' of git…
ArthurZucker Feb 15, 2024
13 changes: 7 additions & 6 deletions src/transformers/cache_utils.py
Expand Up @@ -344,17 +344,15 @@ class StaticCache(Cache):
The default `dtype` to use when initializing the layer.
"""

def __init__(
self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
) -> None:
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
Expand Down Expand Up @@ -386,20 +384,23 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("position_ids")
new_cache_positions = cache_kwargs.get("cache_position")
k_out = self.key_cache
v_out = self.value_cache

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

self.seen_tokens += key_states.shape[-2]
self.seen_tokens += key_states.shape[2]
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
return self.seen_tokens

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_cache_len
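To make the shape-stability point concrete, here is a minimal, self-contained sketch (not the library class itself) of how writing into a preallocated cache at `cache_position` keeps tensor shapes fixed across decoding steps; all sizes below are invented for illustration.

```python
import torch

# Toy stand-in for a static KV cache: one preallocated tensor, updated in place
# at absolute positions so its shape never changes between decoding steps.
max_batch, n_kv_heads, max_len, head_dim = 1, 8, 16, 64
key_cache = torch.zeros(max_batch, n_kv_heads, max_len, head_dim)

def update(key_states: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    # key_states: (batch, n_kv_heads, q_len, head_dim); cache_position: (q_len,) absolute slots
    key_cache[:, :, cache_position] = key_states
    return key_cache  # always (max_batch, n_kv_heads, max_len, head_dim)

prefill = update(torch.randn(1, 8, 5, 64), torch.arange(5))   # prefill fills slots 0..4
step = update(torch.randn(1, 8, 1, 64), torch.tensor([5]))    # next token goes into slot 5
print(prefill.shape, step.shape)  # both torch.Size([1, 8, 16, 64])
```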
5 changes: 3 additions & 2 deletions src/transformers/generation/utils.py
Expand Up @@ -4776,8 +4776,9 @@ def _split_model_inputs(
# Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
# ModelOutput object.
# bool should not be split but replicated for each split
bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
keys_to_ignore = ["cache_position", "encoder_outputs"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
Review comment from the PR author (Collaborator) on lines -4779 to +4781: beam search will split the cache positions otherwise


# we split the tensors and tuples of tensors
data_split_list = [
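The review note above is the motivation for treating `cache_position` like the boolean keys. Below is a hedged sketch of the split-versus-replicate distinction — a simplified stand-in for `_split_model_inputs`, not the actual helper:

```python
import torch
from typing import Any, Dict, List

def split_inputs(model_input: Dict[str, Any], split_size: int, full_batch_size: int) -> List[Dict[str, Any]]:
    # Booleans and `cache_position` are shared by every beam, so they are replicated;
    # everything else is assumed to be a per-example tensor and is split on the batch dim.
    replicated = [k for k, v in model_input.items() if isinstance(v, bool) or k == "cache_position"]
    split = [k for k in model_input if k not in replicated]
    chunks = []
    for start in range(0, full_batch_size, split_size):
        chunk = {k: model_input[k][start : start + split_size] for k in split}
        chunk.update({k: model_input[k] for k in replicated})  # copied whole, never sliced
        chunks.append(chunk)
    return chunks

batch = {
    "input_ids": torch.ones(4, 1, dtype=torch.long),
    "cache_position": torch.tensor([7]),   # absolute cache slot, identical for every beam
    "use_cache": True,
}
for chunk in split_inputs(batch, split_size=2, full_batch_size=4):
    print(chunk["input_ids"].shape, chunk["cache_position"])
```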
126 changes: 71 additions & 55 deletions src/transformers/models/llama/modeling_llama.py
Expand Up @@ -29,7 +29,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
Expand Down Expand Up @@ -303,6 +303,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
Expand Down Expand Up @@ -333,21 +334,13 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
past_seen_tokens = 0
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
kv_seq_len += past_seen_tokens

new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
Expand All @@ -356,7 +349,8 @@ def forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[..., past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
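A small illustration of the mask slicing used above (shapes are invented): indexing a full `(max_len, max_len)` causal mask with `cache_position` selects exactly the rows for the tokens processed in this forward pass, so the same preallocated mask serves both prefill and single-token decoding.

```python
import torch

max_len = 8
min_val = torch.finfo(torch.float32).min
# Full causal mask: 0 where attention is allowed, a large negative value elsewhere.
causal = torch.triu(torch.full((max_len, max_len), min_val), diagonal=1)
mask = causal[None, None, :, :]                 # (1, 1, max_len, max_len)

cache_position = torch.tensor([3])              # decoding the token that lands in slot 3
step_mask = mask[:, :, cache_position, :]       # (1, 1, 1, max_len): row 3 only
print(step_mask.shape, (step_mask == 0).squeeze())  # slots 0..3 visible, 4..7 masked
```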
Expand Down Expand Up @@ -410,6 +404,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
Expand All @@ -427,20 +422,14 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
past_seen_tokens = 0
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
kv_seq_len += past_seen_tokens
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
Expand Down Expand Up @@ -603,6 +592,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
Expand All @@ -617,6 +607,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

bsz, q_len, _ = hidden_states.size()
Expand All @@ -629,29 +620,22 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
past_seen_tokens = 0
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
kv_seq_len += past_seen_tokens
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

causal_mask = None
if attention_mask is not None:
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
Expand All @@ -666,7 +650,6 @@ def forward(
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down Expand Up @@ -703,6 +686,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand Down Expand Up @@ -736,6 +720,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -800,13 +785,20 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()

def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)

if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

for layer in self.model.layers:
weights = layer.self_attn.o_proj.weight
layer.self_attn.past_key_value = cache_cls(
self.config, max_batch_size, max_cache_len, device=layer.self_attn.o_proj.weight.device
self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
)

def _reset_cache(self):
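Sketch of the dtype/device inference in `_setup_cache` above: the per-layer cache now borrows `dtype` and `device` from an existing weight instead of defaulting to `float32`, so a half-precision model gets a half-precision cache without extra configuration. The `Linear` layer below is only a stand-in for `o_proj`.

```python
import torch

layer = torch.nn.Linear(16, 16, dtype=torch.float16)      # stand-in for self_attn.o_proj
weights = layer.weight
# Allocate the cache to match the model's own weights.
key_cache = torch.zeros(1, 4, 128, 4, dtype=weights.dtype, device=weights.device)
assert key_cache.dtype == torch.float16
```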
Expand Down Expand Up @@ -932,6 +924,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -951,12 +944,23 @@ def forward(
)
use_cache = False

if use_cache and not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

if cache_position is None:
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

# embed positions
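A quick sketch of the bookkeeping introduced above: `cache_position` enumerates the absolute cache slots occupied by the tokens in this forward pass, and `position_ids` falls back to the same values with a batch dimension.

```python
import torch

past_seen_tokens = 5      # tokens already stored in the cache
cur_len = 1               # tokens processed in this forward pass (1 during decoding)
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + cur_len)
position_ids = cache_position.unsqueeze(0)
print(cache_position, position_ids)   # tensor([5]) tensor([[5]])
```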
Expand All @@ -980,6 +984,7 @@ def forward(
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
Expand All @@ -989,6 +994,7 @@ def forward(
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -1021,8 +1027,9 @@ def forward(

def _update_causal_mask(self, attention_mask, input_tensor):
if self.config._attn_implementation == "flash_attention_2":
causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
return causal_mask
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

batch_size, seq_length = input_tensor.shape[:2]
dtype = input_tensor.dtype
Expand Down Expand Up @@ -1051,14 +1058,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
)

if self.config._attn_implementation == "sdpa":
if attention_mask is None:
return None
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and (torch.all(attention_mask == 1)):
return None
if is_tracing and seq_length == 1:
return None
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
dtype
)

return causal_mask
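For context on the `mul(...)` expression above, here is a tiny reproduction of the trick with made-up values: rows of the mask that sit entirely at the minimum value (for example pure padding rows) are zeroed out so the memory-efficient SDPA backend never sees a fully masked row.

```python
import torch

min_val = torch.finfo(torch.float32).min
mask = torch.tensor([[0.0, min_val],        # normal row: one visible, one masked position
                     [min_val, min_val]])   # fully masked row (e.g. all padding)
fixed = mask.mul(~torch.all(mask == mask.min(), dim=-1)[..., None]).to(torch.float32)
print(fixed)   # the fully masked row becomes all zeros; the first row is unchanged
```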

Expand Down Expand Up @@ -1107,6 +1111,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1150,6 +1155,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

hidden_states = outputs[0]
Expand Down Expand Up @@ -1189,6 +1195,7 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
Expand Down Expand Up @@ -1228,9 +1235,17 @@ def prepare_inputs_for_generation(

if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
# generation with static cache
seen_tokens = past_key_value.get_seq_length()
input_ids = input_ids[:, seen_tokens:]
position_ids = position_ids[:, seen_tokens:]
past_length = past_key_value.get_seq_length()
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + position_ids.shape[-1], device=position_ids.device
)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
Expand All @@ -1241,6 +1256,7 @@ def prepare_inputs_for_generation(
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
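Finally, a hedged end-to-end sketch of what this PR is building toward: static-shape generation that can be compiled. The model id and the `cache_implementation="static"` generation option are assumptions based on the transformers release this PR targets, not something shown in the diff; and since FA2 is rejected with a static cache, an SDPA or eager attention implementation is assumed.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"      # any Llama-architecture checkpoint (assumed)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Opt in to the static KV cache and compile the forward pass with fixed shapes.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Static caches make compilation", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```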