This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

[Bug Fix] Fixed Qwen loading & Mistral GPTQ convert (#129)
Zhenzhong1 authored Feb 20, 2024
1 parent ee40f28 commit d47984c
Showing 4 changed files with 73 additions and 10 deletions.
62 changes: 62 additions & 0 deletions neural_speed/convert/common.py
@@ -412,3 +412,65 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
    weight.numpy().tofile(fout)

    print(f"converting {dst_name} quantized tensor to fp32 tensor")


def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
    # unpack weight and repack into jblas format
    import neural_speed.llama_cpp as cpp_model
    qzeros = model[f"{src_name}.qzeros"]
    zeros = qzeros_to_zeros(qzeros)
    scales = model[f"{src_name}.scales"]
    qweight = model[f"{src_name}.qweight"]

    int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
    int_weight = int_weight.view(-1, int_weight.shape[-1])

    # permute_func for llama-like model
    if permute_func:
        int_weight = permute_func(int_weight.t(), n_head, n_head_kv).t().contiguous()
        gptq_scales = permute_func(gptq_scales.t(), n_head, n_head_kv).t().contiguous()
        gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous()

    # shuffle weight in GPTQ when act order is on
    if 'desc_act' in q_config and q_config['desc_act']:
        g_idx = model[f"{src_name}.g_idx"]
        int_weight2 = int_weight.clone()
        group_size = q_config['group_size']
        group_dict = {}
        for i in range(len(g_idx)):
            group_idx = g_idx[i].item()
            if group_idx not in group_dict:
                target_idx = group_idx * group_size
                group_dict[group_idx] = 0
            else:
                group_dict[group_idx] = group_dict[group_idx] + 1
                target_idx = group_idx * group_size + group_dict[group_idx]
            int_weight2[target_idx] = int_weight[i]
        int_weight = int_weight2

    shape = int_weight.shape
    write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)

    if q_config['bits'] == 4:
        # map unsigned 4-bit values to a signed range; scales are divided by 16 to compensate
        int_weight = (int_weight - 8) * 16
        gptq_scales = gptq_scales / 16
        gptq_zeros = (gptq_zeros - 8) * 16
    dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
    int_weight = np.ascontiguousarray(int_weight.numpy())
    gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
    if q_config['sym']:
        gptq_zeros = np.empty(0, dtype=np.int8)
    else:
        gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
    if 'desc_act' in q_config and q_config['desc_act']:
        g_idx = np.ascontiguousarray(g_idx.numpy())
    else:
        g_idx = np.empty(0, dtype=np.int32)

    # pack int weight in bestla format
    byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
                                                weight_dtype="int4" if q_config['bits'] == 4 else "int8",
                                                group_size=q_config['group_size'],
                                                alg="sym" if q_config['sym'] else "asym",
                                                compute_dtype="int8")
    dst.flatten()[:byte_size].tofile(fout)
    print(f"converting {dst_name} quantized tensor to bestla q4 block")
16 changes: 8 additions & 8 deletions neural_speed/convert/convert_quantized_llama.py
@@ -27,7 +27,7 @@ def permute_func(weights, n_head: int, n_head_kv: int):
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2,
                            *weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))

-def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
+def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
    # unpack weight and repack into jblas format
    import neural_speed.llama_cpp as cpp_model
    if ".weight" in src_name:
@@ -174,21 +174,21 @@ def main(args_in: Optional[List[str]] = None) -> None:
convert_to_fp32_tensor("lm_head.weight", "output.weight", list_vars, f)

for i in range(n_layer):
convert_q4_bestla_tensor(f"model.layers.{i}.self_attn.q_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.q_proj",
f"layers.{i}.attention.wq.weight", list_vars, f, quantize_config, n_head, n_head,
permute_func=permute_func)
convert_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj",
f"layers.{i}.attention.wk.weight", list_vars, f, quantize_config, n_head, n_head_kv,
permute_func=permute_func)
convert_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj",
f"layers.{i}.attention.wv.weight", list_vars, f, quantize_config, n_head)
convert_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj",
f"layers.{i}.attention.wo.weight", list_vars, f, quantize_config, n_head)
convert_q4_bestla_tensor(f"model.layers.{i}.mlp.gate_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.gate_proj",
f"layers.{i}.feed_forward.w1.weight", list_vars, f, quantize_config, n_head)
convert_q4_bestla_tensor(f"model.layers.{i}.mlp.down_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.down_proj",
f"layers.{i}.feed_forward.w2.weight", list_vars, f, quantize_config, n_head)
convert_q4_bestla_tensor(f"model.layers.{i}.mlp.up_proj",
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.up_proj",
f"layers.{i}.feed_forward.w3.weight", list_vars, f, quantize_config, n_head)

convert_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight",
4 changes: 2 additions & 2 deletions neural_speed/convert/convert_quantized_mistral.py
@@ -24,7 +24,7 @@

def permute_func(weights, n_head: int, n_head_kv: int):
    if n_head_kv is not None and n_head != n_head_kv:
-        n_head //= n_head_kv
+        n_head = n_head_kv
    return (weights.reshape(n_head_kv, 2, weights.shape[0] // n_head_kv // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))
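For intuition, here is a small illustration (not part of the commit) of the row permutation this function performs: rows are grouped into n_head_kv head blocks, and within each block the first and second halves are interleaved, a reordering commonly applied so the rotary-embedding pairing matches what the GGML-style runtime expects. Names and values below are purely illustrative.

import numpy as np

def permute_rows(weights, n_head, n_head_kv):
    # same reshape/swapaxes/reshape as the function above
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head_kv, 2, weights.shape[0] // n_head_kv // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

w = np.arange(8).reshape(8, 1)                  # 8 rows labeled 0..7, e.g. 2 KV heads of dim 4
print(permute_rows(w, n_head=8, n_head_kv=2).ravel())
# -> [0 2 1 3 4 6 5 7]: each head block's two halves are interleaved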
@@ -70,7 +70,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
        n_embd // n_head, # rot (obsolete)
        0, #file_type.value, # TODO
    ]
-    # import pdb; pdb.set_trace()

    f.write(struct.pack("i" * len(values), *values))
    f.write(struct.pack("i", 0))
    f.write(struct.pack("f", 0))
1 change: 1 addition & 0 deletions neural_speed/models/qwen/qwen_utils.cpp
@@ -52,6 +52,7 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
  model.hparams = ml->file_loaders.at(0)->hparams;
  model_file_version file_version = ml->file_loaders.at(0)->file_version;
  auto& hparams = model.hparams;
+  n_ff = hparams.ffn_hidden_size / 2;
  fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
  fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
  fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
