Commit

Add support to convert from EfficientQAT/GPTQv2/exllamav2 weights to gguf
kaleid-liner committed Jul 22, 2024
1 parent c70219e commit d230bc6
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion convert-hf-to-gguf-t-mac.py
@@ -1483,6 +1483,33 @@ def real_quantize_tensor(w, n_bit=8, zero_point=True, q_group_size=-1):
    return w, scales, zeros


def unpack_gptqv2(qweight, scales, qzeros):
    """
    Unpack GPTQv2
    Return T-MAC biased uint8 weight [0, 2 ** bits), fp16 scales, biased fp16 zeros, bits, group_size
    """
    assert qweight.dtype == "int32"
    assert qzeros.dtype == "int32"

    bits = 32 // (scales.shape[1] // qzeros.shape[1])
    K = qweight.shape[0] * (32 // bits)
    M = qweight.shape[1]
    group_size = K // scales.shape[0]
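    # Illustrative shapes (an editorial example, not from the original code): for a
    # 4096 x 4096 layer quantized with bits=2 and group_size=128,
    #   qweight has shape [K // (32 // bits), M] = [256, 4096]  (int32, 16 weights per int32)
    #   scales  has shape [K // group_size, M]   = [32, 4096]
    #   qzeros  has shape [K // group_size, M // (32 // bits)] = [32, 256]  (int32, packed)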

    # Unpack qweight
    qweights = [(qweight >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)]
    w = np.stack(qweights, axis=1).reshape(K, M).T.astype("uint8")

    scales = scales.T

    # Unpack qzeros
    zeros = [(qzeros >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)]
    zeros = np.stack(zeros, axis=-1).reshape(K // group_size, M).T.astype(scales.dtype)
    zeros = (zeros - (2 ** (bits - 1))) * scales

    return w, scales, zeros, bits, group_size
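
# The helper below is an editorial sketch, not part of the original commit: it packs
# synthetic 2-bit weights and zero points into the GPTQv2 int32 layout assumed by
# unpack_gptqv2 above and checks that unpacking recovers them. All names and shapes
# here are illustrative.
def _sketch_check_unpack_gptqv2():
    bits, group_size, K, M = 2, 128, 256, 64
    packed_per_int32 = 32 // bits
    rng = np.random.default_rng(0)

    w_ref = rng.integers(0, 2 ** bits, size=(K, M), dtype=np.int32)
    z_ref = rng.integers(0, 2 ** bits, size=(K // group_size, M), dtype=np.int32)
    scales_ref = rng.standard_normal((K // group_size, M)).astype(np.float16)

    # Pack along K: packed row r, bit slot j holds unpacked row r * packed_per_int32 + j.
    qweight = np.zeros((K // packed_per_int32, M), dtype=np.int64)
    # Pack along M: packed column c, bit slot j holds unpacked column c * packed_per_int32 + j.
    qzeros = np.zeros((K // group_size, M // packed_per_int32), dtype=np.int64)
    for j, bit_offset in enumerate(range(0, 32, bits)):
        qweight += w_ref[j::packed_per_int32].astype(np.int64) << bit_offset
        qzeros += z_ref[:, j::packed_per_int32].astype(np.int64) << bit_offset
    # Reinterpret the packed 32-bit patterns as int32, as stored in GPTQ checkpoints.
    qweight = qweight.astype(np.uint32).view(np.int32)
    qzeros = qzeros.astype(np.uint32).view(np.int32)

    w, scales, zeros, out_bits, out_group_size = unpack_gptqv2(qweight, scales_ref, qzeros)
    assert out_bits == bits and out_group_size == group_size
    assert np.array_equal(w, w_ref.T.astype("uint8"))
    assert np.array_equal(scales, scales_ref.T)
    assert np.allclose(zeros, (z_ref.T - 2 ** (bits - 1)).astype(np.float16) * scales_ref.T)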


@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
class LlamaModel(Model):
    model_arch = gguf.MODEL_ARCH.LLAMA
@@ -1528,11 +1555,58 @@ def write_tensors(self):
        n_kv_head = self.hparams.get("num_key_value_heads")
        n_experts = self.hparams.get("num_local_experts")
        experts = dict()

        quant_dict = {}
        # Store scales and qzeros in a dict so they can be preprocessed later
        # together with the matching qweight; qweight itself is not cached, to save memory.
        for name, data_torch in self.get_tensors():
            if name.endswith(".scales") or name.endswith(".qzeros"):
                data = data_torch.numpy()
                quant_dict[name] = data

        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                continue

            # scales, qzeros and g_idx are converted together with the matching qweight below
            if name.endswith(".scales") or name.endswith(".qzeros") or name.endswith(".g_idx"):
                continue

if name.endswith(".qweight"):
qweight = data_torch.numpy()
scales = quant_dict[name.replace(".qweight", ".scales")]
qzeros = quant_dict[name.replace(".qweight", ".qzeros")]
w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros)
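                # Editorial note: permute() reorders rows along the output dimension into
                # GGUF's interleaved Q/K head layout; since scales and zeros are
                # per-output-row metadata, they must be permuted identically to stay
                # paired with their rows.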
if name.endswith("q_proj.qweight"):
w = permute(w, n_head, n_head)
scales = permute(scales, n_head, n_head)
zeros = permute(zeros, n_head, n_head)
if name.endswith("k_proj.qweight"):
w = permute(w, n_head, n_kv_head)
scales = permute(scales, n_head, n_kv_head)
zeros = permute(zeros, n_head, n_kv_head)
data_shape = w.shape
new_name = tensor_map.get_name(name.replace(".qweight", ".weight"), try_suffixes=(".weight", ".bias"))

                if self.ftype == LlamaFType.MOSTLY_I2:
                    to_dtype = gguf.GGMLQuantizationType.I2
                    data = preprocess_for_t_mac(w, scales, zeros, bits=bits)
                    assert bits == 2, "Currently only 2-bit quantized models are supported; 4-bit support will be added soon."
                else:
                    to_dtype = gguf.GGMLQuantizationType.F32
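                    # Editorial note: unpack_gptqv2 returned zeros already recentered and
                    # scaled, i.e. zeros = (z - 2 ** (bits - 1)) * scales for the integer
                    # zero point z, so z = zeros / scales + 2 ** (bits - 1) and the
                    # expression below is the plain dequantization (w - z) * scales.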
                    w = w.astype("float32").reshape(-1, group_size)
                    scales = scales.astype("float32").reshape(-1, 1)
                    zeros = zeros.astype("float32").reshape(-1, 1)
                    data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales
                    if self.ftype == LlamaFType.MOSTLY_F16:
                        to_dtype = gguf.GGMLQuantizationType.F16
                        data = data.astype("float16")

                logger.info(f"{new_name}, n_dims = {data_torch.ndim}, {data_torch.dtype} --> {to_dtype.name}")
                self.gguf_writer.add_tensor(new_name, data, raw_shape=data_shape, raw_dtype=to_dtype)
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
@@ -3213,7 +3287,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--kcfg", type=str, default="", help="Path to T-MAC kcfg.ini")
parser.add_argument("--quant-type", type=str, default="bitnet", choices=["bitnet", "bitdistiller"])
parser.add_argument("--quant-type", type=str, default="bitnet", choices=["bitnet", "bitdistiller", "gptqv2"])
parser.add_argument("--group-size", type=int, default=128)

return parser.parse_args()
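
The new "gptqv2" choice is what selects the EfficientQAT/GPTQv2/exllamav2 conversion path added above. A hypothetical invocation might look like the following; the model path, output file, and the exact combination of other flags are illustrative assumptions, not taken from this commit:

    python convert-hf-to-gguf-t-mac.py /path/to/gptqv2-model --outfile model.gguf --quant-type gptqv2 --kcfg kcfg.ini --group-size 128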
