Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix merge error #316

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion auto_round/mllm/autoround_mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
act_dynamic: bool = True,
quant_block_list: list = None,
enable_norm_bias_tuning: bool = False,
truncation: bool = False,
**kwargs,
):
if quant_block_list is None:
Expand All @@ -123,6 +124,7 @@ def __init__(
self.template = template if template is not None else model.config.model_type
self.template = get_template(
self.template, model=model, tokenizer=tokenizer, image_processor=image_processor)
self.truncation = truncation
assert dataset is not None, "dataset should not be None"
batch_size, gradient_accumulate_steps = check_mllm_model_batch(model, batch_size, gradient_accumulate_steps)

Expand Down Expand Up @@ -188,7 +190,8 @@ def calib(self, nsamples, bs):
extra_data_dir=self.extra_data_dir,
seqlen=self.seqlen,
bs=bs,
seed=self.seed
seed=self.seed,
truncation=self.truncation,
)
else:
self.dataloader = self.dataset
Expand Down
8 changes: 5 additions & 3 deletions auto_round/mllm/mllm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def check(self, questions, seqlen):
text_lenght += len(text['value'].split(' '))
if text_lenght >= seqlen:
new_questions.append(source)
assert len(new_questions) > 0, f"no data with length greater than {seqlen}, please check"
assert len(new_questions) > 0, \
f"no data with length greater than {seqlen}, please reduce the seqlen or change another dataset."
return new_questions


Expand Down Expand Up @@ -177,6 +178,7 @@ def get_mllm_dataloader(
bs=1,
split=None,
apply_template=None,
truncation=False,
seed=42,
):
"""Generate a DataLoader for calibration using specified parameters.
Expand Down Expand Up @@ -204,11 +206,11 @@ def get_mllm_dataloader(
if os.path.isfile(dataset):
dataset = MLLM_DATASET['liuhaotian/llava'](
template, model, tokenizer, dataset, extra_data_dir,
seqlen=min(seqlen, tokenizer.model_max_length))
seqlen=min(seqlen, tokenizer.model_max_length), truncation=truncation)
elif "liuhaotian/llava" in dataset:
dataset = MLLM_DATASET["liuhaotian/llava"](
template, model, tokenizer, dataset, extra_data_dir,
seqlen=min(seqlen, tokenizer.model_max_length))
seqlen=min(seqlen, tokenizer.model_max_length), truncation=truncation)
else:
from datasets import load_dataset
from ..calib_dataset import get_tokenizer_function
Expand Down
14 changes: 7 additions & 7 deletions auto_round/mllm/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def get_input(
return_tensors="pt",
squeeze=True,
max_length=None,
truncation=False,
truncation_strategy="text",
**kwargs):

Expand All @@ -69,7 +70,7 @@ def get_input(
if images is not None:
images = self.image_processor(images)

if truncation_strategy == "text" and max_length is not None:
if truncation is True and truncation_strategy == "text":
text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])

ret = self.tokenizer.processor(
Expand All @@ -78,7 +79,7 @@ def get_input(
return_tensors=return_tensors,
# videos = None
)
if truncation_strategy == "token" and max_length:
if truncation is True and truncation_strategy == "token":
seqlen = ret['input_ids'].shape[-1]
for key in ret:
shape_ = ret[key].shape
Expand Down Expand Up @@ -118,7 +119,7 @@ def squeeze_result(ret):
@regist_processor("cogvlm2")
class CogVLM2Processor(BasicProcessor):
def get_input(
self, text, images,
self, text, images, truncation=False,
squeeze=True, **kwargs):

if images is not None:
Expand All @@ -127,7 +128,6 @@ def get_input(
padding_len = 2303
max_length = 0
max_length += padding_len
truncation = True
padding = False
input_data = self.model.build_conversation_input_ids(
self.tokenizer,
Expand Down Expand Up @@ -207,7 +207,7 @@ def post_init(self, model, tokenizer, image_processor=None, **kwargs):

def get_input(
self, text, images,max_length=None,
squeeze=True, truncation_strategy="text", **kwargs):
squeeze=True, truncation=False, truncation_strategy="text", **kwargs):

if images is not None:
images = fetch_image(images).convert('RGB')
Expand All @@ -217,13 +217,13 @@ class DataArgs:
is_multimodal = True
mm_use_im_start_end = False

if truncation_strategy == "text" and max_length is not None:
if truncation is True and truncation_strategy == "text":
text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])

input_data = llava_train.preprocess_multimodal([text], DataArgs())
ret = llava_train.preprocess(input_data, self.tokenizer, has_image=(images is not None))

if truncation_strategy == "token" and max_length:
if truncation is True and truncation_strategy == "token":
seqlen = ret['input_ids'].shape[-1]
for key in ret:
if ret[key].shape[-1] == seqlen:
Expand Down
4 changes: 3 additions & 1 deletion auto_round/script/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __init__(self, *args, **kwargs):

self.add_argument("--template", default=None, type=str,
help="the template for building training dataset. It can be a custom one.")

self.add_argument("--truncation", action="store_true")
n1ck-guo marked this conversation as resolved.
Show resolved Hide resolved

## ======================= VLM eval=======================
self.add_argument("--tasks", type=str,
Expand Down Expand Up @@ -331,7 +333,7 @@ def tune(args):
extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
sym=not args.asym, batch_size=args.batch_size, seqlen=seqlen, nblocks=args.nblocks,
iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, amp=not args.disable_amp,
enable_quanted_input=not args.disable_quanted_input,
enable_quanted_input=not args.disable_quanted_input, truncation=args.truncation,
nsamples=args.nsamples, low_gpu_mem_usage=args.low_gpu_mem_usage,
device=device_str, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
scale_dtype=args.scale_dtype, layer_config=layer_config, template=args.template,
Expand Down