The conversion of the llama3 model back from gguf seems weird. #31766

Closed
2 of 4 tasks
PenutChen opened this issue Jul 3, 2024 · 5 comments · Fixed by #31788
Comments

@PenutChen (Contributor)

System Info

torch==2.3.1
accelerate==0.31.0
gguf==0.6.0
numpy==1.26.4
transformers==4.42.3

Who can help?

@SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Prepare environment:

conda create -yn HF-GGUF python=3.11
conda activate HF-GGUF
conda install -y nvidia/label/cuda-12.1.1::cuda
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install accelerate cmake datasets fire gguf hf-transfer "numpy<2.0" sentencepiece transformers
conda activate HF-GGUF

Build two versions of the llama.cpp repository:

# tag: b1742, sha: edd1ab7, date: January
git clone --branch b1742 --single-branch https://github.com/ggerganov/llama.cpp llama-cpp-b1742
cd llama-cpp-b1742
cmake -B build . -DLLAMA_CUBLAS=ON --fresh
cmake --build build --config Release -j
cd ..

# tag: b3286, sha: fadde67, date: July
git clone --branch b3286 --single-branch https://github.com/ggerganov/llama.cpp llama-cpp-b3286
cd llama-cpp-b3286
cmake -B build . -DGGML_CUDA=ON --fresh
cmake --build build --config Release -j
cd ..

Prepare models:

# prepare hf models
HF_HUB_ENABLE_HF_TRANSFER=True huggingface-cli download meta-llama/Meta-Llama-3-8B --local-dir llama3-8b

# convert to gguf
python llama-cpp-b1742/convert.py --outtype f16 --outfile llama3-8b/llama3-8b.b1742.f16.gguf llama3-8b
python llama-cpp-b3286/convert-hf-to-gguf.py --outtype f16 --outfile llama3-8b/llama3-8b.b3286.f16.gguf llama3-8b

Perplexity evaluation script:

import torch
from datasets import load_dataset
from fire import Fire
from tqdm import trange
from transformers import AutoModelForCausalLM as ModelImp
from transformers import AutoTokenizer as TkCls
from transformers import PreTrainedModel as ModelCls


@torch.inference_mode()
def evaluate_perplexity(model_id, gguf_file=None, seqlen=2048):
    model: ModelCls = ModelImp.from_pretrained(
        model_id,
        device_map="auto",
        gguf_file=gguf_file,
        torch_dtype=torch.float16,
    )

    tk = TkCls.from_pretrained(model_id, gguf_file=gguf_file)
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # tokenize the whole test split into one long token stream
    input_ids = list()
    for item in dataset:
        text = item["text"] + "\n"
        tokens = tk.encode(text, add_special_tokens=False)
        input_ids.extend(tokens)

    # split the stream into non-overlapping chunks of `seqlen` tokens
    data_size = len(input_ids) // seqlen
    input_ids = input_ids[: data_size * seqlen]

    # prepend a BOS token to every chunk
    input_ids = torch.LongTensor(input_ids).view(data_size, seqlen)
    bos_token = torch.full((data_size, 1), tk.bos_token_id, dtype=torch.int64)
    input_ids = torch.concat((bos_token, input_ids), dim=1)
    input_ids = input_ids.to(model.device)

    batch_size = 1
    nlls = list()
    with trange(0, data_size, batch_size) as prog:
        for i in prog:
            batch = input_ids[i : i + batch_size]
            # labels=batch makes the model return the mean causal-LM NLL for the batch
            outputs = model(batch, labels=batch)
            nlls.append(outputs.loss)
            ppl = torch.exp(torch.stack(nlls).mean())
            prog.desc = f"ppl: {ppl:.4f}"


if __name__ == "__main__":
    Fire(evaluate_perplexity)

Evaluate perplexity:

python eval-ppl.py llama3-8b  # ppl: 6.2941
python eval-ppl.py llama3-8b llama3-8b.b3286.f16.gguf  # ppl: 1801.3766
python eval-ppl.py llama3-8b llama3-8b.b1742.f16.gguf  # ValueError: Trying to set a tensor of shape torch.Size([128256, 4096]) in "weight" (which has shape torch.Size([32000, 4096])), this look incorrect.

I've modified some integration code to make the f16 conversion work:

# transformers/integrations/ggml.py
GGML_TYPES = {
    "F32": 0,
    "F16": 1,  # modification: register the F16 ggml type
    # ...
}

def load_dequant_gguf_tensor(shape, ggml_type, data):
    if ggml_type == GGML_TYPES["F32"]:
        values = data
    elif ggml_type == GGML_TYPES["F16"]:
        # modification: F16 tensors need no dequantization, pass them through like F32
        values = data
    # ...

Expected behavior

The PPL of llama3 converted back from GGUF is as high as 1801.3766, far from the original model's 6.2941. However, the same steps do not cause this issue with TinyLlama.

Therefore, I checked the upload date of TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF and chose the llama.cpp release from January 1st (b1742) for the GGUF conversion, which produced the following error:

Converting and de-quantizing GGUF tensors...: 100%|███████████████| 291/291 [00:48<00:00,  6.01it/s]
Traceback (most recent call last):
  File "/data2/Penut/Experiments/240703-HF-GGUF/eval-ppl.py", line 48, in <module>
    Fire(evaluate_perplexity)
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/Experiments/240703-HF-GGUF/eval-ppl.py", line 12, in evaluate_perplexity
    model: ModelCls = ModelImp.from_pretrained(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3838, in from_pretrained
    ) = cls._load_pretrained_model(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4227, in _load_pretrained_model
    error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/transformers/modeling_utils.py", line 895, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/data2/Penut/.miniconda/envs/HF-GGUF/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 358, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([128256, 4096]) in "weight" (which has shape torch.Size([32000, 4096])), this look incorrect.

For TinyLlama, the PPL is consistent no matter which version of llama.cpp is used. However, I noticed that when loading a GGUF produced by the newer llama.cpp, transformers performs an additional step: "Merges were not in the checkpoint, building merges on the fly."
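
For context, that message means transformers could not find a BPE merges list in the GGUF and reconstructed one from the stored vocabulary, roughly along these lines (a paraphrased sketch, not the exact transformers implementation):

# Sketch: rebuild BPE merges from a vocab mapping token -> score when the
# checkpoint ships no explicit merges list (paraphrased, not the exact code).
def build_merges(vocab: dict[str, float]) -> list[tuple[str, str]]:
    candidates = []
    for token, score in vocab.items():
        # every split of `token` into two known sub-tokens is a candidate merge
        for i in range(1, len(token)):
            left, right = token[:i], token[i:]
            if left in vocab and right in vocab:
                candidates.append((left, right, score))
    # higher-scored merged tokens get higher merge priority
    candidates.sort(key=lambda c: c[2], reverse=True)
    return [(left, right) for left, right, _ in candidates]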

The TinyLlama part can be reproduced using the following steps:

# tinyllama
HF_HUB_ENABLE_HF_TRANSFER=True huggingface-cli download TinyLlama/TinyLlama-1.1B-Chat-v1.0 --local-dir tinyllama-1b
python llama-cpp-b1742/convert.py --outtype f16 --outfile tinyllama-1b/tinyllama-1b.b1742.fp16.gguf tinyllama-1b
python llama-cpp-b3286/convert-hf-to-gguf.py --outtype f16 --outfile tinyllama-1b/tinyllama-1b.b3286.fp16.gguf tinyllama-1b

# ppl: 8.0233 for all three runs
python eval-ppl.py tinyllama-1b
python eval-ppl.py tinyllama-1b tinyllama-1b.b3286.fp16.gguf
python eval-ppl.py tinyllama-1b tinyllama-1b.b1742.fp16.gguf
@PenutChen (Contributor, Author) commented Jul 4, 2024

I compared the weights of the two models and found that only q_proj and k_proj differ:

import torch
from safetensors.torch import load_file

m1 = load_file("llama3-8b-single/model.safetensors")
m2 = load_file("llama3-8b-b3286/model.safetensors")

for k in m1:
    flag = torch.allclose(m1[k], m2[k])
    if not flag:
        print(k, flag)

"""
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.1.self_attn.k_proj.weight False
model.layers.1.self_attn.q_proj.weight False
...
"""

k = "model.layers.0.self_attn.k_proj.weight"
for i, _ in enumerate(m1[k]):
    if not torch.allclose(m1[k][i], m2[k][i]):
        print(i)
        print(m1[k][i])
        print(m2[k][i])

"""
32
tensor([-0.0615, -0.0119, -0.0183,  ...,  0.0344, -0.0513,  0.0645],
       dtype=torch.float16)
tensor([-0.0447, -0.0293,  0.0396,  ...,  0.0067,  0.0242, -0.0035],
       dtype=torch.float16)

...

990
tensor([-0.0215, -0.0674,  0.0096,  ...,  0.0200, -0.0063, -0.0147],
       dtype=torch.float16)
tensor([-0.0272, -0.0378,  0.0124,  ...,  0.0242, -0.0052, -0.0052],
       dtype=torch.float16)
991
tensor([-0.0287, -0.0288,  0.0073,  ...,  0.0089, -0.0006, -0.0042],
       dtype=torch.float16)
tensor([-0.0052, -0.0284,  0.0289,  ...,  0.0135,  0.0055, -0.0042],
       dtype=torch.float16)
"""

According to convert-hf-to-gguf.py, a permute operation is applied to these tensors during conversion, which might be the root cause of the issue.
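
For reference, the forward permutation in llama.cpp's convert scripts looks roughly like this (paraphrased; the exact code differs slightly between versions):

import numpy as np

# Paraphrased from llama.cpp's convert scripts: the rows belonging to each
# attention head are regrouped so that RoPE can be applied in llama.cpp's layout.
def permute(weights: np.ndarray, n_head: int, n_head_kv: int | None = None) -> np.ndarray:
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (
        weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
        .swapaxes(1, 2)
        .reshape(weights.shape)
    )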

@PenutChen (Contributor, Author) commented Jul 4, 2024

After comparing the metadata between the two versions, I think the shape issue is caused by the missing vocab_size information:

python llama-cpp-b3286/gguf-py/examples/reader.py llama3-8b/llama3-8b.b1742.f16.gguf > llama3-8b-b1742.txt
python llama-cpp-b3286/gguf-py/examples/reader.py llama3-8b/llama3-8b.b3286.f16.gguf > llama3-8b-b3286.txt
diff llama3-8b-b1742.txt llama3-8b-b3286.txt
4c4
< GGUF.kv_count                          : [19]
---
> GGUF.kv_count                          : [21]
6c6
< general.name                           : [46]
---
> general.name                           : [108 108  97 109  97  51  45  56  98]
16a17
> llama.vocab_size                       : [128256]
17a19
> tokenizer.ggml.pre                     : [108 108  97 109  97  45  98 112 101]
19d20
< tokenizer.ggml.scores                  : [0.]
23a25
> general.quantization_version           : [2]

When loading the llama3 GGUF produced by the older llama.cpp, I think transformers misses the vocab size and falls back to a default of 32000 somewhere, which explains the shape mismatch above.
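
The missing key can also be checked programmatically with gguf-py; a rough sketch (the exact GGUFReader field layout may vary between gguf versions):

from gguf import GGUFReader

reader = GGUFReader("llama3-8b/llama3-8b.b1742.f16.gguf")

field = reader.fields.get("llama.vocab_size")
if field is None:
    # the b1742 conversion has no llama.vocab_size; the token list still has the
    # real size, so a fallback like this would avoid a hard-coded 32000 default
    vocab_size = len(reader.fields["tokenizer.ggml.tokens"].data)
else:
    vocab_size = int(field.parts[field.data[0]][0])

print(vocab_size)  # expected: 128256 for Llama3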

@PenutChen (Contributor, Author) commented Jul 4, 2024

The major cause is the q_proj and k_proj permutation. According to convert-hf-to-gguf.py, the permutation depends not only on num_attention_heads but also on num_key_value_heads, which matters for a GQA model like Llama3. I tried to fix the reverse permutation like this: main...PenutChen:transformers:gguf-permute-fix

def reverse_hf_permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv

    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)

# ...

if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
    num_heads = parsed_parameters["config"]["num_attention_heads"]
    n_head_kv = parsed_parameters["config"]["num_key_value_heads"]
    if ".attn_q." in name:
        weights = reverse_hf_permute(weights, num_heads, num_heads)
    elif ".attn_k." in name:
        weights = reverse_hf_permute(weights, num_heads, n_head_kv)
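
As a quick sanity check (assuming the forward permute sketched earlier in this thread and the reverse_hf_permute above are both defined), the reverse permutation should round-trip on Llama3-8B-shaped tensors:

import numpy as np

rng = np.random.default_rng(0)
# hypothetical Llama3-8B shapes: 32 query heads, 8 KV heads, head_dim 128, hidden size 4096
q_proj = rng.standard_normal((32 * 128, 4096), dtype=np.float32)
k_proj = rng.standard_normal((8 * 128, 4096), dtype=np.float32)

# permute (llama.cpp side) followed by reverse_hf_permute (transformers side)
# should reproduce the original weights exactly
assert np.array_equal(reverse_hf_permute(permute(q_proj, 32, 32), 32, 32), q_proj)
assert np.array_equal(reverse_hf_permute(permute(k_proj, 32, 8), 32, 8), k_proj)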

@PenutChen (Contributor, Author)

After this fix, there is still a minor PPL gap between the original HF model and the GGUF model, which might be caused by the tokenizer; using the HF tokenizer instead should result in the same PPL.
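
One way to isolate the tokenizer effect would be a small variation of the eval script above: keep loading the weights from the GGUF, but drop gguf_file when loading the tokenizer so the original tokenizer.json is used (a sketch of the change, not a tested patch):

# inside evaluate_perplexity: weights still come from the de-quantized GGUF
model = ModelImp.from_pretrained(
    model_id,
    device_map="auto",
    gguf_file=gguf_file,
    torch_dtype=torch.float16,
)
# tokenizer comes from the original HF checkpoint instead of the GGUF vocab
tk = TkCls.from_pretrained(model_id)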

@SunMarc (Member) commented Jul 4, 2024

Thanks for this report! Really appreciate it! It makes the review so much easier!
