enable llava int4 inference with autoround format (#237)
Signed-off-by: Zhang, Weiwei1 <[email protected]>
Co-authored-by: wenhuach21 <[email protected]>
WeiweiZhang1 and wenhuach21 authored Sep 10, 2024
1 parent 0d60ad2 commit b698213
Showing 11 changed files with 133 additions and 49 deletions.
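
For context, the int4 inference path this commit enables follows the same loading pattern added to `mm_evaluation/textvqa.py` below; a minimal sketch, assuming a checkpoint already exported in the auto_round format (the checkpoint path is a placeholder):

```python
# Minimal sketch: load an int4 LLaVA checkpoint exported in the auto_round format.
# The checkpoint directory is hypothetical; the loading pattern mirrors the
# evaluation entry point this commit adds to mm_evaluation/textvqa.py.
try:
    from auto_round import AutoRoundConfig  # importing this registers the AutoRound quantizer with transformers
except ImportError:
    from auto_round.auto_quantizer import AutoHfQuantizer  # fallback for older auto-round versions

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "./tmp_autoround/llava-v1.5-7b-round"  # placeholder export directory
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path, model_base=None, model_name=model_name, torch_dtype="auto"
)
# `model` is now an int4 LLaVA ready for multimodal inference or TextVQA evaluation.
```
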
4 changes: 3 additions & 1 deletion auto_round/auto_quantizer.py
@@ -456,7 +456,8 @@ class StoreAttr(object):

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        if model.__class__.main_input_name != "input_ids":
            raise RuntimeError("We can only quantize pure text model.")
            logger.warning("We can only quantize pure text models and "
                           "certain types (Llava/Qwen-VL/Phi-3-vision) of multimodal models.")

        if self.pre_quantized:
            model = self.convert_model(model)
@@ -485,3 +486,4 @@ def is_serializable(self):
transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer


6 changes: 5 additions & 1 deletion auto_round/export/export_to_autogptq/export.py
@@ -152,6 +152,8 @@ def pack_layer(name, model, layer_config, backend, pbar):
qlayer.pack(layer, scale, zero, None)
qlayer.to(device)
pbar.update(1)


@register_format("auto_gptq")
def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exllamav2",
**kwargs):
@@ -215,9 +217,10 @@ def wrapper(name):
if hasattr(model, "config"):
model.config.quantization_config = quantization_config
save(model, output_dir, safe_serialization=safe_serialization)
return model



##
def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True):
"""Save model state dict and configs.
@@ -248,3 +251,4 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf




2 changes: 2 additions & 0 deletions auto_round/export/export_to_autoround/export.py
@@ -294,6 +294,7 @@ def wrapper(name):
if config["bits"] > 8:
modules_to_not_convert.append(name)
save_awq(model, output_dir, modules_to_not_convert=modules_to_not_convert)
return model


def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True):
@@ -360,3 +361,4 @@ def save_awq(
if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
json.dump(quantization_config, f, indent=2)

2 changes: 2 additions & 0 deletions auto_round/export/export_to_awq/export.py
@@ -121,6 +121,7 @@ def save_quantized_as_autoawq(output_dir, model_path, inplace=True, **kwargs):
quant_config["zero_point"] = not sym

save_quantized(compressed_model, save_dir=output_dir, quant_config=quant_config)
return compressed_model


from safetensors.torch import save_file
@@ -225,3 +226,4 @@ def get_module_name(model, module_to_find):
if module is module_to_find:
return name
return None

1 change: 1 addition & 0 deletions auto_round/utils.py
@@ -805,3 +805,4 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
else:
from auto_round_extension.cuda.qlinear_tritonv2 import QuantLinear
return QuantLinear

8 changes: 4 additions & 4 deletions examples/multimodal-modeling/Llava/README.md
@@ -6,7 +6,7 @@ This document presents step-by-step instructions for auto-round.

In this example, we introduce a straightforward way to run quantization on some popular multimodal models such as LLaVA.

Please note that LLAVA quantization is currently an **experimental feature** and does not yet support inference on various devices after export.
Please note that the quantized LLaVA model currently only supports inference with the **auto_round** format.

## Install
If you are not using Linux, do NOT proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md).
@@ -68,12 +68,12 @@ bash run_autoround.sh
```

## 4. Results
Using the [COCO 2017](https://cocodataset.org/) and [LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) datasets for quantization calibration and the TextVQA dataset for evaluation, the accuracy loss stays within 1% when the vision components are not involved in quantization. The results for the fake-quantized LLaVA-7b are as follows:
Using the [COCO 2017](https://cocodataset.org/) and [LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) datasets for quantization calibration and the TextVQA dataset for evaluation, the accuracy loss stays within 1% when the vision components are not involved in quantization. The results for LLaVA-7b are as follows:
| Model | Config | Precision | Hyperparameter | Accuracy% | Relative drop |
| :----: | :----: | :----: | :----: | :----: | :----: |
| liuhaotian/llava-v1.5-7b | - | FP16 | - | 58.21 | - |
| liuhaotian/llava-v1.5-7b | W4G128 | FP16 | with vision | 56.39 | -3.13% |
| liuhaotian/llava-v1.5-7b | W4G128 | FP16 | w/o vision | 58.08 | -0.22% |
| liuhaotian/llava-v1.5-7b | W4G128 | FP16 | with vision | 56.11 | -3.60% |
| liuhaotian/llava-v1.5-7b | W4G128 | FP16 | w/o vision | 57.97 | -0.41% |


## 5. Known Issues
53 changes: 44 additions & 9 deletions examples/multimodal-modeling/Llava/main.py
@@ -1,6 +1,6 @@
import argparse
import sys

import copy
sys.path.insert(0, '../../..')
parser = argparse.ArgumentParser()
import torch
@@ -61,13 +61,46 @@ def __getitem__(self, index):

    def __len__(self):
        return len(self.list_data_dict)


def create_data_loader(dataset, batch_size=1, data_collator=None):
    assert batch_size == 1, "batch_size must be 1"
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
    return data_loader

def save_tower(model, save_path, quant_vision: bool = False, max_shard_size: str = "5GB", safe_serialization: bool = True):
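    """Save the quantized vision tower as a separate checkpoint alongside the LLaVA export.

    The tower weights, processor, and config are written to `<save_path>-vision_tower` with a
    quantization_config whose quant_block_list is filtered down to the tower's own blocks;
    the main model config's `mm_vision_tower` is then pointed at that directory and re-saved.
    """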
    if not quant_vision:
        print("Won't save vision_tower since this part was not quantized.")
        return
    ori_path = save_path
    ori_tower_name = model.get_vision_tower().vision_tower_name
    vision_tower = model.get_vision_tower().vision_tower
    save_path = f'{save_path}-vision_tower'
    os.makedirs(save_path, exist_ok=True)
    quantization_config = model.config.quantization_config
    redundant_prefix = "model.vision_tower.vision_tower."
    org_block_list = copy.deepcopy(quantization_config['quant_block_list'])
    # prepare the vision_tower quantization block list (strip the LLaVA prefix)
    quant_block_list = [element.split(redundant_prefix)[1] if redundant_prefix in element else ""
                        for sublist in org_block_list for element in sublist]
    quant_block_list = [[element for element in quant_block_list if element != ""]]
    quantization_config['quant_block_list'] = quant_block_list
    if hasattr(vision_tower, "config"):
        from transformers import AutoProcessor
        processor = AutoProcessor.from_pretrained(ori_tower_name)
        processor.save_pretrained(save_path)
        vision_tower.config.quantization_config = quantization_config
        vision_tower.config.save_pretrained(save_path)
        vision_tower.save_pretrained(save_path, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
    # prepare the LLaVA model quantization block list (tower entries removed)
    quant_block_list = [element if redundant_prefix not in element else ""
                        for sublist in org_block_list for element in sublist]
    quant_block_list = [[element for element in quant_block_list if element != ""]]
    quantization_config['quant_block_list'] = quant_block_list
    model.config.mm_vision_tower = save_path
    model.config.save_pretrained(ori_path)


if __name__ == '__main__':

parser.add_argument(
@@ -345,20 +378,24 @@ def create_data_loader(dataset, batch_size=1, data_collator=None):
    for gpu_format in gpu_formats:
        if "round" in gpu_format:
            eval_folder = f'{export_dir}-round'
            autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace)
            compressed_model = autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace)
            save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision)
        elif "gptq" in gpu_format:
            eval_folder = f'{export_dir}-gpu'
            autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace)

            compressed_model = autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace)
            save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision)
    if 'xpu' in deployment_device:
        autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace,
        compressed_model = autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace,
                                                    compression_dtype=torch.int8, compression_dim=0, use_optimum_format=False,
                                                    device="xpu")
        save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision)
    if "cpu" in deployment_device:
        autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace)
        compressed_model = autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace)
        save_tower(compressed_model, eval_folder, quant_vision=args.quant_vision)
    if "fake" in deployment_device:
        model = model.to("cpu")
        model.save_pretrained(output_dir)
        save_tower(model, output_dir, quant_vision=args.quant_vision)
        tokenizer.save_pretrained(output_dir)
    if eval_folder is None:
        eval_folder = output_dir
@@ -380,5 +417,3 @@ def create_data_loader(dataset, batch_size=1, data_collator=None):
evaluator.calculate_accuracy(result_file = args.eval_result_file)




75 changes: 69 additions & 6 deletions examples/multimodal-modeling/Llava/mm_evaluation/textvqa.py
@@ -174,12 +174,6 @@ def calculate_accuracy(self, result_file = None):
evaluator = TextVQAAccuracyEvaluator()
print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))



# results



# def eval_single(annotation_file, result_file):
# experiment_name = os.path.splitext(os.path.basename(result_file))[0]
# print(experiment_name)
@@ -199,3 +193,72 @@ def calculate_accuracy(self, result_file = None):
# print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))


if __name__ == "__main__":
import sys
import time
import argparse
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from transformers import AutoConfig
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name", default="liuhaotian/llava-v1.5-7b"
)
parser.add_argument(
"--base_model", default=None, type=float
)
parser.add_argument(
"--dataset_name", default="textvqa_val"
)
parser.add_argument(
"--eval_bs", default=4,
)
parser.add_argument(
"--trust_remote_code", action='store_true',
help="Whether to enable trust_remote_code"
)
parser.add_argument(
"--eval_question_file", type=str,
default="tables/question.jsonl"
)

parser.add_argument(
"--eval_image_folder", type=str
)

parser.add_argument(
"--eval_result_file", type=str
)

parser.add_argument(
"--eval_annotation_file", type=str
)
args = parser.parse_args()
s = time.time()
config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
if hasattr(config, "quantization_config"):
quantization_config = config.quantization_config
if "quant_method" in quantization_config and ("auto-round" in quantization_config["quant_method"] or
("gptq" in quantization_config["quant_method"] and args.device == "hpu")):
try:
from auto_round import AutoRoundConfig
except:
from auto_round.auto_quantizer import AutoHfQuantizer

model_path = args.model_name
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, model_base=args.base_model, model_name=model_name,
torch_dtype="auto")

evaluator = TextVQAEvaluator(
model,
tokenizer,
image_processor,
args.eval_image_folder,
args.eval_question_file,
args.eval_annotation_file,
model_name = model_name
)
evaluator.run_evaluate(result_file = args.eval_result_file)
evaluator.calculate_accuracy(result_file = args.eval_result_file)
print("cost time: ", time.time() - s)
12 changes: 0 additions & 12 deletions examples/multimodal-modeling/Llava/run_autoround_on_gaudi.sh

This file was deleted.

4 changes: 3 additions & 1 deletion examples/multimodal-modeling/requirements.txt
@@ -1,6 +1,7 @@
transformers
torch
tiktoken
torchvision
transformers_stream_generator
peft
sentencepiece
@@ -11,4 +12,5 @@ protobuf
auto-gptq
openpyxl
wandb
py-cpuinfo
py-cpuinfo

15 changes: 0 additions & 15 deletions examples/multimodal-modeling/run_autoround.sh

This file was deleted.
