
Commit

Fix layer order.
heiner committed May 21, 2024
1 parent 080b549 commit ee5921e
Showing 1 changed file with 13 additions and 12 deletions.
convert_grok.py (25 changes: 13 additions & 12 deletions)
@@ -126,7 +126,7 @@ def get_weights(fn):
 def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
     # equivalent to ggml_quantize_q8_0 in ggml.c
     assert tensor.shape[1] % GGML_QK8_0 == 0
-    tensor = tensor.view(-1, GGML_QK8_0)
+    tensor = tensor.reshape(-1, GGML_QK8_0)
     scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
     tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
     # add scale into each block
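Aside, not part of this commit: torch.Tensor.view raises on non-contiguous inputs, while reshape silently copies when it has to, which makes the + lines the safer choice for checkpoint slices. A minimal round-trip sketch of the Q8_0 scheme, assuming the usual block size GGML_QK8_0 = 32:

import torch

GGML_QK8_0 = 32  # Q8_0 block size in ggml

# One block: 32 int8 values sharing one scale; x is reconstructed as q * scale.
x = torch.randn(1, GGML_QK8_0)
scale = x.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
q = (x / scale).round().clamp(min=-128, max=127).char()
x_hat = q.float() * scale
assert (x_hat - x).abs().max() <= scale.item()  # error bounded by the step size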
@@ -152,7 +152,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
 def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
     # equivalent to ggml_quantize_q4_1 in ggml.c
     assert tensor.shape[1] % GGML_QK4_1 == 0
-    tensor = tensor.view(-1, GGML_QK4_1)
+    tensor = tensor.reshape(-1, GGML_QK4_1)
     abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
     max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
     abs_min_indices = tensor.min(dim=-1, keepdim=True).indices
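Aside, not part of this commit: despite the abs_* variable names, Q4_1 uses the plain min and max of each block, storing a minimum m and a scale d so that x is reconstructed as d * q + m with 4-bit q in [0, 15]. A rough sketch under those standard assumptions:

import torch

GGML_QK4_1 = 32  # Q4_1 block size in ggml

def q4_1_round_trip(x: torch.Tensor) -> torch.Tensor:
    # Per block: m = min, d = (max - min) / 15, q = round((x - m) / d).
    m = x.min(dim=-1, keepdim=True).values
    d = (x.max(dim=-1, keepdim=True).values - m) / 15
    q = ((x - m) / d).round().clamp(0, 15)
    return d * q + m

x = torch.randn(4, GGML_QK4_1)
max_err = (q4_1_round_trip(x) - x).abs().max()  # bounded by d / 2 per block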
@@ -185,15 +185,13 @@ def maybe_quantize_tensor(tensor, ggml_type):
         raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})")
 
 
-def get_dtype_and_ggml_type(tensor, ggml_type):
-    if tensor.ndim in (2, 3):
+def get_dtype_and_ggml_type(name, tensor, ggml_type):
+    if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name:
         if tensor.shape[1] % GGML_QK8_0 == 0:
             return np.int8, ggml_type
         else:
             return np.float16, gguf.GGMLQuantizationType.F16
     else:
-        # 1d weight: convert it to float32
-        assert tensor.ndim == 1, tensor
         return np.float32, gguf.GGMLQuantizationType.F32
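Not part of the commit, but why the name argument matters: ffn_gate_inp is Grok-1's MoE router, whose logits select the experts. Keeping it in f32 despite being 2-D matches how llama.cpp-era MoE conversions typically treat routers, and it is why the assert that only 1-D tensors reach the f32 branch had to go. A toy dispatch check with invented shapes (meta tensors allocate nothing):

import numpy as np
import torch
import gguf

GGML_QK8_0 = 32

def get_dtype_and_ggml_type(name, tensor, ggml_type):
    # Mirrors the patched dispatch above.
    if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name:
        if tensor.shape[1] % GGML_QK8_0 == 0:
            return np.int8, ggml_type
        return np.float16, gguf.GGMLQuantizationType.F16
    return np.float32, gguf.GGMLQuantizationType.F32

q8 = gguf.GGMLQuantizationType.Q8_0
router = torch.empty(8, 6144, device="meta")    # hypothetical router shape
proj = torch.empty(32768, 6144, device="meta")  # hypothetical FFN projection
assert get_dtype_and_ggml_type("blk.0.ffn_gate_inp.weight", router, q8)[0] == np.float32
assert get_dtype_and_ggml_type("blk.0.ffn_up.weight", proj, q8)[0] == np.int8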


@@ -205,7 +203,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
     for idx, name in enumerate(weight_names):
         weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
         meta_tensor = convert_weight(name, weight, scales, config, device="meta")
-        dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
+        dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
         quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
         f.add_tensor_info(
             f"{name}.weight",
@@ -227,7 +225,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
     for name in weight_names:
         weight, scales = weights.pop(name)
         tensor = convert_weight(name, weight, scales, config)
-        _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
+        _, tensor_ggml_type = get_dtype_and_ggml_type(name, tensor, ggml_type)
         array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
 
         logging.info(
@@ -317,7 +315,10 @@ def get_weight_names(num_hidden_layers=64):
         gguf.MODEL_TENSOR.FFN_GATE_INP,
     )
 
-    for bid in range(num_hidden_layers):
+    layers = [str(bid) for bid in range(64)]
+    layers.sort()  # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
+
+    for bid in layers[:num_hidden_layers]:
         for key in layer:
             weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))
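This hunk is the fix the commit title refers to: dump_state_dict pairs weight_names[idx] with the file tensor{idx:05}_000, and the original Grok-1 export evidently numbered per-layer tensors in lexicographic order of the layer id, so iterating layers numerically mis-paired names and files from layer 2 onward. The string ordering being reproduced, runnable as-is:

# Lexicographic layer order: '0' < '1' < '10' < ... < '19' < '2' < '20' < ...
layers = sorted(str(i) for i in range(64))
print(layers[:12])   # ['0', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19']
print(layers[12:15]) # ['2', '20', '21']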

@@ -333,7 +334,6 @@ def ffn_size(emb_size, widening_factor):
         return _ffn_size
 
     config = {
-        "vocab_size": 128 * 1024,
         "hidden_act": "gelu",
         "pad_token_id": 0,
         "eos_token_id": 2,
@@ -366,8 +366,7 @@ def ffn_size(emb_size, widening_factor):
 
     f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
 
-    f.add_name("grok")
-    f.add_vocab_size(config.vocab_size)
+    f.add_name("grok-1")
     f.add_context_length(config.max_position_embeddings)
     f.add_embedding_length(config.hidden_size)
     f.add_block_count(config.num_hidden_layers)
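For orientation, a minimal GGUF-writing sketch with made-up sizes. It assumes gguf-py's simpler add_tensor path plus the explicit write_*_to_file sequence, which is not the add_tensor_info path this script uses, so treat it as illustrative rather than authoritative:

import numpy as np
import gguf

w = gguf.GGUFWriter("tiny.gguf", "grok", endianess=gguf.GGUFEndian.LITTLE)
w.add_name("grok-1")
w.add_context_length(8192)  # made-up value
w.add_embedding_length(64)  # made-up value
w.add_tensor("token_embd.weight", np.zeros((16, 64), dtype=np.float32))
w.write_header_to_file()
w.write_kv_data_to_file()
w.write_tensors_to_file()
w.close()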
@@ -389,6 +388,8 @@ def ffn_size(emb_size, widening_factor):
     f.add_token_scores(scores)
     f.add_token_types(toktypes)
 
+    f.add_quantization_version(ggml_type)
+
     dump_state_dict(f, ggml_type, args.input_dir, config)
     f.close()

