diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 1511fbac0976ac..3cf34eab584e48 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import numpy as np
 from tqdm import tqdm
 
@@ -147,10 +149,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
 
         if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
             num_heads = parsed_parameters["config"]["num_attention_heads"]
-            tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
-            weights = weights.reshape(tmp_shape)
-            weights = weights.transpose(0, 2, 1, 3)
-            weights = weights.reshape(shape[::-1])
+            num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
+            if ".attn_q." in name:
+                weights = reverse_permute_weights(weights, num_heads, num_heads)
+            elif ".attn_k." in name:
+                weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
 
         for tensor_name in tensor_key_mapping:
             if tensor_name in name:
@@ -163,3 +166,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
 
     return parsed_parameters
+
+
+def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
+    # Original permutation implementation
+    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+    if num_kv_heads is not None and n_head != num_kv_heads:
+        n_head = num_kv_heads
+
+    dim = weights.shape[0] // n_head // 2
+    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+    return w.swapaxes(2, 1).reshape(weights.shape)
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index e5e8dbaf36cffb..db96e9052c5f36 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -188,8 +188,7 @@ def test_llama3_q4_0(self):
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
 
-        EXPECTED_TEXT = "Hello, I am new to this forum. I am"
-
+        EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_tokenization_xnli(self):
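
For context (not part of the patch): a minimal round-trip sketch checking that the new `reverse_permute_weights` inverts the forward permutation that llama.cpp's convert_hf_to_gguf.py applies to Q/K weights during HF-to-GGUF conversion. The `permute` helper below is a rendering of the permutation referenced by the comment in the diff, and the head counts and shapes are illustrative assumptions, not code from either repository.

from typing import Optional

import numpy as np


def permute(weights: np.ndarray, n_head: int, n_head_kv: Optional[int] = None) -> np.ndarray:
    # Forward permutation based on the convert_hf_to_gguf.py lines linked in
    # the diff: split each head into (2, dim) rotary halves, then interleave.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    w = weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
    return w.swapaxes(1, 2).reshape(weights.shape)


def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
    # Same body as the function added in this diff: undo the (2, dim)
    # interleaving by reshaping to (n_head, dim, 2, ...) and swapping back.
    if num_kv_heads is not None and n_head != num_kv_heads:
        n_head = num_kv_heads
    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)


# Illustrative shapes: 8 attention heads and 2 KV heads (GQA), head_dim=16,
# hidden_size=128, so attn_q is (128, 128) and attn_k is (32, 128).
rng = np.random.default_rng(0)
q = rng.standard_normal((8 * 16, 128))
k = rng.standard_normal((2 * 16, 128))

# attn_q passes num_heads for both arguments; attn_k passes num_kv_heads,
# mirroring the call sites changed in load_gguf_checkpoint above.
assert np.array_equal(reverse_permute_weights(permute(q, 8, 8), 8, 8), q)
assert np.array_equal(reverse_permute_weights(permute(k, 8, 2), 8, 2), k)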