From a129213d87684131a614feb0e16c7e4161c3d0f7 Mon Sep 17 00:00:00 2001
From: Zhenzhong1 <109137058+Zhenzhong1@users.noreply.github.com>
Date: Wed, 6 Mar 2024 20:57:30 +0800
Subject: [PATCH] [GPTQ Enhance] Support GPTQ & AWQ inference for QWEN v1, v1.5
and Mixtral. (#134)
---
docs/gptq_and_awq.md | 9 +-
docs/supported_models.md | 94 +++---
neural_speed/__init__.py | 2 +-
neural_speed/convert/__init__.py | 2 +-
neural_speed/convert/common.py | 110 +++++--
.../convert/convert_quantized_mistral.py | 3 -
.../convert/convert_quantized_mixtral.py | 281 ++++++++++++++++
.../convert/convert_quantized_qwen.py | 306 ++++++++++++++++++
neural_speed/convert/convert_qwen.py | 26 +-
neural_speed/models/llama/llama_utils.cpp | 4 +-
neural_speed/models/model_utils/model_files.h | 43 ++-
neural_speed/models/qwen/qwen_utils.cpp | 6 +-
12 files changed, 778 insertions(+), 108 deletions(-)
create mode 100644 neural_speed/convert/convert_quantized_mixtral.py
create mode 100644 neural_speed/convert/convert_quantized_qwen.py
diff --git a/docs/gptq_and_awq.md b/docs/gptq_and_awq.md
index f3eeecf2c..ec66f5e43 100644
--- a/docs/gptq_and_awq.md
+++ b/docs/gptq_and_awq.md
@@ -6,11 +6,12 @@ Neural Speed supports multiple weight-only quantization algorithms, such as GPTQ
More algorithm details please check [GPTQ](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978).
Validated GPTQ & AWQ models directly from the HuggingFace:
-* [Llama-2-7B-Chat-GPT](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) & [Llama-2-13B-Chat-GPT](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ)
-* [CodeLlama-7B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) & [CodeLlama-13B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GPTQ)
+* [Llama-2-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) & [Llama-2-13B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-13B-Chat-GPTQ) & [Llama-2-7B-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-AWQ) & [Llama-2-13B-chat-AWQ](https://huggingface.co/TheBloke/Llama-2-13B-chat-AWQ)
+* [CodeLlama-7B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) & [CodeLlama-13B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GPTQ) & [CodeLlama-7B-AWQ](https://huggingface.co/TheBloke/CodeLlama-7B-AWQ) & [CodeLlama-13B-AWQ](https://huggingface.co/TheBloke/CodeLlama-13B-AWQ)
+* [Mistral-7B-Instruct-v0.1-GPTQ](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GPTQ) & [Mistral-7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-AWQ)
+* [Mixtral-8x7B-Instruct-v0.1-GPTQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ) & [Mixtral-8x7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ)
+* [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ) & [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4)
* [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ)
-* [Llama-2-7B-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-AWQ) & [Llama-2-13B-chat-AWQ](https://huggingface.co/TheBloke/Llama-2-13B-chat-AWQ)
-* [CodeLlama-7B-AWQ](https://huggingface.co/TheBloke/CodeLlama-7B-AWQ) & [CodeLlama-13B-AWQ](https://huggingface.co/TheBloke/CodeLlama-13B-AWQ)
Please check more validated GPTQ & AWQ models in the list of [supported_models](./docs/supported_models.md).
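
Both new converter scripts take the quantized checkpoint directory as a positional argument plus --outfile/--outtype (see their argparse definitions further down in this patch). A minimal sketch of driving one of them from Python, assuming neural_speed is installed and using purely illustrative paths:

import subprocess

# Illustrative paths only; any local GPTQ/AWQ checkpoint directory holding
# *.safetensors files plus a config.json with a "quantization_config" entry
# is the input that load_quantized_safetensors (added below) expects.
model_dir = "./Qwen1.5-7B-Chat-GPTQ-Int4"
out_file = "./qwen1.5-7b-chat-gptq.bin"

subprocess.check_call([
    "python", "neural_speed/convert/convert_quantized_qwen.py",
    "--outfile", out_file,
    "--outtype", "f32",
    model_dir,
])
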
diff --git a/docs/supported_models.md b/docs/supported_models.md
index c4a658256..ef3f7d362 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -72,17 +72,58 @@ Neural Speed supports the following models:
phi-2,
phi-1_5
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index 7bb39ce16..d0ab95c21 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -66,7 +66,7 @@ def __import_package(self, model_type):
import neural_speed.qwen_cpp as cpp_model
elif model_type == "mistral":
import neural_speed.mistral_cpp as cpp_model
- elif model_type == "qwen":
+ elif model_type == "qwen2":
import neural_speed.qwen_cpp as cpp_model
elif model_type == "phi":
import neural_speed.phi_cpp as cpp_model
diff --git a/neural_speed/convert/__init__.py b/neural_speed/convert/__init__.py
index 9f063a5ec..4e2a6796d 100644
--- a/neural_speed/convert/__init__.py
+++ b/neural_speed/convert/__init__.py
@@ -19,7 +19,7 @@
from transformers import AutoConfig
import subprocess
-model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper"}
+model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper", "qwen2": "qwen"}
def convert_model(model, outfile, outtype="f32", whisper_repo_path=None, use_quantized_model=False):
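
A Qwen1.5 checkpoint reports model_type == "qwen2" in its config.json, and the new map entry routes it to the existing qwen converter scripts. A tiny illustrative lookup (the .get(x, x) fallback mirrors how such maps are typically consulted; it is not a verbatim quote of convert_model):

model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper", "qwen2": "qwen"}

for model_type in ("qwen", "qwen2", "mixtral"):
    print(model_type, "->", model_maps.get(model_type, model_type))
# qwen -> qwen, qwen2 -> qwen, mixtral -> mixtral
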
diff --git a/neural_speed/convert/common.py b/neural_speed/convert/common.py
index b8891f2e8..fdefc07a4 100644
--- a/neural_speed/convert/common.py
+++ b/neural_speed/convert/common.py
@@ -16,13 +16,14 @@
# limitations under the License.
import torch
+import os
from pathlib import Path
import numpy as np
import struct
import json
import warnings
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
- Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar,
+ Union)
from sentencepiece import SentencePieceProcessor # type: ignore
GGML_QK8_0 = 32
@@ -35,6 +36,7 @@
GGML_QK4_1_TYPE = 3
GGML_QJBLAS_TYPE = 19
+
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
"""
@@ -57,6 +59,7 @@ def bytes_to_unicode():
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
+
def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_0 in ggml.c
assert tensor.shape[1] % GGML_QK4_0 == 0
@@ -71,6 +74,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
return tensor
+
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_1 in ggml.c
assert tensor.shape[1] % GGML_QK4_1 == 0
@@ -85,6 +89,7 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
tensor = torch.cat((scale.half().view(torch.int8), min_vals.half().view(torch.int8), tensor), dim=-1)
return tensor
+
class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
@@ -154,8 +159,7 @@ def load_vocab(path: Path) -> SentencePieceVocab:
else:
raise FileNotFoundError(
f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \
- pass the directory as --vocab-dir"
- )
+ pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
@@ -192,6 +196,7 @@ def qzeros_to_zeros(qzeros, bits=4):
col += 1
return zeros
+
def unpack_weight(qweight, scales, qzeros, q_config):
if "quant_method" not in q_config:
raise ValueError(f"Unsupported q_config without quant_method: {q_config}")
@@ -220,17 +225,18 @@ def unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config):
wf = torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32).unsqueeze(0)
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
- torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
+ torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
zeros = zeros + 1
zeros = zeros.reshape(scales.shape)
weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
- torch.bitwise_and(weight,(2 ** bits) - 1, out=weight)
+ torch.bitwise_and(weight, (2**bits) - 1, out=weight)
return weight, scales, zeros
+
def unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config):
print("unpack_gptq_weight_3bits... ", end='')
group_size = q_config['group_size']
@@ -239,23 +245,23 @@ def unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config):
assert bits == 3
# Int32 can only store 10 * 3bits data. This is the offset for each data.
- wf = torch.tensor([[ i for i in range(0, s32_bits - bits, bits)]], dtype=torch.int32)
+ wf = torch.tensor([[i for i in range(0, s32_bits - bits, bits)]], dtype=torch.int32)
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
- torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
+ torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
zeros = zeros + 1
zeros = zeros.reshape(zeros.shape[0], -1)
- zeros = zeros[:,:scales.shape[1]]
+ zeros = zeros[:, :scales.shape[1]]
weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
weight = weight.reshape(-1, weight.shape[-1])
input_feature = group_size * scales.shape[0]
- weight = weight[:input_feature,:]
+ weight = weight[:input_feature, :]
- torch.bitwise_and(weight,(2 ** bits) - 1, out=weight)
+ torch.bitwise_and(weight, (2**bits) - 1, out=weight)
return weight, scales, zeros
@@ -271,12 +277,13 @@ def unpack_awq_weight(qweight, scales, qzeros, q_config):
for col in range(qweight.shape[1]):
for i in range(pack_num):
w_col = torch.bitwise_right_shift(qweight[:, col], 4 * order_map[i])
- weight[:, col * pack_num + i] = torch.bitwise_and(w_col, (2 ** bits) - 1)
+ weight[:, col * pack_num + i] = torch.bitwise_and(w_col, (2**bits) - 1)
z_col = torch.bitwise_right_shift(qzeros[:, col], 4 * order_map[i])
- zeros[:, col * pack_num + i] = torch.bitwise_and(z_col, (2 ** bits) - 1)
+ zeros[:, col * pack_num + i] = torch.bitwise_and(z_col, (2**bits) - 1)
return weight, scales, zeros
+
def write_header(fout, shape, dst_name, ftype_cur):
sname = dst_name.encode('utf-8')
fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
@@ -295,6 +302,31 @@ def find_quantized_model_file(model_path):
print(f"Detected model file {found[0]}")
return str(found[0])
+
+def load_quantized_safetensors(model_path):
+ # load GPTQ & AWQ checkpoints stored in safetensors format
+ from safetensors.torch import load_file
+ safetensors = []
+ for file in os.listdir(model_path):
+ if file.endswith(".safetensors"):
+ safetensors.append(file)
+
+ print(f"safetensors list = {safetensors}")
+ model = {}
+ for file in safetensors:
+ tmp = load_file(model_path + "/" + file)
+ if isinstance(tmp, dict):
+ model.update(tmp)
+
+ with open(model_path + '/config.json', "r", encoding="utf-8") as f:
+ config = json.load(f)
+
+ quantize_config = config["quantization_config"]
+ if "zero_point" in quantize_config:
+ quantize_config["sym"] = not quantize_config["zero_point"]
+ return model, config, config["quantization_config"]
+
+
def load_quantized_model(model_path):
input_path = find_quantized_model_file(model_path)
model = None
@@ -318,8 +350,9 @@ def load_quantized_model(model_path):
def convert_to_fp32_tensor(src_name, dst_name, model, fout):
v = model[src_name]
shape = v.shape
- # print("Processing non-Q4 variable: " + src_name +
- # " with shape: ", shape, " and type: ", v.dtype)
+ n_dims = len(shape)
+ print("Processing non-Q4 variable: " + src_name + " -> " + dst_name + " with shape: ", shape, " and type: ",
+ v.dtype, "data: ", v[:2, :2].tolist() if n_dims > 1 else v[:2].tolist())
v = v.to(torch.float32)
ftype_cur = {torch.float16: 1, torch.float32: 0}[v.dtype]
@@ -329,7 +362,8 @@ def convert_to_fp32_tensor(src_name, dst_name, model, fout):
# data
v.numpy().tofile(fout)
- print(f"converting {dst_name} float tensor")
+ #print(f"converting {src_name} -> {dst_name} float tensor")
+
def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2=0, permute_func=None):
qzeros = model[f"{src_name}.qzeros"]
@@ -338,9 +372,9 @@ def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2
qweight = model[f"{src_name}.qweight"]
int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
- int_weight = int_weight.view(-1,int_weight.shape[-1]).t()
- gptq_scales = gptq_scales.view(-1,gptq_scales.shape[-1]).t()
- gptq_zeros = gptq_zeros.view(-1,gptq_zeros.shape[-1]).t()
+ int_weight = int_weight.view(-1, int_weight.shape[-1]).t()
+ gptq_scales = gptq_scales.view(-1, gptq_scales.shape[-1]).t()
+ gptq_zeros = gptq_zeros.view(-1, gptq_zeros.shape[-1]).t()
write_header(fout, int_weight.shape, dst_name, 2)
if permute_func:
@@ -350,12 +384,13 @@ def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2
tensor = int_weight.reshape(-1, 32) - 8
tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
- gptq_scale = gptq_scales.reshape(-1,1)
+ gptq_scale = gptq_scales.reshape(-1, 1)
# gptq_scale = torch.cat([gptq_scale,gptq_scale,gptq_scale,gptq_scale], dim=1).view(-1,1)
pack_tensor = torch.cat((gptq_scale.half().view(torch.int8), tensor), dim=-1)
pack_tensor.numpy().tofile(fout)
print(f"converting {dst_name} quantized tensor to ggml q4 block")
+
def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2=0, permute_func=None):
qzeros = model[f"{src_name}.qzeros"]
zeros = qzeros_to_zeros(qzeros)
@@ -364,9 +399,9 @@ def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_hea
qweight = model[f"{src_name}.qweight"]
int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
- int_weight = int_weight.view(-1,int_weight.shape[-1]).t()
- gptq_scales = gptq_scales.view(-1,gptq_scales.shape[-1]).t()
- gptq_zeros = gptq_zeros.view(-1,gptq_zeros.shape[-1]).t()
+ int_weight = int_weight.view(-1, int_weight.shape[-1]).t()
+ gptq_scales = gptq_scales.view(-1, gptq_scales.shape[-1]).t()
+ gptq_zeros = gptq_zeros.view(-1, gptq_zeros.shape[-1]).t()
write_header(fout, int_weight.shape, dst_name, 3)
if permute_func:
@@ -376,9 +411,9 @@ def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_hea
tensor = int_weight.reshape(-1, 32)
tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
- gptq_scale = gptq_scales.reshape(-1,1)
- gptq_zeros = gptq_zeros.reshape(-1,1)
- gptq_zeros = -gptq_scale*gptq_zeros
+ gptq_scale = gptq_scales.reshape(-1, 1)
+ gptq_zeros = gptq_zeros.reshape(-1, 1)
+ gptq_zeros = -gptq_scale * gptq_zeros
pack_tensor = torch.cat((gptq_scale.half().view(torch.int8), gptq_zeros.half().view(torch.int8), tensor), dim=-1)
pack_tensor.numpy().tofile(fout)
print(f"converting {dst_name} quantized tensor to ggml q4 1 block")
@@ -413,6 +448,7 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
print(f"converting {dst_name} quantized tensor to fp32 tensor")
+
def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
@@ -422,7 +458,7 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
qweight = model[f"{src_name}.qweight"]
int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
- int_weight = int_weight.view(-1,int_weight.shape[-1])
+ int_weight = int_weight.view(-1, int_weight.shape[-1])
# permute_func for llama-like model
if permute_func:
@@ -431,10 +467,10 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous()
# shuffle weight in GPTQ when act order is on
- if 'desc_act'in q_config and q_config['desc_act']:
+ if 'desc_act' in q_config and q_config['desc_act']:
g_idx = model[f"{src_name}.g_idx"]
int_weight2 = int_weight.clone()
- group_size=q_config['group_size']
+ group_size = q_config['group_size']
group_dict = {}
for i in range(len(g_idx)):
group_idx = g_idx[i].item()
@@ -461,16 +497,20 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
gptq_zeros = np.empty(0, dtype=np.int8)
else:
gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
- if 'desc_act'in q_config and q_config['desc_act']:
+ if 'desc_act' in q_config and q_config['desc_act']:
g_idx = np.ascontiguousarray(g_idx.numpy())
else:
g_idx = np.empty(0, dtype=np.int32)
# pack int weight in bestla format
- byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
- weight_dtype="int4" if q_config['bits'] == 4 else "int8",
- group_size=q_config['group_size'],
- alg="sym" if q_config['sym'] else "asym",
- compute_dtype="int8")
+ byte_size = cpp_model.Model.np_bestla_qpack(int_weight,
+ gptq_scales,
+ gptq_zeros,
+ g_idx,
+ dst,
+ weight_dtype="int4" if q_config['bits'] == 4 else "int8",
+ group_size=q_config['group_size'],
+ alg="sym" if q_config['sym'] else "asym",
+ compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")
diff --git a/neural_speed/convert/convert_quantized_mistral.py b/neural_speed/convert/convert_quantized_mistral.py
index 7e25bd22b..a0af0e13a 100644
--- a/neural_speed/convert/convert_quantized_mistral.py
+++ b/neural_speed/convert/convert_quantized_mistral.py
@@ -161,9 +161,6 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("f", config["rms_norm_eps"]))
f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
f.write(struct.pack("f", rope_scale))
- f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
- f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
- f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
diff --git a/neural_speed/convert/convert_quantized_mixtral.py b/neural_speed/convert/convert_quantized_mixtral.py
new file mode 100644
index 000000000..abb209f9a
--- /dev/null
+++ b/neural_speed/convert/convert_quantized_mixtral.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import sys
+import re
+import argparse
+from common import *
+
+
+def permute_func(weights, n_head: int, n_head_kv: int):
+ if n_head_kv is not None and n_head != n_head_kv:
+ n_head = n_head_kv
+ return (weights.reshape(n_head_kv, 2, weights.shape[0] // n_head_kv // 2,
+ *weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))
+
+
+def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
+ # unpack weight and repack into jblas format
+ import neural_speed.llama_cpp as cpp_model
+ if ".weight" in src_name:
+ src_name = src_name.replace(".weight", "")
+ qzeros = model[f"{src_name}.qzeros"]
+ zeros = qzeros_to_zeros(qzeros)
+ scales = model[f"{src_name}.scales"]
+ qweight = model[f"{src_name}.qweight"]
+
+ int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
+ int_weight = int_weight.view(-1, int_weight.shape[-1])
+
+ # shuffle weight in GPTQ when act order is on
+ if 'desc_act' in q_config and q_config['desc_act']:
+ g_idx = model[f"{src_name}.g_idx"]
+ int_weight2 = int_weight.clone()
+ group_size = q_config['group_size']
+ group_dict = {}
+ for i in range(len(g_idx)):
+ group_idx = g_idx[i].item()
+ if group_idx not in group_dict:
+ target_idx = group_idx * group_size
+ group_dict[group_idx] = 0
+ else:
+ group_dict[group_idx] = group_dict[group_idx] + 1
+ target_idx = group_idx * group_size + group_dict[group_idx]
+ int_weight2[target_idx] = int_weight[i]
+ int_weight = int_weight2
+
+ # permute_func for llama-like model
+ if permute_func:
+ int_weight = permute_func(int_weight.t(), n_head, n_head_kv).t().contiguous()
+ gptq_scales = permute_func(gptq_scales.t(), n_head, n_head_kv).t().contiguous()
+ gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous()
+
+ shape = int_weight.shape[::-1]
+ # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)
+ n_dims = len(shape)
+ str = dst_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ weight_dtype = "int8"
+ if q_config['bits'] == 4:
+ int_weight = (int_weight - 8) * 16
+ gptq_scales = gptq_scales / 16
+ gptq_zeros = (gptq_zeros - 8) * 16
+ weight_dtype = "int4"
+
+ dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
+ int_weight = np.ascontiguousarray(int_weight.numpy())
+ gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
+ if q_config['sym']:
+ gptq_zeros = np.empty(0, dtype=np.int8)
+ else:
+ gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
+ if 'desc_act' in q_config and q_config['desc_act']:
+ g_idx = np.ascontiguousarray(g_idx.numpy())
+ else:
+ g_idx = np.empty(0, dtype=np.int32)
+
+ # pack int weight in bestla format
+ byte_size = cpp_model.Model.np_bestla_qpack(int_weight,
+ gptq_scales,
+ gptq_zeros,
+ g_idx,
+ dst,
+ weight_dtype=weight_dtype,
+ group_size=q_config['group_size'],
+ alg="sym" if q_config['sym'] else "asym",
+ compute_dtype="int8")
+ dst.flatten()[:byte_size].tofile(fout)
+ print(f"convert_to_q4_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")
+
+
+def main(args_in: Optional[List[str]] = None) -> None:
+ parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
+ parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("model", type=Path, help="directory containing model file")
+ args = parser.parse_args(args_in)
+
+ out_path = args.outfile.as_posix()
+ model_path = args.model.as_posix()
+
+ model, config, quantize_config = load_quantized_safetensors(model_path)
+ f = open(out_path, "wb")
+
+ # possible data types
+ # ftype == 0 -> float32
+ # ftype == 1 -> float16
+ ftype = 0
+ if args.outtype == "f16":
+ ftype = 1
+
+ # 1. write hparams
+ n_vocab = config["vocab_size"]
+ n_embd = config["hidden_size"]
+ n_layer = config["num_hidden_layers"]
+ n_head = config["num_attention_heads"]
+ ffn_hidden_size = config["intermediate_size"]
+
+ # hardcoded:
+ n_mult = 256
+ # 1. write head and params
+ ne_file_magic = 0x67676d66
+ f.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
+ #f.write(b"ggjt"[::-1]) # magic
+ rope_scale = 1
+ if "rope_scaling" in config and config["rope_scaling"] is not None:
+ rope_scale = config["rope_scaling"]["factor"] if "factor" in config["rope_scaling"] else 1
+
+ n_head = n_head
+ n_head_kv = 8
+ values = [
+ 1, # file version
+ n_vocab,
+ n_embd,
+ 256, #hparams.n_mult,
+ n_head,
+ n_head_kv, # n_head_kv (multi_query attention)
+ n_layer,
+ n_embd // n_head, # rot (obsolete)
+ 0, #file_type.value, # TODO
+ ]
+
+ f.write(struct.pack("i" * len(values), *values))
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("f", 0))
+ f.write(struct.pack("f", 0))
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ f.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", ffn_hidden_size))
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", 8)) # n_experts
+ f.write(struct.pack("i", 2)) # n_expert_used
+
+ f.write(struct.pack("f", config["rms_norm_eps"]))
+ f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
+ f.write(struct.pack("f", rope_scale))
+
+ f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+
+ # TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json
+ # but bos_token_id = 1 in llama.cpp
+ f.write(struct.pack("i", 1))
+ f.write(struct.pack("i", 2))
+
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", 0))
+
+ # 2. vocab
+ tokenizer_path = os.path.join(model_path, "tokenizer.model")
+ vocab = load_vocab(Path(tokenizer_path))
+ for text, score in vocab.all_tokens():
+ f.write(struct.pack("i", len(text)))
+ f.write(text)
+ f.write(struct.pack("f", score))
+
+ def convert_mixtral_to_fp32_tensor(src_name, dst_name, model, fout):
+ # mixtral-gptq is torch.bfloat16 mostly.
+ if model[src_name].dtype == torch.float32:
+ data = model[src_name].squeeze().numpy()
+ else:
+ data = model[src_name].squeeze().to(torch.float32).numpy()
+ data = data.astype(np.float32)
+ shape = data.shape
+ n_dims = len(shape)
+ print("convert_mixtral_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
+ data.dtype)
+
+ #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ data = data.astype(np.float32)
+
+ # header
+ # write_header(fout, shape, dst_name, ftype_cur)
+ str = dst_name.encode('utf-8')
+ f.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ f.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ f.write(str)
+
+ # data
+ data.tofile(f)
+
+ # 3. write tensors
+ list_vars = model
+ convert_mixtral_to_fp32_tensor("model.embed_tokens.weight", "tok_embeddings.weight", list_vars, f)
+ convert_mixtral_to_fp32_tensor("model.norm.weight", "norm.weight", list_vars, f)
+ convert_mixtral_to_fp32_tensor("lm_head.weight", "output.weight", list_vars, f)
+
+ for i in range(n_layer):
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.q_proj",
+ f"layers.{i}.attention.wq.weight",
+ list_vars,
+ f,
+ quantize_config,
+ n_head,
+ n_head,
+ permute_func=permute_func)
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj",
+ f"layers.{i}.attention.wk.weight",
+ list_vars,
+ f,
+ quantize_config,
+ n_head,
+ n_head_kv,
+ permute_func=permute_func)
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight", list_vars,
+ f, quantize_config, n_head)
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight", list_vars,
+ f, quantize_config, n_head)
+
+ convert_mixtral_to_fp32_tensor(f"model.layers.{i}.block_sparse_moe.gate.weight",
+ f"layers.{i}.ffn_gate_inp.weight", list_vars, f)
+
+ for j in range(8):
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w1",
+ f"layers.{i}.ffn_gate.{j}.weight", list_vars, f, quantize_config, n_head)
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w2",
+ f"layers.{i}.ffn_down.{j}.weight", list_vars, f, quantize_config, n_head)
+ convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w3",
+ f"layers.{i}.ffn_up.{j}.weight", list_vars, f, quantize_config, n_head)
+
+ convert_mixtral_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight",
+ list_vars, f)
+ convert_mixtral_to_fp32_tensor(f"model.layers.{i}.post_attention_layernorm.weight",
+ f"layers.{i}.ffn_norm.weight", list_vars, f)
+
+ f.close()
+ print(f"Success! saved as {out_path}")
+
+
+if __name__ == '__main__':
+ main()
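
The desc_act branch in convert_to_q4_bestla_tensor regroups rows so that all rows belonging to one quantization group become contiguous again: row i moves to group_idx * group_size plus a running per-group counter. A toy run of the same loop on made-up data:

import torch

group_size = 2
g_idx = torch.tensor([0, 1, 0, 1])                   # quantization group of each input row
int_weight = torch.tensor([[10], [20], [30], [40]])  # illustrative rows

int_weight2 = int_weight.clone()
group_dict = {}
for i in range(len(g_idx)):
    group_idx = g_idx[i].item()
    if group_idx not in group_dict:
        target_idx = group_idx * group_size
        group_dict[group_idx] = 0
    else:
        group_dict[group_idx] = group_dict[group_idx] + 1
        target_idx = group_idx * group_size + group_dict[group_idx]
    int_weight2[target_idx] = int_weight[i]

# group 0 rows (originally rows 0 and 2) now fill slots 0-1, group 1 rows fill slots 2-3
assert int_weight2.flatten().tolist() == [10, 30, 20, 40]
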
diff --git a/neural_speed/convert/convert_quantized_qwen.py b/neural_speed/convert/convert_quantized_qwen.py
new file mode 100644
index 000000000..fc0b87ed2
--- /dev/null
+++ b/neural_speed/convert/convert_quantized_qwen.py
@@ -0,0 +1,306 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import sys
+import re
+import argparse
+from common import *
+
+
+def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
+ # unpack weight and repack into 3bits / 4bits BestLA format
+ import neural_speed.llama_cpp as cpp_model
+ if ".weight" in src_name:
+ src_name = src_name.replace(".weight", "")
+ qzeros = model[f"{src_name}.qzeros"]
+ zeros = qzeros_to_zeros(qzeros)
+ scales = model[f"{src_name}.scales"]
+ qweight = model[f"{src_name}.qweight"]
+
+ int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
+ int_weight = int_weight.view(-1, int_weight.shape[-1])
+
+ # shuffle weight in GPTQ when act order is on
+ if 'desc_act' in q_config and q_config['desc_act']:
+ g_idx = model[f"{src_name}.g_idx"]
+ int_weight2 = int_weight.clone()
+ group_size = q_config['group_size']
+ group_dict = {}
+ for i in range(len(g_idx)):
+ group_idx = g_idx[i].item()
+ if group_idx not in group_dict:
+ target_idx = group_idx * group_size
+ group_dict[group_idx] = 0
+ else:
+ group_dict[group_idx] = group_dict[group_idx] + 1
+ target_idx = group_idx * group_size + group_dict[group_idx]
+ int_weight2[target_idx] = int_weight[i]
+ int_weight = int_weight2
+
+ # shape = int_weight.shape[::-1]
+ shape = int_weight.shape[::-1]
+ # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)
+ n_dims = len(shape)
+ str = dst_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # INC stores the signed int4 values as u4 (range 0~15) by adding an offset,
+ # while BesTLA expects s4_clip ((-8, 7) * 16), so subtract the offset and multiply by 16.
+ # Int3 works the same way, but with offset = 4 and scale multiplier 32.
+ weight_dtype = "int8"
+ if q_config['bits'] == 4:
+ int_weight = (int_weight - 8) * 16
+ gptq_scales = gptq_scales / 16
+ gptq_zeros = (gptq_zeros - 8) * 16
+ weight_dtype = "int4"
+ elif q_config['bits'] == 3:
+ int_weight = (int_weight - 4) * 32
+ gptq_scales = gptq_scales / 32
+ gptq_zeros = (gptq_zeros - 4) * 32
+ weight_dtype = "int3"
+ else:
+ ValueError(f"Unsupported q_config[bits]: {q_config['bits']}")
+
+ dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
+ int_weight = np.ascontiguousarray(int_weight.numpy())
+ gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
+ if q_config['sym']:
+ gptq_zeros = np.empty(0, dtype=np.int8)
+ else:
+ gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
+ if 'desc_act' in q_config and q_config['desc_act']:
+ g_idx = np.ascontiguousarray(g_idx.numpy())
+ else:
+ g_idx = np.empty(0, dtype=np.int32)
+
+ # repack int weight in BesTLA format
+ byte_size = cpp_model.Model.np_bestla_qpack(int_weight,
+ gptq_scales,
+ gptq_zeros,
+ g_idx,
+ dst,
+ weight_dtype=weight_dtype,
+ group_size=q_config['group_size'],
+ alg="sym" if q_config['sym'] else "asym",
+ compute_dtype="int8")
+ dst.flatten()[:byte_size].tofile(fout)
+ print(f"convert_to_qx_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")
+
+
+def main(args_in: Optional[List[str]] = None) -> None:
+ parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
+ parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("model", type=Path, help="directory containing model file")
+ args = parser.parse_args(args_in)
+
+ out_path = args.outfile.as_posix()
+ model_path = args.model.as_posix()
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ # QWEN-GPTQ & AWQ
+ model, hparams, quantize_config = load_quantized_safetensors(model_path)
+ list_vars = model
+
+ print(hparams)
+
+ # original QWEN
+ # model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ # hparams = model.config.to_dict()
+ # list_vars = model.state_dict()
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ f = open(out_path, "wb")
+
+ # possible data types
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype = 0
+ if args.outtype == "f16":
+ ftype = 1
+
+ # 1. write hparams
+ # 0x67676d6c is unversioned ne
+ # 0x67676d66 is versioned ggmf (requires token scores)
+ ne_file_magic = 0x67676d66
+ #ne_file_version = 0x00000001 # v1
+
+ f.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
+ f.write(struct.pack("i", 1))
+
+ f.write(struct.pack("i", hparams["vocab_size"]))
+ f.write(struct.pack("i", hparams["hidden_size"]))
+ f.write(struct.pack("i", hparams["intermediate_size"])) # dummy data
+ f.write(struct.pack("i", hparams["num_attention_heads"]))
+ f.write(struct.pack("i", 0)) # multi-query attention
+ f.write(struct.pack("i", hparams["num_hidden_layers"]))
+ f.write(
+ struct.pack(
+ "i", hparams["kv_channels"] if "kv_channels" in hparams else int(hparams["hidden_size"] /
+ hparams["num_attention_heads"])))
+ f.write(struct.pack("i", ftype))
+ f.write(struct.pack("i", hparams["seq_length"] if "seq_length" in hparams else hparams["max_position_embeddings"]))
+ f.write(struct.pack("f", 0.0))
+ f.write(struct.pack("f", 0.0))
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ f.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", hparams["intermediate_size"]))
+ f.write(struct.pack("i", 0))
+ f.write(struct.pack("i", 0)) # n_experts
+ f.write(struct.pack("i", 0)) # n_expert_used
+ f.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ f.write(struct.pack("f", 10000.0)) # freq_base
+ f.write(struct.pack("f", 1.0)) # rope_factor
+
+ f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+
+ f.write(
+ struct.pack(
+ "i", hparams["bos_token_id"] if "bos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>']))
+ f.write(
+ struct.pack(
+ "i", hparams["eos_token_id"] if "eos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>']))
+ f.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ f.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ # 2. vocab
+ for i in range(hparams["vocab_size"]):
+ if i < tokenizer.vocab_size:
+ text = tokenizer.decode([i]).encode('utf-8')
+ f.write(struct.pack("i", len(text)))
+ f.write(text)
+ f.write(struct.pack("f", 0.0 - i))
+ else:
+ text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+ f.write(struct.pack("i", len(text)))
+ f.write(text)
+ f.write(struct.pack("f", -10000))
+
+ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
+ # qwen-gptq is torch.bfloat16 mostly.
+ if model[src_name].dtype == torch.float32:
+ data = model[src_name].squeeze().numpy()
+ else:
+ data = model[src_name].squeeze().to(torch.float32).numpy()
+ data = data.astype(np.float32)
+ shape = data.shape
+ n_dims = len(shape)
+ print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
+ data.dtype)
+
+ #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ data = data.astype(np.float32)
+
+ # header
+ # write_header(fout, shape, dst_name, ftype_cur)
+ str = dst_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # data
+ data.tofile(f)
+
+ #3. write tensors
+ if hparams['model_type'] == 'qwen':
+ convert_qwen_to_fp32_tensor("transformer.wte.weight", "transformer.wte.weight", list_vars, f)
+ convert_qwen_to_fp32_tensor("transformer.ln_f.weight", "transformer.ln_f.weight", list_vars, f)
+ convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, f)
+
+ for i in range(hparams["num_hidden_layers"]):
+ convert_qwen_to_fp32_tensor(f"transformer.h.{i}.ln_1.weight", f"transformer.h.{i}.ln_1.weight", list_vars,
+ f)
+ convert_qwen_to_fp32_tensor(f"transformer.h.{i}.ln_2.weight", f"transformer.h.{i}.ln_2.weight", list_vars,
+ f)
+
+ # qkv GEMM
+ convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.c_attn.weight",
+ f"transformer.h.{i}.attn.c_attn.weight", list_vars, f, quantize_config)
+ convert_qwen_to_fp32_tensor(f"transformer.h.{i}.attn.c_attn.bias", f"transformer.h.{i}.attn.c_attn.bias",
+ list_vars, f)
+ convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.c_proj.weight",
+ f"transformer.h.{i}.attn.c_proj.weight", list_vars, f, quantize_config)
+
+ # ffn GEMM
+ convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.w1.weight", f"transformer.h.{i}.mlp.w1.weight",
+ list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.w2.weight", f"transformer.h.{i}.mlp.w2.weight",
+ list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.c_proj.weight", f"transformer.h.{i}.mlp.c_proj.weight",
+ list_vars, f, quantize_config)
+
+ f.close()
+ print(f"Success! saved as {out_path}")
+ elif hparams['model_type'] == 'qwen2':
+ # 3. write tensors
+ convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, f)
+ convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, f)
+ convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, f)
+
+ for i in range(hparams["num_hidden_layers"]):
+ convert_qwen_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight",
+ f"model.layers.{i}.input_layernorm.weight", list_vars, f)
+ convert_qwen_to_fp32_tensor(f"model.layers.{i}.post_attention_layernorm.weight",
+ f"model.layers.{i}.post_attention_layernorm.weight", list_vars, f)
+
+ # qkv GEMM
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.q_proj.weight",
+ f"model.layers.{i}.self_attn.q_proj.weight", list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.k_proj.weight",
+ f"model.layers.{i}.self_attn.k_proj.weight", list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.v_proj.weight",
+ f"model.layers.{i}.self_attn.v_proj.weight", list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.o_proj.weight",
+ f"model.layers.{i}.self_attn.o_proj.weight", list_vars, f, quantize_config)
+
+ convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.q_proj.bias",
+ f"model.layers.{i}.self_attn.q_proj.bias", list_vars, f)
+ convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.k_proj.bias",
+ f"model.layers.{i}.self_attn.k_proj.bias", list_vars, f)
+ convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.v_proj.bias",
+ f"model.layers.{i}.self_attn.v_proj.bias", list_vars, f)
+
+ # ffn GEMM
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.down_proj.weight",
+ f"model.layers.{i}.mlp.down_proj.weight", list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.gate_proj.weight",
+ f"model.layers.{i}.mlp.gate_proj.weight", list_vars, f, quantize_config)
+ convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.up_proj.weight", f"model.layers.{i}.mlp.up_proj.weight",
+ list_vars, f, quantize_config)
+
+ f.close()
+ print(f"Success! saved as {out_path}")
+
+
+if __name__ == '__main__':
+ main()
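
The bits == 4 branch in convert_to_qx_bestla_tensor rescales the unsigned INC values into BesTLA's s4_clip range without changing what the weights dequantize to, because (s / 16) * ((w - 8) * 16 - (z - 8) * 16) == s * (w - z). A quick numeric check of that identity on random toy values (not a real checkpoint):

import torch

torch.manual_seed(0)
w = torch.randint(0, 16, (4, 8), dtype=torch.float32)   # unpacked u4 weights
z = torch.randint(0, 16, (4, 8), dtype=torch.float32)   # unpacked u4 zero points
s = torch.rand(4, 8) + 0.1                               # per-group scales

dequant_before = s * (w - z)

# same transform as the bits == 4 branch above
w2 = (w - 8) * 16
z2 = (z - 8) * 16
s2 = s / 16
dequant_after = s2 * (w2 - z2)

assert torch.allclose(dequant_before, dequant_after)
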
diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py
index 7966717b2..bcdacbdc2 100644
--- a/neural_speed/convert/convert_qwen.py
+++ b/neural_speed/convert/convert_qwen.py
@@ -100,11 +100,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", 0)) # multi-query attention
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
- fout.write(struct.pack("i", hparams["kv_channels"] if "kv_channels" in hparams
- else int(hparams["hidden_size"]/hparams["num_attention_heads"])))
+ fout.write(
+ struct.pack(
+ "i", hparams["kv_channels"] if "kv_channels" in hparams else int(hparams["hidden_size"] /
+ hparams["num_attention_heads"])))
fout.write(struct.pack("i", ftype))
- fout.write(struct.pack("i", hparams["seq_length"] if "seq_length" in hparams
- else hparams["max_position_embeddings"]))
+ fout.write(
+ struct.pack("i", hparams["seq_length"] if "seq_length" in hparams else hparams["max_position_embeddings"]))
fout.write(struct.pack("f", 0.0))
fout.write(struct.pack("f", 0.0))
fout.write(struct.pack("i", 0))
@@ -120,13 +122,15 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
- fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
- fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
- fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
- fout.write(struct.pack("i", hparams["bos_token_id"] if hparams["bos_token_id"]
- else tokenizer.special_tokens['<|endoftext|>']))
- fout.write(struct.pack("i", hparams["eos_token_id"] if hparams["eos_token_id"]
- else tokenizer.special_tokens['<|endoftext|>']))
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+ fout.write(
+ struct.pack(
+ "i", hparams["bos_token_id"] if "bos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>']))
+ fout.write(
+ struct.pack(
+ "i", hparams["eos_token_id"] if "eos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>']))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp
index 2bca0673c..9757156c0 100644
--- a/neural_speed/models/llama/llama_utils.cpp
+++ b/neural_speed/models/llama/llama_utils.cpp
@@ -118,7 +118,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
const int i_gpu_start = n_layer - n_gpu_layer;
model.layers.resize(n_layer);
size_t vram_total = 0;
- if (ml->verify_tensor("token_embd.weight")) {
+ if (ml->verify_tensor("token_embd.weight")) { // GGUF
model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab},
@@ -168,7 +168,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]);
}
}
- } else {
+ } else { // NE Format
model.others[0] = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("norm.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab},
diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h
index 3813a0a09..e71ee94f0 100644
--- a/neural_speed/models/model_utils/model_files.h
+++ b/neural_speed/models/model_utils/model_files.h
@@ -1076,6 +1076,7 @@ struct model_file_loader {
}
void read_ne_hparams() {
+ unsigned int count = 0;
hparams.n_vocab = file.read_u32();
hparams.n_embd = file.read_u32();
hparams.n_mult = file.read_u32();
@@ -1083,43 +1084,79 @@ struct model_file_loader {
hparams.n_head_kv = file.read_u32();
hparams.n_layer = file.read_u32();
hparams.n_rot = file.read_u32();
+ printf("%-16s %d.hparams.n_vocab = %-30d\n", __func__, count++, hparams.n_vocab);
+ printf("%-16s %d.hparams.n_embd = %-30d\n", __func__, count++, hparams.n_embd);
+ printf("%-16s %d.hparams.n_mult = %-30d\n", __func__, count++, hparams.n_mult);
+ printf("%-16s %d.hparams.n_head = %-30d\n", __func__, count++, hparams.n_head);
+ printf("%-16s %d.hparams.n_head_kv = %-30d\n", __func__, count++, hparams.n_head_kv);
+ printf("%-16s %d.hparams.n_layer = %-30d\n", __func__, count++, hparams.n_layer);
+ printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_vocab);
+
hparams.ftype = (enum ne_ftype)file.read_u32();
hparams.max_seq_len = file.read_u32();
file.read_raw(&hparams.alibi_bias_max, sizeof(float));
file.read_raw(&hparams.clip_qkv, sizeof(float));
hparams.par_res = file.read_u32();
-
hparams.word_embed_proj_dim = file.read_u32();
hparams.do_layer_norm_before = bool(file.read_u32());
+ printf("%-16s %d.hparams.ftype = %-30d\n", __func__, count++, hparams.ftype);
+ printf("%-16s %d.hparams.max_seq_len = %-30d\n", __func__, count++, hparams.max_seq_len);
+ printf("%-16s %d.hparams.alibi_bias_max = %-30f\n", __func__, count++, hparams.alibi_bias_max);
+ printf("%-16s %d.hparams.clip_qkv = %-30f\n", __func__, count++, hparams.clip_qkv);
+ printf("%-16s %d.hparams.par_res = %-30d\n", __func__, count++, hparams.par_res);
+ printf("%-16s %d.hparams.word_embed_proj_dim = %-30d\n", __func__, count++, hparams.word_embed_proj_dim);
+ printf("%-16s %d.hparams.do_layer_norm_before = %-30d\n", __func__, count++, hparams.do_layer_norm_before);
- // For ChatGLM-2
hparams.multi_query_group_num = file.read_u32();
hparams.ffn_hidden_size = file.read_u32();
+ printf("%-16s %d.hparams.multi_query_group_num = %-30d\n", __func__, count++, hparams.multi_query_group_num);
+ printf("%-16s %d.hparams.ffn_hidden_size = %-30d\n", __func__, count++, hparams.ffn_hidden_size);
- // For ChatGLM-2
hparams.inner_hidden_size = file.read_u32();
hparams.n_experts = file.read_u32();
hparams.n_experts_used = file.read_u32();
+ printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
+ printf("%-16s %d.hparams.n_experts = %-30d\n", __func__, count++, hparams.n_experts);
+ printf("%-16s %d.hparams.n_experts_used = %-30d\n", __func__, count++, hparams.n_experts_used);
+ // rms related
file.read_raw(&hparams.rms_norm_eps, sizeof(float));
file.read_raw(&hparams.freq_base, sizeof(float));
file.read_raw(&hparams.freq_scale, sizeof(float));
+ printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
+ printf("%-16s %d.hparams.freq_base = %-30f\n", __func__, count++, hparams.freq_base);
+ printf("%-16s %d.hparams.freq_scale = %-30f\n", __func__, count++, hparams.freq_scale);
file.read_raw(&hparams.rope_scaling_factor, sizeof(float));
hparams.original_max_position_embeddings = file.read_u32();
hparams.use_yarn = file.read_u32();
+ printf("%-16s %d.hparams.rope_scaling_factor = %-30f\n", __func__, count++, hparams.rope_scaling_factor);
+ printf("%-16s %d.hparams.original_max_position_embeddings = %-30d\n", __func__, count++,
+ hparams.original_max_position_embeddings);
+ printf("%-16s %d.hparams.use_yarn = %-30d\n", __func__, count++, hparams.use_yarn);
+ unsigned int total = 25;
+ if (count != total) {
+ fprintf(stderr, "The number of ne_parameters is wrong.\n");
+ }
}
void read_ne_vocab() {
+ unsigned int count = 0;
+ unsigned int ne_hparams_total = 25;
file.read_raw(&vocab.bos_token_id, sizeof(model_vocab::id));
file.read_raw(&vocab.eos_token_id, sizeof(model_vocab::id));
file.read_raw(&vocab.pad_token_id, sizeof(model_vocab::id));
file.read_raw(&vocab.sep_token_id, sizeof(model_vocab::id));
+ printf("%-16s %d.vocab.bos_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.bos_token_id);
+ printf("%-16s %d.vocab.eos_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.eos_token_id);
+ printf("%-16s %d.vocab.pad_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.pad_token_id);
+ printf("%-16s %d.vocab.sep_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.sep_token_id);
vocab.id_to_token.resize(hparams.n_vocab);
for (uint32_t i = 0; i < hparams.n_vocab; i++) {
uint32_t len = file.read_u32();
std::string word = file.read_string(len);
+ // std::cout << "word = " << word << std::endl;
float score = 0.0f;
if (file_version >= MODEL_FILE_VERSION_GGMF_V1) {
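
With the new count check, read_ne_hparams expects exactly 25 fields between the file header and the token ids. For reference, a sketch of that field order (names taken from the printf labels above; the convert_quantized_* scripts in this patch write these 25 values via struct.pack, after the magic/version header and before the bos/eos/pad/sep ids and the vocab):

NE_HPARAMS_ORDER = [
    "n_vocab", "n_embd", "n_mult", "n_head", "n_head_kv", "n_layer", "n_rot",
    "ftype", "max_seq_len", "alibi_bias_max", "clip_qkv", "par_res",
    "word_embed_proj_dim", "do_layer_norm_before",
    "multi_query_group_num", "ffn_hidden_size",
    "inner_hidden_size", "n_experts", "n_experts_used",
    "rms_norm_eps", "freq_base", "freq_scale",
    "rope_scaling_factor", "original_max_position_embeddings", "use_yarn",
]
assert len(NE_HPARAMS_ORDER) == 25   # matches the `total` check added above
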
diff --git a/neural_speed/models/qwen/qwen_utils.cpp b/neural_speed/models/qwen/qwen_utils.cpp
index 2ea38492e..77bba671b 100644
--- a/neural_speed/models/qwen/qwen_utils.cpp
+++ b/neural_speed/models/qwen/qwen_utils.cpp
@@ -58,10 +58,14 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
}
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
+ fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
+ fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
- fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size);
+ fprintf(stderr, "%s: ftype = %u\n", __func__, hparams.ftype);
+ fprintf(stderr, "%s: max_seq_len= %u\n", __func__, hparams.max_seq_len);
+ fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
|