Fix the incorrect permutation of gguf (huggingface#31788)
* Fix the incorrect permutation of gguf

* rename num_kv_heads

Co-authored-by: Marc Sun <[email protected]>

* add typing to num_kv_heads

Co-authored-by: Marc Sun <[email protected]>

* rename variables

* refactor permute function name

* update the expected text of the llama3 q4 test

---------

Co-authored-by: Marc Sun <[email protected]>
2 people authored and MHRDYN7 committed Jul 23, 2024
1 parent bae3fae commit 2b8e8ea
Showing 2 changed files with 19 additions and 6 deletions.
22 changes: 18 additions & 4 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import numpy as np
 from tqdm import tqdm
 
@@ -147,10 +149,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
 
             if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
                 num_heads = parsed_parameters["config"]["num_attention_heads"]
-                tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
-                weights = weights.reshape(tmp_shape)
-                weights = weights.transpose(0, 2, 1, 3)
-                weights = weights.reshape(shape[::-1])
+                num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
+                if ".attn_q." in name:
+                    weights = reverse_permute_weights(weights, num_heads, num_heads)
+                elif ".attn_k." in name:
+                    weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
 
             for tensor_name in tensor_key_mapping:
                 if tensor_name in name:
@@ -163,3 +166,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
 
     return parsed_parameters
+
+
+def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
+    # Original permutation implementation
+    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+    if num_kv_heads is not None and n_head != num_kv_heads:
+        n_head = num_kv_heads
+
+    dim = weights.shape[0] // n_head // 2
+    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+    return w.swapaxes(2, 1).reshape(weights.shape)
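
For context (not part of the commit): the previous code always reshaped with num_attention_heads, which does not match the GGUF permutation of attn_k for grouped-query-attention models such as Llama 3, where num_key_value_heads is smaller than num_attention_heads. Below is a minimal, self-contained numpy sketch checking that reverse_permute_weights undoes the forward permutation applied during HF-to-GGUF conversion; permute_weights here is a local reimplementation of the llama.cpp helper linked above, and the shapes (8 query heads, 4 key/value heads, head dim 6, hidden size 48) are made-up toy values, not a real model configuration.

import numpy as np

def permute_weights(weights, n_head, n_head_kv=None):
    # Forward permutation, mirroring the llama.cpp conversion helper linked above.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    w = weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
    return w.swapaxes(1, 2).reshape(weights.shape)

def reverse_permute_weights(weights, n_head, num_kv_heads=None):
    # Same logic as the function added in this commit.
    if num_kv_heads is not None and n_head != num_kv_heads:
        n_head = num_kv_heads
    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)

# Toy GQA-style shapes: 8 query heads, 4 key/value heads, head dim 6, hidden size 48.
rng = np.random.default_rng(0)
q_proj = rng.standard_normal((8 * 6, 48))  # attn_q: rows = num_heads * head_dim
k_proj = rng.standard_normal((4 * 6, 48))  # attn_k: rows = num_kv_heads * head_dim

assert np.array_equal(reverse_permute_weights(permute_weights(q_proj, 8, 8), 8, 8), q_proj)
assert np.array_equal(reverse_permute_weights(permute_weights(k_proj, 8, 4), 8, 4), k_proj)
print("reverse_permute_weights undoes the GGUF conversion permutation")
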
3 changes: 1 addition & 2 deletions tests/quantization/ggml/test_ggml.py
@@ -188,8 +188,7 @@ def test_llama3_q4_0(self):
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
 
-        EXPECTED_TEXT = "Hello, I am new to this forum. I am"
-
+        EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_tokenization_xnli(self):
