diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index 53490a01..7785b812 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -33,26 +33,30 @@ def run_fast():
 
 
 def run_mllm():
-    from auto_round.script.mllm import setup_parser, tune, eval
-    args = setup_parser()
-    if args.eval:
+    if "--eval" in sys.argv:
+        from auto_round.script.mllm import setup_lmeval_parser, eval
+        sys.argv.remove("--eval")
+        args = setup_lmeval_parser()
         eval(args)
+    elif "--lmms" in sys.argv:
+        sys.argv.remove("--lmms")
+        run_lmms()
     else:
+        from auto_round.script.mllm import setup_parser, tune
+        args = setup_parser()
         tune(args)
 
 def run_lmms():
-    from transformers.utils.versions import require_version
-    require_version("lmms_eval", "lmms_eval need to be installed, `pip install lmms_eval`")
     # from auto_round.script.lmms_eval import setup_lmms_args, eval
     from auto_round.script.mllm import setup_lmms_parser, lmms_eval
     args = setup_lmms_parser()
     lmms_eval(args)
 
 def switch():
-    if "--lmms" in sys.argv:
-        sys.argv.remove("--lmms")
-        run_lmms()
-    elif "--mllm" in sys.argv:
+    # if "--lmms" in sys.argv:
+    #     sys.argv.remove("--lmms")
+    #     run_lmms()
+    if "--mllm" in sys.argv:
         sys.argv.remove("--mllm")
         run_mllm()
     else:
diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py
index 3c9d54b8..3f6b627e 100644
--- a/auto_round/mllm/autoround_mllm.py
+++ b/auto_round/mllm/autoround_mllm.py
@@ -96,7 +96,8 @@ def __init__(
             self,
             model,
             tokenizer,
-            image_processor=None,
+            processor=None,
+            image_processor=None,
             bits: int = 4,
             group_size: int = 128,
             sym: bool = False,
@@ -143,8 +144,8 @@ def __init__(
         self.image_processor = image_processor
         self.template = template if template is not None else model.config.model_type
         self.template = get_template(
-            self.template, model=model, tokenizer=tokenizer, image_processor=image_processor)
-
+            self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
+
         dataset = self.template.default_dataset if dataset is None else dataset
         from ..calib_dataset import CALIB_DATASETS
         if truncation is None:
diff --git a/auto_round/mllm/eval.py b/auto_round/mllm/eval.py
index 2cb27d1d..a2f9d81c 100644
--- a/auto_round/mllm/eval.py
+++ b/auto_round/mllm/eval.py
@@ -350,7 +350,8 @@ def lmms_eval(
         apply_chat_template=False
 ):
     from auto_round import AutoRoundConfig
-
+    from transformers.utils.versions import require_version
+    require_version("lmms_eval", "lmms_eval need to be installed, `pip install lmms_eval`")
     if isinstance(tasks, str):
         tasks = tasks.replace(' ', '').split(',')
 
diff --git a/auto_round/mllm/processor.py b/auto_round/mllm/processor.py
index cd6518f3..c18db601 100644
--- a/auto_round/mllm/processor.py
+++ b/auto_round/mllm/processor.py
@@ -32,10 +32,11 @@ def register(processor):
 class BasicProcessor:
     def __init__(self):
         pass
-
-    def post_init(self, model, tokenizer, image_processor=None, **kwargs):
+
+    def post_init(self, model, tokenizer, processor=None, image_processor=None, **kwargs):
         self.model = model
         self.tokenizer = tokenizer
+        self.processor = processor
         if image_processor is not None:
             self.image_processor = image_processor
         else:
@@ -76,7 +77,7 @@ def get_input(
         if truncation is True and truncation_strategy == "text":
             text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])
 
-        ret = self.tokenizer.processor(
+        ret = self.processor(
             text=text,
             images=images,
             return_tensors=return_tensors,
diff --git a/auto_round/mllm/template.py b/auto_round/mllm/template.py
index 99ad1fc4..970fd1f8 100644
--- a/auto_round/mllm/template.py
+++ b/auto_round/mllm/template.py
@@ -147,7 +147,7 @@ def _load_preset_template():
 
 _load_preset_template()
 
-def get_template(template_or_path: str, model=None, tokenizer=None, image_processor=None):
+def get_template(template_or_path: str, model=None, tokenizer=None, processor=None, image_processor=None):
     """Get template by template name or from a json file.
 
     Args:
@@ -166,6 +166,6 @@ def get_template(template_or_path: str, model=None, tokenizer=None, image_proces
             logger.warning(f"Unable to recognize {template_or_path}, using default template instead.")
             template = TEMPLATES["default"]
 
-    template.processor.post_init(model=model, tokenizer=tokenizer, image_processor=image_processor)
+    template.processor.post_init(model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
 
     return template
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index a47c7605..84febd5d 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -160,37 +160,7 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--to_quant_block_names", default=None, type=str,
                           help="Names of quantitative blocks, please use commas to separate them.")
 
-        ## ======================= VLM eval=======================
-        self.add_argument("--tasks", type=str,
-                          default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE",
-                          help="eval tasks for VLMEvalKit.")
-        # Args that only apply to Video Dataset
-        self.add_argument("--nframe", type=int, default=8,
-                          help="the number of frames to sample from a video,"
-                               " only applicable to the evaluation of video benchmarks.")
-        self.add_argument("--pack", action='store_true',
-                          help="a video may associate with multiple questions, if pack==True,"
-                               " will ask all questions for a video in a single")
-        self.add_argument("--use-subtitle", action='store_true')
-        self.add_argument("--fps", type=float, default=-1)
-        # Work Dir
-        # Infer + Eval or Infer Only
-        self.add_argument("--mode", type=str, default='all', choices=['all', 'infer'],
-                          help="when mode set to 'all', will perform both inference and evaluation;"
-                               " when set to 'infer' will only perform the inference.")
-        self.add_argument('--eval_data_dir', type=str, default=None,
-                          help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData')
-        # API Kwargs, Apply to API VLMs and Judge API LLMs
-        self.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
-        # Explicitly Set the Judge Model
-        self.add_argument('--judge', type=str, default=None)
-        # Logging Utils
-        self.add_argument('--verbose', action='store_true')
-        # Configuration for Resume
-        # Ignore: will not rerun failed VLM inference
-        self.add_argument('--ignore', action='store_true', help='ignore failed indices. ')
-        # Rerun: will remove all evaluation temp files
-        self.add_argument('--rerun', action='store_true')
+
 
 
 def setup_parser():
@@ -215,6 +185,50 @@ def setup_parser():
     return args
 
 
+def setup_lmeval_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", "--model_name", "--model_name_or_path",
+                        help="model name or path")
+    parser.add_argument("--tasks", type=str,
+                        default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE",
+                        help="eval tasks for VLMEvalKit.")
+    # Args that only apply to Video Dataset
+    parser.add_argument("--nframe", type=int, default=8,
+                        help="the number of frames to sample from a video,"
+                             " only applicable to the evaluation of video benchmarks.")
+    parser.add_argument("--pack", action='store_true',
+                        help="a video may associate with multiple questions, if pack==True,"
+                             " will ask all questions for a video in a single query.")
+    parser.add_argument("--fps", type=float, default=-1,
+                        help="set the fps for a video.")
+    # Work Dir
+    # Infer + Eval or Infer Only
+    parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer'],
+                        help="when mode set to 'all', will perform both inference and evaluation;"
+                             " when set to 'infer' will only perform the inference.")
+    parser.add_argument('--eval_data_dir', type=str, default=None,
+                        help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData')
+    # API Kwargs, Apply to API VLMs and Judge API LLMs
+    parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
+    # Explicitly Set the Judge Model
+    parser.add_argument('--judge', type=str, default=None,
+                        help="the judge model to use.")
+    # Logging Utils
+    parser.add_argument('--verbose', action='store_true',
+                        help="whether to display verbose information.")
+    # Configuration for Resume
+    # Ignore: will not rerun failed VLM inference
+    parser.add_argument('--ignore', action='store_true',
+                        help='ignore failed indices.')
+    # Rerun: will remove all evaluation temp files
+    parser.add_argument('--rerun', action='store_true',
+                        help="if true, will remove all evaluation temp files and rerun.")
+    parser.add_argument("--output_dir", default="./eval_result", type=str,
+                        help="the directory to save evaluation results")
+    args = parser.parse_args()
+    return args
+
+
 def tune(args):
     if args.format is None:
         args.format = "auto_round"
@@ -265,14 +279,14 @@ def tune(args):
     processor, image_processor = None, None
     if "llava" in model_name:
         from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
-        tokenizer, model, image_processor, _ = load_pretrained_model(model_name, model_base=None, model_name=model_name,
-                                                                     torch_dtype=torch_dtype)
+        tokenizer, model, image_processor, _ = load_pretrained_model(
+            model_name, model_base=None, model_name=model_name,
+            torch_dtype=torch_dtype)
         model_type = "llava"
     else:
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        tokenizer.processor = processor
         model_type = config.model_type
         if "qwen2_vl" in model_type:
             from transformers import Qwen2VLForConditionalGeneration
@@ -361,7 +375,7 @@ def tune(args):
     if "--truncation" not in sys.argv:
         args.truncation = None
 
-    autoround = round(model, tokenizer, image_processor=image_processor, dataset=args.dataset,
+    autoround = round(model, tokenizer, processor=processor, image_processor=image_processor, dataset=args.dataset,
                       extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
                       sym=not args.asym, batch_size=args.batch_size, seqlen=seqlen, nblocks=args.nblocks,
                       iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, amp=not args.disable_amp,
@@ -406,7 +420,6 @@ def eval(args):
         data_store_dir=args.eval_data_dir,
         dataset=args.tasks,
         pack=args.pack,
-        use_subtitle=args.use_subtitle,
         fps=args.fps,
         nframe=args.nframe,
         rerun=args.rerun,
@@ -426,8 +439,8 @@ def setup_lmms_parser():
         default="pope,textvqa_val,scienceqa,mmbench_en",
         help="To get full list of tasks, use the command lmms-eval --tasks list",
     )
-    parser.add_argument("--output_dir", default="./tmp_autoround", type=str,
-                        help="the directory to save quantized model")
+    parser.add_argument("--output_dir", default="./eval_result", type=str,
+                        help="the directory to save evaluation results")
     parser.add_argument(
         "--num_fewshot",
         type=int,
diff --git a/test/test_basic_usage.py b/test/test_basic_usage.py
index 6d28801d..78b0a4f4 100644
--- a/test/test_basic_usage.py
+++ b/test/test_basic_usage.py
@@ -32,11 +32,24 @@ def test_auto_round_cmd(self):
 
 
         # test mllm script
+        # test auto_round_mllm help
         res = os.system(
             f"cd .. && {python_path} -m auto_round --mllm -h")
         if res > 0 or res == -1:
             assert False, "cmd line test fail, please have a check"
 
+        # test auto_round_mllm --eval help
+        res = os.system(
+            f"cd .. && {python_path} -m auto_round --mllm --eval -h")
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
+        # test auto_round_mllm --lmms help
+        res = os.system(
+            f"cd .. && {python_path} -m auto_round --mllm --lmms -h")
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
         res = os.system(
             f"cd .. && {python_path} -m auto_round --mllm --iter 2 --nsamples 10 --format auto_round --output_dir ./saved")
         if res > 0 or res == -1:
diff --git a/test/test_mllm.py b/test/test_mllm.py
index 3441962d..1d09ca92 100644
--- a/test/test_mllm.py
+++ b/test/test_mllm.py
@@ -42,12 +42,12 @@ def tearDownClass(self):
     def test_tune(self):
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
-        tokenizer.processor = processor
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_name, trust_remote_code=True, device_map="auto")
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
-            model, tokenizer, bits=bits, group_size=group_size,
+            model, tokenizer, processor=processor,
+            bits=bits, group_size=group_size,
             nsamples=1, batch_size=1, iters=2, dataset=self.dataset,seqlen=256)
         autoround.quantize()
 
@@ -57,12 +57,12 @@ def test_tune(self):
     def test_quant_vision(self): ## bug need to fix
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
-        tokenizer.processor = processor
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_name, trust_remote_code=True, device_map="auto")
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
-            model, tokenizer, bits=bits, group_size=group_size,
+            model, tokenizer, processor=processor,
+            bits=bits, group_size=group_size,
             nsamples=5, batch_size=3, iters=2, dataset=self.dataset,
             quant_nontext_module=False,seqlen=256)
         autoround.quantize()
@@ -72,7 +72,6 @@ def test_quant_block_names(self):
         from auto_round.utils import get_multimodal_block_names,find_matching_blocks
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
-        tokenizer.processor = processor
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_name, trust_remote_code=True, device_map="auto")
         to_quant_block_names = 'visual.*12,layers.0,model.layers.*9'