diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py
index 3f6b627e..dabd42e0 100644
--- a/auto_round/mllm/autoround_mllm.py
+++ b/auto_round/mllm/autoround_mllm.py
@@ -50,6 +50,9 @@ class AutoRoundMLLM(AutoRound):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data.
+        processor: Any multi-modal model requires an object to encode or
+            decode data that groups several modalities (text, vision and audio).
+        image_processor: Image processor for special models such as llava.
         bits (int): Number of bits for quantization (default is 4).
         group_size (int): Size of the quantization group (default is 128).
         sym (bool): Whether sym to be used (default is True).
@@ -143,31 +146,40 @@ def __init__(
         self.quant_nontext_module = quant_nontext_module
         self.image_processor = image_processor
         self.template = template if template is not None else model.config.model_type
-        self.template = get_template(
-            self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
+        if not isinstance(dataset, torch.utils.data.DataLoader):
+            self.template = get_template(
+                self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)

         dataset = self.template.default_dataset if dataset is None else dataset
-        from ..calib_dataset import CALIB_DATASETS
-        if truncation is None:
-            truncation = True if dataset in CALIB_DATASETS.keys() else False
-        self.truncation = truncation
-
+
         if nsamples % batch_size != 0:
             nsamples = (nsamples // batch_size + 1) * batch_size
             logger.warning(f"'nsamples' is not divisible by 'batch_size', will adjusted to {nsamples}")
+
+        from ..calib_dataset import CALIB_DATASETS
+        from .mllm_dataset import MLLM_DATASET
+        if isinstance(dataset, str):
+            if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer)):
+                if quant_nontext_module:
+                    logger.warning(f"Text only dataset cannot be used for calibrating non-text modules,"
+                                   " switching to liuhaotian/llava_conv_58k")
+                else:
+                    logger.warning(f"{model.config.model_type} does not support {dataset},"
+                                   " will use liuhaotian/llava_conv_58k with default config as an alternative.")
+                dataset = "liuhaotian/llava_conv_58k"
-        if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer)):
-            if quant_nontext_module:
-                logger.warning(f"Quantitative nontext module is not supported for plain text datasets,"
-                               "will use liuhaotian/llava_conv_58k with default config as an alternative.")
-            else:
-                logger.warning(f"{model.config.model_type} not support for {dataset},"
-                               " will use liuhaotian/llava_conv_58k with default config as an alternative.")
-            dataset = "liuhaotian/llava_conv_58k"
-            self.truncation = False
+            if dataset in MLLM_DATASET.keys():
+                truncation = False
+                batch_size = 1
+                seqlen = 512 if seqlen is None else seqlen
+        if quant_nontext_module and batch_size != 1:
+            logger.warning(f"batch_size({batch_size}) cannot be used for calibrating non-text modules,"
+                           " resetting to 1")
+            gradient_accumulate_steps = batch_size * gradient_accumulate_steps
             batch_size = 1
-            gradient_accumulate_steps = 4
-            seqlen = 512
+        seqlen = 2048 if seqlen is None else seqlen
+        truncation = True if truncation is None else truncation
+        self.truncation = truncation

         super(AutoRoundMLLM, self).__init__(
             model=model,
@@ -250,6 +262,7 @@ def calib(self, nsamples, bs):
         with tqdm(range(1, total + 1), desc="calib") as pbar:
             for data in self.dataloader:
                 if data is None:
+                    pbar.update(1)
                     continue
                 if isinstance(data, torch.Tensor):
                     input_ids = data.to(self.device)
@@ -297,6 +310,9 @@ def calib(self, nsamples, bs):
                     data_new[key] = to_dtype(data_new[key], self.model.dtype)
                 input_ids = data_new["input_ids"]
+                if input_ids.shape[-1] < self.seqlen:
+                    pbar.update(1)
+                    continue
                 try:
                     if isinstance(data_new, torch.Tensor):
                         self.model(data_new)
@@ -318,7 +334,7 @@ def calib(self, nsamples, bs):
                 f"no data has been cached, please provide more data with sequence length >={self.seqlen} in the "
                 f"dataset or decease the sequence length"
             )
-            exit()
+            exit(-1)
         elif total_cnt < nsamples:
             logger.warning(
                 f"Insufficient number of samples collected may affect the quantification. "
@@ -338,3 +354,4 @@ def calib(self, nsamples, bs):
         for n, m in embed_layers:
             m = m.to("meta")
     # torch.cuda.empty_cache()
+
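For reference, the net effect of the new seqlen/truncation defaulting above can be summarized with a small standalone sketch. The helper name and the is_mllm_dataset flag are illustrative stand-ins for the `dataset in MLLM_DATASET` check, not code from this patch.

# Illustrative sketch, not library code: effective defaults after the change.
def resolve_defaults(seqlen, truncation, is_mllm_dataset):
    if is_mllm_dataset:                     # e.g. "liuhaotian/llava_conv_58k"
        truncation = False
        seqlen = 512 if seqlen is None else seqlen
    seqlen = 2048 if seqlen is None else seqlen
    truncation = True if truncation is None else truncation
    return seqlen, truncation

print(resolve_defaults(None, None, True))    # (512, False)
print(resolve_defaults(None, None, False))   # (2048, True)
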
diff --git a/auto_round/mllm/mllm_dataset.py b/auto_round/mllm/mllm_dataset.py
index 89055934..400f5773 100644
--- a/auto_round/mllm/mllm_dataset.py
+++ b/auto_round/mllm/mllm_dataset.py
@@ -59,7 +59,6 @@ class LlavaDataset(Dataset):
     }
     _COCO_DATA_URL = "http://images.cocodataset.org/train2017/"
     IMAGE_TOKEN = ""
-    MAX_SEQLEN = 512

     def __init__(
             self,
@@ -78,23 +77,24 @@ def __init__(
         self.model_type = template.model_type
         self.template = template
         self.tokenizer = tokenzier
-        if dataset_path == "liuhaotian/llava":
-            dataset_path = "llava_conv_58k"
-        else:
-            dataset_path = dataset_path.split("/")[-1]
         if os.path.exists(dataset_path):
             logger.info(f'use dataset {dataset_path}, loading from disk...')
             self.questions = json.load(open(dataset_path, "r"))
         else:
             import requests
+            if dataset_path == "liuhaotian/llava":
+                dataset_path = "llava_conv_58k"
+            else:
+                dataset_path = dataset_path.split("/")[-1]
             dataset_name = dataset_path.split('/')[-1]
             if dataset_name in self.LLAVA_DATASET:
                 logger.info(f'use dataset {dataset_name}, downloading ...')
                 self.questions = requests.get(self.LLAVA_DATASET[dataset_name], stream=True).json()
             else:
                 raise KeyError(f"{dataset_path} is not support, we support {self.LLAVA_DATASET.keys()}.")
-        self.seqlen = min(seqlen, self.MAX_SEQLEN)
-        self.questions = self.check(self.questions, seqlen, nsamples)
+
+        self.seqlen = seqlen
+        self.questions = self.check(self.questions, self.seqlen, nsamples)
         self.padding = padding
         self.truncation = truncation
         self.extra_data_dir = extra_data_dir
@@ -108,8 +108,8 @@ def __init__(
             image_fold = image_fold['image']
         self.image_fold = image_fold

-    def check(self, questions, seqlen, nsamples):
-        def _check(questions, min_seqlen, max_seqlen, nsamples):
+    def check(self, questions, word_len, nsamples):
+        def _check(questions, min_word_len, max_word_len, nsamples):
             new_questions = []
             max_len = 0
             for source in questions:
@@ -120,22 +120,21 @@ def _check(questions, min_word_len, max_word_len, nsamples):
                     str_len += len(text['value'].split(' '))
                 if str_len > max_len:
                     max_len = str_len
-                if min_seqlen <= str_len < max_seqlen:
+                if min_word_len <= str_len < max_word_len:
                     new_questions.append(source)
-                if len(new_questions) >= nsamples:
-                    return new_questions
+            if len(new_questions) >= nsamples:
+                return new_questions
+            if min_word_len > max_len:
+                logger.debug(f"seqlen={min_word_len} is greater than the max length of the dataset (25,236),"
+                             f" will change seqlen to {max_len - 128}")
+                new_min_word_len = max_len - 128
             else:
-                if seqlen > max_len:
-                    logger.warning(f"seqlen={seqlen} is greater than the max length of dataset 25,236,"
-                                   f" will change seqlen to {max_len - 128}")
-                    new_min_seqlen = max_len - 128
-                else:
-                    logger.warning(f"no enough sample for seqlen greater than {min_seqlen},"
-                                   f" will decrease to {min_seqlen - 128}")
-                    new_min_seqlen = min_seqlen - 128
-                return new_questions + _check(questions, new_min_seqlen, min_seqlen, nsamples - len(new_questions))
-
-        return _check(questions, seqlen, float("inf"), nsamples)
+                logger.debug(f"not enough samples with seqlen greater than {min_word_len},"
+                             f" will decrease to {min_word_len - 128}")
+                new_min_word_len = min_word_len - 128
+            return new_questions + _check(questions, new_min_word_len, min_word_len, nsamples - len(new_questions))
+
+        return _check(questions, word_len, float("inf"), nsamples)

     def __len__(self):
         return len(self.questions)
@@ -248,7 +247,8 @@ def get_mllm_dataloader(
             tokenizer, seqlen, dataset, seed, bs, nsamples)
         if quant_nontext_module:
             logger.error(
-                f"Quantitative nontext module is not supported for plain text datasets," \
-                " please disable arg '--quant_nontext_module'")
+                f"Text only dataset cannot be used for calibrating non-text modules,"
+                " please use a multi-modal dataset such as liuhaotian/llava_conv_58k.")
             exit(-1)
     return dataloader, bs, gradient_accumulate_steps
+
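The reworked LlavaDataset.check() above filters samples by conversation word count and recursively relaxes the lower bound by 128 until enough samples are collected. A small standalone sketch of that behavior on toy data (illustrative names, not the library code):

# Standalone sketch of the word-count filter used by LlavaDataset.check().
def filter_by_word_len(questions, word_len, nsamples):
    def _filter(questions, min_wl, max_wl, need):
        kept, max_len = [], 0
        for source in questions:
            # count whitespace-separated words across the whole conversation
            n_words = sum(len(t['value'].split(' ')) for t in source['conversations'])
            max_len = max(max_len, n_words)
            if min_wl <= n_words < max_wl:
                kept.append(source)
        if len(kept) >= need:
            return kept
        # relax the lower bound by 128 and retry on the remaining range
        new_min = (max_len if min_wl > max_len else min_wl) - 128
        return kept + _filter(questions, new_min, min_wl, need - len(kept))
    return _filter(questions, word_len, float("inf"), nsamples)

toy = [{"conversations": [{"value": "word " * n}]} for n in (10, 40, 200)]
print(len(filter_by_word_len(toy, 32, 2)))   # 2: the 40- and 200-word samples
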
diff --git a/auto_round/mllm/processor.py b/auto_round/mllm/processor.py
index c18db601..302d0120 100644
--- a/auto_round/mllm/processor.py
+++ b/auto_round/mllm/processor.py
@@ -126,13 +126,13 @@ def squeeze_result(ret):
 class CogVLM2Processor(BasicProcessor):
     def get_input(
             self, text, images, truncation=False,
-            squeeze=True, **kwargs):
+            squeeze=True, max_length=None, **kwargs):

         if images is not None:
             images = self.image_processor(images)
-
+
         padding_len = 2303
-        max_length = 0
+        max_length = 0 if max_length is None else max_length
         max_length += padding_len
         padding = False
         input_data = self.model.build_conversation_input_ids(
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 84febd5d..7d74494b 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -54,7 +54,8 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--dataset", type=str, default=None,
                           help="the dataset for quantization training."
-                          " current support NeelNanda/pile-10k,llava_conv_58k,llava_instruct_80k "
+                          " currently supports NeelNanda/pile-10k,liuhaotian/llava_conv_58k,"
+                          "liuhaotian/llava_instruct_80k,liuhaotian/llava_instruct_150k. "
                          "It can be a custom one. Default is NeelNanda/pile-10k")

        self.add_argument("--lr", default=None, type=float,
@@ -175,8 +176,8 @@ def setup_parser():
     parser.add_argument("--iters", "--iter", default=200, type=int,
                         help=" iters")

-    parser.add_argument("--seqlen", "--seq_len", default=2048, type=int,
-                        help="sequence length")
+    parser.add_argument("--seqlen", "--seq_len", default=None, type=int,
+                        help="sequence length, default 2048 for text-only, 512 for liuhaotian/llava")

     parser.add_argument("--nsamples", default=128, type=int,
                         help="number of samples")
@@ -306,7 +307,6 @@ def tune(args):
     from auto_round import AutoRoundMLLM

     model = model.eval()
-    seqlen = args.seqlen

     if args.model_dtype != None:
         try:
@@ -374,10 +374,11 @@ def tune(args):
     if "--truncation" not in sys.argv:
         args.truncation = None
+
     autoround = round(model, tokenizer, processor=processor, image_processor=image_processor,
                       dataset=args.dataset, extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
-                      sym=not args.asym, batch_size=args.batch_size, seqlen=seqlen, nblocks=args.nblocks,
+                      sym=not args.asym, batch_size=args.batch_size, seqlen=args.seqlen, nblocks=args.nblocks,
                       iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, amp=not args.disable_amp,
                       enable_quanted_input=not args.disable_quanted_input, truncation=args.truncation,
                       nsamples=args.nsamples, low_gpu_mem_usage=args.low_gpu_mem_usage,
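Since --seqlen now defaults to None, the CLI simply forwards the unset value and lets AutoRoundMLLM pick 2048 (text-only) or 512 (llava-style) as shown earlier. A minimal argparse stub illustrating the new default (a sketch, not the project's real parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seqlen", "--seq_len", default=None, type=int,
                    help="sequence length, default 2048 for text-only, 512 for liuhaotian/llava")
args = parser.parse_args([])
print(args.seqlen)   # None: AutoRoundMLLM later resolves this to 2048 or 512
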
diff --git a/test/test_mllm.py b/test/test_mllm.py
index 1d09ca92..2845c139 100644
--- a/test/test_mllm.py
+++ b/test/test_mllm.py
@@ -64,7 +64,7 @@ def test_quant_vision(self):  ## bug need to fix
             model, tokenizer, processor=processor,
             bits=bits, group_size=group_size,
             nsamples=5,
-            batch_size=3, iters=2, dataset=self.dataset, quant_nontext_module=False,seqlen=256)
+            batch_size=3, iters=2, dataset=self.dataset, quant_nontext_module=True, seqlen=256)
         autoround.quantize()
         autoround.save_quantized("./saved/", format="auto_round", inplace=True)

@@ -80,8 +80,31 @@ def test_quant_block_names(self):
         blocks = find_matching_blocks(model, all_blocks, to_quant_block_names)
         assert target_blocks == blocks
-
+    def test_dataset_check(self):
+        from auto_round.mllm.mllm_dataset import MLLM_DATASET
+        class Myclass:
+            model_type=None
+        dataset = MLLM_DATASET['liuhaotian/llava'](template=Myclass(), model=None, tokenzier=None, dataset_path="liuhaotian/llava", seqlen=32, nsamples=32)
+        self.assertEqual(len(dataset.questions), 32)
+        dataset = MLLM_DATASET['liuhaotian/llava'](template=Myclass(), model=None, tokenzier=None, dataset_path="liuhaotian/llava", seqlen=2048, nsamples=512)
+        self.assertEqual(len(dataset.questions), 512)
+
+    def test_diff_dataset(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            self.model_name, trust_remote_code=True, device_map="auto")
+        bits, group_size = 4, 128
+        dataset = ["dataset test", "list test"]
+        autoround = AutoRoundMLLM(
+            model, tokenizer, processor=processor,
+            bits=bits, group_size=group_size,
+            nsamples=2,
+            batch_size=1, iters=2, dataset=dataset, seqlen=1)
+        autoround.quantize()

 if __name__ == "__main__":
     unittest.main()
+
+
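For context, an end-to-end sketch of the list-dataset path exercised by test_diff_dataset above. The checkpoint name is a placeholder (the test's self.model_name is not part of this patch), so treat it as illustrative rather than a drop-in script:

from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
from auto_round import AutoRoundMLLM

model_name = "Qwen/Qwen2-VL-2B-Instruct"   # placeholder checkpoint, not from the patch
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name, trust_remote_code=True, device_map="auto")

# dataset may now be a plain list of strings (or a DataLoader) instead of a
# registered dataset name; a tiny seqlen keeps the short samples from being
# skipped by the new length check in calib().
autoround = AutoRoundMLLM(
    model, tokenizer, processor=processor,
    bits=4, group_size=128, nsamples=2,
    batch_size=1, iters=2, dataset=["dataset test", "list test"], seqlen=1)
autoround.quantize()
autoround.save_quantized("./saved/", format="auto_round", inplace=True)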