Possible Bug with KV Caching in Llama (original) model #25420
Comments
cc @ArthurZucker and @gante |
Hey! It seems like the problem is in your custom code rather than the library. I don't know exactly what is wrong with your custom greedy decoding, but I would guess that you are not feeding the positional ID information that generate creates automatically. |
Hi @maximkha 👋 Thank you for raising this issue! Sadly, our bandwidth is limited, so we can't dive deeply into custom code for which a solution already exists :) As @ArthurZucker wrote, you are missing the position IDs, which may have a significant impact on the output. The same is true for the attention mask. Our modeling code makes its best effort to infer these two inputs when they are missing, but it fails in some cases. My suggestion would be to introduce a call to prepare_inputs_for_generation in your decoding loop. |
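For illustration, here is a minimal sketch of a custom greedy loop that reuses the model's prepare_inputs_for_generation, so the position IDs and attention mask are built the same way generate builds them. The checkpoint name, prompt, and token budget are placeholders, and the loop assumes a decoder-only model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "huggyllama/llama-7b"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids)
past_key_values = None

with torch.no_grad():
    for _ in range(5):  # greedily decode 5 tokens
        model_inputs = model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=True,
        )
        outputs = model(**model_inputs)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
        past_key_values = outputs.past_key_values

print(tokenizer.decode(input_ids[0]))
```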
Thanks so much! Turns out prepare_inputs_for_generation was exactly what I needed. |
Actually, I'm currently experiencing another issue when using this for Llama sequence classification. It seems that even when I use prepare_inputs_for_generation, I'm getting values that disagree. I'm not exactly sure what the culprit is, but I have been using the appropriate _reorder_cache function. |
Are you using padding? If so, which padding side are you using? We had a few bug fixes related to padding recently (see #24979); it should work on main with left padding. |
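As a side note on the padding question, here is a small sketch of the left-padding setup being referred to (placeholder checkpoint; Llama ships without a pad token, so one is assigned here for illustration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default
tokenizer.padding_side = "left"            # decoder-only models should be padded on the left

batch = tokenizer(["Hello", "A much longer prompt"], return_tensors="pt", padding=True)
print(batch.input_ids)       # pad tokens appear on the left of the shorter sequence
print(batch.attention_mask)  # zeros mark the padded positions
```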
Hey @ArthurZucker, thanks for the response. I actually am not doing any padding. Here's a minimally reproducible example:

from transformers import LlamaForSequenceClassification
import torch
# simple attention mask code
def create_attention_mask(seq_len, bsz=1):
    return torch.ones((bsz, seq_len))
# from https://github.com/huggingface/transformers/blob/5e5fa0d88c293e6d5be2517b4f45680ba3bb5df2/src/transformers/models/llama/modeling_llama.py#L856
def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
    if past_key_values:
        input_ids = input_ids[:, -1:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
# this is huggyllama/llama-7b
MODEL = "/nobackup-fast/khanov/llama-7b"
classification_model = LlamaForSequenceClassification.from_pretrained(MODEL, num_labels=1, torch_dtype=torch.bfloat16).cuda()
# for simplicity (and to clearly illustrate the effect), set all the weights to 1
with torch.no_grad():
    classification_model.score.weight.set_(torch.ones_like(classification_model.score.weight))
# some random tokens
test_tokens = torch.tensor([1,263,29901,2599])
test_tokens = test_tokens.unsqueeze(0).cuda()
# some additional test token that we would like to run our classification model on
new_test_tokens = torch.hstack((test_tokens, torch.tensor([5]).unsqueeze(0).cuda()))
# generate the cache
cls_out = classification_model(**prepare_inputs_for_generation(test_tokens, past_key_values=None, attention_mask=create_attention_mask(test_tokens.shape[-1], test_tokens.shape[0]), use_cache=True))
# run the classification model without any special caching stuff
print("Correct output (with prepare_inputs)")
cls_out_new = classification_model(**prepare_inputs_for_generation(new_test_tokens, past_key_values=None, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0])))
print(f"{cls_out_new.logits=}")
# cls_out_new.logits = 89
# run it without the prepare input (just in case that's the issue)
print("Correct output (no prepare_inputs)")
cls_out_new = classification_model(new_test_tokens)
print(f"{cls_out_new.logits=}")
# cls_out_new.logits = 89
# with caching, and prepare input
print("With past_key_values (with prepare_inputs)")
cls_out_test = classification_model(**prepare_inputs_for_generation(new_test_tokens, past_key_values=cls_out.past_key_values, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0]), use_cache=True))
print(f"{cls_out_test.logits=}")
# cls_out_test.logits = 88.5
# with caching, without prepare input
print("With past_key_values (no prepare_inputs)")
cls_out_test = classification_model(new_test_tokens[:, -1:], past_key_values=cls_out.past_key_values, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0]), position_ids=torch.tensor([[new_test_tokens.shape[-1] -1]]), use_cache=True)
print(f"{cls_out_test.logits=}")
# cls_out_test.logits = 88.5

The cached logits (88.5) disagree with the uncached logits (89). Please let me know if anything seems wrong about this! I really appreciate the help! |
Hmmmm this is also happening if I replace the LlamaForSequenceClassification with LlamaForCausalLM. There are slight discrepancies in the logits:

Example

from transformers import LlamaForSequenceClassification, LlamaForCausalLM
import torch
# this is huggyllama/llama-7b
MODEL = "/nobackup-fast/khanov/llama-7b"
llm = LlamaForCausalLM.from_pretrained(MODEL, num_labels=1, torch_dtype=torch.bfloat16).cuda()
# simple attention mask code
def create_attention_mask(seq_len, bsz=1):
    return torch.ones((bsz, seq_len))
# from https://github.com/huggingface/transformers/blob/5e5fa0d88c293e6d5be2517b4f45680ba3bb5df2/src/transformers/models/llama/modeling_llama.py#L856
def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
    if past_key_values:
        input_ids = input_ids[:, -1:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
# for simplicity (and to clearly illustrate the effect), set all the weights to 1
# with torch.no_grad():
# classification_model.score.weight.set_(torch.ones_like(classification_model.score.weight))
# some random tokens
test_tokens = torch.tensor([1,263,29901,2599])
test_tokens = test_tokens.unsqueeze(0).cuda()
# some additional test token that we would like to run our classification model on
new_test_tokens = torch.hstack((test_tokens, torch.tensor([5]).unsqueeze(0).cuda()))
# generate the cache
llm_out = llm(**prepare_inputs_for_generation(test_tokens, past_key_values=None, attention_mask=create_attention_mask(test_tokens.shape[-1], test_tokens.shape[0]), use_cache=True))
# run the classification model without any special caching stuff
print("Correct output (with prepare_inputs)")
llm_out_new = llm(**prepare_inputs_for_generation(new_test_tokens, past_key_values=None, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0])))
print(f"{llm_out_new.logits[0, -1, :]=}")
"""Correct output (with prepare_inputs)
llm_out_new.logits[0, -1, :]=tensor([-12.0625, -15.3125, 2.5781, ..., -6.4688, -8.1250, -6.8125],
device='cuda:0', grad_fn=<SliceBackward0>)"""
# run it without the prepare input (just in case that's the issue)
print("Correct output (no prepare_inputs)")
llm_out_new = llm(new_test_tokens)
print(f"{llm_out_new.logits[0, -1, :]=}")
"""Correct output (no prepare_inputs)
llm_out_new.logits[0, -1, :]=tensor([-12.0625, -15.3125, 2.5781, ..., -6.4688, -8.1250, -6.8125],
device='cuda:0', grad_fn=<SliceBackward0>)"""
# with caching, and prepare input
print("With past_key_values (with prepare_inputs)")
llm_out_test = llm(**prepare_inputs_for_generation(new_test_tokens, past_key_values=llm_out.past_key_values, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0]), use_cache=True))
print(f"{llm_out_test.logits[0, -1, :]=}")
"""With past_key_values (with prepare_inputs)
llm_out_test.logits[0, -1, :]=tensor([-12.0625, -15.3750, 2.5938, ..., -6.5000, -8.1250, -6.8125],
device='cuda:0', grad_fn=<SliceBackward0>)"""
# with caching, without prepare input
print("With past_key_values (no prepare_inputs)")
llm_out_test = llm(new_test_tokens[:, -1:], past_key_values=llm_out.past_key_values, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0]), position_ids=torch.tensor([[new_test_tokens.shape[-1] -1]]), use_cache=True)
print(f"{llm_out_test.logits[0, -1, :]=}")
"""With past_key_values (no prepare_inputs)
llm_out_test.logits[0, -1, :]=tensor([-12.0625, -15.3750, 2.5938, ..., -6.5000, -8.1250, -6.8125],
device='cuda:0', grad_fn=<SliceBackward0>)""" |
Ok, I think I found the culprit! It seems that when using past_key_values and bfloat16, the errors are huge. (Output comparisons for float32 (default), bfloat16, and float16 were shown here.) Since the unit tests only check for f32, they aren't catching this. Here's the script to measure this:

Script

from transformers import LlamaForSequenceClassification, LlamaForCausalLM
import torch
# this is huggyllama/llama-7b
MODEL = "/nobackup-fast/khanov/llama-7b"
WITH_BFLOAT16 = False
if WITH_BFLOAT16:
    llm = LlamaForCausalLM.from_pretrained(MODEL, num_labels=1, torch_dtype=torch.bfloat16).cuda()
else:
    llm = LlamaForCausalLM.from_pretrained(MODEL, num_labels=1).cuda()
# simple attention mask code
def create_attention_mask(seq_len, bsz=1):
    return torch.ones((bsz, seq_len))
# from https://github.com/huggingface/transformers/blob/5e5fa0d88c293e6d5be2517b4f45680ba3bb5df2/src/transformers/models/llama/modeling_llama.py#L856
def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
    if past_key_values:
        input_ids = input_ids[:, -1:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
# some random tokens
test_tokens = torch.tensor([1,263,29901,2599])
test_tokens = test_tokens.unsqueeze(0).cuda()
# some additional test token that we would like to run our classification model on
new_test_tokens = torch.hstack((test_tokens, torch.tensor([5]).unsqueeze(0).cuda()))
# generate the cache
llm_out = llm(**prepare_inputs_for_generation(test_tokens, past_key_values=None, attention_mask=create_attention_mask(test_tokens.shape[-1], test_tokens.shape[0]), use_cache=True))
# run the classification model without any special caching stuff
print("Correct output (with prepare_inputs)")
llm_out_new = llm(**prepare_inputs_for_generation(new_test_tokens, past_key_values=None, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0])))
print(f"{llm_out_new.logits[0, -1, :]=}")
# run it without the prepare input (just in case that's the issue)
print("Correct output (no prepare_inputs)")
llm_out_new = llm(new_test_tokens)
print(f"{llm_out_new.logits[0, -1, :]=}")
# with caching, and prepare input
print("With past_key_values (with prepare_inputs)")
llm_out_test = llm(**prepare_inputs_for_generation(new_test_tokens, past_key_values=llm_out.past_key_values, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0]), use_cache=True))
print(f"{llm_out_test.logits[0, -1, :]=}")
print(f"{torch.max(torch.abs(llm_out_new.logits[0, -1, :]-llm_out_test.logits[0, -1, :]))=}")
# HERE: this is 1.0490e-05 when using f32, and 0.1250 when using bfloat16
# with caching, without prepare input
print("With past_key_values (no prepare_inputs)")
llm_out_test = llm(new_test_tokens[:, -1:], past_key_values=llm_out.past_key_values, attention_mask=create_attention_mask(new_test_tokens.shape[-1], new_test_tokens.shape[0]), position_ids=torch.tensor([[new_test_tokens.shape[-1] -1]]), use_cache=True)
print(f"{llm_out_test.logits[0, -1, :]=}")
print(f"{torch.max(torch.abs(llm_out_new.logits[0, -1, :]-llm_out_test.logits[0, -1, :]))=}")
# HERE: this is 1.0490e-05 when using f32, and 0.1250 when using bfloat16

Any ideas of how to fix this discrepancy? |
@ArthurZucker, any updates on this? |
I appreciate the update! |
Likewise, I won't have bandwidth to help unless it is a bug reproducible from a short script based on a non-custom generate call. |
Hey @gante, this isn't an issue with generate specifically. It seems that when using KV caching and bfloat16, the logits are significantly different from the non-cached version (some precision loss, I'm assuming). There is no generation involved; just using past_key_values with bfloat16 skews the logits. I'm not sure whether this level of precision loss is to be expected or not. TL;DR: this is a problem with precision + caching, not generate. Also, sorry for all the messages, but this level of precision loss is impacting my results. |
Hey folks 👋 I've done a deep dive on this issue, and I will link related issues to this comment, which attempts to summarize the findings. cc:
TL;DR

Using KV caches, assisted generation, left-padding, and batching will change the model's logits slightly; this is expected.

Why does this happen?

A key operation in neural networks is matrix multiplication, where values are multiplied and accumulated. Unless you have infinite precision, different implementations or different shapes (e.g. cropping a few rows of the first matrix) may produce different outputs, as the intermediary calculations must remain in the specified precision and are subject to rounding. For instance, our models with TF and JAX implementations never have the exact same output as the PyTorch implementation; they tend to differ by a small maximum amount.

When using KV caches (and, in some models, left-padding), we are changing the input shape to some matrix multiplication operations. For instance, in Llama, when you apply the linear projection to obtain the QKV for the attention layer, the input shape will be different depending on whether you're using left-padding and/or KV caches. Therefore, the output of these operations may be different, and these tiny differences build up across layers and across generated tokens, especially at lower precisions. If you place a breakpoint inside the model and compare what happens with and without KV caches, you'll see that the inputs to these matrix multiplications have different shapes and that their outputs diverge slightly.
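As a toy illustration of the shape effect described above (not code from this thread; the layer size, batch size, and seed are arbitrary), running the last row of an input through a linear projection on its own versus as part of the full matrix can produce slightly different values in bfloat16, because the backend may use different kernels and accumulation orders for different shapes:

```python
import torch

torch.manual_seed(0)
proj = torch.nn.Linear(4096, 4096, bias=False).to(torch.bfloat16)
x = torch.randn(8, 4096, dtype=torch.bfloat16)

full = proj(x)[-1]        # last row computed together with the rest ("no cache" shape)
single = proj(x[-1:])[0]  # last row computed alone (KV-cache-like shape)

# Depending on the backend kernel, this can be exactly zero or small but nonzero.
print(torch.max(torch.abs(full.float() - single.float())))
```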
How big is this difference?

Let's do a simple experiment: for the same set of inputs, let's measure the hidden states' and the logits' maximum difference for the first generated token, with and without KV caching. I created the following test script from an example given in a related issue (#26344). TL;DR: it averages the maximum value for the variables described above over 1000 runs:

Test script

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
from tqdm import tqdm
TOTAL_NUM_SAMPLES = 1000
INPUT_LEN = 64
model_name = "codellama/CodeLlama-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
)
# model = AutoModelForCausalLM.from_pretrained(model_name)
ds = load_dataset("bigcode/the-stack", data_dir="data/python", split="train", streaming=True)
ds_iterator = iter(ds.take(TOTAL_NUM_SAMPLES))
max_diffs = {}
for _ in tqdm(range(TOTAL_NUM_SAMPLES)):
    next_data = next(ds_iterator)["content"]
    all_input_ids = tokenizer(
        [next_data], return_tensors="pt", max_length=INPUT_LEN, truncation=True
    ).input_ids.to(model.device)

    # process the whole sequence
    all_outputs = model(all_input_ids, output_hidden_states=True, return_dict=True)
    # get logits for the last token
    last_token_logits = all_outputs.logits[0][-1:]

    # process the sequence except the last token
    kv = model(all_input_ids[:, :-1]).past_key_values
    # input only the last token with previous kv_cache
    new_output = model(all_input_ids[:, -1:], past_key_values=kv, output_hidden_states=True, return_dict=True)
    # extract the last token logits
    new_last_token_logits = new_output.logits[0][-1:]

    for layer_idx in range(len(all_outputs.hidden_states)):
        max_diff = torch.abs(
            all_outputs.hidden_states[layer_idx][:, -1, :] - new_output.hidden_states[layer_idx]
        ).max()
        max_diffs.setdefault(f"layer {layer_idx}", []).append(max_diff.cpu().item())

    # these two distributions should be equal, but they are not.
    max_diffs.setdefault("logits", []).append(torch.abs(last_token_logits - new_last_token_logits).max().cpu().item())
for key, value in max_diffs.items():
    print(f"{key}: {sum(value) / len(value)}")

Here are the results I got:

Llama, FP32
Llama, FP16 (the expected 16-bit format to use)
Llama, BF16 (the wrong 16-bit format to use with Llama)
GPT2, FP16
As we can see, the mismatch is negligible in FP32 and grows at 16-bit precision, with BF16 (the wrong 16-bit format for Llama) showing the largest differences; GPT2 was included to show the behavior is not Llama-specific.
What can we do about it?

First of all: the benefits of using lower-precision variables and KV caching are obvious. Are the downsides worth it? My advice is to measure the model on metrics relevant to your task (e.g. perplexity) and compare the cost-benefit for your use case. I suspect using KV caching will remain cost-effective :)

Secondly: there may be ways to reduce this mismatch, but so far I haven't found any. A common trick is to upcast some sensitive operations to FP32 (like the softmax on the attention layers). For completeness, I tried several such changes on Llama.
Most had no impact, and some reduced the mismatch at a high throughput cost. Finally, regarding left-padding: we might be able to mitigate problems here when we migrate batched generation to nested tensors, which don't need padding. I hope this comprehensive analysis helps you understand what's going on 🤗 And, who knows, maybe it will be the spark that ignites a solution to this issue 🪄 |
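To make the softmax-upcast trick mentioned above concrete, here is a minimal sketch (a general illustration of the idea, not the exact transformers implementation): the attention probabilities are computed in float32 and then cast back to the working dtype.

```python
import torch

def softmax_in_fp32(attn_scores: torch.Tensor) -> torch.Tensor:
    """Attention softmax computed in float32, then cast back to the input dtype.

    attn_scores is assumed to have shape (batch, num_heads, query_len, key_len)
    and may be in fp16/bf16; upcasting reduces rounding error in the normalization.
    """
    return torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(attn_scores.dtype)
```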
Thanks for the detailed explanation @gante ! makes a lot of sense! |
@gante |
System Info
transformers==4.31.0
Who can help?
@ArthurZucker, @younesbelkada
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
I was working on a custom decoding method; however, I found a deviation from greedy search when using KV caching.
Expected behavior
I was expecting the results not to change when using the past_key_values kwarg; however, when passing past_key_values, the model assigned different logits to the tokens. This deviates from the model.generate behavior too. This is possibly related to #18809 and #21080.
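For completeness, here is a minimal sketch of the comparison being described (the checkpoint and token IDs are placeholders borrowed from the examples in this thread): the last-token logits are computed with and without past_key_values and compared.

```python
import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float32)  # placeholder checkpoint
input_ids = torch.tensor([[1, 263, 29901, 2599, 5]])  # token IDs borrowed from the examples above

with torch.no_grad():
    # Full forward pass, no cache.
    full_logits = model(input_ids).logits[:, -1, :]

    # Prefill all but the last token, then feed only the last token with the cache.
    cache = model(input_ids[:, :-1], use_cache=True).past_key_values
    cached_logits = model(input_ids[:, -1:], past_key_values=cache, use_cache=True).logits[:, -1, :]

# Tiny in float32; reported above to be much larger in bfloat16.
print(torch.max(torch.abs(full_logits - cached_logits)))
```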