diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py
index ffa68d39..a615a99b 100644
--- a/auto_round/mllm/autoround_mllm.py
+++ b/auto_round/mllm/autoround_mllm.py
@@ -117,6 +117,7 @@ def __init__(
         act_dynamic: bool = True,
         quant_block_list: list = None,
         enable_norm_bias_tuning: bool = False,
+        truncation: bool = False,
         enable_torch_compile: bool = None,
         **kwargs,
     ):
@@ -128,6 +129,7 @@ def __init__(
         self.template = template if template is not None else model.config.model_type
         self.template = get_template(
             self.template, model=model, tokenizer=tokenizer, image_processor=image_processor)
+        self.truncation = truncation
 
         assert dataset is not None, "dataset should not be None"
         batch_size, gradient_accumulate_steps = check_mllm_model_batch(model, batch_size, gradient_accumulate_steps)
@@ -193,7 +195,8 @@ def calib(self, nsamples, bs):
                 extra_data_dir=self.extra_data_dir,
                 seqlen=self.seqlen,
                 bs=bs,
-                seed=self.seed
+                seed=self.seed,
+                truncation=self.truncation,
             )
         else:
             self.dataloader = self.dataset
diff --git a/auto_round/mllm/mllm_dataset.py b/auto_round/mllm/mllm_dataset.py
index cf319fbf..dae71fca 100644
--- a/auto_round/mllm/mllm_dataset.py
+++ b/auto_round/mllm/mllm_dataset.py
@@ -57,6 +57,7 @@ class LlavaDataset(Dataset):
         "llava_instruct_150k": BASE_LLAVA_URL + "llava_instruct_150k.json?download=true",
     }
     _COCO_DATA_URL = "http://images.cocodataset.org/train2017/"
+    IMAGE_TOKEN = "<image>"
 
     def __init__(
             self,
@@ -101,16 +102,18 @@ def __init__(
 
         self.image_fold = image_fold
 
-    @staticmethod
-    def check(questions, seqlen):
+    def check(self, questions, seqlen):
         new_questions = []
         for source in questions:
             text_lenght = 0
             for text in source['conversations']:
+                if self.IMAGE_TOKEN in text['value']:
+                    text['value'] = self.IMAGE_TOKEN + text['value'].replace(self.IMAGE_TOKEN, '')
                 text_lenght += len(text['value'].split(' '))
             if text_lenght >= seqlen:
                 new_questions.append(source)
-        assert len(new_questions) > 0, f"no data with length greater than {seqlen}, please check"
+        assert len(new_questions) > 0, \
+            f"no data with length greater than {seqlen}, please reduce the seqlen or use another dataset."
         return new_questions
@@ -175,6 +178,7 @@ def get_mllm_dataloader(
         bs=1,
         split=None,
         apply_template=None,
+        truncation=False,
         seed=42,
 ):
     """Generate a DataLoader for calibration using specified parameters.
@@ -202,11 +206,11 @@ def get_mllm_dataloader(
     if os.path.isfile(dataset):
         dataset = MLLM_DATASET['liuhaotian/llava'](
             template, model, tokenizer, dataset, extra_data_dir,
-            seqlen=min(seqlen, tokenizer.model_max_length))
+            seqlen=min(seqlen, tokenizer.model_max_length), truncation=truncation)
     elif "liuhaotian/llava" in dataset:
         dataset = MLLM_DATASET["liuhaotian/llava"](
             template, model, tokenizer, dataset, extra_data_dir,
-            seqlen=min(seqlen, tokenizer.model_max_length))
+            seqlen=min(seqlen, tokenizer.model_max_length), truncation=truncation)
     else:
         from datasets import load_dataset
         from ..calib_dataset import get_tokenizer_function
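# Standalone sketch of the reworked LlavaDataset.check() above: the "<image>"
# placeholder is moved to the front of each turn before the whitespace word
# count is taken, so samples are measured consistently no matter where the
# image token sits. The sample conversation below is made up for illustration.
IMAGE_TOKEN = "<image>"

def check(questions, seqlen):
    kept = []
    for source in questions:
        length = 0
        for turn in source["conversations"]:
            if IMAGE_TOKEN in turn["value"]:
                # normalize: pull the image token to the front of the turn
                turn["value"] = IMAGE_TOKEN + turn["value"].replace(IMAGE_TOKEN, "")
            length += len(turn["value"].split(" "))
        if length >= seqlen:  # keep only samples long enough for calibration
            kept.append(source)
    return kept

sample = [{"conversations": [{"value": "<image>\nWhat is shown in this picture?"}]}]
print(check(sample, seqlen=5))  # sample survives: 6 whitespace-separated words >= 5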
diff --git a/auto_round/mllm/processor.py b/auto_round/mllm/processor.py
index 3b285e01..7de74d7f 100644
--- a/auto_round/mllm/processor.py
+++ b/auto_round/mllm/processor.py
@@ -45,6 +45,7 @@ def get_input(
             return_tensors="pt",
             squeeze=True,
             max_length=None,
+            truncation=False,
             truncation_strategy="text",
             **kwargs):
 
@@ -69,8 +70,8 @@ def get_input(
         if images is not None:
             images = self.image_processor(images)
 
-        if truncation_strategy == "text" and max_length is not None:
-            text = text[:max_length]
+        if truncation is True and truncation_strategy == "text":
+            text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])
 
         ret = self.tokenizer.processor(
             text=text,
@@ -78,7 +79,7 @@ def get_input(
             return_tensors=return_tensors,
             # videos = None
         )
-        if truncation_strategy == "token" and max_length:
+        if truncation is True and truncation_strategy == "token":
             seqlen = ret['input_ids'].shape[-1]
             for key in ret:
                 shape_ = ret[key].shape
@@ -118,7 +119,7 @@ def squeeze_result(ret):
 @regist_processor("cogvlm2")
 class CogVLM2Processor(BasicProcessor):
     def get_input(
-            self, text, images,
+            self, text, images, truncation=False,
             squeeze=True, **kwargs):
 
         if images is not None:
@@ -127,7 +128,6 @@ def get_input(
         padding_len = 2303
         max_length = 0
         max_length += padding_len
-        truncation = True
         padding = False
         input_data = self.model.build_conversation_input_ids(
             self.tokenizer,
@@ -207,7 +207,7 @@ def post_init(self, model, tokenizer, image_processor=None, **kwargs):
 
     def get_input(
             self, text, images,max_length=None,
-            squeeze=True, truncation_strategy="text", **kwargs):
+            squeeze=True, truncation=False, truncation_strategy="text", **kwargs):
 
         if images is not None:
             images = fetch_image(images).convert('RGB')
@@ -217,13 +217,13 @@ class DataArgs:
             is_multimodal = True
             mm_use_im_start_end = False
 
-        if truncation_strategy == "text" and max_length is not None:
-            text = text[:max_length]
+        if truncation is True and truncation_strategy == "text":
+            text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])
 
         input_data = llava_train.preprocess_multimodal([text], DataArgs())
         ret = llava_train.preprocess(input_data, self.tokenizer, has_image=(images is not None))
 
-        if truncation_strategy == "token" and max_length:
+        if truncation is True and truncation_strategy == "token":
             seqlen = ret['input_ids'].shape[-1]
             for key in ret:
                 if ret[key].shape[-1] == seqlen:
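# Sketch of the "text" truncation strategy introduced above: rather than
# slicing the raw string to max_length characters, the prompt is tokenized,
# cut to max_length tokens, and decoded back, so the limit is measured in
# tokens. Assumes any HuggingFace tokenizer; "gpt2" is only a stand-in.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "a long calibration prompt " * 64
max_length = 32

ids = tokenizer(text).input_ids[:max_length]  # truncate in token space
truncated = tokenizer.decode(ids)
print(len(tokenizer(truncated).input_ids))    # ~max_length tokens, not characters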
diff --git a/auto_round/mllm/template.py b/auto_round/mllm/template.py
index 6d477f04..9154187d 100644
--- a/auto_round/mllm/template.py
+++ b/auto_round/mllm/template.py
@@ -48,12 +48,12 @@ class Template:
     format_separator: str
     default_system: str
    replace_tokens: List[tuple]
-    add_special_token: bool
+    extra_encode: bool
    processor: "BasicProcessor"
 
    def _encode(self, sources):
        """Encodes formatted inputs to pairs of token ids."""
-        if self.add_special_token:
+        if self.extra_encode:
            element = ""
            for i, source in enumerate(sources):
                if i == 0:
@@ -84,7 +84,7 @@ def _register_template(
         format_separator: Optional[str] = None,
         default_system: str = "",
         replace_tokens: List[tuple] = None,
-        add_special_token: Optional[bool] = True,
+        extra_encode: Optional[bool] = True,
         processor: "BasicProcessor" = PROCESSORS["basic"],
 ):
     """Registers a chat template."""
@@ -105,7 +105,7 @@ def _register_template(
         format_separator = format_separator or default_format_separator,
         default_system = default_system,
         replace_tokens = replace_tokens,
-        add_special_token = add_special_token,
+        extra_encode = extra_encode,
         processor = processor()
     )
     return TEMPLATES[model_type]
diff --git a/auto_round/mllm/templates/default.json b/auto_round/mllm/templates/default.json
index 51ec37f3..2a14283d 100644
--- a/auto_round/mllm/templates/default.json
+++ b/auto_round/mllm/templates/default.json
@@ -9,5 +9,5 @@
     "default_system": "You are a helpful assistant.",
     "replace_tokens": null,
     "processor": "basic",
-    "add_special_token" : false
+    "extra_encode" : false
 }
\ No newline at end of file
diff --git a/auto_round/mllm/templates/llava.json b/auto_round/mllm/templates/llava.json
index 13d5a6b4..87b4be02 100644
--- a/auto_round/mllm/templates/llava.json
+++ b/auto_round/mllm/templates/llava.json
@@ -2,5 +2,5 @@
     "model_type": "llava",
     "replace_tokens": null,
     "processor": "llava",
-    "add_special_token" : false
+    "extra_encode" : false
 }
\ No newline at end of file
diff --git a/auto_round/mllm/templates/phi3_v.json b/auto_round/mllm/templates/phi3_v.json
index 09949ff4..955c15f7 100644
--- a/auto_round/mllm/templates/phi3_v.json
+++ b/auto_round/mllm/templates/phi3_v.json
@@ -2,5 +2,5 @@
     "model_type": "phi3_v",
     "replace_tokens": ["<image>", "<|image_1|>"],
     "processor": "basic",
-    "add_special_token" : false
+    "extra_encode" : false
 }
\ No newline at end of file
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 54a25f3a..ca5d0351 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -147,6 +147,9 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--template", default=None, type=str,
                           help="the template for building training dataset. It can be a custom one.")
 
+        self.add_argument("--truncation", action="store_true",
+                          help="whether to truncate sequences at the maximum length.")
+
         ## ======================= VLM eval=======================
         self.add_argument("--tasks", type=str,
@@ -331,7 +334,7 @@ def tune(args):
         extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
         sym=not args.asym, batch_size=args.batch_size, seqlen=seqlen, nblocks=args.nblocks,
         iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, amp=not args.disable_amp,
-        enable_quanted_input=not args.disable_quanted_input,
+        enable_quanted_input=not args.disable_quanted_input, truncation=args.truncation,
         nsamples=args.nsamples, low_gpu_mem_usage=args.low_gpu_mem_usage,
         device=device_str, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
         scale_dtype=args.scale_dtype, layer_config=layer_config, template=args.template,
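# End to end, truncation is now opt-in and threaded through as:
#   --truncation (CLI, store_true, default False)
#     -> AutoRoundMLLM(..., truncation=...)
#       -> get_mllm_dataloader(..., truncation=...)
#         -> processor.get_input(..., truncation=..., truncation_strategy=...)
# A hedged usage sketch: the checkpoint is a placeholder, the loading calls
# follow the usual transformers pattern, and the dataset name reuses the
# "liuhaotian/llava" branch shown above; none of this is fixed by the patch.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundMLLM

model_name = "Qwen/Qwen-VL-Chat"  # placeholder multimodal checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

autoround = AutoRoundMLLM(
    model, tokenizer,
    dataset="liuhaotian/llava_instruct_150k",
    seqlen=512,
    truncation=True,  # calibration sequences beyond seqlen get truncated
)
autoround.quantize()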