Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patch 1 for mllm #298

Merged
merged 31 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions auto_round/mllm/autoround_mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Optional, Union
from tqdm import tqdm

import torch

Expand All @@ -27,6 +28,8 @@
from .mllm_dataset import get_mllm_dataloader
from ..low_cpu_mem.utils import get_layers_before_block
from ..special_model_handler import check_mllm_model_batch


class AutoRoundMLLM(AutoRound):
"""Class for automatic rounding-based quantization with MLLMs.

Expand Down Expand Up @@ -74,6 +77,7 @@ def __init__(
self,
model,
tokenizer,
image_processor = None,
bits: int = 4,
group_size: int = 128,
sym: bool = False,
Expand All @@ -83,7 +87,7 @@ def __init__(
device: str = None,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = None,
extra_data_dir: Union[str, torch.utils.data.DataLoader] = None,
extra_data_dir: str = None,
template: Union[str, Template] = None,
quant_nontext_module: bool = False,
enable_quanted_input: bool = True,
Expand Down Expand Up @@ -115,13 +119,10 @@ def __init__(
quant_block_list = get_multimodal_block_names(model, quant_nontext_module)
self.extra_data_dir = extra_data_dir
self.quant_nontext_module = quant_nontext_module
self.template = template
if self.template is None:
self.template = get_template(model.config.model_type)
self.template = template if template is not None else model.config.model_type
self.template = get_template(self.template, tokenizer, image_processor)
assert dataset is not None, "dataset should not be None"
batch_size, gradient_accumulate_steps = check_mllm_model_batch(model, batch_size, gradient_accumulate_steps)
if isinstance(dataset, str):
dataset = get_mllm_dataloader(self.template, model, tokenizer, dataset, extra_data_dir, seqlen, batch_size)

super(AutoRoundMLLM, self).__init__(
model=model,
Expand Down Expand Up @@ -187,7 +188,7 @@ def calib(self, nsamples, bs):
for n, m in embed_layers:
m = m.to(self.device)

for data in self.dataloader:
for data in tqdm(self.dataloader, desc="calib", total=nsamples):
if data is None:
continue
if isinstance(data, torch.Tensor):
Expand Down Expand Up @@ -243,9 +244,6 @@ def calib(self, nsamples, bs):
data_new[key] = to_dtype(data_new[key], self.model.dtype)
input_ids = data_new["input_ids"]

if input_ids.shape[-1] < self.seqlen:
continue

try:
if isinstance(data_new, torch.Tensor):
self.model(data_new)
Expand All @@ -258,7 +256,7 @@ def calib(self, nsamples, bs):
except Exception as error:
raise error
total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
if total_cnt >= nsamples:
if total_cnt > nsamples:
n1ck-guo marked this conversation as resolved.
Show resolved Hide resolved
break
if total_cnt == 0:
logger.error(
Expand Down
12 changes: 11 additions & 1 deletion auto_round/mllm/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# limitations under the License.

import os
import time
import json
from functools import partial

Expand Down Expand Up @@ -84,6 +85,10 @@ def mllm_eval(

model = None
if data_store_dir is not None:
if not os.path.exists(data_store_dir):
oldmask = os.umask(000)
os.makedirs(data_store_dir, mode=0o777)
os.umask(oldmask)
os.environ['LMUData'] = data_store_dir

model_name = pretrained_model_name_or_path
Expand Down Expand Up @@ -119,7 +124,10 @@ def mllm_eval(
pred_root = os.path.join(work_dir, model_name)
os.makedirs(pred_root, exist_ok=True)

st = time.time()
rt_file = open(f'{pred_root}/{model_name}_eval_cost.txt', 'w')
for dataset_name in dataset:
task_st = time.time()
try:
dataset_kwargs = {}
if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
Expand Down Expand Up @@ -269,8 +277,10 @@ def mllm_eval(
logger.info('\n' + tabulate.tabulate(eval_results))
except:
logger.info(eval_results.to_string())

rt_file.write('%s cost: %.4fs\n' % (dataset_name, time.time() - task_st))
except Exception as e:
logger.exception(f'Model {model_name} x Dataset {dataset_name} combination failed: {e}, '
'skipping this combination.')
continue
rt_file.write('%d tasks cost: %.4fs\n' % (len(dataset), time.time() - st))
rt_file.close()
45 changes: 34 additions & 11 deletions auto_round/mllm/mllm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .utils import _extract_data_dir
from .template import Template
from ..utils import logger


MLLM_DATASET : Dict[str, Dataset] = {}
Expand All @@ -43,6 +44,9 @@ def register(dataset):
return dataset
return register

_LLAVA_V1_5_MIX665K_URL = ("https://huggingface.co/datasets/liuhaotian/"
n1ck-guo marked this conversation as resolved.
Show resolved Hide resolved
"LLaVA-Instruct-150K/resolve/main/conversation_58k.json?download=true")
_COCO_DATA_URL = "http://images.cocodataset.org/train2017/"

@register_dataset("llava")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better add more information like liuhaotian/llava and 58k or 150k

class LlavaDataset(Dataset):
Expand All @@ -54,8 +58,8 @@ def __init__(
model,
tokenzier,
dataset_path,
extra_data_dir,
max_length,
extra_data_dir=None,
max_length=None,
padding=True,
truncation=True,
) -> None:
Expand All @@ -64,7 +68,15 @@ def __init__(
self.model_type = template.model_type
self.template = template
self.tokenizer = tokenzier
self.questions = json.load(open(dataset_path, "r"))
if os.path.exists(dataset_path):
self.questions = json.load(open(dataset_path, "r"))
else:
import requests
logger.info('the path of llava dataset is not provide, download from url...')
if dataset_path == 'llava_v1_5_mix665k':
self.questions = requests.get(_LLAVA_V1_5_MIX665K_URL, stream=True).json()
else:
raise KeyError(f"{dataset_path} is not support, please check.")
self.padding = padding
self.truncation = truncation
self.extra_data_dir = extra_data_dir
Expand All @@ -81,9 +93,9 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return self.cached_data_dict[i]

text = self.questions[i]["conversations"]
text = self.covert_conversations(text)
if self.template.model_type != "llava":
text = self.covert_conversations(text)

text = self.template._encode(text)
if self.extra_data_dir is not None:
image_fold = _extract_data_dir(self.extra_data_dir)
if isinstance(image_fold, dict):
Expand All @@ -92,13 +104,16 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image_fold, os.path.basename(self.questions[i]["image"]))
else:
image_path = self.questions[i]["image"]
image = self.template.processor.image_processor(image_path)
if not os.path.exists(image_path):
image_path = _COCO_DATA_URL + '/' + self.questions[i]["image"].split('/')[-1]
# image = self.template.processor.image_processor(image_path)

text = self.template._encode(text)

ret = self.template.processor.get_input(
self.model,
self.tokenizer,
text=text,
images=image,
images=image_path,
padding=self.padding,
truncation=self.truncation,
return_tensors="pt",
Expand All @@ -111,8 +126,9 @@ def covert_conversations(self, data):
new_data = []
for d in data:
content = d["value"]
for old, new in self.template.replace_tokens:
content = content.replace(old, new)
if self.template.replace_tokens is not None:
for old, new in self.template.replace_tokens:
content = content.replace(old, new)
new_data.append({
"role": self.role_mapping.get(d["from"], d["from"]),
"content": content
Expand Down Expand Up @@ -148,13 +164,19 @@ def get_mllm_dataloader(
Returns:
DataLoader: The DataLoader for the calibrated datasets.
"""
assert isinstance(template, Template)
if isinstance(template, str):
from .template import get_template
template = get_template(template)

if isinstance(dataset, str):
if os.path.isfile(dataset):
dataset = MLLM_DATASET['llava'](
template, model, tokenizer, dataset, extra_data_dir,
max_length=min(seqlen, tokenizer.model_max_length))
elif "llava" in dataset:
dataset = MLLM_DATASET["llava"](
template, model, tokenizer, "llava_v1_5_mix665k", extra_data_dir,
max_length=min(seqlen, tokenizer.model_max_length))
else:
from datasets import load_dataset
from ..calib_dataset import get_tokenizer_function
Expand All @@ -165,6 +187,7 @@ def get_mllm_dataloader(

dataloader_params = {
"batch_size": bs,
"shuffle": True,
n1ck-guo marked this conversation as resolved.
Show resolved Hide resolved
"collate_fn": dataset.template.processor.data_collator
}

Expand Down
Loading
Loading