From a129213d87684131a614feb0e16c7e4161c3d0f7 Mon Sep 17 00:00:00 2001 From: Zhenzhong1 <109137058+Zhenzhong1@users.noreply.github.com> Date: Wed, 6 Mar 2024 20:57:30 +0800 Subject: [PATCH] [GPTQ Enhence] Support GPTQ & AWQ inference for QWENv1, v1.5 and Mixtral. (#134) --- docs/gptq_and_awq.md | 9 +- docs/supported_models.md | 94 +++--- neural_speed/__init__.py | 2 +- neural_speed/convert/__init__.py | 2 +- neural_speed/convert/common.py | 110 +++++-- .../convert/convert_quantized_mistral.py | 3 - .../convert/convert_quantized_mixtral.py | 281 ++++++++++++++++ .../convert/convert_quantized_qwen.py | 306 ++++++++++++++++++ neural_speed/convert/convert_qwen.py | 26 +- neural_speed/models/llama/llama_utils.cpp | 4 +- neural_speed/models/model_utils/model_files.h | 43 ++- neural_speed/models/qwen/qwen_utils.cpp | 6 +- 12 files changed, 778 insertions(+), 108 deletions(-) create mode 100644 neural_speed/convert/convert_quantized_mixtral.py create mode 100644 neural_speed/convert/convert_quantized_qwen.py diff --git a/docs/gptq_and_awq.md b/docs/gptq_and_awq.md index f3eeecf2c..ec66f5e43 100644 --- a/docs/gptq_and_awq.md +++ b/docs/gptq_and_awq.md @@ -6,11 +6,12 @@ Neural Speed supports multiple weight-only quantization algorithms, such as GPTQ More algorithm details please check [GPTQ](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978). 
Validated GPTQ & AWQ models directly from the HuggingFace: -* [Llama-2-7B-Chat-GPT](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) & [Llama-2-13B-Chat-GPT](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) -* [CodeLlama-7B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) & [CodeLlama-13B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GPTQ) +* [Llama-2-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) & [Llama-2-13B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-13B-Chat-GPTQ) & [Llama-2-7B-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-AWQ) & [Llama-2-13B-chat-AWQ](https://huggingface.co/TheBloke/Llama-2-13B-chat-AWQ) +* [CodeLlama-7B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) & [CodeLlama-13B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GPTQ) & [CodeLlama-7B-AWQ](https://huggingface.co/TheBloke/CodeLlama-7B-AWQ) & [CodeLlama-13B-AWQ](https://huggingface.co/TheBloke/CodeLlama-13B-AWQ) +* [Mistral-7B-Instruct-v0.1-GPTQ](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GPTQ) & [Mistral-7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-AWQ) +* [Mixtral-8x7B-Instruct-v0.1-GPTQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ) & [Mixtral-8x7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ) +* [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ) & [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4) * [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ) -* [Llama-2-7B-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-AWQ) & [Llama-2-13B-chat-AWQ](https://huggingface.co/TheBloke/Llama-2-13B-chat-AWQ) -* [CodeLlama-7B-AWQ](https://huggingface.co/TheBloke/CodeLlama-7B-AWQ) & 
[CodeLlama-13B-AWQ](https://huggingface.co/TheBloke/CodeLlama-13B-AWQ) Please check more validated GPTQ & AWQ models in the list of [supported_models](./docs/supported_models.md). diff --git a/docs/supported_models.md b/docs/supported_models.md index c4a658256..ef3f7d362 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -72,17 +72,58 @@ Neural Speed supports the following models: ✅ ✅ Latest + + + Neural-Chat-7B-v3-1, + Neural-Chat-7B-v3-2 + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + Latest + + + Mistral-7B, + Mixtral-8x7B + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + 4.36.0 or newer + + + Qwen-7B, + Qwen-14B, + Qwen1.5-7B, + Qwen1.5-0.5B + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ + Latest GPT-J-6B ✅ - - - ✅ - - - + ✅ + ✅ + ✅ + ✅ + ✅ + ✅ Latest @@ -160,19 +201,6 @@ Neural Speed supports the following models: Latest - - - Neural-Chat-7B-v3-1, - Neural-Chat-7B-v3-2 - ✅ - ✅ - ✅ - ✅ - ✅ - ✅ - ✅ - ✅ - Latest ChatGLM-6B, @@ -200,34 +228,6 @@ Neural Speed supports the following models: 4.33.1 - - Mistral-7B, - Mixtral-8x7B - ✅ - - - - ✅ - - - - 4.36.0 or newer - - - Qwen-7B, - Qwen-14B, - Qwen1.5-7B, - Qwen1.5-0.5B - ✅ - - - - ✅ - - - - Latest - phi-2, phi-1_5 diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py index 7bb39ce16..d0ab95c21 100644 --- a/neural_speed/__init__.py +++ b/neural_speed/__init__.py @@ -66,7 +66,7 @@ def __import_package(self, model_type): import neural_speed.qwen_cpp as cpp_model elif model_type == "mistral": import neural_speed.mistral_cpp as cpp_model - elif model_type == "qwen": + elif model_type == "qwen2": import neural_speed.qwen_cpp as cpp_model elif model_type == "phi": import neural_speed.phi_cpp as cpp_model diff --git a/neural_speed/convert/__init__.py b/neural_speed/convert/__init__.py index 9f063a5ec..4e2a6796d 100644 --- a/neural_speed/convert/__init__.py +++ b/neural_speed/convert/__init__.py @@ -19,7 +19,7 @@ from transformers import AutoConfig import subprocess -model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", 
"whisper": "whisper"} +model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper", "qwen2": "qwen"} def convert_model(model, outfile, outtype="f32", whisper_repo_path=None, use_quantized_model=False): diff --git a/neural_speed/convert/common.py b/neural_speed/convert/common.py index b8891f2e8..fdefc07a4 100644 --- a/neural_speed/convert/common.py +++ b/neural_speed/convert/common.py @@ -16,13 +16,14 @@ # limitations under the License. import torch +import os from pathlib import Path import numpy as np import struct import json import warnings -from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, - Literal, Optional, Sequence, Tuple, TypeVar, Union) +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, + Union) from sentencepiece import SentencePieceProcessor # type: ignore GGML_QK8_0 = 32 @@ -35,6 +36,7 @@ GGML_QK4_1_TYPE = 3 GGML_QJBLAS_TYPE = 19 + # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py def bytes_to_unicode(): """ @@ -57,6 +59,7 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) + def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q4_0 in ggml.c assert tensor.shape[1] % GGML_QK4_0 == 0 @@ -71,6 +74,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1) return tensor + def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q4_1 in ggml.c assert tensor.shape[1] % GGML_QK4_1 == 0 @@ -85,6 +89,7 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: tensor = torch.cat((scale.half().view(torch.int8), min_vals.half().view(torch.int8), tensor), dim=-1) return tensor + class SentencePieceVocab: def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: self.sentencepiece_tokenizer = 
SentencePieceProcessor(str(fname_tokenizer)) @@ -154,8 +159,7 @@ def load_vocab(path: Path) -> SentencePieceVocab: else: raise FileNotFoundError( f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \ - pass the directory as --vocab-dir" - ) + pass the directory as --vocab-dir") added_tokens_path = path.parent / "added_tokens.json" print(f"Loading vocab file {path}") return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) @@ -192,6 +196,7 @@ def qzeros_to_zeros(qzeros, bits=4): col += 1 return zeros + def unpack_weight(qweight, scales, qzeros, q_config): if "quant_method" not in q_config: raise ValueError(f"Unsupported q_config without quant_method: {q_config}") @@ -220,17 +225,18 @@ def unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config): wf = torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32).unsqueeze(0) zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) zeros = zeros + 1 zeros = zeros.reshape(scales.shape) weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(weight,(2 ** bits) - 1, out=weight) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) return weight, scales, zeros + def unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config): print("unpack_gptq_weight_3bits... ", end='') group_size = q_config['group_size'] @@ -239,23 +245,23 @@ def unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config): assert bits == 3 # Int32 can only store 10 * 3bits data. This is the offset for each data. 
- wf = torch.tensor([[ i for i in range(0, s32_bits - bits, bits)]], dtype=torch.int32) + wf = torch.tensor([[i for i in range(0, s32_bits - bits, bits)]], dtype=torch.int32) zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) zeros = zeros + 1 zeros = zeros.reshape(zeros.shape[0], -1) - zeros = zeros[:,:scales.shape[1]] + zeros = zeros[:, :scales.shape[1]] weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8) weight = weight.reshape(-1, weight.shape[-1]) input_feature = group_size * scales.shape[0] - weight = weight[:input_feature,:] + weight = weight[:input_feature, :] - torch.bitwise_and(weight,(2 ** bits) - 1, out=weight) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) return weight, scales, zeros @@ -271,12 +277,13 @@ def unpack_awq_weight(qweight, scales, qzeros, q_config): for col in range(qweight.shape[1]): for i in range(pack_num): w_col = torch.bitwise_right_shift(qweight[:, col], 4 * order_map[i]) - weight[:, col * pack_num + i] = torch.bitwise_and(w_col, (2 ** bits) - 1) + weight[:, col * pack_num + i] = torch.bitwise_and(w_col, (2**bits) - 1) z_col = torch.bitwise_right_shift(qzeros[:, col], 4 * order_map[i]) - zeros[:, col * pack_num + i] = torch.bitwise_and(z_col, (2 ** bits) - 1) + zeros[:, col * pack_num + i] = torch.bitwise_and(z_col, (2**bits) - 1) return weight, scales, zeros + def write_header(fout, shape, dst_name, ftype_cur): sname = dst_name.encode('utf-8') fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur)) @@ -295,6 +302,31 @@ def find_quantized_model_file(model_path): print(f"Detected model file {found[0]}") return str(found[0]) + +def load_quantized_safetensors(model_path): + # load GPTQ & AWQ models, only for 
safetensors + from safetensors.torch import load_file + safetensors = [] + for file in os.listdir(model_path): + if file.endswith(".safetensors"): + safetensors.append(file) + + print(f"safetensors list = {safetensors}") + model = {} + for file in safetensors: + tmp = load_file(model_path + "/" + file) + if isinstance(tmp, dict): + model.update(tmp) + + with open(model_path + '/config.json', "r", encoding="utf-8") as f: + config = json.load(f) + + quantize_config = config["quantization_config"] + if "zero_point" in quantize_config: + quantize_config["sym"] = not quantize_config["zero_point"] + return model, config, config["quantization_config"] + + def load_quantized_model(model_path): input_path = find_quantized_model_file(model_path) model = None @@ -318,8 +350,9 @@ def load_quantized_model(model_path): def convert_to_fp32_tensor(src_name, dst_name, model, fout): v = model[src_name] shape = v.shape - # print("Processing non-Q4 variable: " + src_name + - # " with shape: ", shape, " and type: ", v.dtype) + n_dims = len(shape) + print("Processing non-Q4 variable: " + src_name + " -> " + dst_name + " with shape: ", shape, " and type: ", + v.dtype, "data: ", v[:2, :2].tolist() if n_dims > 1 else v[:2].tolist()) v = v.to(torch.float32) ftype_cur = {torch.float16: 1, torch.float32: 0}[v.dtype] @@ -329,7 +362,8 @@ def convert_to_fp32_tensor(src_name, dst_name, model, fout): # data v.numpy().tofile(fout) - print(f"converting {dst_name} float tensor") + #print(f"converting {src_name} -> {dst_name} float tensor") + def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2=0, permute_func=None): qzeros = model[f"{src_name}.qzeros"] @@ -338,9 +372,9 @@ def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2 qweight = model[f"{src_name}.qweight"] int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) - int_weight = int_weight.view(-1,int_weight.shape[-1]).t() - gptq_scales = 
gptq_scales.view(-1,gptq_scales.shape[-1]).t() - gptq_zeros = gptq_zeros.view(-1,gptq_zeros.shape[-1]).t() + int_weight = int_weight.view(-1, int_weight.shape[-1]).t() + gptq_scales = gptq_scales.view(-1, gptq_scales.shape[-1]).t() + gptq_zeros = gptq_zeros.view(-1, gptq_zeros.shape[-1]).t() write_header(fout, int_weight.shape, dst_name, 2) if permute_func: @@ -350,12 +384,13 @@ def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2 tensor = int_weight.reshape(-1, 32) - 8 tensor = tensor[:, :16] | (tensor[:, 16:] << 4) - gptq_scale = gptq_scales.reshape(-1,1) + gptq_scale = gptq_scales.reshape(-1, 1) # gptq_scale = torch.cat([gptq_scale,gptq_scale,gptq_scale,gptq_scale], dim=1).view(-1,1) pack_tensor = torch.cat((gptq_scale.half().view(torch.int8), tensor), dim=-1) pack_tensor.numpy().tofile(fout) print(f"converting {dst_name} quantized tensor to ggml q4 block") + def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2=0, permute_func=None): qzeros = model[f"{src_name}.qzeros"] zeros = qzeros_to_zeros(qzeros) @@ -364,9 +399,9 @@ def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_hea qweight = model[f"{src_name}.qweight"] int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) - int_weight = int_weight.view(-1,int_weight.shape[-1]).t() - gptq_scales = gptq_scales.view(-1,gptq_scales.shape[-1]).t() - gptq_zeros = gptq_zeros.view(-1,gptq_zeros.shape[-1]).t() + int_weight = int_weight.view(-1, int_weight.shape[-1]).t() + gptq_scales = gptq_scales.view(-1, gptq_scales.shape[-1]).t() + gptq_zeros = gptq_zeros.view(-1, gptq_zeros.shape[-1]).t() write_header(fout, int_weight.shape, dst_name, 3) if permute_func: @@ -376,9 +411,9 @@ def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_hea tensor = int_weight.reshape(-1, 32) tensor = tensor[:, :16] | (tensor[:, 16:] << 4) - gptq_scale = gptq_scales.reshape(-1,1) - gptq_zeros = 
gptq_zeros.reshape(-1,1) - gptq_zeros = -gptq_scale*gptq_zeros + gptq_scale = gptq_scales.reshape(-1, 1) + gptq_zeros = gptq_zeros.reshape(-1, 1) + gptq_zeros = -gptq_scale * gptq_zeros pack_tensor = torch.cat((gptq_scale.half().view(torch.int8), gptq_zeros.half().view(torch.int8), tensor), dim=-1) pack_tensor.numpy().tofile(fout) print(f"converting {dst_name} quantized tensor to ggml q4 1 block") @@ -413,6 +448,7 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h print(f"converting {dst_name} quantized tensor to fp32 tensor") + def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None): # unpack weight and repack into jblas format import neural_speed.llama_cpp as cpp_model @@ -422,7 +458,7 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, qweight = model[f"{src_name}.qweight"] int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) - int_weight = int_weight.view(-1,int_weight.shape[-1]) + int_weight = int_weight.view(-1, int_weight.shape[-1]) # permute_func for llama-like model if permute_func: @@ -431,10 +467,10 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous() # shuffle weight in GPTQ when act order is on - if 'desc_act'in q_config and q_config['desc_act']: + if 'desc_act' in q_config and q_config['desc_act']: g_idx = model[f"{src_name}.g_idx"] int_weight2 = int_weight.clone() - group_size=q_config['group_size'] + group_size = q_config['group_size'] group_dict = {} for i in range(len(g_idx)): group_idx = g_idx[i].item() @@ -461,16 +497,20 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, gptq_zeros = np.empty(0, dtype=np.int8) else: gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy()) - if 'desc_act'in q_config and q_config['desc_act']: + if 'desc_act' in q_config and 
q_config['desc_act']: g_idx = np.ascontiguousarray(g_idx.numpy()) else: g_idx = np.empty(0, dtype=np.int32) # pack int weight in bestla format - byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst, - weight_dtype="int4" if q_config['bits'] == 4 else "int8", - group_size=q_config['group_size'], - alg="sym" if q_config['sym'] else "asym", - compute_dtype="int8") + byte_size = cpp_model.Model.np_bestla_qpack(int_weight, + gptq_scales, + gptq_zeros, + g_idx, + dst, + weight_dtype="int4" if q_config['bits'] == 4 else "int8", + group_size=q_config['group_size'], + alg="sym" if q_config['sym'] else "asym", + compute_dtype="int8") dst.flatten()[:byte_size].tofile(fout) print(f"converting {dst_name} qauntized tensor to bestla q4 block") diff --git a/neural_speed/convert/convert_quantized_mistral.py b/neural_speed/convert/convert_quantized_mistral.py index 7e25bd22b..a0af0e13a 100644 --- a/neural_speed/convert/convert_quantized_mistral.py +++ b/neural_speed/convert/convert_quantized_mistral.py @@ -161,9 +161,6 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("f", config["rms_norm_eps"])) f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) f.write(struct.pack("f", rope_scale)) - f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled - f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings - f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings diff --git a/neural_speed/convert/convert_quantized_mixtral.py b/neural_speed/convert/convert_quantized_mixtral.py new file mode 100644 index 000000000..abb209f9a --- /dev/null +++ b/neural_speed/convert/convert_quantized_mixtral.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# 
Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import sys +import re +import argparse +from common import * + + +def permute_func(weights, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head_kv, 2, weights.shape[0] // n_head_kv // 2, + *weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape)) + + +def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None): + # unpack weight and repack into jblas format + import neural_speed.llama_cpp as cpp_model + if ".weight" in src_name: + src_name = src_name.replace(".weight", "") + qzeros = model[f"{src_name}.qzeros"] + zeros = qzeros_to_zeros(qzeros) + scales = model[f"{src_name}.scales"] + qweight = model[f"{src_name}.qweight"] + + int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) + int_weight = int_weight.view(-1, int_weight.shape[-1]) + + # shuffle weight in GPTQ when act order is on + if 'desc_act' in q_config and q_config['desc_act']: + g_idx = model[f"{src_name}.g_idx"] + int_weight2 = int_weight.clone() + group_size = q_config['group_size'] + group_dict = {} + for i in range(len(g_idx)): + group_idx = g_idx[i].item() + if group_idx not in group_dict: + target_idx = group_idx * group_size + group_dict[group_idx] = 0 + else: + group_dict[group_idx] = group_dict[group_idx] + 1 + 
target_idx = group_idx * group_size + group_dict[group_idx] + int_weight2[target_idx] = int_weight[i] + int_weight = int_weight2 + + # permute_func for llama-like model + if permute_func: + int_weight = permute_func(int_weight.t(), n_head, n_head_kv).t().contiguous() + gptq_scales = permute_func(gptq_scales.t(), n_head, n_head_kv).t().contiguous() + gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous() + + shape = int_weight.shape[::-1] + # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE) + n_dims = len(shape) + str = dst_name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE)) + for i in range(n_dims): + fout.write(struct.pack("i", shape[n_dims - 1 - i])) + fout.write(str) + + weight_dtype = "int8" + if q_config['bits'] == 4: + int_weight = (int_weight - 8) * 16 + gptq_scales = gptq_scales / 16 + gptq_zeros = (gptq_zeros - 8) * 16 + weight_dtype == "int4" + + dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8) + int_weight = np.ascontiguousarray(int_weight.numpy()) + gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy()) + if q_config['sym']: + gptq_zeros = np.empty(0, dtype=np.int8) + else: + gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy()) + if 'desc_act' in q_config and q_config['desc_act']: + g_idx = np.ascontiguousarray(g_idx.numpy()) + else: + g_idx = np.empty(0, dtype=np.int32) + + # pack int weight in bestla format + byte_size = cpp_model.Model.np_bestla_qpack(int_weight, + gptq_scales, + gptq_zeros, + g_idx, + dst, + weight_dtype=weight_dtype, + group_size=q_config['group_size'], + alg="sym" if q_config['sym'] else "asym", + compute_dtype="int8") + dst.flatten()[:byte_size].tofile(fout) + print(f"convert_to_q4_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}") + + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="Convert a model to a NE compatible 
file") + parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file") + args = parser.parse_args(args_in) + + out_path = args.outfile.as_posix() + model_path = args.model.as_posix() + + model, config, quantize_config = load_quantized_safetensors(model_path) + f = open(out_path, "wb") + + # possible data types + # ftype == 0 -> float32 + # ftype == 1 -> float16 + ftype = 0 + if args.outtype == "f16": + ftype = 1 + + # 1. write hparams + n_vocab = config["vocab_size"] + n_embd = config["hidden_size"] + n_layer = config["num_hidden_layers"] + n_head = config["num_attention_heads"] + ffn_hidden_size = config["intermediate_size"] + + # hardcoded: + n_mult = 256 + # 1. write head and params + ne_file_magic = 0x67676d66 + f.write(struct.pack("i", ne_file_magic)) # magic: ne in hex + #f.write(b"ggjt"[::-1]) # magic + rope_scale = 1 + if "rope_scaling" in config and config["rope_scaling"] is not None: + rope_scale = config["rope_scaling"]["factor"] if "factor" in config["rope_scaling"] else 1 + + n_head = n_head + n_head_kv = 8 + values = [ + 1, # file version + n_vocab, + n_embd, + 256, #hparams.n_mult, + n_head, + n_head_kv, # n_head_kv (multi_query attention) + n_layer, + n_embd // n_head, # rot (obsolete) + 0, #file_type.value, # TODO + ] + + f.write(struct.pack("i" * len(values), *values)) + f.write(struct.pack("i", 0)) + f.write(struct.pack("f", 0)) + f.write(struct.pack("f", 0)) + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) + f.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) + + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", ffn_hidden_size)) + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 8)) # n_experts + f.write(struct.pack("i", 2)) # n_expert_used + + 
f.write(struct.pack("f", config["rms_norm_eps"])) + f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) + f.write(struct.pack("f", rope_scale)) + + f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) + + # TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json + # but bos_token_id = 1 in llama.cpp + f.write(struct.pack("i", 1)) + f.write(struct.pack("i", 2)) + + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 0)) + + # 2. vocab + tokenizer_path = os.path.join(model_path, "tokenizer.model") + vocab = load_vocab(Path(tokenizer_path)) + for text, score in vocab.all_tokens(): + f.write(struct.pack("i", len(text))) + f.write(text) + f.write(struct.pack("f", score)) + + def convert_mixtral_to_fp32_tensor(src_name, dst_name, model, fout): + # qwen-gptq is torch.bfloat16 mostly. + if model[src_name].dtype == torch.float32: + data = model[src_name].squeeze().numpy() + else: + data = model[src_name].squeeze().to(torch.float32).numpy() + data = data.astype(np.float32) + shape = data.shape + n_dims = len(shape) + print("convert_mixtral_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ", + data.dtype) + + #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype] + # default type is fp32 + ftype_cur = 0 + if ftype == 1 and n_dims > 1: + data = data.astype(np.float16) + ftype_cur = 1 + else: + data = data.astype(np.float32) + + # header + # write_header(fout, shape, dst_name, ftype_cur) + str = dst_name.encode('utf-8') + f.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + f.write(struct.pack("i", data.shape[n_dims - 1 - i])) + f.write(str) + + # data + data.tofile(f) + + # 3. 
write tensors + list_vars = model + convert_mixtral_to_fp32_tensor("model.embed_tokens.weight", "tok_embeddings.weight", list_vars, f) + convert_mixtral_to_fp32_tensor("model.norm.weight", "norm.weight", list_vars, f) + convert_mixtral_to_fp32_tensor("lm_head.weight", "output.weight", list_vars, f) + + for i in range(n_layer): + convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.q_proj", + f"layers.{i}.attention.wq.weight", + list_vars, + f, + quantize_config, + n_head, + n_head, + permute_func=permute_func) + convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj", + f"layers.{i}.attention.wk.weight", + list_vars, + f, + quantize_config, + n_head, + n_head_kv, + permute_func=permute_func) + convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight", list_vars, + f, quantize_config, n_head) + convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight", list_vars, + f, quantize_config, n_head) + + convert_mixtral_to_fp32_tensor(f"model.layers.{i}.block_sparse_moe.gate.weight", + f"layers.{i}.ffn_gate_inp.weight", list_vars, f) + + for j in range(8): + convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w1", + f"layers.{i}.ffn_gate.{j}.weight", list_vars, f, quantize_config, n_head) + convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w2", + f"layers.{i}.ffn_down.{j}.weight", list_vars, f, quantize_config, n_head) + convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w3", + f"layers.{i}.ffn_up.{j}.weight", list_vars, f, quantize_config, n_head) + + convert_mixtral_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight", + list_vars, f) + convert_mixtral_to_fp32_tensor(f"model.layers.{i}.post_attention_layernorm.weight", + f"layers.{i}.ffn_norm.weight", list_vars, f) + + f.close() + print(f"Success! 
saved as {out_path}") + + +if __name__ == '__main__': + main() diff --git a/neural_speed/convert/convert_quantized_qwen.py b/neural_speed/convert/convert_quantized_qwen.py new file mode 100644 index 000000000..fc0b87ed2 --- /dev/null +++ b/neural_speed/convert/convert_quantized_qwen.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import sys +import re +import argparse +from common import * + + +def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): + # unpack weight and repack into 3bits / 4bits BestLA format + import neural_speed.llama_cpp as cpp_model + if ".weight" in src_name: + src_name = src_name.replace(".weight", "") + qzeros = model[f"{src_name}.qzeros"] + zeros = qzeros_to_zeros(qzeros) + scales = model[f"{src_name}.scales"] + qweight = model[f"{src_name}.qweight"] + + int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) + int_weight = int_weight.view(-1, int_weight.shape[-1]) + + # shuffle weight in GPTQ when act order is on + if 'desc_act' in q_config and q_config['desc_act']: + g_idx = model[f"{src_name}.g_idx"] + int_weight2 = int_weight.clone() + group_size = q_config['group_size'] + group_dict = {} + for i in range(len(g_idx)): + group_idx = g_idx[i].item() + if group_idx not in group_dict: + target_idx = group_idx * group_size + group_dict[group_idx] = 0 + else: 
+ group_dict[group_idx] = group_dict[group_idx] + 1 + target_idx = group_idx * group_size + group_dict[group_idx] + int_weight2[target_idx] = int_weight[i] + int_weight = int_weight2 + + # shape = int_weight.shape[::-1] + shape = int_weight.shape[::-1] + # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE) + n_dims = len(shape) + str = dst_name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE)) + for i in range(n_dims): + fout.write(struct.pack("i", shape[n_dims - 1 - i])) + fout.write(str) + + # INC stores sig-int4 value as u4(range 0~15, they add a offset), + # BesTLA requires s4_clip((-8,7)*16), so we sub the offset and then mul 16. + # Int3 is the same as int4, but offset=4, mul scale==32. + weight_dtype = "int8" + if q_config['bits'] == 4: + int_weight = (int_weight - 8) * 16 + gptq_scales = gptq_scales / 16 + gptq_zeros = (gptq_zeros - 8) * 16 + weight_dtype = "int4" + elif q_config['bits'] == 3: + int_weight = (int_weight - 4) * 32 + gptq_scales = gptq_scales / 32 + gptq_zeros = (gptq_zeros - 4) * 32 + weight_dtype = "int3" + else: + ValueError(f"Unsupported q_config[bits]: {q_config['bits']}") + + dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8) + int_weight = np.ascontiguousarray(int_weight.numpy()) + gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy()) + if q_config['sym']: + gptq_zeros = np.empty(0, dtype=np.int8) + else: + gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy()) + if 'desc_act' in q_config and q_config['desc_act']: + g_idx = np.ascontiguousarray(g_idx.numpy()) + else: + g_idx = np.empty(0, dtype=np.int32) + + # repack int weight in BesTLA format + byte_size = cpp_model.Model.np_bestla_qpack(int_weight, + gptq_scales, + gptq_zeros, + g_idx, + dst, + weight_dtype=weight_dtype, + group_size=q_config['group_size'], + alg="sym" if q_config['sym'] else "asym", + compute_dtype="int8") + dst.flatten()[:byte_size].tofile(fout) + print(f"convert_to_qx_bestla_tensor: 
{src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}") + + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file") + parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file") + args = parser.parse_args(args_in) + + out_path = args.outfile.as_posix() + model_path = args.model.as_posix() + + from transformers import AutoModelForCausalLM, AutoTokenizer + # QWEN-GPTQ & AWQ + model, hparams, quantize_config = load_quantized_safetensors(model_path) + list_vars = model + + print(hparams) + + # original QWEN + # model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) + # hparams = model.config.to_dict() + # list_vars = model.state_dict() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + f = open(out_path, "wb") + + # possible data types + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype = 0 + if args.outtype == "f16": + ftype = 1 + + # 1. 
write hparams + # 0x67676d6c is unversioned ne + # 0x67676d66 is versioned ggmf (requires token scores) + ne_file_magic = 0x67676d66 + #ne_file_version = 0x00000001 # v1 + + f.write(struct.pack("i", ne_file_magic)) # magic: ne in hex + f.write(struct.pack("i", 1)) + + f.write(struct.pack("i", hparams["vocab_size"])) + f.write(struct.pack("i", hparams["hidden_size"])) + f.write(struct.pack("i", hparams["intermediate_size"])) # dummy data + f.write(struct.pack("i", hparams["num_attention_heads"])) + f.write(struct.pack("i", 0)) # multi-query attention + f.write(struct.pack("i", hparams["num_hidden_layers"])) + f.write( + struct.pack( + "i", hparams["kv_channels"] if "kv_channels" in hparams else int(hparams["hidden_size"] / + hparams["num_attention_heads"]))) + f.write(struct.pack("i", ftype)) + f.write(struct.pack("i", hparams["seq_length"] if "seq_length" in hparams else hparams["max_position_embeddings"])) + f.write(struct.pack("f", 0.0)) + f.write(struct.pack("f", 0.0)) + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) + f.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) + + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", hparams["intermediate_size"])) + f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 0)) # n_experts + f.write(struct.pack("i", 0)) # n_expert_used + f.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps + f.write(struct.pack("f", 10000.0)) # freq_base + f.write(struct.pack("f", 1.0)) # rope_factor + + f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) + + f.write( + struct.pack( + "i", hparams["bos_token_id"] if "bos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>'])) + f.write( + struct.pack( + "i", hparams["eos_token_id"] if "eos_token_id" in hparams 
else tokenizer.special_tokens['<|endoftext|>'])) + f.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) + f.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1)) + + # 2. vocab + for i in range(hparams["vocab_size"]): + if i < tokenizer.vocab_size: + text = tokenizer.decode([i]).encode('utf-8') + f.write(struct.pack("i", len(text))) + f.write(text) + f.write(struct.pack("f", 0.0 - i)) + else: + text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8') + f.write(struct.pack("i", len(text))) + f.write(text) + f.write(struct.pack("f", -10000)) + + def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout): + # qwen-gptq is torch.bfloat16 mostly. + if model[src_name].dtype == torch.float32: + data = model[src_name].squeeze().numpy() + else: + data = model[src_name].squeeze().to(torch.float32).numpy() + data = data.astype(np.float32) + shape = data.shape + n_dims = len(shape) + print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ", + data.dtype) + + #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype] + # default type is fp32 + ftype_cur = 0 + if ftype == 1 and n_dims > 1: + data = data.astype(np.float16) + ftype_cur = 1 + else: + data = data.astype(np.float32) + + # header + # write_header(fout, shape, dst_name, ftype_cur) + str = src_name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str) + + # data + data.tofile(f) + + #3. 
write tensors + if hparams['model_type'] == 'qwen': + convert_qwen_to_fp32_tensor("transformer.wte.weight", "transformer.wte.weight", list_vars, f) + convert_qwen_to_fp32_tensor("transformer.ln_f.weight", "transformer.ln_f.weight", list_vars, f) + convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, f) + + for i in range(hparams["num_hidden_layers"]): + convert_qwen_to_fp32_tensor(f"transformer.h.{i}.ln_1.weight", f"transformer.h.{i}.ln_1.weight", list_vars, + f) + convert_qwen_to_fp32_tensor(f"transformer.h.{i}.ln_2.weight", f"transformer.h.{i}.ln_2.weight", list_vars, + f) + + # qkv GEMM + convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.c_attn.weight", + f"transformer.h.{i}.attn.c_attn.weight", list_vars, f, quantize_config) + convert_qwen_to_fp32_tensor(f"transformer.h.{i}.attn.c_attn.bias", f"transformer.h.{i}.attn.c_attn.bias", + list_vars, f) + convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.c_proj.weight", + f"transformer.h.{i}.attn.c_proj.weight", list_vars, f, quantize_config) + + # ffn GEMM + convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.w1.weight", f"transformer.h.{i}.mlp.w1.weight", + list_vars, f, quantize_config) + convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.w2.weight", f"transformer.h.{i}.mlp.w2.weight", + list_vars, f, quantize_config) + convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.c_proj.weight", f"transformer.h.{i}.mlp.c_proj.weight", + list_vars, f, quantize_config) + + f.close() + print(f"Success! saved as {out_path}") + elif hparams['model_type'] == 'qwen2': + # 3. 
write tensors + convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, f) + convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, f) + convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, f) + + for i in range(hparams["num_hidden_layers"]): + convert_qwen_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight", + f"model.layers.{i}.input_layernorm.weight", list_vars, f) + convert_qwen_to_fp32_tensor(f"model.layers.{i}.post_attention_layernorm.weight", + f"model.layers.{i}.post_attention_layernorm.weight", list_vars, f) + + # qkv GEMM + convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.q_proj.weight", + f"model.layers.{i}.self_attn.q_proj.weight", list_vars, f, quantize_config) + convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.k_proj.weight", + f"model.layers.{i}.self_attn.k_proj.weight", list_vars, f, quantize_config) + convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.v_proj.weight", + f"model.layers.{i}.self_attn.v_proj.weight", list_vars, f, quantize_config) + convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.o_proj.weight", + f"model.layers.{i}.self_attn.o_proj.weight", list_vars, f, quantize_config) + + convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.q_proj.bias", + f"model.layers.{i}.self_attn.q_proj.bias", list_vars, f) + convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.k_proj.bias", + f"model.layers.{i}.self_attn.k_proj.bias", list_vars, f) + convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.v_proj.bias", + f"model.layers.{i}.self_attn.v_proj.bias", list_vars, f) + + # ffn GEMM + convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.down_proj.weight", + f"model.layers.{i}.mlp.down_proj.weight", list_vars, f, quantize_config) + convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.gate_proj.weight", + f"model.layers.{i}.mlp.gate_proj.weight", list_vars, f, quantize_config) + 
convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.up_proj.weight", f"model.layers.{i}.mlp.up_proj.weight", + list_vars, f, quantize_config) + + f.close() + print(f"Success! saved as {out_path}") + + +if __name__ == '__main__': + main() diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py index 7966717b2..bcdacbdc2 100644 --- a/neural_speed/convert/convert_qwen.py +++ b/neural_speed/convert/convert_qwen.py @@ -100,11 +100,13 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", hparams["num_attention_heads"])) fout.write(struct.pack("i", 0)) # multi-query attention fout.write(struct.pack("i", hparams["num_hidden_layers"])) - fout.write(struct.pack("i", hparams["kv_channels"] if "kv_channels" in hparams - else int(hparams["hidden_size"]/hparams["num_attention_heads"]))) + fout.write( + struct.pack( + "i", hparams["kv_channels"] if "kv_channels" in hparams else int(hparams["hidden_size"] / + hparams["num_attention_heads"]))) fout.write(struct.pack("i", ftype)) - fout.write(struct.pack("i", hparams["seq_length"] if "seq_length" in hparams - else hparams["max_position_embeddings"])) + fout.write( + struct.pack("i", hparams["seq_length"] if "seq_length" in hparams else hparams["max_position_embeddings"])) fout.write(struct.pack("f", 0.0)) fout.write(struct.pack("f", 0.0)) fout.write(struct.pack("i", 0)) @@ -120,13 +122,15 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor - fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled - fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings - fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) - fout.write(struct.pack("i", hparams["bos_token_id"] if hparams["bos_token_id"] - else tokenizer.special_tokens['<|endoftext|>'])) - fout.write(struct.pack("i", 
hparams["eos_token_id"] if hparams["eos_token_id"] - else tokenizer.special_tokens['<|endoftext|>'])) + fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) + fout.write( + struct.pack( + "i", hparams["bos_token_id"] if "bos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>'])) + fout.write( + struct.pack( + "i", hparams["eos_token_id"] if "eos_token_id" in hparams else tokenizer.special_tokens['<|endoftext|>'])) fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1)) diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp index 2bca0673c..9757156c0 100644 --- a/neural_speed/models/llama/llama_utils.cpp +++ b/neural_speed/models/llama/llama_utils.cpp @@ -118,7 +118,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, const int i_gpu_start = n_layer - n_gpu_layer; model.layers.resize(n_layer); size_t vram_total = 0; - if (ml->verify_tensor("token_embd.weight")) { + if (ml->verify_tensor("token_embd.weight")) { // GGUF model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU); model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, @@ -168,7 +168,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); } } - } else { + } else { // NE Format model.others[0] = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); model.others[1] = ml->get_tensor("norm.weight", {n_embd}, NE_BACKEND_CPU); 
model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index 3813a0a09..e71ee94f0 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -1076,6 +1076,7 @@ struct model_file_loader { } void read_ne_hparams() { + unsigned int count = 0; hparams.n_vocab = file.read_u32(); hparams.n_embd = file.read_u32(); hparams.n_mult = file.read_u32(); @@ -1083,43 +1084,79 @@ struct model_file_loader { hparams.n_head_kv = file.read_u32(); hparams.n_layer = file.read_u32(); hparams.n_rot = file.read_u32(); + printf("%-16s %d.hparams.n_vocab = %-30d\n", __func__, count++, hparams.n_vocab); + printf("%-16s %d.hparams.n_embd = %-30d\n", __func__, count++, hparams.n_embd); + printf("%-16s %d.hparams.n_mult = %-30d\n", __func__, count++, hparams.n_mult); + printf("%-16s %d.hparams.n_head = %-30d\n", __func__, count++, hparams.n_head); + printf("%-16s %d.hparams.n_head_kv = %-30d\n", __func__, count++, hparams.n_head_kv); + printf("%-16s %d.hparams.n_layer = %-30d\n", __func__, count++, hparams.n_layer); + printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_rot); + hparams.ftype = (enum ne_ftype)file.read_u32(); hparams.max_seq_len = file.read_u32(); file.read_raw(&hparams.alibi_bias_max, sizeof(float)); file.read_raw(&hparams.clip_qkv, sizeof(float)); hparams.par_res = file.read_u32(); - hparams.word_embed_proj_dim = file.read_u32(); hparams.do_layer_norm_before = bool(file.read_u32()); + printf("%-16s %d.hparams.ftype = %-30d\n", __func__, count++, hparams.ftype); + printf("%-16s %d.hparams.max_seq_len = %-30d\n", __func__, count++, hparams.max_seq_len); + printf("%-16s %d.hparams.alibi_bias_max = %-30f\n", __func__, count++, hparams.alibi_bias_max); + printf("%-16s %d.hparams.clip_qkv = %-30f\n", __func__, count++, hparams.clip_qkv); + printf("%-16s %d.hparams.par_res = %-30d\n", 
__func__, count++, hparams.par_res); + printf("%-16s %d.hparams.word_embed_proj_dim = %-30d\n", __func__, count++, hparams.word_embed_proj_dim); + printf("%-16s %d.hparams.do_layer_norm_before = %-30d\n", __func__, count++, hparams.do_layer_norm_before); - // For ChatGLM-2 hparams.multi_query_group_num = file.read_u32(); hparams.ffn_hidden_size = file.read_u32(); + printf("%-16s %d.hparams.multi_query_group_num = %-30d\n", __func__, count++, hparams.multi_query_group_num); + printf("%-16s %d.hparams.ffn_hidden_size = %-30d\n", __func__, count++, hparams.ffn_hidden_size); - // For ChatGLM-2 hparams.inner_hidden_size = file.read_u32(); hparams.n_experts = file.read_u32(); hparams.n_experts_used = file.read_u32(); + printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size); + printf("%-16s %d.hparams.n_experts = %-30d\n", __func__, count++, hparams.n_experts); + printf("%-16s %d.hparams.n_experts_used = %-30d\n", __func__, count++, hparams.n_experts_used); + // rms related file.read_raw(&hparams.rms_norm_eps, sizeof(float)); file.read_raw(&hparams.freq_base, sizeof(float)); file.read_raw(&hparams.freq_scale, sizeof(float)); + printf("%-16s %d.hparams.rms_norm_eps = %-30f\n", __func__, count++, hparams.rms_norm_eps); + printf("%-16s %d.hparams.freq_base = %-30f\n", __func__, count++, hparams.freq_base); + printf("%-16s %d.hparams.freq_scale = %-30f\n", __func__, count++, hparams.freq_scale); file.read_raw(&hparams.rope_scaling_factor, sizeof(float)); hparams.original_max_position_embeddings = file.read_u32(); hparams.use_yarn = file.read_u32(); + printf("%-16s %d.hparams.rope_scaling_factor = %-30f\n", __func__, count++, hparams.rope_scaling_factor); + printf("%-16s %d.hparams.original_max_position_embeddings = %-30d\n", __func__, count++, + hparams.original_max_position_embeddings); + printf("%-16s %d.hparams.use_yarn = %-30d\n", __func__, count++, hparams.use_yarn); + unsigned int total = 25; + if (count != total) 
{ + fprintf(stderr, "The number of ne_parameters is wrong.\n"); + } } void read_ne_vocab() { + unsigned int count = 0; + unsigned int ne_hparams_total = 25; file.read_raw(&vocab.bos_token_id, sizeof(model_vocab::id)); file.read_raw(&vocab.eos_token_id, sizeof(model_vocab::id)); file.read_raw(&vocab.pad_token_id, sizeof(model_vocab::id)); file.read_raw(&vocab.sep_token_id, sizeof(model_vocab::id)); + printf("%-16s %d.vocab.bos_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.bos_token_id); + printf("%-16s %d.vocab.eos_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.eos_token_id); + printf("%-16s %d.vocab.pad_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.pad_token_id); + printf("%-16s %d.vocab.sep_token_id = %-30d\n", __func__, ne_hparams_total + count++, vocab.sep_token_id); vocab.id_to_token.resize(hparams.n_vocab); for (uint32_t i = 0; i < hparams.n_vocab; i++) { uint32_t len = file.read_u32(); std::string word = file.read_string(len); + // std::cout << "word = " << word << std::endl; float score = 0.0f; if (file_version >= MODEL_FILE_VERSION_GGMF_V1) { diff --git a/neural_speed/models/qwen/qwen_utils.cpp b/neural_speed/models/qwen/qwen_utils.cpp index 2ea38492e..77bba671b 100644 --- a/neural_speed/models/qwen/qwen_utils.cpp +++ b/neural_speed/models/qwen/qwen_utils.cpp @@ -58,10 +58,14 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo } fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); - fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size); + fprintf(stderr, 
"%s: ftype = %u\n", __func__, hparams.ftype); + fprintf(stderr, "%s: max_seq_len= %u\n", __func__, hparams.max_seq_len); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size()); n_embd = hparams.n_embd; n_vocab = hparams.n_vocab;