Skip to content

Commit

Permalink
mllm eval bug fix (#297)
Browse files Browse the repository at this point in the history
* fix mllm eval bugs

Signed-off-by: n1ck-guo <[email protected]>
  • Loading branch information
n1ck-guo authored Nov 4, 2024
1 parent 4384914 commit 0bb70a6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 4 additions & 2 deletions auto_round/mllm/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#model_name
"Qwen-VL": dict(cls="QwenVL"),
"Qwen-VL-Chat": dict(cls="QwenVLChat"),
"Qwen2-VL": dict(cls="Qwen2VLChat", min_pixels=1280*28*28, max_pixels=16384*28*28),
"Qwen2-VL": dict(cls="Qwen2VLChat", min_pixels=1280*28*28, max_pixels=16384*28*28, verbose=False),
"Llama-3.2": dict(cls="llama_vision"),
"Phi-3-vision": dict(cls="Phi3Vision"),
"Phi-3.5-vision": dict(cls="Phi3_5Vision"),
Expand Down Expand Up @@ -112,7 +112,9 @@ def mllm_eval(
kwargs["model_path"] = pretrained_model_name_or_path
model_cls = kwargs.pop("cls")
model_cls = getattr(vlmeval.vlm, model_cls)
vlmeval.config.supported_VLM[model_name] = partial(model_cls, verbose=verbose, **kwargs)
if "verbose" in kwargs:
kwargs["verbose"] = verbose
vlmeval.config.supported_VLM[model_name] = partial(model_cls, **kwargs)

pred_root = os.path.join(work_dir, model_name)
os.makedirs(pred_root, exist_ok=True)
Expand Down
3 changes: 2 additions & 1 deletion auto_round/script/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, *args, **kwargs):
self.add_argument("--asym", action='store_true',
help=" asym quantization")

self.add_argument("--dataset", required=True, type=str,
self.add_argument("--dataset", type=str, default=None,
help="The dataset for quantization training. It can be a custom one.")

self.add_argument("--lr", default=None, type=float,
Expand Down Expand Up @@ -201,6 +201,7 @@ def tune(args):
model_name = model_name[:-1]
logger.info(f"start to quantize {model_name}")

assert args.dataset is not None, "dataset should not be None."

device_str = detect_device(args.device)
torch_dtype = "auto"
Expand Down

0 comments on commit 0bb70a6

Please sign in to comment.