[Phi-3] Bug on stale kv cache #33129

Merged
13 commits merged on Sep 13, 2024
23 changes: 21 additions & 2 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -257,7 +257,7 @@ def __init__(self, dim, config, device=None):

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
-   seq_len = torch.max(position_ids) + 1
+   seq_len = seq_len or torch.max(position_ids) + 1
    if seq_len > self.original_max_position_embeddings:
        ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
    else:
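The hunk above replaces the unconditional `seq_len = torch.max(position_ids) + 1` with a version that lets an explicitly passed `seq_len` take precedence. For context, here is a minimal, self-contained sketch (illustrative names and values, not the actual Phi-3 implementation) of the long/short factor switch that makes the cache go stale once generation crosses `original_max_position_embeddings`:

```python
import torch

def pick_rescale_factors(seq_len, original_max, short_factor, long_factor):
    """Return the per-dimension RoPE rescale factors selected for a given sequence length."""
    if seq_len > original_max:
        return torch.tensor(long_factor, dtype=torch.float32)
    return torch.tensor(short_factor, dtype=torch.float32)

original_max = 4096
short, long_ = [1.0] * 48, [4.0] * 48  # illustrative factor lists

# Keys for positions 0..4095 (seq_len == 4096) were rotated with the short factors...
cached = pick_rescale_factors(4096, original_max, short, long_)
# ...but once generation crosses the boundary, the long factors are selected, so the
# cached keys no longer match the new rotary basis and must be recomputed.
current = pick_rescale_factors(4097, original_max, short, long_)
print(torch.equal(cached, current))  # False
```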
@@ -1238,6 +1238,15 @@ def forward(
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
```"""
if (
    use_cache
    and self.config.rope_scaling
    and cache_position is not None
    and cache_position[0] == self.config.original_max_position_embeddings
):
    logger.warning(
        f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
    )

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1294,7 +1303,6 @@ def forward(
attentions=outputs.attentions,
)

# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
@@ -1307,6 +1315,17 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# The first time the input length crosses the long/short factor switching point, force the KV cache to be re-computed.
# This makes that single token position slower, but it is better than the current failure (nonsensical outputs from a stale cache).
if (
    past_key_values
    and self.config.rope_scaling
    and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
):
    past_length = cache_position[0]
    if past_length <= self.config.original_max_position_embeddings:
        past_key_values = None

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
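The `prepare_inputs_for_generation` change boils down to one rule: the first time the total input length exceeds `original_max_position_embeddings` while the cache was still built with the short factors, drop the cache so the next forward pass rebuilds it with the long factors (this is also the point where the warning added to `forward` fires for callers not using `generate`). Below is a hedged, standalone sketch of that rule, with flattened arguments and an illustrative function name rather than the real method signature, which operates on `transformers` cache objects:

```python
from typing import Optional

def maybe_reset_stale_cache(
    past_key_values: Optional[object],
    cache_position_0: int,
    input_length: int,
    original_max_position_embeddings: int,
    has_rope_scaling: bool,
) -> Optional[object]:
    """Drop the KV cache the first time the total input length crosses the
    long/short factor switching point, so it is rebuilt with the long factors."""
    if (
        past_key_values is not None
        and has_rope_scaling
        and input_length >= original_max_position_embeddings + 1
        and cache_position_0 <= original_max_position_embeddings
    ):
        return None  # forces a full re-computation on the next forward pass
    return past_key_values

# Example: 4096 tokens are cached (positions 0..4095, short factors) and the 4097th
# token arrives -> the cache is discarded once, then rebuilt with the long factors.
print(maybe_reset_stale_cache(object(), cache_position_0=4096, input_length=4097,
                              original_max_position_embeddings=4096, has_rope_scaling=True))  # None
```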
41 changes: 41 additions & 0 deletions tests/models/phi3/test_modeling_phi3.py
@@ -442,6 +442,47 @@ def test_model_rope_scaling_from_config(self, scaling_type):
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

@parameterized.expand([("longrope",)])
def test_model_rope_scaling_short_long_factor(self, scaling_type):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    n_factors = config.hidden_size // config.num_key_value_heads // 2
    config.rope_scaling = {
        "type": scaling_type,
        "short_factor": [3.0 for _ in range(n_factors)],
        "long_factor": [5.0 for _ in range(n_factors)],
    }
    input_tensor = ids_tensor([1, 4090], config.vocab_size)
    model = Phi3ForCausalLM(config)
    model.to(torch_device)
    model.eval()
    generation_args_short = {
        "max_length": config.original_max_position_embeddings,
        "temperature": 0.0,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
    }
    output_with_short_factor = model.generate(input_tensor, **generation_args_short)
    keys_with_short_factor = output_with_short_factor.past_key_values[0][0]
    generation_args_long = {
        "max_length": config.original_max_position_embeddings + 5,
        "temperature": 0.0,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
        "output_logits": True,
    }
    output_with_long_factor = model.generate(input_tensor, **generation_args_long)
    keys_with_long_factor = output_with_long_factor.past_key_values[0][0]
    last_token_logits = output_with_long_factor.logits[-1][-1]
    regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1]
    keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :]

    # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
    self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-2, rtol=1e-2))
    # Last token generated using long factor
    self.assertTrue(torch.allclose(last_token_logits, regenerated_last_token_logits, atol=1e-2, rtol=1e-2))
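
For completeness, a usage sketch of the scenario the test above covers, seen from the user's side. The checkpoint name, the assumed 4096-token switching point, and the filler prompt are illustrative assumptions; any Phi-3 checkpoint with LongRoPE scaling and a prompt that ends just below `original_max_position_embeddings` exercises the same path:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3-mini-128k-instruct"  # example LongRoPE checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Build an input that ends just below the assumed switching point (4096 tokens here),
# so that generation itself crosses original_max_position_embeddings.
input_ids = tokenizer("lorem ipsum " * 3000, return_tensors="pt").input_ids[:, :4000].to(model.device)

# generate() calls prepare_inputs_for_generation at every step; on the first step whose
# total length exceeds the switching point, the patched method drops the cache so the
# keys are rebuilt with the long factors instead of staying stale.
output = model.generate(input_ids, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True))
```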


@slow
@require_torch