diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index 559c30ad..081274d9 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -42,7 +42,8 @@
 from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
 from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod
 
-from auto_round.utils import get_module, set_module, is_hpu_supported
+from auto_round.utils import (get_module, set_module, is_hpu_supported, get_block_names,
+                              get_multimodal_block_names, find_matching_blocks)
 from auto_round.backend import get_layer_backend, dynamic_import_inference_linear
@@ -409,7 +410,15 @@ def convert_model(self, model: nn.Module):
         sym = quantization_config.sym
         to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
                                                                                    "to_quant_block_names") else None
-        layer_names = get_layer_names_in_block(model, to_quant_block_names=to_quant_block_names)
+        quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config,
+                                                                           "quant_block_list") else None
+        if to_quant_block_names is None:  # TODO check compatibility
+            all_blocks = get_block_names(model)
+        else:
+            all_blocks = get_multimodal_block_names(model, quant_vision=True)
+        if quant_block_list is None:
+            quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
+        layer_names = get_layer_names_in_block(model, quant_block_list=quant_block_list)
 
         extra_config = {}
         if hasattr(quantization_config, "extra_config"):
@@ -734,3 +743,4 @@ def is_serializable(self):
 
 transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
 transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
+
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 4f342797..1371dc20 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -198,11 +198,12 @@ def __init__(
         self.device = detect_device(device)
         self.scale_dtype = convert_dtype_str2torch(scale_dtype)
         self.set_amp_dtype()
-
-        self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
-        if not hasattr(self, 'to_quant_block_names'):
+        self.to_quant_block_names = to_quant_block_names
+        if not hasattr(self, 'quant_block_list'):
             all_blocks = get_block_names(model)
-            self.to_quant_block_names = find_matching_blocks(model, all_blocks, to_quant_block_names)
+            self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names)
+        self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
+
         ##activation
         self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size
@@ -281,8 +282,8 @@ def quantize(self):
         The quantized model and layer configurations.
         """
-        if bool(self.to_quant_block_names):
-            all_blocks = self.to_quant_block_names
+        if bool(self.quant_block_list):
+            all_blocks = self.quant_block_list
         else:
             all_blocks = get_block_names(self.model)
@@ -434,7 +435,7 @@ def set_layerwise_config(self, layer_config):
         Returns:
             None
         """
-        layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.to_quant_block_names)
+        layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list)
         keys = ["data_type", "bits", "group_size", "sym", "scale_dtype", "act_bits", "act_group_size", "act_sym",
                 "act_dynamic", "act_data_type"]
         for n, m in self.model.named_modules():
@@ -1333,6 +1334,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
             serialization_dict=serialization_dict,
             backend=backend,
             to_quant_block_names=self.to_quant_block_names,
+            quant_block_list=self.quant_block_list,
             **kwargs
         )
         return compressed_model
@@ -1347,7 +1349,7 @@ def get_quantized_layer_names_outside_blocks(self):
             return []
 
         layer_names = []
-        all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.to_quant_block_names)
+        all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list)
 
         for key in self.layer_config.keys():
             if key in all_layers_in_block:
@@ -1735,3 +1737,4 @@ def __init__(
             optimizer=optimizer,
             **kwargs,
         )
+
diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py
index bc7bfd71..05034576 100644
--- a/auto_round/export/export_to_autogptq/export.py
+++ b/auto_round/export/export_to_autogptq/export.py
@@ -121,6 +121,7 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
     supported_types = kwargs["supported_types"]
     safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
     to_quant_block_names = kwargs["to_quant_block_names"]
+    quant_block_list = kwargs.get("quant_block_list", None)
     logger.info("Saving quantized model to autogptq format, this may take a while...")
     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
@@ -130,8 +131,8 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
         processor.save_pretrained(output_dir)
     ##check module quantized in block, this may have bug for mixed precision quantization
     quantization_config = kwargs["serialization_dict"]
-    if bool(to_quant_block_names):
-        all_blocks = to_quant_block_names
+    if bool(quant_block_list):
+        all_blocks = quant_block_list
         flattened_list = [item for sublist in all_blocks for item in sublist]
         common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
         if common_prefix not in BLOCK_PATTERNS:
@@ -220,3 +221,4 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf
     with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
         json.dump(model.config.quantization_config, f, indent=2)
+
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index 1a4e33dd..f90fb270 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -189,10 +189,11 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
 
     model = kwargs["model"]
     to_quant_block_names = kwargs["to_quant_block_names"]
+    quant_block_list = kwargs.get("quant_block_list", None)
     safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
     if not inplace:
         model = copy.deepcopy(model.to("cpu"))
-    layer_names_in_block = get_layer_names_in_block(model, to_quant_block_names=to_quant_block_names)
+    layer_names_in_block = get_layer_names_in_block(model, quant_block_list=quant_block_list)
 
     layer_config = kwargs["layer_config"]
     quantization_config = kwargs["serialization_dict"]
@@ -279,3 +280,4 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
         json.dump(model.config.quantization_config, f, indent=2)
+
diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py
index 61f89338..fa07bd77 100644
--- a/auto_round/mllm/autoround_mllm.py
+++ b/auto_round/mllm/autoround_mllm.py
@@ -22,7 +22,8 @@
     to_device,
     to_dtype,
     get_multimodal_block_names,
-    find_matching_blocks
+    find_matching_blocks,
+    extract_block_names_to_str
 )
 from ..autoround import AutoRound
 from .template import get_template, Template
@@ -143,7 +144,10 @@ def __init__(
         **kwargs,
     ):
         all_blocks = get_multimodal_block_names(model, quant_nontext_module)
-        self.to_quant_block_names = find_matching_blocks(model, all_blocks, to_quant_block_names)
+        self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
+        if to_quant_block_names is None:
+            to_quant_block_names = extract_block_names_to_str(self.quant_block_list)
+        self.to_quant_block_names = to_quant_block_names
         self.extra_data_dir = extra_data_dir
         self.quant_nontext_module = quant_nontext_module
         self.image_processor = image_processor
@@ -368,3 +372,4 @@ def calib(self, nsamples, bs):
                 m = m.to("meta")
         # torch.cuda.empty_cache()
+
diff --git a/auto_round/utils.py b/auto_round/utils.py
index b5c8236d..8864d5ed 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -318,6 +318,22 @@ def validate_modules(module_names, quant_vision=False, vison_blocks_names=None):
                          "or raise an issue at https://github.com/intel/auto-round/issues.")
     return
 
+def get_common_prefix(paths):
+    # Split each path into components and find the common prefix
+    split_paths = [path.split('.') for path in paths]
+    common_prefix = split_paths[0]
+    for path in split_paths[1:]:
+        common_prefix = [comp for comp, other in zip(common_prefix, path) if comp == other]
+    return '.'.join(common_prefix)
+
+def extract_block_names_to_str(quant_block_list):
+    if not isinstance(quant_block_list, (list,tuple)):
+        return None
+    # Extract common prefix for each list
+    prefixes = [get_common_prefix(blocks) for blocks in quant_block_list]
+    # Join prefixes into a single string
+    return ','.join(prefixes)
+
 
 def find_matching_blocks(model, all_blocks, to_quant_block_names):
     """
@@ -347,9 +363,9 @@ def find_matching_blocks(model, all_blocks, to_quant_block_names):
                 matched_sublist.extend(matches)
         if matched_sublist:
             target_blocks.append(matched_sublist)
-        if not target_blocks:
-            raise ValueError("No block names matched. Please check the input for to_quant_block_name," \
-                             "or set to_quant_block_name to None to automatically match quantizable blocks.")
+    if not target_blocks:
+        raise ValueError("No block names matched. Please check the input for to_quant_block_name," \
+                         "or set to_quant_block_name to None to automatically match quantizable blocks.")
     return target_blocks
@@ -776,7 +792,7 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs):
 
 
 def get_layer_names_in_block(model, supported_types=[torch.nn.Linear,
-                                                     transformers.modeling_utils.Conv1D], to_quant_block_names=None):
+                                                     transformers.modeling_utils.Conv1D], quant_block_list=None):
     """Retrieves the names of layers within each block of the model.
 
     Returns:
@@ -787,8 +803,8 @@ def get_layer_names_in_block(model, supported_types=[torch.nn.Linear,
         if isinstance(m, tuple(supported_types)):
             m.tmp_name = n
     layers_in_block = []
-    if bool(to_quant_block_names):
-        all_blocks = to_quant_block_names
+    if bool(quant_block_list):
+        all_blocks = quant_block_list
     else:
         all_blocks = get_block_names(model)
     for block_names in all_blocks:
@@ -1081,3 +1097,4 @@ def get_fp_layer_names(model, fp_layers):
                 not_to_quantized_layers.append(name)
 
     return not_to_quantized_layers
+
diff --git a/test/test_cuda_before_release.py b/test/test_cuda_before_release.py
index daddf720..4e3c5df4 100644
--- a/test/test_cuda_before_release.py
+++ b/test/test_cuda_before_release.py
@@ -119,3 +119,21 @@ def test_undivided_group_size_tuning(self):
 
         autoround = AutoRound(model, tokenizer, bits=4, group_size=128, nsamples=1, iters=1)
         autoround.quantize()
+
+    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
+    def test_vision_generation(self):
+        quantized_model_path = "OPEA/Phi-3.5-vision-instruct-qvision-int4-sym-inc"
+        from auto_round import AutoRoundConfig
+        device = "auto"  ##cpu, hpu, cuda
+        quantization_config = AutoRoundConfig(
+            backend=device
+        )
+        model = AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True,
+                                                     device_map=device, quantization_config=quantization_config)
+        tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
+        text = "There is a girl who likes adventure,"
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+        res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])
+        print(res)
+        assert (
+            res == """ There is a girl who likes adventure, and she is looking for a partner to go on a treasure hunt. She has found a map that leads to a hidden treasure, but she needs a partner to help her decipher the clues and find the treasure. You""")
diff --git a/test/test_generation.py b/test/test_generation.py
index 22392367..6133bdfe 100644
--- a/test/test_generation.py
+++ b/test/test_generation.py
@@ -33,9 +33,8 @@ def tearDownClass(self):
         shutil.rmtree("./saved", ignore_errors=True)
         shutil.rmtree("runs", ignore_errors=True)
 
+    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
     def test_llm_generation_sym_gpu_gptq(self):
-        if not torch.cuda.is_available():
-            return
         bits = 4
         group_size = 32
         autoround = AutoRound(
@@ -67,11 +66,9 @@ def test_llm_generation_sym_gpu_gptq(self):
         assert (
                 res == """There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""")
 
-    # # #
+    # @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
     # def test_llm_generation_sym_gpu_gptq_marlin(self): ##need auto_gptq >0.7.1
-    #     if not torch.cuda.is_available():
-    #         return
     #     bits = 4
     #     group_size = 128
     #     autoround = AutoRound(
@@ -103,10 +100,8 @@ def test_llm_generation_sym_gpu_gptq(self):
     #     assert (
     #             res == """There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""")
 
-
+    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
     def test_llm_generation_asym_gpu_awq(self):
-        if not torch.cuda.is_available():
-            return
         bits = 4
         group_size = 32
         autoround = AutoRound(
@@ -175,3 +170,4 @@ def test_llm_generation_asym_qbits(self):
+