Skip to content

Commit

Permalink
convert-llama-h5-to-gguf.py : clarify the reverse permute
Browse files Browse the repository at this point in the history
  • Loading branch information
klosax authored Aug 16, 2023
1 parent 4a1741a commit ea5615a
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions convert-llama-h5-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
# compatible with python < 3.9
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'

def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
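# Undo the permutation applied by the Hugging Face (HF) conversion script, restoring
# the original .pth checkpoint layout. The forward permute is defined in:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py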
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
Expand Down Expand Up @@ -219,9 +221,9 @@ def count_model_parts(dir_model: str) -> int:

data = data.squeeze().numpy()

# permute these
# undo the HF permutation on the q_proj / k_proj weights
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
data = permute(data, head_count, head_count_kv)
data = reverse_hf_permute(data, head_count, head_count_kv)

# map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map:
Expand Down Expand Up @@ -288,9 +290,9 @@ def count_model_parts(dir_model: str) -> int:

data = data.squeeze().numpy()

# permute these
# undo the HF permutation on the q_proj / k_proj weights
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
data = permute(data, head_count, head_count_kv)
data = reverse_hf_permute(data, head_count, head_count_kv)

# map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map:
Expand Down

0 comments on commit ea5615a

Please sign in to comment.