Fix static generation when compiling! (huggingface#28937)
* wow I was scared!

* fix everything

* nits

* make it BC?

* add todo

* nits

* is_tracing should still be used to pass tracing tests

* nits

* some nits to make sure generation works with static cache uncompiled

* fix sdpa

* fix FA2 for both static and dynamic in a better way?

* style

* fix-copies

* fix fix copies

* fix sequential beam search

* style

* use `keys_to_ignore`

* nit

* correct dtype inference when init

* :( the fix for FA2 is still not optimal to investigate!

* styling

* nits

* nit

* this might work better

* add comment

* Update src/transformers/models/llama/modeling_llama.py

* "position_ids" -> "cache_position"

* style

* nit

* Remove changes that should not be propagated just yet

* Apply suggestions from code review

* Styling

* make sure we raise an error for static cache with FA2 enabled

* move to the bottom of the signature

* style

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/llama/modeling_llama.py

* nit in the name

---------

Co-authored-by: Younes Belkada <[email protected]>
2 people authored and hackyon committed Feb 15, 2024
1 parent d72cb5a commit 7410112
Showing 7 changed files with 85 additions and 87 deletions.
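
For context, a rough sketch of the workflow this commit is fixing: pre-allocating a static KV cache on a Llama model and decoding through a compiled forward pass. The checkpoint name, cache sizes, and compile flags below are illustrative assumptions, and `_setup_cache` is the private helper touched in this diff, so the exact incantation may differ between versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # any Llama-architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Pre-allocate fixed-shape per-layer KV caches so every decode step sees the same tensor shapes.
model._setup_cache(StaticCache, max_batch_size=1, max_cache_len=256)

# With static shapes, torch.compile does not have to re-trace at every new sequence length.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("The key to static generation is", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
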
13 changes: 7 additions & 6 deletions src/transformers/cache_utils.py
@@ -344,17 +344,15 @@ class StaticCache(Cache):
The default `dtype` to use when initializing the layer.
"""

- def __init__(
- self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
- ) -> None:
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
+ self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
- self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -386,20 +384,23 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
- new_cache_positions = cache_kwargs.get("position_ids")
+ new_cache_positions = cache_kwargs.get("cache_position")
k_out = self.key_cache
v_out = self.value_cache

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

- self.seen_tokens += key_states.shape[-2]
+ self.seen_tokens += key_states.shape[2]
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens

+ def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
+ return self.seen_tokens

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_cache_len
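
A small sketch of how the reworked StaticCache behaves after this change, assuming a transformers checkout that includes this commit (config values and shapes are made up): the dtype now comes from the `dtype` argument with a `torch.float32` fallback, and `update()` writes new key/value states at the slots named by the `cache_position` kwarg while always returning buffers of width `max_cache_len`.

import torch
from transformers import AutoConfig
from transformers.cache_utils import StaticCache

config = AutoConfig.for_model("llama", hidden_size=64, num_attention_heads=4, num_hidden_layers=1)
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float16)

# Prefill: write 5 tokens at positions 0..4 of the pre-allocated buffers.
head_dim = config.hidden_size // config.num_attention_heads
key_states = torch.randn(1, config.num_attention_heads, 5, head_dim, dtype=torch.float16)
value_states = torch.randn_like(key_states)
cache_position = torch.arange(5)

k, v = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": cache_position})
print(k.shape)                 # torch.Size([1, 4, 16, 16]) -- always max_cache_len wide
print(cache.get_seq_length())  # 5
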
5 changes: 3 additions & 2 deletions src/transformers/generation/utils.py
@@ -4776,8 +4776,9 @@ def _split_model_inputs(
# Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
# ModelOutput object.
# bool should not be split but replicated for each split
- bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
- non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+ bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+ keys_to_ignore = ["cache_position", "encoder_outputs"]
+ non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

# we split the tensors and tuples of tensors
data_split_list = [
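
The intent of the `keys_to_ignore` change, sketched with a toy helper rather than the private `_split_model_inputs` itself (the function below is illustrative): when beam search splits the batch into smaller forward passes, tensors are sliced along the batch dimension while booleans and `cache_position` are replicated unchanged into every split.

import torch

def split_inputs(model_input: dict, full_batch_size: int, split_size: int) -> list:
    # Toy version of the splitting logic described above.
    splits = []
    for start in range(0, full_batch_size, split_size):
        piece = {}
        for key, value in model_input.items():
            if isinstance(value, bool) or key == "cache_position":
                piece[key] = value  # replicated, never sliced along the batch dimension
            elif isinstance(value, torch.Tensor):
                piece[key] = value[start : start + split_size]
        splits.append(piece)
    return splits

batch = {
    "input_ids": torch.arange(12).reshape(4, 3),
    "use_cache": True,
    "cache_position": torch.arange(3),  # the same positions apply to every sequence in the batch
}
for piece in split_inputs(batch, full_batch_size=4, split_size=2):
    print(piece["input_ids"].shape, piece["cache_position"].tolist())
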
126 changes: 71 additions & 55 deletions src/transformers/models/llama/modeling_llama.py
@@ -29,7 +29,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
+ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -303,6 +303,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -333,21 +334,13 @@
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

- kv_seq_len = key_states.shape[-2]
- past_seen_tokens = 0
past_key_value = getattr(self, "past_key_value", past_key_value)
- if past_key_value is not None:
- past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
- kv_seq_len += past_seen_tokens

- new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
- position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -356,7 +349,8 @@
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[..., past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
+ if cache_position is not None:
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
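
The new mask indexing is easy to check in isolation. A standalone sketch (sizes are arbitrary): the model keeps a full `max_cache_len x max_cache_len` additive causal mask, and each forward pass selects only the rows for the positions currently being written, so the same expression covers both prefill and single-token decoding.

import torch

max_cache_len = 8
min_value = torch.finfo(torch.float32).min

# Full additive causal mask: 0 where attention is allowed, a large negative value elsewhere.
full_mask = torch.triu(torch.full((max_cache_len, max_cache_len), min_value), diagonal=1)
attention_mask = full_mask[None, None, :, :]  # broadcastable to (batch, heads, q_len, kv_len)

cache_position = torch.arange(5)    # prefill: 5 query tokens at positions 0..4
print(attention_mask[:, :, cache_position, :max_cache_len].shape)  # torch.Size([1, 1, 5, 8])

cache_position = torch.tensor([5])  # decode: one query token at position 5
print(attention_mask[:, :, cache_position, :max_cache_len].shape)  # torch.Size([1, 1, 1, 8])
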
@@ -410,6 +404,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
@@ -427,20 +422,14 @@
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

- kv_seq_len = key_states.shape[-2]
- past_seen_tokens = 0
- past_key_value = getattr(self, "past_key_value", past_key_value)
- if past_key_value is not None:
- past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
- kv_seq_len += past_seen_tokens
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

- new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
- position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
- cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -603,6 +592,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -617,6 +607,7 @@
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

bsz, q_len, _ = hidden_states.size()
@@ -629,29 +620,22 @@
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

- kv_seq_len = key_states.shape[-2]
- past_seen_tokens = 0
- past_key_value = getattr(self, "past_key_value", past_key_value)
- if past_key_value is not None:
- past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
- kv_seq_len += past_seen_tokens
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

- new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
- position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

- causal_mask = None
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
+ causal_mask = attention_mask
+ if attention_mask is not None and cache_position is not None:
+ causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -666,7 +650,6 @@
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- is_causal=causal_mask is None,
)

attn_output = attn_output.transpose(1, 2).contiguous()
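
One way to read the removal of `is_causal`: with a static cache the key/value tensors always span `max_cache_len` slots, so `q_len == kv_len` no longer holds during decoding and masking has to come from the explicitly sliced mask. A hedged sketch with made-up shapes:

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, max_cache_len, head_dim = 1, 4, 1, 8, 16
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, max_cache_len, head_dim)    # static cache: always full width
value = torch.randn(bsz, n_heads, max_cache_len, head_dim)

min_value = torch.finfo(query.dtype).min
full_mask = torch.triu(torch.full((max_cache_len, max_cache_len), min_value), diagonal=1)[None, None]
causal_mask = full_mask[:, :, torch.tensor([5]), :]  # row for the token being decoded at position 5

# is_causal=True would assume q_len == kv_len; the sliced additive mask handles the static layout.
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask, dropout_p=0.0)
print(attn_output.shape)  # torch.Size([1, 4, 1, 16])
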
@@ -703,6 +686,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -736,6 +720,7 @@
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -800,13 +785,20 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()

def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )

if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

for layer in self.model.layers:
+ weights = layer.self_attn.o_proj.weight
layer.self_attn.past_key_value = cache_cls(
- self.config, max_batch_size, max_cache_len, device=layer.self_attn.o_proj.weight.device
+ self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
)
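
A short sketch of the behaviour added here (the checkpoint is illustrative): `_setup_cache` now fails fast when Flash Attention 2 is selected, and otherwise allocates each per-layer cache in the dtype and on the device of that layer's weights instead of defaulting to float32.

import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import StaticCache

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",   # with "flash_attention_2", _setup_cache below raises a ValueError
)
model._setup_cache(StaticCache, max_batch_size=2, max_cache_len=1024)

layer_cache = model.model.layers[0].self_attn.past_key_value
print(layer_cache.key_cache.dtype)  # torch.bfloat16, taken from the o_proj weights
print(layer_cache.key_cache.shape)  # (2, num_key_value_heads, 1024, head_dim)
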

def _reset_cache(self):
@@ -932,6 +924,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -951,12 +944,23 @@
)
use_cache = False

- if use_cache and not isinstance(past_key_values, Cache):
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

+ past_seen_tokens = 0
+ if use_cache: # kept for BC (cache positions)
+ if not isinstance(past_key_values, StaticCache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()

+ if cache_position is None:
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )

+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

# embed positions
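
The `cache_position` values derived above are easy to reason about on their own; a small sketch of the two cases (numbers are illustrative):

import torch

def make_cache_position(past_seen_tokens: int, cur_len: int) -> torch.Tensor:
    # Positions of the tokens being processed now, offset by what the cache has already seen.
    return torch.arange(past_seen_tokens, past_seen_tokens + cur_len)

print(make_cache_position(0, 6))               # prefill: tensor([0, 1, 2, 3, 4, 5])
print(make_cache_position(6, 1))               # first decode step: tensor([6])
print(make_cache_position(6, 1).unsqueeze(0))  # default position_ids: tensor([[6]])
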
@@ -980,6 +984,7 @@
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -989,6 +994,7 @@
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]
@@ -1021,8 +1027,9 @@

def _update_causal_mask(self, attention_mask, input_tensor):
if self.config._attn_implementation == "flash_attention_2":
- causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- return causal_mask
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None

batch_size, seq_length = input_tensor.shape[:2]
dtype = input_tensor.dtype
@@ -1051,14 +1058,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
)

if self.config._attn_implementation == "sdpa":
- if attention_mask is None:
- return None
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
- if not is_tracing and (torch.all(attention_mask == 1)):
- return None
- if is_tracing and seq_length == 1:
- return None
- causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype)
+ if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
+ causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+ dtype
+ )

return causal_mask
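
The `mul(~torch.all(...))` expression above is the easiest part to misread. The reading sketched below: once padding is folded into the causal mask, a query row can end up fully masked, and zeroing such rows avoids NaNs when softmax runs over an all-masked row (mask sizes here are arbitrary).

import torch

min_value = torch.finfo(torch.float32).min
seq_len = 4

# Additive causal mask, with row 0 fully masked as if it were a left-padded position.
causal_mask = torch.triu(torch.full((seq_len, seq_len), min_value), diagonal=1)
causal_mask[0, :] = min_value

fully_masked_rows = torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]
print(fully_masked_rows.squeeze(-1))  # tensor([ True, False, False, False])

safe_mask = causal_mask.mul(~fully_masked_rows)  # fully masked rows become all zeros
print(safe_mask[0])                   # tensor([0., 0., 0., 0.])
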

@@ -1107,6 +1111,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1150,6 +1155,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

hidden_states = outputs[0]
@@ -1189,6 +1195,7 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
+ past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
@@ -1228,9 +1235,17 @@ def prepare_inputs_for_generation(

if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
# generation with static cache
- seen_tokens = past_key_value.get_seq_length()
- input_ids = input_ids[:, seen_tokens:]
- position_ids = position_ids[:, seen_tokens:]
+ past_length = past_key_value.get_seq_length()
+ input_ids = input_ids[:, past_length:]
+ position_ids = position_ids[:, past_length:]

+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
+ # same goes for position ids. Could also help with continued generation.
+ cache_position = kwargs.get("cache_position", None)
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_length, past_length + position_ids.shape[-1], device=position_ids.device
+ )
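
How `cache_position` advances across decoding steps with the static cache, sketched with plain tensors (prompt length and step count are illustrative): the full prompt occupies positions 0..N-1 on the first call, then each subsequent call processes a single token at the next slot.

import torch

prompt_len, max_new_tokens = 5, 3
past_length = 0
for step in range(max_new_tokens + 1):
    cur_len = prompt_len if step == 0 else 1  # whole prompt first, then one token at a time
    cache_position = torch.arange(past_length, past_length + cur_len)
    print(f"step {step}: cache_position={cache_position.tolist()}")
    past_length += cur_len
# step 0: cache_position=[0, 1, 2, 3, 4]
# step 1: cache_position=[5]
# step 2: cache_position=[6]
# step 3: cache_position=[7]
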

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
@@ -1241,6 +1256,7 @@
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,