From eed9b3072f4db0feeb14788c254a4a56478e0da9 Mon Sep 17 00:00:00 2001
From: Zhenzhong1 <109137058+Zhenzhong1@users.noreply.github.com>
Date: Fri, 15 Mar 2024 10:45:20 +0800
Subject: [PATCH] [GPTQ Enhance] Support GPTQ for Baichuan2-13B & Falcon 7B &
Phi-1.5 (#169)
---
docs/gptq_and_awq.md | 5 +-
docs/supported_models.md | 12 +-
neural_speed/__init__.py | 10 +-
neural_speed/convert/__init__.py | 11 +-
neural_speed/convert/common.py | 83 +++++
.../convert/convert_quantized_baichuan.py | 196 +++++++++++
.../convert/convert_quantized_bloom.py | 243 -------------
.../convert/convert_quantized_falcon.py | 174 +++++++++
neural_speed/convert/convert_quantized_phi.py | 329 ++++++++++++++++++
.../convert/convert_quantized_qwen.py | 107 +-----
10 files changed, 819 insertions(+), 351 deletions(-)
create mode 100644 neural_speed/convert/convert_quantized_baichuan.py
delete mode 100644 neural_speed/convert/convert_quantized_bloom.py
create mode 100644 neural_speed/convert/convert_quantized_falcon.py
create mode 100644 neural_speed/convert/convert_quantized_phi.py
diff --git a/docs/gptq_and_awq.md b/docs/gptq_and_awq.md
index d8dfbc43f..ef887c1c1 100644
--- a/docs/gptq_and_awq.md
+++ b/docs/gptq_and_awq.md
@@ -12,8 +12,11 @@ Validated GPTQ & AWQ models directly from the HuggingFace:
* [Mixtral-8x7B-Instruct-v0.1-GPTQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ) & [Mixtral-8x7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ)
* [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ) & [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4)
* [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ)
+* [Baichuan2-13B-Chat-GPTQ](https://hf-mirror.com/TheBloke/Baichuan2-13B-Chat-GPTQ)
+* [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b/tree/main)
+* [onlinex/phi-1_5-gptq-4bit](https://hf-mirror.com/onlinex/phi-1_5-gptq-4bit)
-Please check more validated GPTQ & AWQ models in the list of [supported_models](./supported_models.md).
+For more validated GPTQ & AWQ models, please check the list of [supported_models](./supported_models.md).
## Examples
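For reference, all of the GPTQ checkpoints listed above go through the same high-level flow; a minimal sketch, assuming the `Model.init(..., use_gptq=True)` path described in the Examples section of this doc (model name and generation settings are illustrative):

```python
from transformers import AutoTokenizer, TextStreamer
from neural_speed import Model

model_name = "TheBloke/Baichuan2-13B-Chat-GPTQ"   # any validated GPTQ model from the list above
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
streamer = TextStreamer(tokenizer)

model = Model()
# use_gptq=True keeps the pre-quantized GPTQ weights instead of re-quantizing from fp32.
model.init(model_name, weight_dtype="int4", compute_dtype="int8", use_gptq=True)
model.generate(inputs, streamer=streamer, max_new_tokens=128)
```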
diff --git a/docs/supported_models.md b/docs/supported_models.md
index 4aad26d29..7db9c6b77 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -235,13 +235,13 @@ Neural Speed supports the following models:
Baichuan-13B-Chat,
Baichuan2-13B-Chat |
✅ |
- |
- |
- |
✅ |
- |
- |
- |
+ ✅ |
+ ✅ |
+ ✅ |
+ ✅ |
+ ✅ |
+ ✅ |
4.33.1 |
4096 |
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index dda41c270..8116b2086 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -24,7 +24,6 @@
class Model:
-
def __init__(self):
self.module = None
self.model = None
@@ -83,6 +82,15 @@ def get_model_type(model_config):
model_type = model_maps.get(model_config.model_type, model_config.model_type)
if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
model_type = "chatglm2"
+
+ # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
+ if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
+ model_type = "falcon"
+
+ # for TheBloke/phi-2-GPTQ
+ if model_type == "phi-msft":
+ model_type = "phi"
+
return model_type
def init(self,
diff --git a/neural_speed/convert/__init__.py b/neural_speed/convert/__init__.py
index 18ce11490..3cc4f2301 100644
--- a/neural_speed/convert/__init__.py
+++ b/neural_speed/convert/__init__.py
@@ -18,7 +18,15 @@
from pathlib import Path
import subprocess
-model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper", "qwen2": "qwen"}
+model_maps = {
+ "gpt_neox": "gptneox",
+ "gpt_bigcode": "starcoder",
+ "whisper": "whisper",
+ "qwen2": "qwen",
+ "RefinedWebModel": "falcon",
+ "RefinedWeb": "falcon",
+ "phi-msft": "phi"
+}
def convert_model(model, outfile, outtype="f32", model_hub="huggingface", use_quantized_model=False):
@@ -28,6 +36,7 @@ def convert_model(model, outfile, outtype="f32", model_hub="huggingface", use_qu
else:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+
model_type = model_maps.get(config.model_type, config.model_type)
if use_quantized_model:
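With the extended `model_maps`, a pre-quantized checkpoint for any of the newly covered architectures can be converted through this single entry point; a minimal sketch against the `convert_model` signature above (paths are placeholders, and the routing to the per-architecture `convert_quantized_*.py` scripts is assumed from `use_quantized_model=True`):

```python
from neural_speed.convert import convert_model

# Dispatches on config.model_type ("baichuan", "RefinedWeb*" -> "falcon", "phi-msft" -> "phi", ...)
# and, with use_quantized_model=True, should end up in the matching convert_quantized_*.py script.
convert_model("./Baichuan2-13B-Chat-GPTQ",    # local GPTQ checkpoint directory (placeholder)
              "ne-baichuan2-13b-gptq.bin",    # output NE file (placeholder)
              outtype="f32",
              use_quantized_model=True)
```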
diff --git a/neural_speed/convert/common.py b/neural_speed/convert/common.py
index fc0c7d1fc..d4e5f49cc 100644
--- a/neural_speed/convert/common.py
+++ b/neural_speed/convert/common.py
@@ -516,3 +516,86 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")
+
+
+def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
+ # unpack weight and repack into 3bits / 4bits BestLA format
+ import neural_speed.llama_cpp as cpp_model
+ if ".weight" in src_name:
+ src_name = src_name.replace(".weight", "")
+ qzeros = model[f"{src_name}.qzeros"]
+ zeros = qzeros_to_zeros(qzeros)
+ scales = model[f"{src_name}.scales"]
+ qweight = model[f"{src_name}.qweight"]
+
+ int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
+ int_weight = int_weight.view(-1, int_weight.shape[-1])
+
+ # shuffle weight in GPTQ when act order is on
+ if 'desc_act' in q_config and q_config['desc_act']:
+ g_idx = model[f"{src_name}.g_idx"]
+ int_weight2 = int_weight.clone()
+ group_size = q_config['group_size']
+ group_dict = {}
+ for i in range(len(g_idx)):
+ group_idx = g_idx[i].item()
+ if group_idx not in group_dict:
+ target_idx = group_idx * group_size
+ group_dict[group_idx] = 0
+ else:
+ group_dict[group_idx] = group_dict[group_idx] + 1
+ target_idx = group_idx * group_size + group_dict[group_idx]
+ int_weight2[target_idx] = int_weight[i]
+ int_weight = int_weight2
+
+ shape = int_weight.shape[::-1]
+ # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)
+ n_dims = len(shape)
+ str = dst_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # INC stores the signed int4 values as u4 (range 0~15) by adding an offset;
+ # BesTLA expects s4_clip ((-8,7) * 16), so we subtract the offset and then multiply by 16.
+ # Int3 follows the same scheme, but with offset 4 and scale multiplier 32.
+ weight_dtype = "int8"
+ if q_config['bits'] == 4:
+ int_weight = (int_weight - 8) * 16
+ gptq_scales = gptq_scales / 16
+ gptq_zeros = (gptq_zeros - 8) * 16
+ weight_dtype = "int4"
+ elif q_config['bits'] == 3:
+ int_weight = (int_weight - 4) * 32
+ gptq_scales = gptq_scales / 32
+ gptq_zeros = (gptq_zeros - 4) * 32
+ weight_dtype = "int3"
+ else:
+ ValueError(f"Unsupported q_config[bits]: {q_config['bits']}")
+
+ dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
+ int_weight = np.ascontiguousarray(int_weight.numpy())
+ gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
+ if q_config['sym']:
+ gptq_zeros = np.empty(0, dtype=np.int8)
+ else:
+ gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
+ if 'desc_act' in q_config and q_config['desc_act']:
+ g_idx = np.ascontiguousarray(g_idx.numpy())
+ else:
+ g_idx = np.empty(0, dtype=np.int32)
+
+ # repack int weight in BesTLA format
+ byte_size = cpp_model.Model.np_bestla_qpack(int_weight,
+ gptq_scales,
+ gptq_zeros,
+ g_idx,
+ dst,
+ weight_dtype=weight_dtype,
+ group_size=q_config['group_size'],
+ alg="sym" if q_config['sym'] else "asym",
+ compute_dtype="int8")
+ dst.flatten()[:byte_size].tofile(fout)
+ print(f"convert_to_qx_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")
diff --git a/neural_speed/convert/convert_quantized_baichuan.py b/neural_speed/convert/convert_quantized_baichuan.py
new file mode 100644
index 000000000..22928a6bc
--- /dev/null
+++ b/neural_speed/convert/convert_quantized_baichuan.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import sys
+import re
+import argparse
+from common import *
+from sentencepiece import SentencePieceProcessor
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def load_vocab_for_baichuan(path: Path) -> SentencePieceVocab:
+ # Be extra-friendly and accept either a file or a directory. Also, if it's
+ # a directory, it might be the model directory, and tokenizer.model might
+ # be in the parent of that.
+ if path.is_dir():
+ path2 = path / "tokenizer.model"
+ # Use `.parent` instead of /.. to handle the symlink case better.
+ path3 = path.parent / "tokenizer.model"
+ if path2.exists():
+ path = path2
+ elif path3.exists():
+ path = path3
+ else:
+ raise FileNotFoundError(
+ f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \
+ pass the directory as --vocab-dir")
+ added_tokens_path = path.parent / "added_tokens.json"
+ print(f"Loading vocab file {path}")
+ return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+
+
+def main(args_in: Optional[List[str]] = None) -> None:
+ parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
+ parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("--model_hub",
+ choices=["huggingface", "modelscope"],
+ default="huggingface",
+ help="hub to load model")
+ parser.add_argument("model", type=Path, help="directory containing model file")
+ args = parser.parse_args(args_in)
+
+ out_path = args.outfile.as_posix()
+ model_path = args.model.as_posix()
+
+ model, hparams, quantize_config = load_quantized_safetensors(model_path)
+ list_vars = model
+
+ print(hparams)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ fout = open(out_path, "wb")
+
+ # possible data types
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype = 0
+ if args.outtype == "f16":
+ ftype = 1
+
+ # 1. write hparams
+ print(hparams)
+ ne_file_magic = 0x67676d66
+ fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
+ fout.write(struct.pack("i", 1))
+
+ fout.write(struct.pack("i", hparams["vocab_size"]))
+ fout.write(struct.pack("i", hparams["hidden_size"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", hparams["num_attention_heads"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", hparams["num_hidden_layers"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", ftype))
+ fout.write(struct.pack("i", hparams["model_max_length"]))
+ fout.write(struct.pack("f", 0))
+ fout.write(struct.pack("f", 0))
+ fout.write(struct.pack("i", 0))
+
+ fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", hparams["intermediate_size"]))
+ fout.write(struct.pack("i", 0)) # n_experts
+ fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
+
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+
+ fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+ fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+ fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ # 2. vocab
+ tokenizer_path = Path(tokenizer.vocab_file).parent
+ vocab = load_vocab_for_baichuan(Path(tokenizer_path))
+ counter = 0
+ for text, score in vocab.all_tokens():
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", score))
+ counter += 1
+
+ while counter < hparams["vocab_size"]:
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", 0))
+ counter += 1
+
+ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
+ # GPTQ checkpoints store these tensors mostly as torch.bfloat16; upcast to fp32 before writing.
+ if model[src_name].dtype == torch.float32:
+ data = model[src_name].squeeze().numpy()
+ else:
+ data = model[src_name].squeeze().to(torch.float32).numpy()
+ data = data.astype(np.float32)
+ shape = data.shape
+ n_dims = len(shape)
+ print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
+ data.dtype)
+
+ #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ data = data.astype(np.float32)
+
+ # header
+ # write_header(fout, shape, dst_name, ftype_cur)
+ str = src_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # data
+ data.tofile(fout)
+
+ #3. write tensors
+ convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)
+
+ for i in range(hparams["num_hidden_layers"]):
+ prefix = "model.layers." + str(i)
+
+ convert_qwen_to_fp32_tensor(f"{prefix}.input_layernorm.weight", f"{prefix}.input_layernorm.weight", list_vars,
+ fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.post_attention_layernorm.weight",
+ f"{prefix}.post_attention_layernorm.weight", list_vars, fout)
+ # qkv GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attn.W_pack.weight", f"{prefix}.self_attn.W_pack.weight", list_vars,
+ fout, quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attn.o_proj.weight", f"{prefix}.self_attn.o_proj.weight", list_vars,
+ fout, quantize_config)
+
+ # ffn GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.gate_proj", f"{prefix}.mlp.gate_proj.weight", list_vars, fout,
+ quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.down_proj", f"{prefix}.mlp.down_proj.weight", list_vars, fout,
+ quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.up_proj", f"{prefix}.mlp.up_proj.weight", list_vars, fout,
+ quantize_config)
+
+ fout.close()
+ print(f"Success! saved as {out_path}")
+
+
+if __name__ == '__main__':
+ main()
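The converter assumes a standard GPTQ safetensors layout, i.e. `qweight` / `qzeros` / `scales` (plus `g_idx` when `desc_act` is on) stored next to each packed linear such as `self_attn.W_pack`. A quick way to inspect that before converting (file name is a placeholder; large checkpoints may be sharded):

```python
from safetensors import safe_open

# Placeholder path; adjust to the actual (possibly sharded) safetensors file.
with safe_open("Baichuan2-13B-Chat-GPTQ/model.safetensors", framework="pt") as f:
    for key in f.keys():
        if key.startswith("model.layers.0.self_attn.W_pack"):
            print(key, f.get_slice(key).get_shape())
# Expected: ...W_pack.qweight, ...W_pack.qzeros, ...W_pack.scales (and ...W_pack.g_idx with desc_act).
```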
diff --git a/neural_speed/convert/convert_quantized_bloom.py b/neural_speed/convert/convert_quantized_bloom.py
deleted file mode 100644
index a323019e8..000000000
--- a/neural_speed/convert/convert_quantized_bloom.py
+++ /dev/null
@@ -1,243 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import os
-import numpy as np
-import struct
-from transformers import AutoTokenizer, TextStreamer, AutoConfig
-from transformers import AutoModelForCausalLM
-import json
-import copy
-from neural_compressor.adaptor.torch_utils.weight_only import quant_weight, quant_weight_w_scale
-import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model
-
-GGML_QK8_0 = 32
-GGML_QK4_0 = 32
-GGML_QK4_1 = 32
-GGML_QK5_0 = 32
-GGML_QK5_1 = 32
-
-
-def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
- # equivalent to ggml_quantize_q4_0 in ggml.c
- # import pudb; pudb.set_trace()
- assert tensor.shape[1] % GGML_QK4_0 == 0
- tensor = tensor.view(-1, GGML_QK4_0)
- abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
- max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
- scale = max_values / -8
- tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char()
- # compress two int4 weights into an int8
- tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
- # add scale into each block
- tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
- return tensor
-
-
-def fetch_module(model, op_name):
- """Get module with a given op name.
-
- Args:
- model (object): the input model.
- op_name (str): name of op.
-
- Returns:
- module (object).
- """
- module = model
- name_list = op_name.split(".")
- for name in name_list:
- if hasattr(module, name):
- module = getattr(module, name)
- else:
- module = module
- return module
-
-
-def extract_gptq(model, k, v):
- print(f"Compressing {k}")
- if v["dtype"] == "fp32":
- return
- else:
- dtype = v["dtype"]
- num_bits = v["bits"]
- group_size = v["group_size"]
- scheme = v["scheme"]
- m = fetch_module(model, k)
- m_weight = m.recover()
- # import pdb; pdb.set_trace()
- gptq_conf = gptq_config[k]
- if "perm" in gptq_conf:
- gptq_perm = torch.tensor(gptq_conf["perm"])
- fp32_weight = m_weight[:, gptq_perm]
- else:
- fp32_weight = m_weight
- gptq_perm = None
- gptq_scale = torch.tensor(gptq_conf["scale"])
- gptq_zp = None if scheme == "sym" else torch.tensor(gptq_conf["zero"])
- int_weight = quant_weight_w_scale(fp32_weight, gptq_scale, gptq_zp, group_size)
- return int_weight.to(torch.int8), gptq_scale, gptq_zp
-
-
-# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
-def bytes_to_unicode():
- """
- Returns list of utf-8 byte and a corresponding list of unicode strings.
- The reversible bpe codes work on unicode strings.
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
- This is a significant percentage of your normal, say, 32K bpe vocab.
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
- And avoids mapping to whitespace/control characters the bpe code barfs on.
- """
- bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
- cs = bs[:]
- n = 0
- for b in range(2**8):
- if b not in bs:
- bs.append(b)
- cs.append(2**8 + n)
- n += 1
-
- cs = [chr(n) for n in cs]
-
- return dict(zip(bs, cs))
-
-
-model_name = "/mnt/disk1/data2/zhenweil/models/bloom/bloom-7b1"
-prompt = "Once upon a time, a little girl"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-inputs = tokenizer(prompt, return_tensors="pt").input_ids
-streamer = TextStreamer(tokenizer)
-model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-
-gptq_model = "/mnt/disk1/data2/zhenweil/models/bloom/bloom-gptq/"
-from neural_compressor.utils.pytorch import load
-
-new_model = load(gptq_model, copy.deepcopy(model), weight_only=True)
-new_model_bk = copy.deepcopy(new_model)
-from neural_compressor.model import Model as INCModel
-
-inc_model = INCModel(new_model)
-qweight_config_path = gptq_model + "qconfig.json"
-gptq_config_path = gptq_model + "gptq_config.json"
-inc_model.export_compressed_model(qweight_config_path=qweight_config_path, gptq_config_path=gptq_config_path)
-
-with open(qweight_config_path, "r") as f:
- weight_config = json.load(f)
-with open(gptq_config_path, "r") as f:
- gptq_config = json.load(f)
-
-list_vars = new_model_bk.state_dict()
-f = open("bloom_gptq_q4.bin", "wb")
-
-# 1. write head and params
-hparams = config.to_dict()
-ftype = 0
-f.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-
-f.write(struct.pack("i", hparams["vocab_size"]))
-f.write(struct.pack("i", hparams["hidden_size"]))
-f.write(struct.pack("i", 1))
-f.write(struct.pack("i", hparams["n_head"]))
-f.write(struct.pack("i", hparams.get("n_head_kv", 0))) # multi-query attention
-f.write(struct.pack("i", hparams["n_layer"]))
-f.write(struct.pack("i", 0))
-f.write(struct.pack("i", ftype))
-f.write(struct.pack("i", 0))
-f.write(struct.pack("f", 0))
-f.write(struct.pack("f", 0))
-f.write(struct.pack("i", 0))
-f.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
-f.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
-
-f.write(struct.pack("i", 0))
-f.write(struct.pack("i", 0))
-f.write(struct.pack("i", 0))
-f.write(struct.pack("i", 0)) # n_experts
-f.write(struct.pack("i", 0)) # n_expert_used
-f.write(struct.pack("f", 1e-6)) # rms norm eps
-f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
-f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
-f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
-
-f.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
-f.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
-f.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
-f.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
-
-# 2. vocab
-reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v: k for k, v in byte_encoder.items()}
-
-for i in range(hparams["vocab_size"]):
- text = tokenizer.decode([i]).encode('utf-8')
- f.write(struct.pack("i", len(text)))
- f.write(text)
-
-# 3. write tensors
-for name in list_vars.keys():
- src = name
- if "query_key_value" in src:
- q_d, k_d, v_d = list_vars[src].reshape(config.n_head, 3, -1).unbind(1)
- list_vars[src] = torch.cat([q_d, k_d, v_d], dim=0).reshape_as(list_vars[src])
-
- ftype_cur = 0
- if ".weight" in name and list_vars[name].dim() == 2:
- ftype_cur = 2 # TODO(Zhenwei) support bestla
-
- data = list_vars[src].squeeze().numpy()
- data = data.astype(np.float32)
-
- n_dims = len(data.shape)
- print(name, n_dims, data.shape)
- str = name.encode('utf-8')
- f.write(struct.pack("iii", n_dims, len(str), ftype_cur))
- for i in range(n_dims):
- f.write(struct.pack("i", data.shape[n_dims - 1 - i]))
- f.write(str)
-
- if ".weight" in name and list_vars[name].dim() == 2:
- # to quantize
- k = name.replace(".weight", "")
- if k in weight_config and weight_config[k]["dtype"] != "fp32":
- print(f"bestla {k}")
- int_weight, gptq_scale, gptq_zp = extract_gptq(new_model, k, weight_config[k])
-
- tensor = int_weight.view(-1, 32) + 8
- tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
- gptq_scale = gptq_scale.view(-1, 1)
- gptq_scale = torch.cat([gptq_scale, gptq_scale, gptq_scale, gptq_scale], dim=1).view(-1, 1)
- tensor = torch.cat((gptq_scale.half().view(torch.int8), tensor), dim=-1)
- if "query_key_value" in src:
- q_d, k_d, v_d = tensor.reshape(config.n_head, 3, -1).unbind(1)
- tensor = torch.cat([q_d, k_d, v_d], dim=0).reshape_as(tensor)
- tensor.numpy().tofile(f)
-
- else:
- print(f"q4_0 {k}")
- tensor = quantize_q4_0(list_vars[name])
- tensor.numpy().tofile(f)
- else:
- # keep float32
- print(f"float {name}")
- data.tofile(f)
- # break
-f.close()
diff --git a/neural_speed/convert/convert_quantized_falcon.py b/neural_speed/convert/convert_quantized_falcon.py
new file mode 100644
index 000000000..956b0a92b
--- /dev/null
+++ b/neural_speed/convert/convert_quantized_falcon.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import sys
+import re
+import argparse
+from common import *
+from sentencepiece import SentencePieceProcessor
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def main(args_in: Optional[List[str]] = None) -> None:
+ parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
+ parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("--model_hub",
+ choices=["huggingface", "modelscope"],
+ default="huggingface",
+ help="hub to load model")
+ parser.add_argument("model", type=Path, help="directory containing model file")
+ args = parser.parse_args(args_in)
+
+ out_path = args.outfile.as_posix()
+ model_path = args.model.as_posix()
+
+ model, hparams, quantize_config = load_quantized_safetensors(model_path)
+ list_vars = model
+
+ print(hparams)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ fout = open(out_path, "wb")
+
+ # possible data types
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype = 0
+ if args.outtype == "f16":
+ ftype = 1
+
+ # 1. write hparams
+ n_head_kv = hparams.get("n_head_kv", 1)
+ n_head = hparams["n_head"]
+ head_dim = hparams["hidden_size"] // n_head
+
+ fout.write(struct.pack("i", 0x67676d6c)) # magic: falcon in hex
+
+ fout.write(struct.pack("i", hparams["vocab_size"]))
+ fout.write(struct.pack("i", hparams["hidden_size"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", n_head))
+ fout.write(struct.pack("i", n_head_kv)) # multi-query attention
+ fout.write(struct.pack("i", hparams["n_layer"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", ftype))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("f", 0))
+ fout.write(struct.pack("f", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # n_experts
+ fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
+
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+
+ fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+ fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+ fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ # 2. vocab
+ reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+ byte_encoder = bytes_to_unicode()
+ byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+ for i in range(hparams["vocab_size"]):
+ text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+
+ def convert_to_fp32_tensor(src_name, dst_name, model, fout):
+ # GPTQ checkpoints store these tensors mostly as torch.bfloat16; upcast to fp32 before writing.
+ if model[src_name].dtype == torch.float32:
+ data = model[src_name].squeeze().numpy()
+ else:
+ data = model[src_name].squeeze().to(torch.float32).numpy()
+ data = data.astype(np.float32)
+ shape = data.shape
+ n_dims = len(shape)
+ print("convert_to_fp32_tensor: %45s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
+ data.dtype)
+
+ #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ data = data.astype(np.float32)
+
+ # header
+ # write_header(fout, shape, dst_name, ftype_cur)
+ str = src_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # data
+ data.tofile(fout)
+
+ #3. write tensors
+ convert_to_fp32_tensor("transformer.word_embeddings.weight", "transformer.word_embeddings.weight", list_vars, fout)
+ convert_to_fp32_tensor("transformer.ln_f.weight", "transformer.ln_f.weight", list_vars, fout)
+ convert_to_fp32_tensor("transformer.ln_f.bias", "transformer.ln_f.bias", list_vars, fout)
+ convert_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)
+
+ for i in range(hparams["n_layer"]):
+ prefix = "transformer.h." + str(i)
+
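+ # falcon-7b-style checkpoints (n_head_kv defaults to 1) carry a single input_layernorm per block;
+ # falcon-40b-style ones (n_head_kv == 8) use separate ln_attn / ln_mlp layernorms.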
+ if n_head_kv == 1:
+ convert_to_fp32_tensor(f"{prefix}.input_layernorm.weight", f"{prefix}.input_layernorm.weight", list_vars,
+ fout)
+ convert_to_fp32_tensor(f"{prefix}.input_layernorm.bias", f"{prefix}.input_layernorm.bias", list_vars, fout)
+ elif n_head_kv == 8:
+ convert_to_fp32_tensor(f"{prefix}.ln_mlp.weight", f"{prefix}.ln_mlp.weight", list_vars, fout)
+ convert_to_fp32_tensor(f"{prefix}.ln_mlp.bias", f"{prefix}.ln_mlp.bias", list_vars, fout)
+ convert_to_fp32_tensor(f"{prefix}.ln_attn.weight", f"{prefix}.ln_attn.weight", list_vars, fout)
+ convert_to_fp32_tensor(f"{prefix}.ln_attn.bias", f"{prefix}.ln_attn.bias", list_vars, fout)
+
+ # qkv GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attention.query_key_value.weight",
+ f"{prefix}.self_attention.query_key_value.weight", list_vars, fout, quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attention.dense.weight", f"{prefix}.self_attention.dense.weight",
+ list_vars, fout, quantize_config)
+
+ # ffn GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.dense_h_to_4h", f"{prefix}.mlp.dense_h_to_4h.weight", list_vars,
+ fout, quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.dense_4h_to_h", f"{prefix}.mlp.dense_4h_to_h.weight", list_vars,
+ fout, quantize_config)
+
+ fout.close()
+ print(f"Success! saved as {out_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/neural_speed/convert/convert_quantized_phi.py b/neural_speed/convert/convert_quantized_phi.py
new file mode 100644
index 000000000..3c085cda2
--- /dev/null
+++ b/neural_speed/convert/convert_quantized_phi.py
@@ -0,0 +1,329 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import sys
+import re
+import argparse
+from common import *
+from sentencepiece import SentencePieceProcessor
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def convert_phi1_5_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams, quantize_config):
+ list_vars = model
+ for name in list_vars.keys():
+ print(name)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ fout = open(out_path, "wb")
+
+ # possible data types
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype = 0
+ if outtype == "f16":
+ ftype = 1
+
+ # 1. write hparams
+ print(hparams)
+ ne_file_magic = 0x67676d66
+ n_rot = int(hparams["partial_rotary_factor"] * hparams["hidden_size"] / hparams["num_attention_heads"])
+ # n_rot = hparams['rotary_dim']
+ fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
+ fout.write(struct.pack("i", 1))
+ fout.write(struct.pack("i", hparams["vocab_size"]))
+ fout.write(struct.pack("i", hparams["hidden_size"]))
+ fout.write(struct.pack("i", hparams["intermediate_size"])) # dummy data
+ fout.write(struct.pack("i", hparams["num_attention_heads"]))
+ fout.write(struct.pack("i", hparams["num_key_value_heads"])) # multi-query attention
+ fout.write(struct.pack("i", hparams["num_hidden_layers"]))
+ fout.write(struct.pack("i", n_rot))
+ fout.write(struct.pack("i", ftype))
+ fout.write(struct.pack("i", hparams["max_position_embeddings"]))
+ fout.write(struct.pack("f", 0.0))
+ fout.write(struct.pack("f", 0.0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # n_experts
+ fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+ fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ # 2. vocab
+ for i in range(hparams["vocab_size"]):
+ if i < tokenizer.vocab_size:
+ text = tokenizer.decode([i]).encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", 0.0 - i))
+ else:
+ text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", -10000))
+
+ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
+ # GPTQ checkpoints store these tensors mostly as torch.bfloat16; upcast to fp32 before writing.
+ if model[src_name].dtype == torch.float32:
+ data = model[src_name].squeeze().numpy()
+ else:
+ data = model[src_name].squeeze().to(torch.float32).numpy()
+ data = data.astype(np.float32)
+ shape = data.shape
+ n_dims = len(shape)
+ print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
+ data.dtype)
+
+ #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ data = data.astype(np.float32)
+
+ # header
+ # write_header(fout, shape, dst_name, ftype_cur)
+ str = dst_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # data
+ data.tofile(fout)
+
+ #3. write tensors
+ convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("model.final_layernorm.weight", "model.final_layernorm.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("model.final_layernorm.bias", "model.final_layernorm.bias", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.bias", "lm_head.bias", list_vars, fout)
+
+ for i in range(hparams["num_hidden_layers"]):
+ prefix = "model.layers." + str(i)
+ renamed_prefix = "model.layers." + str(i)
+
+ convert_qwen_to_fp32_tensor(f"{prefix}.input_layernorm.weight", f"{renamed_prefix}.input_layernorm.weight",
+ list_vars, fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.input_layernorm.bias", f"{renamed_prefix}.input_layernorm.bias",
+ list_vars, fout)
+
+ # qkv GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attn.q_proj.weight", f"{prefix}.self_attn.q_proj.weight", list_vars,
+ fout, quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attn.k_proj.weight", f"{prefix}.self_attn.k_proj.weight", list_vars,
+ fout, quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attn.v_proj.weight", f"{prefix}.self_attn.v_proj.weight", list_vars,
+ fout, quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.self_attn.dense.weight", f"{prefix}.self_attn.dense.weight", list_vars,
+ fout, quantize_config)
+
+ convert_qwen_to_fp32_tensor(f"{prefix}.self_attn.q_proj.bias", f"{prefix}.self_attn.q_proj.bias", list_vars,
+ fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.self_attn.k_proj.bias", f"{prefix}.self_attn.k_proj.bias", list_vars,
+ fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.self_attn.v_proj.bias", f"{prefix}.self_attn.v_proj.bias", list_vars,
+ fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.self_attn.dense.bias", f"{prefix}.self_attn.dense.bias", list_vars, fout)
+
+ # ffn GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc1.weight", f"{renamed_prefix}.mlp.fc1.weight", list_vars, fout,
+ quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc2.weight", f"{renamed_prefix}.mlp.fc2.weight", list_vars, fout,
+ quantize_config)
+
+ convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc1.bias", f"{renamed_prefix}.mlp.fc1.bias", list_vars, fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc2.bias", f"{renamed_prefix}.mlp.fc2.bias", list_vars, fout)
+
+ fout.close()
+ print(f"Success! saved as {out_path}")
+
+
+def convert_phi2_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams, quantize_config):
+ list_vars = model
+ for name in list_vars.keys():
+ print(name)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ fout = open(out_path, "wb")
+
+ # possible data types
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype = 0
+ if outtype == "f16":
+ ftype = 1
+
+ # 1. write hparams
+ print(hparams)
+ ne_file_magic = 0x67676d66
+ #n_rot = int(hparams["partial_rotary_factor"]*hparams["hidden_size"]/hparams["num_attention_heads"])
+ n_rot = hparams['rotary_dim']
+ fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
+ fout.write(struct.pack("i", 1))
+ fout.write(struct.pack("i", hparams["vocab_size"]))
+ fout.write(struct.pack("i", hparams["n_embd"]))
+ fout.write(struct.pack("i", hparams["n_embd"] * 4)) # dummy data
+ fout.write(struct.pack("i", hparams["n_head"]))
+ fout.write(struct.pack("i", hparams["n_head"])) # multi-query attention
+ fout.write(struct.pack("i", hparams["n_layer"]))
+ fout.write(struct.pack("i", n_rot))
+ fout.write(struct.pack("i", ftype))
+ fout.write(struct.pack("i", hparams["n_positions"]))
+ fout.write(struct.pack("f", 0.0))
+ fout.write(struct.pack("f", 0.0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # n_experts
+ fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+ fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ # 2. vocab
+ for i in range(hparams["vocab_size"]):
+ if i < tokenizer.vocab_size:
+ text = tokenizer.decode([i]).encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", 0.0 - i))
+ else:
+ text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", -10000))
+
+ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
+ # GPTQ checkpoints store these tensors mostly as torch.bfloat16; upcast to fp32 before writing.
+ if model[src_name].dtype == torch.float32:
+ data = model[src_name].squeeze().numpy()
+ else:
+ data = model[src_name].squeeze().to(torch.float32).numpy()
+ data = data.astype(np.float32)
+ shape = data.shape
+ n_dims = len(shape)
+ print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
+ data.dtype)
+
+ #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ data = data.astype(np.float32)
+
+ # header
+ # write_header(fout, shape, dst_name, ftype_cur)
+ str = dst_name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str)
+
+ # data
+ data.tofile(fout)
+
+ #3. write tensors
+ convert_qwen_to_fp32_tensor("transformer.embd.wte.weight", "model.embed_tokens.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.ln.weight", "model.final_layernorm.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.ln.bias", "model.final_layernorm.bias", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.linear.weight", "lm_head.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor("lm_head.linear.bias", "lm_head.bias", list_vars, fout)
+
+ for i in range(hparams["n_layer"]):
+ prefix = "transformer.h." + str(i)
+ renamed_prefix = "model.layers." + str(i)
+
+ convert_qwen_to_fp32_tensor(f"{prefix}.ln.weight", f"{renamed_prefix}.input_layernorm.weight", list_vars, fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.ln.bias", f"{renamed_prefix}.input_layernorm.bias", list_vars, fout)
+
+ # qkv GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.mixer.Wqkv.weight", f"{renamed_prefix}.mixer.Wqkv.weight", list_vars,
+ fout, quantize_config)
+ convert_qwen_to_fp32_tensor(f"{prefix}.mixer.Wqkv.bias", f"{renamed_prefix}.mixer.Wqkv.bias", list_vars, fout)
+
+ convert_to_qx_bestla_tensor(f"{prefix}.mixer.out_proj.weight", f"{renamed_prefix}.mixer.out_proj.weight",
+ list_vars, fout, quantize_config)
+ convert_qwen_to_fp32_tensor(f"{prefix}.mixer.out_proj.bias", f"{renamed_prefix}.mixer.out_proj.bias", list_vars,
+ fout)
+
+ # ffn GEMM
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc1.weight", f"{renamed_prefix}.mlp.fc1.weight", list_vars, fout,
+ quantize_config)
+ convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc2.weight", f"{renamed_prefix}.mlp.fc2.weight", list_vars, fout,
+ quantize_config)
+
+ convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc1.bias", f"{renamed_prefix}.mlp.fc1.bias", list_vars, fout)
+ convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc2.bias", f"{renamed_prefix}.mlp.fc2.bias", list_vars, fout)
+
+ fout.close()
+ print(f"Success! saved as {out_path}")
+
+
+def main(args_in: Optional[List[str]] = None) -> None:
+ parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
+ parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("--model_hub",
+ choices=["huggingface", "modelscope"],
+ default="huggingface",
+ help="hub to load model")
+ parser.add_argument("model", type=Path, help="directory containing model file")
+ args = parser.parse_args(args_in)
+
+ out_path = args.outfile.as_posix()
+ model_path = args.model.as_posix()
+
+ model, hparams, quantize_config = load_quantized_safetensors(model_path)
+
+ if hparams['model_type'] == "phi":
+ convert_phi1_5_gptq_to_bestTLA(model_path, out_path, args.outtype, model, hparams, quantize_config)
+ elif hparams['model_type'] == "phi-msft":
+ convert_phi2_gptq_to_bestTLA(model_path, out_path, args.outtype, model, hparams, quantize_config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/neural_speed/convert/convert_quantized_qwen.py b/neural_speed/convert/convert_quantized_qwen.py
index b57238862..02ded7622 100644
--- a/neural_speed/convert/convert_quantized_qwen.py
+++ b/neural_speed/convert/convert_quantized_qwen.py
@@ -21,115 +21,28 @@
import re
import argparse
from common import *
-
-
-def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
- # unpack weight and repack into 3bits / 4bits BestLA format
- import neural_speed.llama_cpp as cpp_model
- if ".weight" in src_name:
- src_name = src_name.replace(".weight", "")
- qzeros = model[f"{src_name}.qzeros"]
- zeros = qzeros_to_zeros(qzeros)
- scales = model[f"{src_name}.scales"]
- qweight = model[f"{src_name}.qweight"]
-
- int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
- int_weight = int_weight.view(-1, int_weight.shape[-1])
-
- # shuffle weight in GPTQ when act order is on
- if 'desc_act' in q_config and q_config['desc_act']:
- g_idx = model[f"{src_name}.g_idx"]
- int_weight2 = int_weight.clone()
- group_size = q_config['group_size']
- group_dict = {}
- for i in range(len(g_idx)):
- group_idx = g_idx[i].item()
- if group_idx not in group_dict:
- target_idx = group_idx * group_size
- group_dict[group_idx] = 0
- else:
- group_dict[group_idx] = group_dict[group_idx] + 1
- target_idx = group_idx * group_size + group_dict[group_idx]
- int_weight2[target_idx] = int_weight[i]
- int_weight = int_weight2
-
- # shape = int_weight.shape[::-1]
- shape = int_weight.shape[::-1]
- # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)
- n_dims = len(shape)
- str = dst_name.encode('utf-8')
- fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE))
- for i in range(n_dims):
- fout.write(struct.pack("i", shape[n_dims - 1 - i]))
- fout.write(str)
-
- # INC stores sig-int4 value as u4(range 0~15, they add a offset),
- # BesTLA requires s4_clip((-8,7)*16), so we sub the offset and then mul 16.
- # Int3 is the same as int4, but offset=4, mul scale==32.
- weight_dtype = "int8"
- if q_config['bits'] == 4:
- int_weight = (int_weight - 8) * 16
- gptq_scales = gptq_scales / 16
- gptq_zeros = (gptq_zeros - 8) * 16
- weight_dtype = "int4"
- elif q_config['bits'] == 3:
- int_weight = (int_weight - 4) * 32
- gptq_scales = gptq_scales / 32
- gptq_zeros = (gptq_zeros - 4) * 32
- weight_dtype = "int3"
- else:
- ValueError(f"Unsupported q_config[bits]: {q_config['bits']}")
-
- dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
- int_weight = np.ascontiguousarray(int_weight.numpy())
- gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
- if q_config['sym']:
- gptq_zeros = np.empty(0, dtype=np.int8)
- else:
- gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
- if 'desc_act' in q_config and q_config['desc_act']:
- g_idx = np.ascontiguousarray(g_idx.numpy())
- else:
- g_idx = np.empty(0, dtype=np.int32)
-
- # repack int weight in BesTLA format
- byte_size = cpp_model.Model.np_bestla_qpack(int_weight,
- gptq_scales,
- gptq_zeros,
- g_idx,
- dst,
- weight_dtype=weight_dtype,
- group_size=q_config['group_size'],
- alg="sym" if q_config['sym'] else "asym",
- compute_dtype="int8")
- dst.flatten()[:byte_size].tofile(fout)
- print(f"convert_to_qx_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")
+from transformers import AutoModelForCausalLM, AutoTokenizer
def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
- parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
- default="huggingface", help="hub to load model")
+ parser.add_argument("--model_hub",
+ choices=["huggingface", "modelscope"],
+ default="huggingface",
+ help="hub to load model")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)
out_path = args.outfile.as_posix()
model_path = args.model.as_posix()
- from transformers import AutoModelForCausalLM, AutoTokenizer
- # QWEN-GPTQ & AWQ
model, hparams, quantize_config = load_quantized_safetensors(model_path)
list_vars = model
print(hparams)
- # orinal QWEN
- # model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
- # hparams = model.config.to_dict()
- # list_vars = model.state_dict()
-
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
f = open(out_path, "wb")
@@ -140,11 +53,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
ftype = 1
# 1. write hparams
- # 0x67676d6c is unversioned ne
- # 0x67676d66 is versioned ggmf (requires token scores)
ne_file_magic = 0x67676d66
- #ne_file_version = 0x00000001 # v1
-
f.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
f.write(struct.pack("i", 1))
@@ -168,9 +77,9 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0))
if hparams['model_type']=='qwen2':
- fout.write(struct.pack("i", hparams["intermediate_size"]))
+ f.write(struct.pack("i", hparams["intermediate_size"]))
else:
- fout.write(struct.pack("i", int(hparams["intermediate_size"]/2)))
+ f.write(struct.pack("i", int(hparams["intermediate_size"]/2)))
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 0)) # n_experts
f.write(struct.pack("i", 0)) # n_expert_used
@@ -182,7 +91,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
- if hparams['model_type']=='qwen2':
+ if hparams['model_type'] == 'qwen2':
f.write(struct.pack("i", hparams["bos_token_id"]))
f.write(struct.pack("i", hparams["eos_token_id"]))
else: