Fix static generation when compiling! #28937

Merged · 42 commits · Feb 15, 2024
Changes from 37 commits

Commits
2187685
wow I was scared!
ArthurZucker Feb 9, 2024
4922c92
fix everything
ArthurZucker Feb 9, 2024
56768a0
nits
ArthurZucker Feb 9, 2024
b565051
make it BC?
ArthurZucker Feb 12, 2024
99afd1a
add todo
ArthurZucker Feb 12, 2024
edc498f
nits
ArthurZucker Feb 12, 2024
651c4bd
is_tracing should still be used to pass tracing tests
ArthurZucker Feb 12, 2024
f69626e
nits
ArthurZucker Feb 12, 2024
96136ac
some nits to make sure generation works with static cache uncompiled
ArthurZucker Feb 12, 2024
d5ebd80
fix sdpa
ArthurZucker Feb 12, 2024
70adcf6
fix FA2 for both static and dynamic in a better way?
ArthurZucker Feb 14, 2024
61ed4cb
style
ArthurZucker Feb 14, 2024
fedc563
fix-copies
ArthurZucker Feb 14, 2024
0195d58
fix fix copies
ArthurZucker Feb 14, 2024
07f3adb
fix sequential beam search
ArthurZucker Feb 14, 2024
9402c25
style
ArthurZucker Feb 14, 2024
86303c4
use `keys_to_ignore`
ArthurZucker Feb 14, 2024
fb9e907
nit
ArthurZucker Feb 14, 2024
9aa667e
correct dtype inference at init
ArthurZucker Feb 14, 2024
68a5f29
:( the fix for FA2 is still not optimal to investigate!
ArthurZucker Feb 14, 2024
3b9969b
styling
ArthurZucker Feb 14, 2024
162ab87
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 14, 2024
914b0d7
nits
ArthurZucker Feb 14, 2024
e79f79f
nit
ArthurZucker Feb 14, 2024
ee2317d
this might work better
ArthurZucker Feb 14, 2024
93b2691
add comment
ArthurZucker Feb 14, 2024
3619ed3
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 14, 2024
c23cdc4
"position_ids" -> "cache_position"
ArthurZucker Feb 14, 2024
717a8e7
style
ArthurZucker Feb 14, 2024
7fe0964
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 14, 2024
464c463
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 15, 2024
80148ab
nit
ArthurZucker Feb 15, 2024
c9f3c82
Remove changes that should not be propagated just yet
ArthurZucker Feb 15, 2024
5f54d84
Apply suggestions from code review
ArthurZucker Feb 15, 2024
b3fc042
Styling
ArthurZucker Feb 15, 2024
5fdb2da
make sure we raise an error for static cache with FA2 enabled
ArthurZucker Feb 15, 2024
03edf91
move to the bottom of the signature
ArthurZucker Feb 15, 2024
b762304
style
ArthurZucker Feb 15, 2024
9fbe901
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 15, 2024
7afe7d9
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 15, 2024
3772d1c
nit in the name
ArthurZucker Feb 15, 2024
cf0bc32
Merge branches 'fix-static-kv-cache' and 'fix-static-kv-cache' of git…
ArthurZucker Feb 15, 2024
13 changes: 7 additions & 6 deletions src/transformers/cache_utils.py
@@ -344,17 +344,15 @@ class StaticCache(Cache):
             The default `dtype` to use when initializing the layer.
     """

-    def __init__(
-        self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
-    ) -> None:
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.head_dim = config.hidden_size // config.num_attention_heads
+        self.dtype = dtype if dtype is not None else torch.float32
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
-        self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -386,20 +384,23 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
-        new_cache_positions = cache_kwargs.get("position_ids")
+        new_cache_positions = cache_kwargs.get("cache_position")
         k_out = self.key_cache
         v_out = self.value_cache

         k_out[:, :, new_cache_positions] = key_states
         v_out[:, :, new_cache_positions] = value_states

-        self.seen_tokens += key_states.shape[-2]
+        self.seen_tokens += key_states.shape[2]
         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
         return self.seen_tokens

+    def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return self.max_cache_len
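Taken together, the changes in this file mean a `StaticCache` now takes its dtype from the explicit `dtype` argument (falling back to `torch.float32` rather than `config.torch_dtype`), and `update()` reads the write positions from the `cache_position` entry of `cache_kwargs`. A minimal usage sketch of the post-PR call pattern; the config sizes and tensor values below are illustrative, not taken from the PR:

```python
import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# Illustrative config: 32 query heads, 8 key/value heads (GQA), default hidden_size of 4096.
config = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)
head_dim = config.hidden_size // config.num_attention_heads  # 128

# dtype is now honoured directly; without it the cache falls back to float32.
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float16)

# Write 4 "prompt" tokens into slots 0..3 of the pre-allocated buffer, using the renamed key.
key = torch.zeros(1, 8, 4, head_dim, dtype=torch.float16)
value = torch.zeros_like(key)
k_out, v_out = cache.update(key, value, 0, cache_kwargs={"cache_position": torch.arange(4)})

print(k_out.shape)             # torch.Size([1, 8, 16, 128]) -- the full static buffer is returned
print(cache.get_seq_length())  # 4 -- tokens recorded so far, per the update() shown above
```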
5 changes: 3 additions & 2 deletions src/transformers/generation/utils.py
@@ -4776,8 +4776,9 @@ def _split_model_inputs(
     # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
     # ModelOutput object.
     # bool should not be split but replicated for each split
-    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
-    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+    keys_to_ignore = ["cache_position", "encoder_outputs"]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
Comment on lines -4779 to +4781
Collaborator Author:
beam search will split the cache positions otherwise


     # we split the tensors and tuples of tensors
     data_split_list = [
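To make the review comment above concrete, here is a self-contained sketch (my simplification, not the actual `_split_model_inputs` helper; `encoder_outputs` handling is omitted) of the behaviour the `keys_to_ignore` change enforces: `cache_position` is replicated into every batch chunk of sequential beam search instead of being sliced along the batch dimension, while ordinary tensors are still split.

```python
import torch

def split_inputs_sketch(model_input: dict, split_size: int, full_batch_size: int):
    """Simplified stand-in for _split_model_inputs."""
    keys = list(model_input.keys())
    # bools and cache_position are replicated for each split, never sliced
    replicated_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
    keys_to_ignore = ["cache_position", "encoder_outputs"]
    tensor_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

    splits = []
    for start in range(0, full_batch_size, split_size):
        chunk = {k: model_input[k][start : start + split_size] for k in tensor_keys}
        chunk.update({k: model_input[k] for k in replicated_keys})
        splits.append(chunk)
    return splits

inputs = {
    "input_ids": torch.ones(4, 1, dtype=torch.long),
    "use_cache": True,
    "cache_position": torch.arange(7),  # identical for every beam, so it must not be split
}
for chunk in split_inputs_sketch(inputs, split_size=2, full_batch_size=4):
    print(chunk["input_ids"].shape, chunk["cache_position"].shape)  # torch.Size([2, 1]) torch.Size([7])
```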
131 changes: 76 additions & 55 deletions src/transformers/models/llama/modeling_llama.py

Large diffs are not rendered by default.
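The llama file carries the bulk of the fix but its diff is not rendered here. As a rough, hedged reconstruction of the idea only (the names match the rest of this PR, but this is not the verbatim `modeling_llama.py` change): generation now threads an explicit `cache_position` tensor through the model instead of recovering cache write indices from `position_ids`, advancing it by the number of tokens the static cache has already seen.

```python
import torch

# Sketch only: how a cache_position tensor for the next forward pass could be derived.
past_seen_tokens = 0   # e.g. what StaticCache.get_seq_length() reports before prefill
num_new_tokens = 5     # tokens being fed to the model in this step

# Slots of the pre-allocated static cache that this forward pass will fill; it ends up in
# StaticCache.update(..., cache_kwargs={"cache_position": ...}) as shown in cache_utils.py above.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
print(cache_position)  # tensor([0, 1, 2, 3, 4]) for this prefill step
```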

7 changes: 0 additions & 7 deletions src/transformers/models/persimmon/modeling_persimmon.py
@@ -823,7 +823,6 @@ def forward(
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -864,12 +863,6 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
8 changes: 1 addition & 7 deletions src/transformers/models/phi/modeling_phi.py
@@ -1084,7 +1084,7 @@ def forward(
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+    # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -1125,12 +1125,6 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
7 changes: 0 additions & 7 deletions src/transformers/models/stablelm/modeling_stablelm.py
@@ -1048,7 +1048,6 @@ def forward(
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -1089,12 +1088,6 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
6 changes: 3 additions & 3 deletions tests/test_cache_utils.py
@@ -143,23 +143,23 @@ def _random_kvs(config):
         mha_config = LlamaConfig(num_attention_heads=32)
         mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mha_static_cache.update(
-            *_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
         self.assertTrue(cached_values.shape == (1, 32, 10, 128))

         gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
         gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = gqa_static_cache.update(
-            *_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
         self.assertTrue(cached_values.shape == (1, 4, 10, 128))

         mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
         mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mqa_static_cache.update(
-            *_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
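Only the `cache_kwargs` key changes in these tests; the expected shapes still follow `(max_batch_size, num_key_value_heads, max_cache_len, head_dim)`. A quick sanity check of those numbers, assuming the default LlamaConfig `hidden_size` of 4096 (an assumption on my part, not stated in the diff):

```python
from transformers import LlamaConfig

# MHA, GQA and MQA variants from the test above.
for kv_heads, expected_heads in [(None, 32), (4, 4), (1, 1)]:
    cfg = LlamaConfig(num_attention_heads=32, num_key_value_heads=kv_heads)
    head_dim = cfg.hidden_size // cfg.num_attention_heads  # 4096 // 32 = 128
    cache_shape = (1, cfg.num_key_value_heads, 10, head_dim)
    print(cache_shape == (1, expected_heads, 10, 128))  # True for all three variants
```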