The conversion of the llama3 model back from gguf seems weird. #31766
Comments
I tried to compare the two models' weights and found that only the q_proj and k_proj weights differ:

import torch
from safetensors.torch import load_file

m1 = load_file("llama3-8b-single/model.safetensors")
m2 = load_file("llama3-8b-b3286/model.safetensors")

for k in m1:
    flag = torch.allclose(m1[k], m2[k])
    if not flag:
        print(k, flag)
"""
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.1.self_attn.k_proj.weight False
model.layers.1.self_attn.q_proj.weight False
...
"""
k = "model.layers.0.self_attn.k_proj.weight"
for i, _ in enumerate(m1[k]):
if not torch.allclose(m1[k][i], m2[k][i]):
print(i)
print(m1[k][i])
print(m2[k][i])
"""
32
tensor([-0.0615, -0.0119, -0.0183, ..., 0.0344, -0.0513, 0.0645],
dtype=torch.float16)
tensor([-0.0447, -0.0293, 0.0396, ..., 0.0067, 0.0242, -0.0035],
dtype=torch.float16)
...
990
tensor([-0.0215, -0.0674, 0.0096, ..., 0.0200, -0.0063, -0.0147],
dtype=torch.float16)
tensor([-0.0272, -0.0378, 0.0124, ..., 0.0242, -0.0052, -0.0052],
dtype=torch.float16)
991
tensor([-0.0287, -0.0288, 0.0073, ..., 0.0089, -0.0006, -0.0042],
dtype=torch.float16)
tensor([-0.0052, -0.0284, 0.0289, ..., 0.0135, 0.0055, -0.0042],
dtype=torch.float16)
""" According to |
After comparing the metadata of the two versions, I think the shape issue is caused by the missing llama.vocab_size key:

python llama-cpp-b3286/gguf-py/examples/reader.py llama3-8b/llama3-8b.b1742.f16.gguf > llama3-8b-b1742.txt
python llama-cpp-b3286/gguf-py/examples/reader.py llama3-8b/llama3-8b.b3286.f16.gguf > llama3-8b-b3286.txt

$ diff llama3-8b-b1742.txt llama3-8b-b3286.txt
4c4
< GGUF.kv_count : [19]
---
> GGUF.kv_count : [21]
6c6
< general.name : [46]
---
> general.name : [108 108 97 109 97 51 45 56 98]
16a17
> llama.vocab_size : [128256]
17a19
> tokenizer.ggml.pre : [108 108 97 109 97 45 98 112 101]
19d20
< tokenizer.ggml.scores : [0.]
23a25
> general.quantization_version : [2]

When trying to load the old llama3 GGUF, I think the loader misses the vocab size and falls back to a default of 32000 somewhere.
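A quick way to confirm that is to read the key directly with gguf-py; a minimal sketch (assuming the same gguf-py checkout and file paths as above):

from gguf.gguf_reader import GGUFReader  # same gguf-py package as examples/reader.py above

for path in ("llama3-8b/llama3-8b.b1742.f16.gguf", "llama3-8b/llama3-8b.b3286.f16.gguf"):
    field = GGUFReader(path).fields.get("llama.vocab_size")
    if field is None:
        print(path, "-> llama.vocab_size missing, the loader has to fall back to a default")
    else:
        print(path, "-> llama.vocab_size =", field.parts[field.data[0]])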
The major cause is the missing reverse permutation of the attn_q/attn_k weights when loading the GGUF checkpoint; the following change fixes it:

def reverse_hf_permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)

# ...

if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
    num_heads = parsed_parameters["config"]["num_attention_heads"]
    n_head_kv = parsed_parameters["config"]["num_key_value_heads"]
    if ".attn_q." in name:
        weights = reverse_hf_permute(weights, num_heads, num_heads)
    elif ".attn_k." in name:
        weights = reverse_hf_permute(weights, num_heads, n_head_kv)
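As a sanity check (my own sketch; the forward permute below is paraphrased from llama.cpp's convert script and is an assumption about its exact form), reverse_hf_permute is the exact inverse of that permutation for llama3-8b-shaped q/k tensors:

import numpy as np

def permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    # Forward permutation applied on the llama.cpp side (paraphrased from its
    # HF -> GGUF convert script); reverse_hf_permute must undo exactly this.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                   .swapaxes(1, 2)
                   .reshape(weights.shape))

def reverse_hf_permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)

rng = np.random.default_rng(0)
q = rng.standard_normal((4096, 4096)).astype(np.float32)  # q_proj shape: 32 heads x 128 dims
k = rng.standard_normal((1024, 4096)).astype(np.float32)  # k_proj shape: 8 KV heads x 128 dims
assert np.array_equal(reverse_hf_permute(permute(q, 32, 32), 32, 32), q)
assert np.array_equal(reverse_hf_permute(permute(k, 32, 8), 32, 8), k)
print("round trip OK")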
After this fix, there is still a minor PPL gap between the original HF model and the GGUF model, which might be caused by the tokenizer. Using the HF tokenizer should result in the same PPL.
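For example (a sketch with illustrative paths, assuming the converted weights live in llama3-8b-b3286/ and the original tokenizer is meta-llama/Meta-Llama-3-8B):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Converted-back weights, paired with the original HF tokenizer instead of the
# tokenizer rebuilt from GGUF metadata (path and repo id are illustrative).
model = AutoModelForCausalLM.from_pretrained("llama3-8b-b3286")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")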
Thanks for this report! Really appreciate it! It makes the review so much easier!
System Info
Who can help?
@SunMarc
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Prepare environment:
conda create -yn HF-GGUF python=3.11
conda activate HF-GGUF
conda install -y nvidia/label/cuda-12.1.1::cuda
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install accelerate cmake datasets fire gguf hf-transfer "numpy<2.0" sentencepiece transformers
conda activate HF-GGUF
The two versions of the llama.cpp repository:
Prepare models:
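The exact commands are collapsed here; as a rough sketch of the convert-back step only (assuming transformers' gguf_file support and the illustrative directory/file names used elsewhere in this issue):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the f16 GGUF produced by llama.cpp b3286 and write it back out as
# safetensors, giving the llama3-8b-b3286/ directory compared above.
model = AutoModelForCausalLM.from_pretrained("llama3-8b", gguf_file="llama3-8b.b3286.f16.gguf")
tokenizer = AutoTokenizer.from_pretrained("llama3-8b", gguf_file="llama3-8b.b3286.f16.gguf")
model.save_pretrained("llama3-8b-b3286")
tokenizer.save_pretrained("llama3-8b-b3286")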
Perplexity evaluation script:
Evaluate perplexity:
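The script and invocation are collapsed above; a minimal sketch of what such a perplexity evaluation could look like (my own, assuming WikiText-2 and the illustrative llama3-8b-b3286/ directory, not the collapsed script itself):

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "llama3-8b-b3286"  # illustrative path to the converted-back model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
model.eval()

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)

ctx_len, nlls, n_tokens = 2048, [], 0
for start in range(0, ids.size(1), ctx_len):
    chunk = ids[:, start : start + ctx_len]
    if chunk.size(1) < 2:
        break
    with torch.no_grad():
        loss = model(chunk, labels=chunk).loss  # labels are shifted internally
    nlls.append(loss.float() * (chunk.size(1) - 1))
    n_tokens += chunk.size(1) - 1

print("PPL:", torch.exp(torch.stack(nlls).sum() / n_tokens).item())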
I've modified some integration code to make the f16 conversion work:
Expected behavior
The PPL of llama3 converted back from GGUF is as high as 1801.3766, far higher than that of the original model. However, the same steps do not cause this issue with TinyLlama.
Therefore, I referred to the upload time of TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF and chose the version of llama.cpp released on January 1st for the gguf conversion. The following error occurred:
No matter which version of llama.cpp is used, the PPL is consistent on TinyLlama. However, I noticed that when using a newer version of llama.cpp, there is an additional step in HF: "Merges were not in the checkpoint, building merges on the fly."
The TinyLlama part can be reproduced using the following steps: