refine mllm API and add help info #334

Merged (11 commits) on Nov 22, 2024
22 changes: 13 additions & 9 deletions auto_round/__main__.py
@@ -33,26 +33,30 @@ def run_fast():


def run_mllm():
from auto_round.script.mllm import setup_parser, tune, eval
args = setup_parser()
if args.eval:
if "--eval" in sys.argv:
from auto_round.script.mllm import setup_lmeval_parser, eval
sys.argv.remove("--eval")
args = setup_lmeval_parser()
eval(args)
elif "--lmms" in sys.argv:
sys.argv.remove("--lmms")
run_lmms()
else:
from auto_round.script.mllm import setup_parser, tune
args = setup_parser()
tune(args)

def run_lmms():
from transformers.utils.versions import require_version
require_version("lmms_eval", "lmms_eval need to be installed, `pip install lmms_eval`")
# from auto_round.script.lmms_eval import setup_lmms_args, eval
from auto_round.script.mllm import setup_lmms_parser, lmms_eval
args = setup_lmms_parser()
lmms_eval(args)

def switch():
if "--lmms" in sys.argv:
sys.argv.remove("--lmms")
run_lmms()
elif "--mllm" in sys.argv:
# if "--lmms" in sys.argv:
# sys.argv.remove("--lmms")
# run_lmms()
if "--mllm" in sys.argv:
sys.argv.remove("--mllm")
run_mllm()
else:
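With this change the mllm entry point inspects the raw flags before any parser is built, so each mode parses only its own options. A minimal standalone sketch of the precedence (not the actual __main__.py code; argv values are illustrative):

def pick_mode(argv):
    # Mirrors the new switch()/run_mllm() precedence: --mllm selects the multimodal
    # path, then --eval (VLMEvalKit) or --lmms (lmms_eval) picks a dedicated parser;
    # plain --mllm falls through to tuning.
    if "--mllm" not in argv:
        return "run_fast"
    if "--eval" in argv:
        return "setup_lmeval_parser + eval"
    if "--lmms" in argv:
        return "setup_lmms_parser + lmms_eval"
    return "setup_parser + tune"

print(pick_mode(["--mllm", "--eval", "--model", "./saved"]))  # setup_lmeval_parser + eval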
7 changes: 4 additions & 3 deletions auto_round/mllm/autoround_mllm.py
@@ -96,7 +96,8 @@ def __init__(
self,
model,
tokenizer,
image_processor=None,
processor = None,
image_processor = None,
bits: int = 4,
group_size: int = 128,
sym: bool = False,
@@ -143,8 +144,8 @@ def __init__(
self.image_processor = image_processor
self.template = template if template is not None else model.config.model_type
self.template = get_template(
self.template, model=model, tokenizer=tokenizer, image_processor=image_processor)

self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
dataset = self.template.default_dataset if dataset is None else dataset
from ..calib_dataset import CALIB_DATASETS
if truncation is None:
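AutoRoundMLLM now takes the Hugging Face processor as an explicit argument instead of expecting it to be patched onto the tokenizer. A minimal construction sketch mirroring the updated tests (the model name and hyperparameters are illustrative only):

from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
from auto_round import AutoRoundMLLM

model_name = "Qwen/Qwen2-VL-2B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto")

# The processor is passed directly; no `tokenizer.processor = processor` patching needed.
autoround = AutoRoundMLLM(model, tokenizer, processor=processor,
                          bits=4, group_size=128, iters=2, nsamples=1, seqlen=256)
autoround.quantize()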
3 changes: 2 additions & 1 deletion auto_round/mllm/eval.py
@@ -350,7 +350,8 @@ def lmms_eval(
apply_chat_template=False
):
from auto_round import AutoRoundConfig

from transformers.utils.versions import require_version
require_version("lmms_eval", "lmms_eval needs to be installed, `pip install lmms_eval`")
if isinstance(tasks, str):
tasks = tasks.replace(' ', '').split(',')

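Moving the require_version check into lmms_eval() means the lmms_eval dependency is only enforced when that code path actually runs. A small sketch of the guard, using the standard transformers helper:

from transformers.utils.versions import require_version

def ensure_lmms_eval_installed():
    # Raises ImportError with an actionable hint if lmms_eval is missing;
    # otherwise evaluation proceeds as usual.
    require_version("lmms_eval", "lmms_eval needs to be installed, `pip install lmms_eval`")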
7 changes: 4 additions & 3 deletions auto_round/mllm/processor.py
@@ -32,10 +32,11 @@ def register(processor):
class BasicProcessor:
def __init__(self):
pass

def post_init(self, model, tokenizer, image_processor=None, **kwargs):
def post_init(self, model, tokenizer, processor=None, image_processor=None, **kwargs):
self.model = model
self.tokenizer = tokenizer
self.processor = processor
if image_processor is not None:
self.image_processor = image_processor
else:
@@ -76,7 +77,7 @@ def get_input(
if truncation is True and truncation_strategy == "text":
text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])

ret = self.tokenizer.processor(
ret = self.processor(
text=text,
images=images,
return_tensors=return_tensors,
4 changes: 2 additions & 2 deletions auto_round/mllm/template.py
@@ -147,7 +147,7 @@ def _load_preset_template():
_load_preset_template()


def get_template(template_or_path: str, model=None, tokenizer=None, image_processor=None):
def get_template(template_or_path: str, model=None, tokenizer=None, processor=None, image_processor=None):
"""Get template by template name or from a json file.

Args:
@@ -166,6 +166,6 @@ def get_template(template_or_path: str, model=None, tokenizer=None, image_processor=None):
logger.warning(f"Unable to recognize {template_or_path}, using default template instead.")
template = TEMPLATES["default"]

template.processor.post_init(model=model, tokenizer=tokenizer, image_processor=image_processor)
template.processor.post_init(model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)

return template
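get_template() now threads the processor through to the template's post_init, matching the call made in AutoRoundMLLM.__init__. A hedged sketch that reuses the model, tokenizer, and processor objects loaded in the construction example above:

from auto_round.mllm.template import get_template  # module path taken from this diff

# The template name falls back to model.config.model_type when none is given explicitly.
template = get_template(
    model.config.model_type, model=model, tokenizer=tokenizer, processor=processor)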
89 changes: 51 additions & 38 deletions auto_round/script/mllm.py
@@ -160,37 +160,7 @@ def __init__(self, *args, **kwargs):
self.add_argument("--to_quant_block_names", default=None, type=str,
help="Names of quantitative blocks, please use commas to separate them.")

## ======================= VLM eval=======================
self.add_argument("--tasks", type=str,
default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE",
help="eval tasks for VLMEvalKit.")
# Args that only apply to Video Dataset
self.add_argument("--nframe", type=int, default=8,
help="the number of frames to sample from a video,"
" only applicable to the evaluation of video benchmarks.")
self.add_argument("--pack", action='store_true',
help="a video may associate with multiple questions, if pack==True,"
" will ask all questions for a video in a single")
self.add_argument("--use-subtitle", action='store_true')
self.add_argument("--fps", type=float, default=-1)
# Work Dir
# Infer + Eval or Infer Only
self.add_argument("--mode", type=str, default='all', choices=['all', 'infer'],
help="when mode set to 'all', will perform both inference and evaluation;"
" when set to 'infer' will only perform the inference.")
self.add_argument('--eval_data_dir', type=str, default=None,
help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData')
# API Kwargs, Apply to API VLMs and Judge API LLMs
self.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
# Explicitly Set the Judge Model
self.add_argument('--judge', type=str, default=None)
# Logging Utils
self.add_argument('--verbose', action='store_true')
# Configuration for Resume
# Ignore: will not rerun failed VLM inference
self.add_argument('--ignore', action='store_true', help='ignore failed indices. ')
# Rerun: will remove all evaluation temp files
self.add_argument('--rerun', action='store_true')



def setup_parser():
@@ -215,6 +185,50 @@ def setup_parser():
return args


def setup_lmeval_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model", "--model_name", "--model_name_or_path",
help="model name or path")
parser.add_argument("--tasks", type=str,
default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE",
help="eval tasks for VLMEvalKit.")
# Args that only apply to Video Dataset
parser.add_argument("--nframe", type=int, default=8,
help="the number of frames to sample from a video,"
" only applicable to the evaluation of video benchmarks.")
parser.add_argument("--pack", action='store_true',
help="a video may associate with multiple questions, if pack==True,"
" will ask all questions for a video in a single")
parser.add_argument("--fps", type=float, default=-1,
help="set the fps for a video.")
# Work Dir
# Infer + Eval or Infer Only
parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer'],
help="when mode set to 'all', will perform both inference and evaluation;"
" when set to 'infer' will only perform the inference.")
parser.add_argument('--eval_data_dir', type=str, default=None,
help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData')
# API Kwargs, Apply to API VLMs and Judge API LLMs
parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
# Explicitly Set the Judge Model
parser.add_argument('--judge', type=str, default=None,
help="the judge model to use for evaluation.")
# Logging Utils
parser.add_argument('--verbose', action='store_true',
help="whether to display verbose information.")
# Configuration for Resume
# Ignore: will not rerun failed VLM inference
parser.add_argument('--ignore', action='store_true',
help='ignore failed indices. ')
# Rerun: will remove all evaluation temp files
parser.add_argument('--rerun', action='store_true',
help="if true, will remove all evaluation temp files and rerun.")
parser.add_argument("--output_dir", default="./eval_result", type=str,
help="the directory to save quantized model")
args = parser.parse_args()
return args


def tune(args):
if args.format is None:
args.format = "auto_round"
@@ -265,14 +279,14 @@ def tune(args):
processor, image_processor = None, None
if "llava" in model_name:
from llava.model.builder import load_pretrained_model # pylint: disable=E0401
tokenizer, model, image_processor, _ = load_pretrained_model(model_name, model_base=None, model_name=model_name,
torch_dtype=torch_dtype)
tokenizer, model, image_processor, _ = load_pretrained_model(
model_name, model_base=None, model_name=model_name,
torch_dtype=torch_dtype)
model_type = "llava"
else:
config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
tokenizer.processor = processor
model_type = config.model_type
if "qwen2_vl" in model_type:
from transformers import Qwen2VLForConditionalGeneration
@@ -361,7 +375,7 @@ def tune(args):
if "--truncation" not in sys.argv:
args.truncation = None

autoround = round(model, tokenizer, image_processor=image_processor, dataset=args.dataset,
autoround = round(model, tokenizer, processor=processor, image_processor=image_processor, dataset=args.dataset,
extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
sym=not args.asym, batch_size=args.batch_size, seqlen=seqlen, nblocks=args.nblocks,
iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, amp=not args.disable_amp,
@@ -406,7 +420,6 @@ def eval(args):
data_store_dir=args.eval_data_dir,
dataset=args.tasks,
pack=args.pack,
use_subtitle=args.use_subtitle,
fps=args.fps,
nframe=args.nframe,
rerun=args.rerun,
@@ -426,8 +439,8 @@ def setup_lmms_parser():
default="pope,textvqa_val,scienceqa,mmbench_en",
help="To get full list of tasks, use the command lmms-eval --tasks list",
)
parser.add_argument("--output_dir", default="./tmp_autoround", type=str,
help="the directory to save quantized model")
parser.add_argument("--output_dir", default="./eval_result", type=str,
help="the directory to save quantized model")
parser.add_argument(
"--num_fewshot",
type=int,
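Because the VLMEvalKit options now live in setup_lmeval_parser(), the evaluation path can also be driven programmatically without touching the tuning arguments. A hedged sketch (the model path and task list are placeholders):

import sys
from auto_round.script.mllm import setup_lmeval_parser, eval

# Simulate the argv that __main__.py hands over after stripping --mllm/--eval.
sys.argv = ["auto-round", "--model", "./saved", "--tasks", "POPE",
            "--output_dir", "./eval_result"]
args = setup_lmeval_parser()  # parses only the evaluation-related options
eval(args)                    # runs VLMEvalKit inference and evaluation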
13 changes: 13 additions & 0 deletions test/test_basic_usage.py
@@ -32,11 +32,24 @@ def test_auto_round_cmd(self):


# test mllm script
# test auto_round_mllm help
res = os.system(
f"cd .. && {python_path} -m auto_round --mllm -h")
if res > 0 or res == -1:
assert False, "cmd line test fail, please have a check"

# test auto_round_mllm --eval help
res = os.system(
f"cd .. && {python_path} -m auto_round --mllm --eval -h")
if res > 0 or res == -1:
assert False, "cmd line test fail, please have a check"

# test auto_round_mllm --lmms help
res = os.system(
f"cd .. && {python_path} -m auto_round --mllm --lmms -h")
if res > 0 or res == -1:
assert False, "cmd line test fail, please have a check"

res = os.system(
f"cd .. && {python_path} -m auto_round --mllm --iter 2 --nsamples 10 --format auto_round --output_dir ./saved")
if res > 0 or res == -1:
9 changes: 4 additions & 5 deletions test/test_mllm.py
@@ -42,12 +42,12 @@ def tearDownClass(self):
def test_tune(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
tokenizer.processor = processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
self.model_name, trust_remote_code=True, device_map="auto")
bits, group_size = 4, 128
autoround = AutoRoundMLLM(
model, tokenizer, bits=bits, group_size=group_size,
model, tokenizer, processor=processor,
bits=bits, group_size=group_size,
nsamples=1,
batch_size=1, iters=2, dataset=self.dataset,seqlen=256)
autoround.quantize()
@@ -57,12 +57,12 @@ def test_quant_vision(self): ## bug need to fix
def test_quant_vision(self): ## bug need to fix
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
tokenizer.processor = processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
self.model_name, trust_remote_code=True, device_map="auto")
bits, group_size = 4, 128
autoround = AutoRoundMLLM(
model, tokenizer, bits=bits, group_size=group_size,
model, tokenizer, processor=processor,
bits=bits, group_size=group_size,
nsamples=5,
batch_size=3, iters=2, dataset=self.dataset, quant_nontext_module=False,seqlen=256)
autoround.quantize()
@@ -72,7 +72,6 @@ def test_quant_block_names(self):
from auto_round.utils import get_multimodal_block_names,find_matching_blocks
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
tokenizer.processor = processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
self.model_name, trust_remote_code=True, device_map="auto")
to_quant_block_names = 'visual.*12,layers.0,model.layers.*9'