Support for more VLMs
Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Dec 19, 2024
1 parent ecc17be commit 919995d
Showing 6 changed files with 76 additions and 10 deletions.
10 changes: 7 additions & 3 deletions auto_round/mllm/autoround_mllm.py
@@ -37,8 +37,8 @@ def _only_text_test(model, tokenizer, device):
     text = ["only text", "test"]
     tokenizer.padding_side = 'left'
     if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    if device.split(':')[0] != model.device.type:
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.unk_token
+    if device.split(':')[0] != model.device.type:  # TODO: OOM
         model = model.to(device)
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
     model(**inputs)
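
The eos-to-unk fallback covers tokenizers that ship without an eos_token. A minimal sketch of how a text-only probe like _only_text_test is typically consumed, assuming the caller treats a failed forward pass as "this model needs image calibration data rather than a text corpus such as NeelNanda/pile-10k"; the wrapper below is illustrative, not the repository's exact call site:

def supports_text_only_calibration(model, tokenizer, device="cuda:0"):
    # Illustrative wrapper (assumption): any exception from the pure-text
    # forward pass means text-only calibration data cannot be used.
    try:
        _only_text_test(model, tokenizer, device)
        return True
    except Exception:
        return False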
@@ -158,6 +158,9 @@ def __init__(
             self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
         dataset = self.template.default_dataset if dataset is None else dataset
 
+        if model.config.model_type == "deepseek_vl_v2":
+            model.forward = model.language.forward
+
         from ..calib_dataset import CALIB_DATASETS
         from .mllm_dataset import MLLM_DATASET
         if isinstance(dataset, str):
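
This patch relies on DeepseekVLV2ForCausalLM keeping its text stack on an inner `language` module, as the diff itself shows; rebinding `forward` lets calibration code that calls the wrapper with text inputs reach the language model directly. A toy illustration of the rebinding, using stand-in classes rather than the DeepSeek implementation:

import torch
import torch.nn as nn

class ToyLanguageModel(nn.Module):
    def forward(self, input_ids=None, **kwargs):
        return input_ids * 2  # stand-in for real logits

class ToyVLMWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.language = ToyLanguageModel()

wrapper = ToyVLMWrapper()
wrapper.forward = wrapper.language.forward      # same rebinding as the patch
print(wrapper(input_ids=torch.tensor([1, 2])))  # tensor([2, 4]), via the inner model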
@@ -256,6 +259,7 @@ def calib(self, nsamples, bs):
             template=self.template,
             model=self.model,
             tokenizer=self.tokenizer,
+            processor=self.processor,
             image_processor=self.image_processor,
             dataset=dataset,
             extra_data_dir=self.extra_data_dir,
@@ -324,7 +328,7 @@ def calib(self, nsamples, bs):
             data_new = {}
             for key in data.keys():
                 data_new[key] = to_device(data[key], self.model.device)
-                if key == 'images':
+                if key in ['images', 'pixel_values']:
                     data_new[key] = to_dtype(data_new[key], self.model.dtype)
             input_ids = data_new["input_ids"]
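
Adding 'pixel_values' to the cast covers processors that emit HF-style keys instead of 'images'. A sketch of the intended effect, with a simplified stand-in for auto_round's to_dtype (assumed here to cast only floating-point tensors, leaving integer ids and masks untouched):

import torch

def to_dtype_sketch(value, dtype):
    # Simplified stand-in: cast floats only; ids/attention masks stay integral.
    if isinstance(value, torch.Tensor) and value.is_floating_point():
        return value.to(dtype)
    return value

batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": torch.rand(1, 3, 224, 224),
}
for key in batch:
    if key in ['images', 'pixel_values']:
        batch[key] = to_dtype_sketch(batch[key], torch.float16)
print(batch["pixel_values"].dtype)  # torch.float16; input_ids remain int64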

3 changes: 2 additions & 1 deletion auto_round/mllm/mllm_dataset.py
@@ -190,6 +190,7 @@ def get_mllm_dataloader(
         template,
         model,
         tokenizer,
+        processor,
         image_processor=None,
         dataset="liuhaotian/llava_conv_58k",
         extra_data_dir=None,
@@ -222,7 +223,7 @@ def get_mllm_dataloader(
     """
     if isinstance(template, str):
         from .template import get_template
-        template = get_template(template, model=model, tokenizer=tokenizer, image_processor=image_processor)
+        template = get_template(template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
 
     if os.path.isfile(dataset) or dataset in MLLM_DATASET.keys():
         dataset = MLLM_DATASET['liuhaotian/llava'](
45 changes: 45 additions & 0 deletions auto_round/mllm/processor.py
@@ -111,6 +111,51 @@ def squeeze_result(ret):
         return ret
 
 
+@regist_processor("hf")
+class HFProcessor(BasicProcessor):
+    IMAGE_TOKEN = '<image>'
+    def __init__(self):
+        pass
+
+    def post_init(self, model, tokenizer, processor=None, image_processor=None, **kwargs):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.processor = processor
+        if image_processor is not None:
+            self.image_processor = image_processor
+        else:
+            self.image_processor = self.default_image_processor
+
+    def get_input(
+            self,
+            text,
+            images,
+            return_tensors="pt",
+            squeeze=True,
+            max_length=None,
+            truncation=False,
+            truncation_strategy="text",
+            **kwargs):
+
+        messages = []
+        for content in text:
+            messages.append({
+                "role": content['role'],
+                "content": [
+                    {"text": content["content"].replace(self.IMAGE_TOKEN, ""), "type": "text"}
+                ]
+            })
+            if self.IMAGE_TOKEN in content['content']:
+                messages[-1]["content"].append({"text": None, "type": "image"})
+        text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        if images is not None:
+            images = self.image_processor(images)
+        ret = self.processor(text=text, images=images, return_tensors="pt")
+        if squeeze:
+            ret = self.squeeze_result(ret)
+        return ret
+
+
 @regist_processor("qwen2_vl")
 class Qwen2VLProcessor(BasicProcessor):
     @staticmethod
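
A hedged usage sketch for the new "hf" processor, assuming a transformers checkpoint whose AutoProcessor chat template supports image slots; the checkpoint name and image path are illustrative, and loading an image from a plain path is assumed to be handled by the default image processor:

from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

hf_proc = HFProcessor()
hf_proc.post_init(model=model, tokenizer=processor.tokenizer, processor=processor)

# One user turn with an inline <image> tag; get_input splits it into a text
# chunk plus an image slot before calling apply_chat_template.
inputs = hf_proc.get_input(
    text=[{"role": "user", "content": "<image>\nDescribe this picture."}],
    images="sample.jpg",  # assumption: resolved by the default image processor
)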
2 changes: 1 addition & 1 deletion auto_round/mllm/templates/default.json
@@ -10,5 +10,5 @@
     "replace_tokens": null,
     "extra_encode" : false,
     "default_dataset": "NeelNanda/pile-10k",
-    "processor": "basic"
+    "processor": "hf"
 }
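
Switching the default template's processor from "basic" to "hf" routes models without a dedicated template through the new HFProcessor. The string is presumably resolved via the regist_processor decorator seen in processor.py; a minimal sketch of that registry pattern, under the assumption that it is a plain name-to-class mapping:

PROCESSORS = {}

def regist_processor(name):
    # Decorator sketch: map a template's "processor" string to a handler class.
    def wrapper(cls):
        PROCESSORS[name] = cls
        return cls
    return wrapper

# With the classes above decorated, the JSON field resolves roughly as:
# processor = PROCESSORS["hf"]()  ->  HFProcessor instance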
13 changes: 12 additions & 1 deletion auto_round/script/mllm.py
@@ -270,7 +270,8 @@ def tune(args):
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        use_auto_mapping = True
+        if len(devices) > 1:  ## on a single card, "auto" can offload layers of a 70B model to CPU, so enable it only for multi-card runs
+            use_auto_mapping = True
     elif args.device == "auto":
         use_auto_mapping = True

@@ -288,6 +289,13 @@ def tune(args):
             model_name, model_base=None, model_name=model_name,
             torch_dtype=torch_dtype)
         model_type = "llava"
+    elif "deepseek" in model_name.lower():
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
+        processor = DeepseekVLV2Processor.from_pretrained(model_name)
+        tokenizer = processor.tokenizer
+        model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
+            model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
+            device_map="auto" if use_auto_mapping else None)
     else:
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -299,6 +307,9 @@ def tune(args):
         elif "mllama" in model_type:
             from transformers import MllamaForConditionalGeneration
             cls = MllamaForConditionalGeneration
+        elif "idefics3" in model_type:
+            from transformers import AutoModelForVision2Seq
+            cls = AutoModelForVision2Seq
         else:
             cls = AutoModelForCausalLM
13 changes: 9 additions & 4 deletions auto_round/utils.py
@@ -402,11 +402,16 @@ def get_multimodal_block_names(model, quant_vision=False):
     """
     block_names = []
     target_modules = []
-    vison_blocks_tuple = ("vision", "visual",)
+    vison_blocks_tuple = ("vision", "visual", "projector")
+    module_list_type = ("ModuleList", "Sequential")
+    last_module_list = None
     for n, m in model.named_modules():
-        if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
-            if quant_vision or all(key not in n.lower() for key in (vison_blocks_tuple)):
-                target_modules.append((n, m))
+        # if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
+        if hasattr(type(m), "__name__") and any(key in type(m).__name__ for key in module_list_type):
+            if quant_vision or all(key not in n.lower() for key in vison_blocks_tuple):
+                if last_module_list is None or last_module_list not in n:
+                    last_module_list = n
+                    target_modules.append((n, m))
     validate_modules(target_modules, quant_vision, vison_blocks_tuple)
     for i, target_m in enumerate(target_modules):
         block_names.append([])
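
The rewritten loop matches Sequential containers as well as ModuleList, and last_module_list skips containers nested inside one already collected, since a nested module's name starts with the collected prefix. A toy check of both behaviors; the model below is illustrative, not a real VLM:

import torch.nn as nn

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        # A Sequential wrapping a ModuleList: only the outer container counts.
        self.blocks = nn.Sequential(nn.ModuleList([nn.Linear(4, 4)]))
        self.vision_tower = nn.ModuleList([nn.Linear(4, 4)])  # filtered by name

model, collected, last = ToyVLM(), [], None
for n, m in model.named_modules():
    if any(key in type(m).__name__ for key in ("ModuleList", "Sequential")):
        if all(key not in n.lower() for key in ("vision", "visual", "projector")):
            if last is None or last not in n:
                last = n
                collected.append(n)
print(collected)  # ['blocks'] -- 'blocks.0' (nested) and 'vision_tower' skipped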
