diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 340d457e..461811d2 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -141,7 +141,7 @@ def __init__(self, *args, **kwargs): self.add_argument("--mllm", action='store_true', help="To determine whether use multimodel-llm mode.") - self.add_argument("--quant_vision", action='store_true', + self.add_argument("--quant_nontext_module", action='store_true', help="To determine whether the quantization should handle vision component.") self.add_argument("--extra_data_dir", default="", type=str, @@ -521,7 +521,7 @@ def tune_mllm(args): device=device_str, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps, scale_dtype=args.scale_dtype, layer_config=layer_config, enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits, - quant_vision=args.quant_vision) + quant_nontext_module=args.quant_nontext_module) model, _ = autoround.quantize() model.eval() @@ -564,3 +564,4 @@ def run_mllm(): if __name__ == '__main__': run() + diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 731974ec..2d499717 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -18,7 +18,6 @@ import copy import time from typing import Optional, Union - from transformers import set_seed from torch import autocast from tqdm import tqdm @@ -27,9 +26,11 @@ from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer, \ WrapperTransformerConv1d from .special_model_handler import (check_hidden_state_dim, - check_share_attention_mask, - check_not_share_position_ids, - check_not_share_rotary_pos_emb) + shareable_keywords, + special_model_init, + reset_params, + skip_keywards_hint + ) from .utils import ( CpuInfo, block_forward, @@ -174,7 +175,7 @@ def __init__( self.low_cpu_mem_usage = low_cpu_mem_usage self.layer_config = {} if layer_config is None else layer_config self.seqlen = seqlen - self.train_bs = batch_size + self.batch_size, self.gradient_accumulate_steps = batch_size, gradient_accumulate_steps self.nblocks = nblocks self.dataset = dataset self.iters = iters @@ -201,12 +202,11 @@ def __init__( self.quant_block_list = quant_block_list self.sampler = sampler - self.gradient_accumulate_steps = gradient_accumulate_steps self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap self.lr_scheduler = lr_scheduler self.optimizer = self.get_optimizer(None) - self.share_attention_mask_flag = None ##TODO remove it later + self.input_dim = None self.infer_bs_coeff = 1 self.set_layerwise_config(self.layer_config) ##better place in the end @@ -230,7 +230,7 @@ def check_configs(self): assert self.group_size == -1 or self.group_size >= 1, "only supports positive group_size or -1(per channel)" assert self.act_group_size == -1 or self.act_group_size >= 1, \ "only supports positive group_size or -1(per channel)" - assert self.train_bs > 0, "batch size must be positive" + assert self.batch_size > 0, "batch size must be positive" assert self.iters > 0, "iters must be positive" assert self.seqlen > 0, "seqlen must be positive" assert self.nblocks > 0, "nblocks must be positive" @@ -280,14 +280,19 @@ def quantize(self): for block_names in all_blocks: inputs = all_inputs[block_names[0]] all_inputs.pop(block_names[0]) - + keys = inputs.keys() + input_id_str = [key for key in keys if key.startswith('hidden_state')] + if len(input_id_str) != 1: + raise RuntimeError("hidden_states arg mismatch error," \ + " please check the input kwargs of block 
forward for more details.") + inputs["input_ids"] = inputs.pop(input_id_str[0], None) clear_memory(self.inputs) if "input_ids" in inputs.keys(): total_samples = len(inputs["input_ids"]) self.n_samples = total_samples - if total_samples < self.train_bs: - self.train_bs = total_samples + if total_samples < self.batch_size: + self.batch_size = total_samples logger.warning(f"force the train batch size to {total_samples}") @@ -453,18 +458,12 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de input_others, indices, self.seqlen, - self.share_attention_mask_flag, - self.not_share_position_ids_flag, - self.not_share_rotary_pos_emb_flag, self.input_dim ) tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( cache_device ) - if self.train_bs == 1 and self.not_share_rotary_pos_emb_flag: - output.append(tmp_output) - else: - output.extend(list(torch.split(tmp_output, 1, dim=self.input_dim))) + output.extend(list(torch.split(tmp_output, 1, dim=self.input_dim))) if self.low_gpu_mem_usage: clear_memory() @@ -542,6 +541,11 @@ def calib(self, nsamples, bs): self.model(**data_new) except NotImplementedError: pass + except RuntimeError as error: + logger.warning("When quantization encounters tensor" \ + " shape mismatch error, you can try to avoid it with batch_size=1") + logger.error(error) + pass except Exception as error: raise error total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 @@ -643,7 +647,7 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n if last_cache_name is None and len(block_names) + len(layer_names) == 1: self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0] # do not set last_cache_name for multimodal models - calib_bs = self.train_bs + calib_bs = self.batch_size self.hook_handles = [] self._replace_forward() self.calib(nsamples, calib_bs) @@ -665,95 +669,92 @@ def get_block_forward_func(self, name): Returns: function: The forward function. """ + + def post_process_cache_data(batch_size, data, data_name): + """ + Processes store data for batch handling, reshaping if necessary. - def forward(m, hidden_states, *positional_args, **kwargs): + Args: + batch_size (int): The size of the batch. + data: The data value to store, potentially for caching. + data_name (str): Name of the data. + + Returns: + Processed data or None + """ + new_data = data + if batch_size <= 1: + return new_data + if data_name in shareable_keywords: + return None + if "alibi" in data_name: + if isinstance(data, torch.Tensor): + alibi = data + alibi = alibi.reshape(batch_size, -1, alibi.shape[1], alibi.shape[2]) + new_data = alibi + return new_data + + def forward(m, hidden_states=None, *positional_inputs, **kwargs): """Rewrite forward function, process and collect input data. Args: hidden_states (torch.Tensor): The hidden states tensor. - *positional_args: Variable number of positional arguments. + *positional_inputs: Variable number of positional arguments. **kwargs: Variable number of keyword arguments. Returns: NotImplementedError: Getting the first layer inputs and then raise the error to save runtime. 
""" - if self.share_attention_mask_flag is None: - self.input_dim = check_hidden_state_dim(self.model, positional_args) - self.share_attention_mask_flag = check_share_attention_mask(self.model, hidden_states, **kwargs) - self.not_share_position_ids_flag = check_not_share_position_ids(self.model, **kwargs) - self.not_share_rotary_pos_emb_flag = check_not_share_rotary_pos_emb(self.model, **kwargs) - if name in self.inputs: - if self.train_bs == 1 and self.not_share_rotary_pos_emb_flag: - self.inputs[name]["input_ids"].append(hidden_states.to("cpu")) - else: - self.inputs[name]["input_ids"].extend( - list(torch.split(hidden_states.to("cpu"), 1, dim=self.input_dim))) - else: + if name not in self.inputs: self.inputs[name] = {} - if self.train_bs == 1 and self.not_share_rotary_pos_emb_flag: - self.inputs[name]["input_ids"] = [hidden_states.to("cpu")] - else: - self.inputs[name]["input_ids"] = list(torch.split(hidden_states.to("cpu"), 1, dim=self.input_dim)) - - if "positional_inputs" not in self.inputs[name]: - self.inputs[name]["positional_inputs"] = [] - for idx, item in enumerate(positional_args): - self.inputs[name]["positional_inputs"] = to_device(positional_args) + special_model_init(self.model, positional_inputs, self.inputs[name]) + + if self.input_dim is None: + self.input_dim = check_hidden_state_dim(self.model, positional_inputs) + + if hidden_states is not None: + kwargs['hidden_states'] = hidden_states for key in kwargs.keys(): if isinstance(kwargs[key], torch.Tensor) or isinstance(kwargs[key], list) \ - or isinstance(kwargs[key], tuple) \ - or (key == "alibi") or (key == "attention_mask"): - if "attention_mask" in key: - if key not in self.inputs[name].keys(): - self.inputs[name][key] = None - if kwargs[key] is not None: - if (not self.share_attention_mask_flag) and self.inputs[name][key] is not None: - self.inputs[name][key].extend(list(torch.split(kwargs[key].to("cpu"), 1, dim=0))) - else: - self.inputs[name][key] = list(torch.split(kwargs[key].to("cpu"), 1, dim=0)) - elif "alibi" in key: - if key not in self.inputs[name].keys(): - self.inputs[name][key] = None - if isinstance(kwargs[key], torch.Tensor): - alibi = kwargs[key] - batch = kwargs["attention_mask"].shape[0] - alibi = alibi.reshape(batch, -1, alibi.shape[1], alibi.shape[2]) - if (not self.share_attention_mask_flag) and self.inputs[name][key] is not None: - self.inputs[name][key].extend(list(torch.split(alibi.to("cpu"), 1, dim=0))) - else: - self.inputs[name][key] = list(torch.split(alibi.to("cpu"), 1, dim=0)) - elif "position_ids" in key or 'cache_position' in key or 'position_embeddings' in key: - if self.train_bs == 1 and self.not_share_position_ids_flag: - if key not in self.inputs[name].keys(): - self.inputs[name][key] = [to_device(kwargs[key], device=torch.device("cpu"))] - else: - self.inputs[name][key].append(to_device(kwargs[key], device=torch.device("cpu"))) - elif key not in self.inputs[name].keys(): - self.inputs[name][key] = list(torch.split(kwargs[key].to("cpu"), 1, dim=0)) \ - if self.not_share_position_ids_flag \ - else to_device(kwargs[key], device=torch.device("cpu")) - elif kwargs[key] is not None and self.not_share_position_ids_flag: - self.inputs[name][key].extend(list(torch.split(kwargs[key].to("cpu"), 1, dim=0))) - elif 'rotary_pos_emb' in key or 'cu_seqlens' in key: - if key not in self.inputs[name].keys(): - self.inputs[name][key] = [to_device(kwargs[key], device=torch.device("cpu"))] \ - if self.not_share_rotary_pos_emb_flag \ - else to_device(kwargs[key], device=torch.device("cpu")) - elif 
kwargs[key] is not None and self.not_share_rotary_pos_emb_flag: - self.inputs[name][key].append(to_device(kwargs[key], device=torch.device("cpu"))) - elif "cross_attention_states" in key: - if key not in self.inputs[name].keys(): - self.inputs[name][key] = [to_device(kwargs[key], device=torch.device("cpu"))] + or isinstance(kwargs[key], tuple): + if key not in self.inputs[name].keys(): # initialization + data = to_device(kwargs[key], device=torch.device("cpu")) + if data is None or (self.batch_size > 1 and key in shareable_keywords): + self.inputs[name][key] = data + continue + if self.batch_size <= 1: + self.inputs[name][key] = [data] else: - self.inputs[name][key].extend(list(torch.split(kwargs[key].to("cpu"), 1, dim=0))) - elif key not in self.inputs[name].keys(): - self.inputs[name][key] = to_device(kwargs[key], device=torch.device("cpu")) - + data = post_process_cache_data(self.batch_size, data, key) + self.inputs[name][key] = list(torch.split(data, 1, dim=self.input_dim)) + else: # append cache inputs + new_data = post_process_cache_data(self.batch_size, kwargs[key], key) + if new_data is None: # shareable args or NoneType + continue + new_data = to_device(new_data, device=torch.device("cpu")) + if self.batch_size <= 1: + self.inputs[name][key].append(new_data) + else: + self.inputs[name][key].extend(list(torch.split(new_data, 1, dim=self.input_dim))) + elif isinstance(kwargs[key], (str, bool, type(None))): + if key not in self.inputs[name].keys(): + self.inputs[name][key] = kwargs[key] + else: + # Parameters not to be cached + if skip_keywards_hint(key): + logger.warning_once(f"Please note that this '{key}' key" \ + " is not currently used in quantization fine-tuning.") + reset_params(self.inputs[name]) if name == self.last_cache_name: raise NotImplementedError else: - return m.orig_forward(hidden_states, *positional_args, **kwargs) + if hidden_states is not None: + kwargs.pop('hidden_states') + return m.orig_forward(hidden_states, *positional_inputs, **kwargs) + else: + #Currently only for Llama-3.2-Vision-Instruct Series + return m.orig_forward(*positional_inputs, **kwargs) return forward @@ -848,9 +849,10 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp scaler = self.get_scaler() # pylint: disable=assignment-from-none init_loss = None # best_v, best_min_scale, best_max_scale = torch.tensor(0), torch.tensor(1.0), torch.tensor(1.0) - gradient_accumulate_steps = self.train_bs ##Force to low gpu - train_bs = 1 ##Force to low gpu - pick_samples = train_bs * gradient_accumulate_steps + gradient_accumulate_steps = self.batch_size ##Force to low gpu + batch_size = 1 ##Force to low gpu + pick_samples = batch_size * gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) if self.sampler != "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] total_loss = 0 @@ -859,7 +861,7 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp if self.sampler == "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] for tmp_step in range(gradient_accumulate_steps): - indices = whole_indices[tmp_step * train_bs: (tmp_step + 1) * train_bs] + indices = whole_indices[tmp_step * batch_size: (tmp_step + 1) * batch_size] if q_inputs is not None: current_input = [q_inputs[i] for i in indices] current_input = torch.cat(current_input, dim=0).to(device) @@ -925,7 +927,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch Tuple: (q_outputs, output) if self.enable_quanted_input is True, 
else (None, output) """ - output = self.get_block_outputs(block, input_ids, input_others, self.train_bs * self.infer_bs_coeff, device, + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device) if q_input is not None: @@ -968,8 +970,9 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch else: lr_schedule = copy.deepcopy(self.lr_scheduler) - pick_samples = self.train_bs * self.gradient_accumulate_steps nsamples = len(input_ids) + pick_samples = self.batch_size * self.gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) if self.sampler != "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] last_best_iter = 0 @@ -984,15 +987,12 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch if self.sampler == "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] for tmp_step in range(self.gradient_accumulate_steps): - indices = whole_indices[tmp_step * self.train_bs: (tmp_step + 1) * self.train_bs] + indices = whole_indices[tmp_step * self.batch_size: (tmp_step + 1) * self.batch_size] current_input_ids, current_input_others = sampling_inputs( input_ids, input_others, indices, seqlen=self.seqlen, - share_attention_mask_flag=self.share_attention_mask_flag, - not_share_position_ids_flag=self.not_share_position_ids_flag, - not_share_rotary_pos_emb_flag=self.not_share_rotary_pos_emb_flag, input_dim=self.input_dim, ) @@ -1050,7 +1050,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch if self.low_cpu_mem_usage: block = block.to(device) q_outputs = self.get_block_outputs( - block, input_ids, input_others, self.train_bs * self.infer_bs_coeff, device, + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, cache_device=self.cache_device ) mv_module_from_gpu(block, self.low_cpu_mem_usage) @@ -1189,7 +1189,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k "enable_minmax_tuning", "data_type", "seqlen", - "train_bs", + "batch_size", "scale_dtype", "lr", "minmax_lr", @@ -1617,3 +1617,4 @@ def __init__( optimizer=optimizer, **kwargs, ) + diff --git a/auto_round/low_cpu_mem/utils.py b/auto_round/low_cpu_mem/utils.py index 6b02c2a5..6f579118 100644 --- a/auto_round/low_cpu_mem/utils.py +++ b/auto_round/low_cpu_mem/utils.py @@ -188,8 +188,7 @@ def _forward(module, name, *args, **kwargs): try: if model.device.type == 'meta': - target_device = 'meta' - + target_device = 'cpu' else: target_device = model.device input = { @@ -469,4 +468,4 @@ def layer_wise_load(path): if len(d) > 0: d = pickle.loads(d) state_dict.update(d) - return state_dict \ No newline at end of file + return state_dict diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py index 45e1b57e..2aebd756 100644 --- a/auto_round/mllm/autoround_mllm.py +++ b/auto_round/mllm/autoround_mllm.py @@ -26,7 +26,7 @@ from .template import get_template, Template from .mllm_dataset import get_mllm_dataloader from ..low_cpu_mem.utils import get_layers_before_block - +from ..special_model_handler import check_mllm_model_batch class AutoRoundMLLM(AutoRound): """Class for automatic rounding-based quantization with MLLMs. 
@@ -119,6 +119,7 @@ def __init__( if self.template is None: self.template = get_template(model.config.model_type) assert dataset is not None, "dataset should not be None" + batch_size, gradient_accumulate_steps = check_mllm_model_batch(model, batch_size, gradient_accumulate_steps) if isinstance(dataset, str): dataset = get_mllm_dataloader(self.template, model, tokenizer, dataset, extra_data_dir, seqlen, batch_size) @@ -275,4 +276,4 @@ def calib(self, nsamples, bs): if self.low_cpu_mem_usage: for n, m in embed_layers: m = m.to("meta") - # torch.cuda.empty_cache() \ No newline at end of file + # torch.cuda.empty_cache() diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 722d51a9..238cf952 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -13,37 +13,48 @@ # limitations under the License. import torch +from collections import UserDict +special_states_dim_tuple = ("chatglm",) # input_dim is not the default dimension 0 +shareable_keywords = ("position_ids", "cache_position", "position_embeddings") +mllms_with_limited_bs = ("llava", "qwen2-vl", "phi3_v", "mllama") # Limitations on batch_size +skippable_cache_keys = ("past_key_value",) -share_attention_mask_tuple = ("baichuan",) -special_states_dim_tuple = ("chatglm",) -not_share_position_ids_tuple = ("llava", "phi3_v", "qwen2_vl",) -not_share_rotary_pos_emb_tuple = ("qwen2_vl",) -def check_share_attention_mask(model, hidden_states, attention_mask=None, **kwargs): - """Checks if the attention mask states of the hidden states are shared in the model. +def to_device(input, device=torch.device("cpu")): + """Moves input data to the specified device. Args: - hidden_states (torch.Tensor): The hidden states of the model. - attention_mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. - **kwargs: Additional keyword arguments. + input: The input data to be moved. + device: The target device. Returns: - bool: True if attention mask is shared in the model, False otherwise. + The input data on the specified device. """ - if attention_mask is None or not isinstance(hidden_states, torch.Tensor): - return False - is_special = False - for key in share_attention_mask_tuple: - if hasattr(model, "config") and key in model.config.model_type: - is_special = True - break - return bool(is_special and attention_mask.shape[0] != hidden_states.shape[0]) + if input is None: + return None + if isinstance(input, torch.Tensor): + return input.to(device) + if isinstance(input, dict) or isinstance(input, UserDict): + for inp in input.keys(): + input[inp] = to_device(input[inp], device) + elif isinstance(input, list) or isinstance(input, tuple): + if len(input) == 0: + return input + input_res = [] + for inp in input: + input_res.append(to_device(inp, device)) + if isinstance(input, tuple): + input_res = tuple(input_res) + input = input_res -def check_hidden_state_dim(model, positional_args): - """Checks the dimensionality of the hidden states. + return input + + +def check_hidden_state_dim(model, positional_inputs): + """Check the concatenable dimension of hidden states. Args: - positional_args: The positional arguments. + positional_inputs: The positional arguments. Returns: int: 1 if the model type is 'chatglm' and positional arguments are not None, 0 otherwise. 
@@ -53,23 +64,59 @@ def check_hidden_state_dim(model, positional_args): if hasattr(model, "config") and key in model.config.model_type: is_special = True break - return int(is_special and positional_args is not None) + return int(is_special and positional_inputs is not None) -def check_not_share_position_ids(model, **kwargs): - is_special = False - for key in not_share_position_ids_tuple: - if hasattr(model, "config") and key in model.config.model_type: - is_special = True - break - return bool(is_special) +def special_model_init(model, positional_inputs, inputs): + """ + Initializes special model inputs by adding positional inputs if missing. + + Args: + model: The model instance being initialized. + positional_inputs (list): List of positional inputs to add to inputs. + inputs (dict): Dictionary of model inputs. + + Modifies: + inputs (dict): Adds "positional_inputs" key if not present. + """ + if "positional_inputs" not in inputs: # for chatglm Series + inputs["positional_inputs"] = [] + for idx, item in enumerate(positional_inputs): + inputs["positional_inputs"] = to_device(positional_inputs) -def check_not_share_rotary_pos_emb(model, **kwargs): - is_special = False - for key in not_share_rotary_pos_emb_tuple: - if hasattr(model, "config") and key in model.config.model_type: - is_special = True - break - return bool(is_special) +def reset_params(inputs): + """ + Resets specific input parameters to avoid saving the key-value cache during fine-tuning. + + Args: + inputs (dict): Dictionary of model inputs. + + Modifies: + inputs (dict): Sets "use_cache" to False if the key is present. + """ + if "use_cache" in inputs.keys(): # Not storing kv cache + inputs['use_cache'] = False + + +def skip_keywards_hint(key): + """ + Prints a reminder if a key is not stored during quantization fine-tuning. + """ + for cache_key in skippable_cache_keys: + if cache_key not in key: + return True + return False + +def check_mllm_model_batch(model, batch_size, gradient_accumulate_steps): + """ + Checks model configuration to determine if it's necessary to limit bs to avoid potential input shape mismatches. + """ + for key in mllms_with_limited_bs: + if hasattr(model, "config") and key in model.config.model_type and batch_size != 1: + accumulate_steps = batch_size * gradient_accumulate_steps + print("To avoid the tensor concat mismatch problem, modified parameters to " \ + f"batch_size=1. As an alternative, set the gradient_accumulate_steps={accumulate_steps}") + return 1, accumulate_steps + return batch_size, gradient_accumulate_steps diff --git a/auto_round/utils.py b/auto_round/utils.py index ad8cdd52..9b21a15c 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -29,6 +29,7 @@ from functools import lru_cache from packaging import version import gc +from .special_model_handler import shareable_keywords @lru_cache(None) def warning_once(self, msg: str): @@ -347,9 +348,7 @@ def collect_best_params(block): @torch.no_grad() def sampling_inputs(input_ids, input_others, indices, seqlen, - share_attention_mask_flag=False, - not_share_position_ids_flag=False, - not_share_rotary_pos_emb_flag=False, input_dim=0): + input_dim=0): """Samples inputs based on the given indices and sequence length. 
Args: @@ -366,19 +365,20 @@ def sampling_inputs(input_ids, input_others, indices, seqlen, current_input_ids = torch.cat(current_input_ids, dim=input_dim) current_input_others = {"positional_inputs": input_others["positional_inputs"]} for key in input_others.keys(): - if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key) \ - or (not_share_position_ids_flag and ("position_ids" in key or \ - "cache_position" in key or "position_embeddings" in key)) \ - or (not_share_rotary_pos_emb_flag and ("rotary_pos_emb" in key or 'cu_seqlens' in key)) \ - or "cross_attention_states" in key: + if "positional_inputs" in key: + continue + if (key not in shareable_keywords or len(indices) == 1) \ + and not isinstance(input_others[key], (str, bool, type(None))): current_input_others[key] = None if input_others[key] is not None: current_input_others[key] = [input_others[key][i] for i in indices] - if not isinstance(current_input_others[key], torch.Tensor): - if len(current_input_others[key]) == 1: - current_input_others[key] = current_input_others[key][0] - else: + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") else: current_input_others[key] = input_others[key] @@ -864,3 +864,4 @@ def clear_memory(tensor=None): del tensor gc.collect() torch.cuda.empty_cache() + diff --git a/examples/language-modeling/main.py b/examples/language-modeling/main.py index 1fb88eec..d209f04e 100644 --- a/examples/language-modeling/main.py +++ b/examples/language-modeling/main.py @@ -40,7 +40,7 @@ parser.add_argument("--group_size", default=128, type=int, help="group size") - parser.add_argument("--train_bs", default=8, type=int, + parser.add_argument("--batch_size", default=8, type=int, help="train batch size") parser.add_argument("--eval_bs", default=None, type=int, @@ -323,7 +323,7 @@ error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization." 
raise EnvironmentError(error_message) - autoround = round(model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.train_bs, + autoround = round(model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.batch_size, dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, @@ -466,3 +466,4 @@ from lm_eval.utils import make_table print(make_table(res)) + diff --git a/examples/language-modeling/requirements.txt b/examples/language-modeling/requirements.txt index d7e84841..13edc765 100644 --- a/examples/language-modeling/requirements.txt +++ b/examples/language-modeling/requirements.txt @@ -19,3 +19,6 @@ wandb py-cpuinfo numpy < 2.0 threadpoolctl +numexpr +bitsandbytes # for baichuan Series + diff --git a/examples/multimodal-modeling/Common_model/main.py b/examples/multimodal-modeling/Common_model/main.py index bb35a8bd..92883580 100644 --- a/examples/multimodal-modeling/Common_model/main.py +++ b/examples/multimodal-modeling/Common_model/main.py @@ -277,7 +277,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat parser.add_argument("--group_size", default=128, type=int, help="group size") - parser.add_argument("--train_bs", default=1, type=int, + parser.add_argument("--batch_size", default=1, type=int, help="train batch size") parser.add_argument("--eval_bs", default=4, type=int, @@ -288,7 +288,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat "allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.") parser.add_argument("--sym", action='store_true', - help=" sym quantization") + help="sym quantization") parser.add_argument("--iters", default=200, type=int, help=" iters") @@ -339,6 +339,9 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat parser.add_argument("--disable_trust_remote_code", action='store_true', help="Whether to disable trust_remote_code") + + parser.add_argument("--not_use_best_mse", action='store_true', + help="To determine whether the quantization should handle vision component.") parser.add_argument("--disable_quanted_input", action='store_true', help="whether to disuse the output of quantized block to tune the next block") @@ -381,8 +384,8 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat if args.act_bits <= 8 and args.deployment_device != "fake": assert False, "only support fake mode for activation quantization currently" - if "marlin" in args.deployment_device and args.sym == False: - assert False, "marlin backend only supports sym quantization, please set --sym" + if "marlin" in args.deployment_device and args.asym == True: + assert False, "marlin backend only supports sym quantization, please enable --sym" model_name = args.model_name if model_name[-1] == "/": @@ -420,9 +423,10 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat questions = json.load(open(args.question_file, "r")) config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) model_type = config.model_type + processor = None if "mllama" in model_type: from transformers import MllamaForConditionalGeneration - model = MllamaForConditionalGeneration.from_pretrained(args.model_name, attn_implementation="eager", + model = 
MllamaForConditionalGeneration.from_pretrained(args.model_name, trust_remote_code=not args.disable_trust_remote_code) # torch_dtype=torch.bfloat16 processor = AutoProcessor.from_pretrained(args.model_name) tokenizer.processor = processor @@ -448,7 +452,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat raw_data = DataFormating(questions, args.image_folder, model_type=model_type) dataset = LazySupervisedDataset(raw_data, tokenizer, max_len=min(args.seqlen, tokenizer.model_max_length), image_folder=args.image_folder) - dataloader = get_train_dataloader(dataset, model, data_collator=default_collator, train_batch_size=args.train_bs) + dataloader = get_train_dataloader(dataset, model, data_collator=default_collator, train_batch_size=args.batch_size) from auto_round import (AutoRound, AutoRoundAdam) @@ -497,10 +501,10 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat quant_block_list = get_multimodal_block_names(model, args.quant_vision) - autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs, + autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size, dataset=dataloader, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, - amp=not args.disable_amp, nsamples=args.nsamples, + amp=not args.disable_amp, nsamples=args.nsamples, not_use_best_mse=args.not_use_best_mse, low_gpu_mem_usage=args.low_gpu_mem_usage, device=device_str, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps, scale_dtype=args.scale_dtype, layer_config=layer_config, @@ -579,5 +583,3 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat ) - - diff --git a/examples/multimodal-modeling/Llava/main.py b/examples/multimodal-modeling/Llava/main.py index bf6f5793..8408fdee 100644 --- a/examples/multimodal-modeling/Llava/main.py +++ b/examples/multimodal-modeling/Llava/main.py @@ -68,8 +68,8 @@ def create_data_loader(dataset, batch_size=1, data_collator=None): data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator) return data_loader -def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str = "5GB", safe_serialization: bool = True): - if not quant_vision: +def save_tower(model, save_path, quant_nontext_module: bool = False, max_shard_size: str = "5GB", safe_serialization: bool = True): + if not quant_nontext_module: print("Won't save vision_tower since this part was not quantized.") return ori_path = save_path @@ -113,7 +113,7 @@ def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str parser.add_argument("--group_size", default=128, type=int, help="group size") - parser.add_argument("--train_bs", default=1, type=int, + parser.add_argument("--batch_size", default=1, type=int, help="train batch size") parser.add_argument("--eval_bs", default=4, type=int, @@ -194,7 +194,7 @@ def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str parser.add_argument("--is_multimodal", type=bool, default=True, help="To determine whether the preprocessing should handle multimodal infomations.") - parser.add_argument("--quant_vision", action='store_true', + parser.add_argument("--quant_nontext_module", action='store_true', help="To determine whether the quantization should handle vision component.") # ========== Calibration Datasets 
============= @@ -300,7 +300,7 @@ def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str questions = json.load(open(args.question_file, "r")) dataset = CustomDataset(questions, args.image_folder, tokenizer, image_processor, args=args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - dataloader = create_data_loader(dataset, args.train_bs, data_collator) + dataloader = create_data_loader(dataset, args.batch_size, data_collator) round = AutoRound if args.adam: @@ -340,9 +340,9 @@ def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str print(f"warning, low_gpu_mem_usage=False is strongly recommended if the whole model could be loaded to " f"gpu") - quant_block_list = get_multimodal_block_names(model, args.quant_vision) + quant_block_list = get_multimodal_block_names(model, args.quant_nontext_module) - autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs, + autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size, dataset=dataloader, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, layer_config=layer_config, @@ -379,23 +379,23 @@ def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str if "round" in gpu_format: eval_folder = f'{export_dir}-round' compressed_model = autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace) - save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision) + save_tower(compressed_model, eval_folder, quant_nontext_module=args.quant_nontext_module) elif "gptq" in gpu_format: eval_folder = f'{export_dir}-gpu' compressed_model = autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace) - save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision) + save_tower(compressed_model, eval_folder, quant_nontext_module=args.quant_nontext_module) if 'xpu' in deployment_device: compressed_model = autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace, compression_dtype=torch.int8, compression_dim=0, use_optimum_format=False, device="xpu") - save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision) + save_tower(compressed_model, eval_folder, quant_nontext_module=args.quant_nontext_module) if "cpu" in deployment_device: compressed_model = autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace) - save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision) + save_tower(compressed_model, eval_folder, quant_nontext_module=args.quant_nontext_module) if "fake" in deployment_device: model = model.to("cpu") model.save_pretrained(output_dir) - save_tower(model, output_dir, quant_vision=args.quant_vision) + save_tower(model, output_dir, quant_nontext_module=args.quant_nontext_module) tokenizer.save_pretrained(output_dir) if eval_folder is None: eval_folder = output_dir @@ -417,3 +417,4 @@ def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str evaluator.calculate_accuracy(result_file = args.eval_result_file) + diff --git a/examples/multimodal-modeling/Phi-3-vision/main.py b/examples/multimodal-modeling/Phi-3-vision/main.py index 99f457f8..682e1737 100644 --- a/examples/multimodal-modeling/Phi-3-vision/main.py +++ 
b/examples/multimodal-modeling/Phi-3-vision/main.py @@ -167,7 +167,7 @@ def create_data_loader(dataset, batch_size=1, data_collator=None): parser.add_argument("--group_size", default=128, type=int, help="group size") - parser.add_argument("--train_bs", default=1, type=int, + parser.add_argument("--batch_size", default=1, type=int, help="train batch size") parser.add_argument("--eval_bs", default=4, type=int, @@ -248,7 +248,7 @@ def create_data_loader(dataset, batch_size=1, data_collator=None): parser.add_argument("--act_bits", default=32, type=int, help="activation bits") - parser.add_argument("--quant_vision", action='store_true', + parser.add_argument("--quant_nontext_module", action='store_true', help="To determine whether the quantization should handle vision component.") parser.add_argument("--enable_safe_serialization", action='store_true', @@ -334,7 +334,7 @@ def create_data_loader(dataset, batch_size=1, data_collator=None): data_path=args.question_file, processor=processor, data_args=data_args ) data_collator = DataCollatorForSupervisedDataset(tokenizer=processor.tokenizer) - dataloader = create_data_loader(dataset, batch_size=args.train_bs, data_collator=data_collator) + dataloader = create_data_loader(dataset, batch_size=args.batch_size, data_collator=data_collator) from auto_round import (AutoRound, AutoRoundAdam) @@ -378,9 +378,9 @@ def create_data_loader(dataset, batch_size=1, data_collator=None): print(f"warning, low_gpu_mem_usage=False is strongly recommended if the whole model could be loaded to " f"gpu") - quant_block_list = get_multimodal_block_names(model, args.quant_vision) + quant_block_list = get_multimodal_block_names(model, args.quant_nontext_module) - autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs, + autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size, dataset=dataloader, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, layer_config=layer_config, @@ -465,3 +465,4 @@ def create_data_loader(dataset, batch_size=1, data_collator=None): print(make_table(res)) + diff --git a/requirements.txt b/requirements.txt index abd50f03..455c6052 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ datasets py-cpuinfo sentencepiece torch -transformers<=4.45.2 +transformers triton numpy < 2.0 threadpoolctl @@ -11,4 +11,4 @@ lm-eval>=0.4.2,<=0.4.5 tqdm packaging auto-gptq>=0.7.1 -pillow \ No newline at end of file +pillow
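
For a quick local check of the core rename (`train_bs` → `batch_size`, which also shows up in the config written by `save_quantized`), here is a minimal text-only sketch. The checkpoint and hyperparameter values are illustrative assumptions; only the argument names are taken from this patch.

```python
# Minimal sketch, assuming an arbitrary small causal LM; the values are
# illustrative, the argument names follow the renamed API in this patch.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder checkpoint for a quick try
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    bits=4,
    group_size=128,
    batch_size=8,                  # formerly train_bs
    gradient_accumulate_steps=1,
    seqlen=2048,
    iters=200,
    nsamples=128,
)
model, _ = autoround.quantize()
autoround.save_quantized("./opt-125m-w4g128", format="auto_round", inplace=True)
```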
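
The multimodal path can be exercised the same way with the renamed `quant_nontext_module` flag. This is a sketch only: the checkpoint is a placeholder, the import targets the defining module, and the keyword set assumes `AutoRoundMLLM` accepts the same arguments that `tune_mllm` in `__main__.py` forwards.

```python
# Sketch of the MLLM flow after the quant_vision -> quant_nontext_module
# rename. The loading pattern mirrors the mllama branch added in this patch;
# the checkpoint and keyword values are placeholders, not a verified recipe.
from transformers import AutoProcessor, AutoTokenizer, MllamaForConditionalGeneration

from auto_round.mllm.autoround_mllm import AutoRoundMLLM

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # illustrative, gated checkpoint
model = MllamaForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.processor = AutoProcessor.from_pretrained(model_name)

autoround = AutoRoundMLLM(
    model,
    tokenizer,
    bits=4,
    group_size=128,
    batch_size=1,                    # see the check_mllm_model_batch sketch below
    gradient_accumulate_steps=4,
    quant_nontext_module=False,      # set True to also tune the non-text (vision) blocks
    seqlen=512,
    # dataset=...  # the constructor asserts a dataset; pass one explicitly if
    #              # the default calibration set (not shown in this patch) does not fit
)
model, _ = autoround.quantize()
```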
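
The batch-size guard added in `special_model_handler.py` can be checked in isolation. For the listed model families it folds any `batch_size > 1` into `gradient_accumulate_steps`, so the number of samples contributing to each optimizer step is preserved. The `SimpleNamespace` stand-ins below only mimic the `model.config.model_type` attribute the helper reads.

```python
# Quick check of the new batch-size guard: for llava/qwen2-vl/phi3_v/mllama,
# batch_size > 1 is folded into gradient_accumulate_steps; other model types
# pass through unchanged.
from types import SimpleNamespace

from auto_round.special_model_handler import check_mllm_model_batch

fake_llava = SimpleNamespace(config=SimpleNamespace(model_type="llava"))
bs, accum = check_mllm_model_batch(fake_llava, batch_size=4, gradient_accumulate_steps=2)
assert (bs, accum) == (1, 8)  # 4 samples x 2 steps -> 1 sample x 8 steps

fake_opt = SimpleNamespace(config=SimpleNamespace(model_type="opt"))
assert check_mllm_model_batch(fake_opt, 4, 2) == (4, 2)  # non-MLLM types untouched
```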