From e6f89d196b5f648e7bba430f6a9ffb61ca33f23e Mon Sep 17 00:00:00 2001
From: WeiweiZhang1
Date: Wed, 11 Dec 2024 17:41:12 +0800
Subject: [PATCH 1/3] support_llava_hf_vlm_example (#381)

* support_llava_hf_vlm_example

Signed-off-by: Zhang, Weiwei1

* skip import check

Signed-off-by: Zhang, Weiwei1

---------

Signed-off-by: Zhang, Weiwei1
---
 auto_round/script/mllm.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 89e5b23a..634518be 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -282,18 +282,21 @@ def tune(args):
     # load_model
     processor, image_processor = None, None
-    if "llava" in model_name:
-        from llava.model.builder import load_pretrained_model # pylint: disable=E0401
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+    if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
+        from llava.model.builder import load_pretrained_model # pylint: disable=E0401
         tokenizer, model, image_processor, _ = load_pretrained_model(
             model_name, model_base=None, model_name=model_name, torch_dtype=torch_dtype)
         model_type = "llava"
     else:
-        config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         model_type = config.model_type
-        if "qwen2_vl" in model_type:
+        if "llava" in model_type:
+            from transformers import LlavaForConditionalGeneration
+            cls = LlavaForConditionalGeneration
+        elif "qwen2_vl" in model_type:
             from transformers import Qwen2VLForConditionalGeneration
             cls = Qwen2VLForConditionalGeneration
         elif "mllama" in model_type:
@@ -511,3 +514,4 @@ def lmms_eval(args):
         apply_chat_template=False,
     )
     return results
+

From 8d8c70d51945328ec9d38e86c604db80eab890f0 Mon Sep 17 00:00:00 2001
From: WeiweiZhang1
Date: Thu, 12 Dec 2024 13:07:22 +0800
Subject: [PATCH 2/3] fix duplicated block_name_to_quantize exporting in gptq format (#382)

Signed-off-by: Zhang, Weiwei1
---
 README.md                               |  4 ++--
 .../export/export_to_autogptq/export.py | 24 ++++++++-----------
 auto_round/mllm/eval.py                 |  2 ++
 auto_round/quantizer.py                 |  3 +--
 4 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 039e4ef6..9ca783d9 100644
--- a/README.md
+++ b/README.md
@@ -16,8 +16,7 @@ steps, which competes impressively against recent methods without introducing
 any additional inference overhead and keeping low tuning cost. The below image presents an overview of AutoRound. Check out our paper on [arxiv](https://arxiv.org/pdf/2309.05516) for more
-details and visit [low_bit_open_llm_leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) for
-more accuracy data and recipes across various models.
+details and quantized huggingface space models in [OPEA](https://huggingface.co/OPEA), [Kaitchup](https://huggingface.co/kaitchup) and [fbaldassarri](https://huggingface.co/fbaldassarri).

 <div align="center">
@@ -414,3 +413,4 @@ If you find AutoRound useful for your research, please cite our paper:
 ```
 ```
+

diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py
index ad2cf5dc..b0c116d9 100644
--- a/auto_round/export/export_to_autogptq/export.py
+++ b/auto_round/export/export_to_autogptq/export.py
@@ -121,7 +121,7 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
     supported_types = kwargs["supported_types"]
     safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
     to_quant_block_names = kwargs["to_quant_block_names"]
-    quant_block_list = kwargs.get("quant_block_list", None)
+    quant_block_list = kwargs.get("quant_block_list", get_block_names(model))
     logger.info("Saving quantized model to autogptq format, this may take a while...")
     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
@@ -131,19 +131,14 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
         processor.save_pretrained(output_dir)
     ##check module quantized in block, this may have bug for mixed precision quantization
     quantization_config = kwargs["serialization_dict"]
-    if bool(quant_block_list):
-        all_blocks = quant_block_list
-        flattened_list = [item for sublist in all_blocks for item in sublist]
-        common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
-        if common_prefix not in BLOCK_PATTERNS:
-            logger.error(f"auto-gptq format may not support loading this quantized model")
-        quantization_config['block_name_to_quantize'] = common_prefix
-    else:
-        all_blocks = get_block_names(model)
-        flattened_list = [item for sublist in all_blocks for item in sublist]
-        common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
-        if common_prefix not in BLOCK_PATTERNS:
-            quantization_config['block_name_to_quantize'] = common_prefix
+    all_blocks = quant_block_list
+    flattened_list = [item for sublist in all_blocks for item in sublist]
+    common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
+    if common_prefix not in BLOCK_PATTERNS:
+        logger.error(f"auto-gptq format may not support loading this quantized model")
+    quantization_config['block_name_to_quantize'] = common_prefix
+    quantization_config.pop("to_quant_block_names", None)
+
     all_to_quantized = True
     modules_in_block_to_quantize = []
@@ -222,3 +217,4 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf
         json.dump(model.config.quantization_config, f, indent=2)
+

diff --git a/auto_round/mllm/eval.py b/auto_round/mllm/eval.py
index a2f9d81c..568bbd2b 100644
--- a/auto_round/mllm/eval.py
+++ b/auto_round/mllm/eval.py
@@ -81,6 +81,7 @@
     "llava_next": dict(cls="LLaVA_Next"),
     "phi3_v": dict(cls="Phi3Vision"),
     "mllama": dict(cls="llama_vision"),
+    "glm-4v-9b": dict(cls="GLM4v"),
 }
@@ -409,3 +410,4 @@ class CliArgs:
     json.dump(results, open(output_file, 'w'), indent=4, default=_handle_non_serializable)
     return results
+

diff --git a/auto_round/quantizer.py b/auto_round/quantizer.py
index 90d00493..296fa853 100644
--- a/auto_round/quantizer.py
+++ b/auto_round/quantizer.py
@@ -94,7 +94,6 @@ def _init_tuning_params_and_quant_func(self):
         self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0)
         self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0)
         self._init_params("value", p_dtype, weight_reshape.shape, 0, True)
-
         # Min-max scale initialization
         shape = get_scale_shape(orig_weight, orig_layer.group_size)
         self._init_params("min_scale", p_dtype, shape, 1.0,
                           self.enable_minmax_tuning)
@@ -304,7 +303,6 @@ def forward(self, x):
         bias = self.orig_layer.bias
         if bias is not None and bias.device.type == 'meta':
             bias = self.orig_layer.get_bias().to(self.device)
-
         if self.enable_norm_bias_tuning:
             bias, _, _ = self._qdq_bias(bias, self.bias_v)
@@ -520,3 +518,4 @@ def unwrapper_block(block, best_params):
                 best_param = None
             orig_layer = m.unwrapper(best_param)
             set_module(block, n, orig_layer)
+

From e88882ef78df7f6e9f6cc068ebb485794fb48cdd Mon Sep 17 00:00:00 2001
From: WeiweiZhang1
Date: Thu, 12 Dec 2024 13:45:54 +0800
Subject: [PATCH 3/3] fix incorrect device setting in autoround format inference (#383)

Signed-off-by: Zhang, Weiwei1
---
 auto_round/auto_quantizer.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index 081274d9..e3a193e2 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -363,7 +363,14 @@ def detect_device(self, target_backend, orig_backend):
         if backend is None:
             raise ValueError("Backend not found, please set it to 'auto' to have a try ")
-        return BackendInfos[backend].device[0]
+        device = BackendInfos[backend].device[0]
+        if "cuda" in device and torch.cuda.is_available():
+            return device
+        elif "hpu" in device and is_hpu_supported():
+            return device
+        else:
+            return "cpu"
+
     def convert_model(self, model: nn.Module):
         """Converts the given model to an AutoRound model by replacing its layers with quantized layers.
@@ -392,6 +399,7 @@
             quantization_config.target_backend = quantization_config.backend
         target_device = self.detect_device(quantization_config.target_backend, quantization_config.backend)
+        self.target_device = target_device
         if hasattr(quantization_config, "backend"):  # pragma: no cover
@@ -744,3 +752,4 @@ def is_serializable(self):
 transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
 transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
+
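The sketches below are illustrative only and are not part of the patch series.

Patch 1/3 decides how to load a "llava" checkpoint by inspecting the architecture recorded in its config: HF-style checkpoints go through transformers, everything else still goes through the llava package. A minimal sketch of that decision, assuming a hypothetical llava-hf checkpoint name:

```python
# Hedged sketch of the routing added in PATCH 1/3; the model name is an example only.
from transformers import AutoConfig

model_name = "llava-hf/llava-1.5-7b-hf"  # hypothetical example checkpoint
config = AutoConfig.from_pretrained(model_name)

if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
    # original llava repo checkpoints keep using llava.model.builder.load_pretrained_model
    loader = "llava-package"
else:
    # llava-hf checkpoints are handled like the other HF vision-language models
    from transformers import LlavaForConditionalGeneration
    loader = LlavaForConditionalGeneration
```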
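The core of patch 2/3 is that `block_name_to_quantize` is now derived once from `quant_block_list` (falling back to `get_block_names(model)`), instead of being written along two nearly identical branches. A small self-contained illustration of how the common prefix is computed; the block names are made up for the example:

```python
# Stand-alone illustration of the prefix derivation used for auto-gptq export.
# The nested list mimics quant_block_list; the names are hypothetical.
import os

quant_block_list = [["model.layers.0", "model.layers.1", "model.layers.2"]]
flattened = [name for sublist in quant_block_list for name in sublist]
block_name_to_quantize = os.path.commonprefix(flattened).rstrip('.')
print(block_name_to_quantize)  # -> "model.layers"
```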
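Patch 3/3 stops `detect_device` from returning a backend's preferred accelerator when that accelerator is not actually present. A rough sketch of the fallback order, with `backend_device` standing in for `BackendInfos[backend].device[0]`; the HPU branch is only noted in a comment because `is_hpu_supported` is internal to auto-round:

```python
# Hedged sketch of the device fallback from PATCH 3/3: keep the backend's device only
# when the matching accelerator is available, otherwise run on CPU.
import torch

def pick_device(backend_device: str) -> str:
    if "cuda" in backend_device and torch.cuda.is_available():
        return backend_device
    # the real detect_device() also keeps "hpu" devices when is_hpu_supported() is True
    return "cpu"

print(pick_device("cuda"))  # "cuda" on a CUDA machine, "cpu" everywhere else
```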