From f16bc7d6146204dea51a49a071f33530f6892aa1 Mon Sep 17 00:00:00 2001 From: changwangss Date: Mon, 9 Sep 2024 21:52:05 -0700 Subject: [PATCH 1/5] add convert_format_awq_to_gptq Signed-off-by: changwangss --- .../torch/algorithms/weight_only/utility.py | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index eced733ca8d..23328131d82 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Weight-Only utility.""" +import numpy as np import torch from neural_compressor.torch.utils import accelerator, device_synchronize, logger @@ -1228,3 +1229,185 @@ def convert_dtype_str2torch(str_dtype): return torch.bfloat16 else: assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype) + + +# ref reverse reorder from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L491 +def awq_reverse_reorder_int_tensor(int_tensor, bits: int): + assert bits == 4 + + int_tensor = int_tensor.T.contiguous() + compress_ratio = 32 // bits + assert int_tensor.shape[-1] % compress_ratio == 0 + + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_tensor = torch.tensor(order_map, dtype=torch.int32, device=int_tensor.device).reshape(1, -1) + order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1) + order_tensor = order_tensor + torch.arange( + 0, + int_tensor.shape[1], + compress_ratio, + dtype=torch.int32, + device=int_tensor.device, + ).reshape(-1, 1) + order_tensor = order_tensor.reshape(-1) + + reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor] + reverse_order_tensor = reverse_order_tensor[order_tensor] + int_tensor = int_tensor[:, reverse_order_tensor] + return int_tensor + + +# ref weight unpack from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516 +def unpack_awq( + awq_qweight: torch.Tensor, + awq_qzeros: torch.Tensor, + awq_scales: torch.Tensor, + bits: int, + group_size: int, +): + """ + Args: + awq_qweight (`torch.LongTensor`): + Expected shape: (in_features, out_features // (32 // bits)) + awq_qzeros (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features // (32 // bits)) + awq_scales (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features) + + Returns: + fp16_weight (`torch.LongTensor`): + With shape (in_features, out_features). + zeros (`torch.LongTensor`): + With shape (in_features // group_size, out_features). + """ + assert bits == 4 + + qzeros = awq_qzeros + qweight = awq_qweight + qweight = qweight.T.contiguous() + + infeatures = awq_qweight.shape[0] + + wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device).unsqueeze(0) + zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to( + torch.int16 if bits == 8 else torch.int8 + ) + + # zeros = zeros + 1 + + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + + zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) + + weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1), wf.unsqueeze(-1)).to( + torch.int16 if bits == 8 else torch.int8 + ) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) + weight = weight.reshape(-1, group_size, weight.shape[2]) + + weight = weight.view(-1, weight.shape[-1]) + zeros = zeros.view(-1, zeros.shape[-1]) + + zeros = zeros.T.contiguous() + zeros = awq_reverse_reorder_int_tensor(zeros, bits) + weight = awq_reverse_reorder_int_tensor(weight, bits) + + # Dequantize weights. + scales = awq_scales + zeros = zeros.contiguous() + scale_zeros = zeros * scales + + g_idx = torch.tensor([i // group_size for i in range(infeatures)], dtype=torch.int32) + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx].half() + + qdq_weight_T = weight * scale_mat - scale_zeros_mat.half() + + fp16_weight = qdq_weight_T.T + + return fp16_weight, zeros + + +# ref weight unpack from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516 +def pack_from_tensors( + unpacked_qweight: torch.Tensor, + unpacked_qzeros: torch.Tensor, + awq_scales: torch.Tensor, + bits: int, + group_size: int, +): + """ + Args: + unpacked_qweight (`torch.LongTensor`): + Expected shape: (in_features, out_features) + unpacked_qzeros (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features) + awq_scales (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features) + + Returns: + qweight (`torch.LongTensor`): + With shape (in_features // (32 // bits), out_features) + qzeros (`torch.LongTensor`): + With shape (in_features // group_size, out_features // (32 // bits)) + """ + assert bits == 4 + W = unpacked_qweight.clone().cpu() + + # TODO: This should be checked somehow. + # if isinstance(linear, nn.Conv2d): + # W = W.flatten(1) + # if isinstance(linear, transformers.pytorch_utils.Conv1D): + # W = W.t() + + awq_scales = awq_scales.t().contiguous() + unpacked_qzeros = unpacked_qzeros.contiguous() + unpacked_qzeros = unpacked_qzeros.cpu() + + awq_scales = awq_scales.cpu() + scale_zeros = unpacked_qzeros.t() * awq_scales + scales = awq_scales.clone() + + infeatures = unpacked_qweight.shape[1] + + intweight = [] + for idx in range(infeatures): + g_idx = idx // group_size + + intweight.append(torch.round((W[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + + i = 0 + row = 0 + qweight = np.zeros((intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=np.uint32) + while row < qweight.shape[0]: + for j in range(i, i + (32 // bits)): + qweight[row] |= intweight[j] << (bits * (j - i)) + i += 32 // bits + row += 1 + + qweight = qweight.astype(np.int32) + qweight = torch.from_numpy(qweight) + + unpacked_qzeros = unpacked_qzeros - 1 + torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros) + + unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (unpacked_qzeros.shape[0], unpacked_qzeros.shape[1] // 32 * bits), + dtype=np.uint32, + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + for j in range(i, i + (32 // bits)): + qzeros[:, col] |= unpacked_qzeros[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + + qzeros = qzeros.astype(np.int32) + qzeros = torch.from_numpy(qzeros) + + return qweight, qzeros From 606dc72de3870c9f550e1207736bd6f8540c2965 Mon Sep 17 00:00:00 2001 From: changwangss Date: Mon, 9 Sep 2024 22:10:24 -0700 Subject: [PATCH 2/5] add repack func Signed-off-by: changwangss --- .../torch/algorithms/weight_only/utility.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 23328131d82..7405aca960d 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -1411,3 +1411,32 @@ def pack_from_tensors( qzeros = torch.from_numpy(qzeros) return qweight, qzeros + + +def repack_awq_to_optimum_format( + awq_qweight: torch.Tensor, + awq_qzeros: torch.Tensor, + awq_scales: torch.Tensor, + bits: int, + group_size: int, +): + """ + Args: + awq_qweight (`torch.LongTensor`): + Expected shape: (in_features, out_features // (32 // bits)) + awq_qzeros (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features // (32 // bits)) + awq_scales (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features) + + Returns: + qweight (`torch.LongTensor`): + With shape (in_features // (32 // bits), out_features) + qzeros (`torch.LongTensor`): + With shape (in_features // group_size, out_features // (32 // bits)) + scales (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features) + """ + unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size) + qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales) + return qweight, qzeros, awq_scales From 8ecb856df23cd7750598d8988d8107a4ff97952b Mon Sep 17 00:00:00 2001 From: changwangss Date: Wed, 18 Sep 2024 01:29:07 -0700 Subject: [PATCH 3/5] add ut and add backend Signed-off-by: changwangss --- .../torch/algorithms/weight_only/utility.py | 2 +- .../transformers/models/modeling_auto.py | 82 ++++++++++++------- .../transformers/quantization/utils.py | 38 +++++++++ .../transformers/utils/quantization_config.py | 2 + .../weight_only/test_transfomers.py | 13 +++ 5 files changed, 106 insertions(+), 31 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 7405aca960d..6f4256534bd 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -1438,5 +1438,5 @@ def repack_awq_to_optimum_format( Expected shape: (in_features // group_size, out_features) """ unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size) - qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales) + qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size) return qweight, qzeros, awq_scales diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py index a4a91e27f03..657d2a9bd49 100644 --- a/neural_compressor/transformers/models/modeling_auto.py +++ b/neural_compressor/transformers/models/modeling_auto.py @@ -47,7 +47,13 @@ from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear from neural_compressor.torch.utils import set_module -from ..quantization.utils import convert_dtype_torch2str, convert_to_quantized_model, replace_linear, save_low_bit +from ..quantization.utils import ( + convert_dtype_torch2str, + convert_to_quantized_model, + repack_awq_and_load_state_dict, + replace_linear, + save_low_bit, +) from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig @@ -179,6 +185,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) and model.config.model_type == "chatglm": model = model.float() model = convert_to_quantized_model(model, quantization_config, device=device_map) + if isinstance(quantization_config, AwqConfig): + quantization_config.backend = "inc" quantization_config.remove_redundant_parameters() model.config.quantization_config = quantization_config else: @@ -295,6 +303,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): quantization_config = GPTQConfig.from_dict(quantization_config) elif quantization_config["quant_method"] == "autoround": quantization_config = AutoRoundConfig.from_dict(quantization_config) + assert quantization_config is not None, "Detect this model is not a low-bit model." if commit_hash is None: @@ -613,41 +622,54 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): with ContextManagers(init_contexts): model = model_class(config, *model_args, **kwargs) - + if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc": + if quantization_config.modules_to_not_convert is None: + quantization_config.modules_to_not_convert = ["lm_head", "transformer.output_layer", "embed_out"] + else: + quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"] model = build_woq_model(model, quantization_config) - if is_sharded: - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: - # Time to load the checkpoint - state_dict = load_state_dict(resolved_archive_file) - loaded_state_dict_keys = list(state_dict.keys()) - # restore default dtype if dtype_orig is not None: torch.set_default_dtype(dtype_orig) - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = model_class._load_pretrained_model( - model, - None, - loaded_state_dict_keys, # XXX: rename? - resolved_archive_file, - pretrained_model_name_or_path, - sharded_metadata=sharded_metadata, - _fast_init=_fast_init, - low_cpu_mem_usage=True, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - keep_in_fp32_modules=[], - ) + if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc": + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) + model = repack_awq_and_load_state_dict( + model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded + ) + else: + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], + ) # make sure token embedding weights are still tied if needed model.tie_weights() diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py index 6f209344348..e66e573e3b2 100644 --- a/neural_compressor/transformers/quantization/utils.py +++ b/neural_compressor/transformers/quantization/utils.py @@ -23,6 +23,7 @@ from neural_compressor.common.utils import LazyImport, logger from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear +from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format from neural_compressor.torch.quantization import ( AutoRoundConfig, AWQConfig, @@ -654,3 +655,40 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo token=kwargs.get("token"), ) self.quantization_config.save_pretrained(save_directory, **kwargs) + + +def repack_awq_and_load_state_dict( + model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded +): + from transformers.modeling_utils import load_state_dict + + bits = quantization_config.bits + group_size = quantization_config.group_size + + state_dict = {} + if isinstance(resolved_archive_file, str): + resolved_archive_file = [resolved_archive_file] + assert isinstance(resolved_archive_file, list), "Please check if the loading weight is shared." + for shard_file in resolved_archive_file: + assert shard_file.endswith("safetensors"), "Please check the loading weight saved format." + state_dict.update(load_state_dict(shard_file)) + assert len(state_dict.keys()) > 0, "Please check the state_dict loading." + for name, module in model.named_modules(): + if isinstance(module, INCWeightOnlyLinear): + assert name + ".qweight" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.qweight'}" + assert name + ".qzeros" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qzeros'}" + assert name + ".scales" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.scales'}" + if name + ".scales" in loaded_state_dict_keys: + awq_qweight = state_dict[name + ".qweight"] + awq_qzeros = state_dict[name + ".qzeros"] + awq_scales = state_dict[name + ".scales"] + qweight, qzeros, awq_scales = repack_awq_to_optimum_format( + awq_qweight, awq_qzeros, awq_scales, bits, group_size + ) + state_dict[name + ".qweight"] = qweight + state_dict[name + ".qzeros"] = qzeros + state_dict[name + ".scales"] = awq_scales + + model.load_state_dict(state_dict, strict=False, assign=True) + + return model diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py index 925cc3ccc7a..13dff04dc4f 100644 --- a/neural_compressor/transformers/utils/quantization_config.py +++ b/neural_compressor/transformers/utils/quantization_config.py @@ -409,6 +409,7 @@ def __init__( zero_point: bool = True, absorb_layer_dict: dict = {}, quant_lm_head: bool = False, + backend: str = None, **kwargs, ): self.quant_method = QuantizationMethod.AWQ @@ -427,6 +428,7 @@ def __init__( self.seq_len = seq_len self.absorb_layer_dict = absorb_layer_dict self.quant_lm_head = quant_lm_head + self.backend = backend self.modules_to_not_convert = kwargs.get( "modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"] ) diff --git a/test/3x/torch/quantization/weight_only/test_transfomers.py b/test/3x/torch/quantization/weight_only/test_transfomers.py index 95a89f86f68..e9194d9a371 100644 --- a/test/3x/torch/quantization/weight_only/test_transfomers.py +++ b/test/3x/torch/quantization/weight_only/test_transfomers.py @@ -18,6 +18,9 @@ class TestTansformersLikeAPI: def setup_class(self): self.model_name_or_path = "hf-internal-testing/tiny-random-gptj" + self.autoawq_model = "casperhansen/opt-125m-awq" + self.prompt = "One day, the little girl" + self.generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4) def teardown_class(self): shutil.rmtree("nc_workspace", ignore_errors=True) @@ -111,3 +114,13 @@ def test_save_load(self): loaded_model = AutoModelForCausalLM.from_pretrained(output_dir) loaded_output = loaded_model(dummy_input)[0] assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check." + + def test_loading_autoawq_model(self): + user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model) + tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model) + input_ids = tokenizer(self.prompt, return_tensors="pt")["input_ids"] + self.generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4) + gen_ids = user_model.generate(input_ids, **self.generate_kwargs) + gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True) + target_text = ["One day, the little girl in the back of my mind will ask me if I'm a"] + assert gen_text == target_text, "loading autoawq quantized model failed." From 96fc62b5d2a641e3732204d30a8d9715f2876a3f Mon Sep 17 00:00:00 2001 From: changwangss Date: Wed, 18 Sep 2024 01:39:09 -0700 Subject: [PATCH 4/5] reuse code Signed-off-by: changwangss --- .../transformers/models/modeling_auto.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py index 657d2a9bd49..55ef52a8e01 100644 --- a/neural_compressor/transformers/models/modeling_auto.py +++ b/neural_compressor/transformers/models/modeling_auto.py @@ -629,26 +629,20 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"] model = build_woq_model(model, quantization_config) + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) # restore default dtype if dtype_orig is not None: torch.set_default_dtype(dtype_orig) if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc": - if is_sharded: - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: - state_dict = load_state_dict(resolved_archive_file) - loaded_state_dict_keys = list(state_dict.keys()) model = repack_awq_and_load_state_dict( model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded ) else: - if is_sharded: - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: - # Time to load the checkpoint - state_dict = load_state_dict(resolved_archive_file) - loaded_state_dict_keys = list(state_dict.keys()) ( model, missing_keys, From f9153297af478cd56c9f7658bd0d10fda759792e Mon Sep 17 00:00:00 2001 From: changwangss Date: Wed, 18 Sep 2024 02:03:20 -0700 Subject: [PATCH 5/5] fix description Signed-off-by: changwangss --- .../torch/algorithms/weight_only/utility.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 6f4256534bd..e2a0463f95a 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -1233,6 +1233,10 @@ def convert_dtype_str2torch(str_dtype): # ref reverse reorder from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L491 def awq_reverse_reorder_int_tensor(int_tensor, bits: int): + """Awq tensor convert tool. + + Reverse_reorder_int_tensor + """ assert bits == 4 int_tensor = int_tensor.T.contiguous() @@ -1265,7 +1269,8 @@ def unpack_awq( bits: int, group_size: int, ): - """ + """Unpack awq format to actual values. + Args: awq_qweight (`torch.LongTensor`): Expected shape: (in_features, out_features // (32 // bits)) @@ -1336,7 +1341,8 @@ def pack_from_tensors( bits: int, group_size: int, ): - """ + """Pack the tensor to optimum format. + Args: unpacked_qweight (`torch.LongTensor`): Expected shape: (in_features, out_features) @@ -1420,7 +1426,8 @@ def repack_awq_to_optimum_format( bits: int, group_size: int, ): - """ + """The function to repack_awq_to_optimum_format. + Args: awq_qweight (`torch.LongTensor`): Expected shape: (in_features, out_features // (32 // bits))