From 1a250eb39c70d980e262949057811b57dd6d4f11 Mon Sep 17 00:00:00 2001
From: Wenhua Cheng
Date: Wed, 4 Dec 2024 11:53:46 +0800
Subject: [PATCH] add gpu test (#367)

---
 README.md                        |   4 +-
 auto_round/script/llm.py         |   4 +-
 auto_round/script/mllm.py        |   4 +-
 auto_round/utils.py              |   2 +-
 test/test_cuda_before_release.py | 121 +++++++++++++++++++++++++++++++
 5 files changed, 128 insertions(+), 7 deletions(-)
 create mode 100644 test/test_cuda_before_release.py

diff --git a/README.md b/README.md
index e9afcffa..c482fdc7 100644
--- a/README.md
+++ b/README.md
@@ -26,8 +26,8 @@ more accuracy data and recipes across various models.
 
 ## What's New
-
-* [2024/11] We provide experimental support for VLLM quantization, please check out
+* [2024/12] Many quantized LLMs/VLMs using AutoRound are released in [OPEA Space](https://huggingface.co/OPEA)
+* [2024/11] We provide experimental support for VLM quantization, please check out
   the [README](./auto_round/mllm/README.md)
 * [2024/11] We provide some tips and tricks for LLM&VLM quantization, please check out
   [this blog](https://medium.com/@NeuralCompressor/10-tips-for-quantizing-llms-and-vlms-with-autoround-923e733879a7)
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index cf373b79..f7cd48b0 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -27,7 +27,7 @@
 # limitations under the License.
 import argparse
 
-from auto_round.utils import detect_device, set_layer_config_by_fp_layers
+from auto_round.utils import detect_device, get_fp_layer_names
 
 
 class BasicArgumentParser(argparse.ArgumentParser):
@@ -380,7 +380,7 @@ def tune(args):
                        " resulting in an exporting issue to autogptq")
 
     layer_config = {}
-    not_quantize_layer_names = set_layer_config_by_fp_layers(model, args.fp_layers)
+    not_quantize_layer_names = get_fp_layer_names(model, args.fp_layers)
     for name in not_quantize_layer_names:
         layer_config[name] = {"bits": 16}
     if len(not_quantize_layer_names) > 0:
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 54d35d47..3c88a5b4 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -24,7 +24,7 @@
 torch.use_deterministic_algorithms(True, warn_only=True)
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoProcessor
 
-from auto_round.utils import detect_device, set_layer_config_by_fp_layers
+from auto_round.utils import detect_device, get_fp_layer_names
 
 from auto_round.utils import logger
 
@@ -327,7 +327,7 @@ def tune(args):
         round = AutoRoundMLLM
 
     layer_config = {}
-    not_quantize_layer_names = set_layer_config_by_fp_layers(model, args.fp_layers)
+    not_quantize_layer_names = get_fp_layer_names(model, args.fp_layers)
     for name in not_quantize_layer_names:
         layer_config[name] = {"bits": 16}
     if len(not_quantize_layer_names) > 0:
diff --git a/auto_round/utils.py b/auto_round/utils.py
index fc4189ef..b5c8236d 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -1045,7 +1045,7 @@ def can_pack_with_numba():  # pragma: no cover
     return True
 
 
-def set_layer_config_by_fp_layers(model, fp_layers):
+def get_fp_layer_names(model, fp_layers):
     """Identifies and returns layers in the model to exclude from quantization.
 
     This function processes a comma-separated list of fully precision (FP) layers,
diff --git a/test/test_cuda_before_release.py b/test/test_cuda_before_release.py
new file mode 100644
index 00000000..daddf720
--- /dev/null
+++ b/test/test_cuda_before_release.py
@@ -0,0 +1,121 @@
+import copy
+import shutil
+import sys
+import unittest
+import re
+
+sys.path.insert(0, "..")
+import torch
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round import AutoRound
+from auto_round.eval.evaluation import simple_evaluate
+from lm_eval.utils import make_table  # pylint: disable=E0401
+
+
+def get_accuracy(data):
+    match = re.search(r'\|acc\s+\|[↑↓]\s+\|\s+([\d.]+)\|', data)
+
+    if match:
+        accuracy = float(match.group(1))
+        return accuracy
+    else:
+        return 0.0
+
+
+class TestAutoRound(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.save_dir = "./saved"
+        self.tasks = "lambada_openai"
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree("./saved", ignore_errors=True)
+        shutil.rmtree("runs", ignore_errors=True)
+
+    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
+    def test_backend(self):
+        model_name = "/models/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        autoround = AutoRound(model, tokenizer, bits=4, group_size=128)
+        autoround.quantize()
+
+        ## test auto_round format
+        autoround.save_quantized(self.save_dir, format="auto_round", inplace=False)
+        model_args = f"pretrained={self.save_dir}"
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=self.tasks,
+                              batch_size="auto")
+        res = make_table(res)
+        accuracy = get_accuracy(res)
+        assert accuracy > 0.35
+        shutil.rmtree("./saved", ignore_errors=True)
+
+        ## test auto_gptq format
+        autoround.save_quantized(self.save_dir, format="auto_gptq", inplace=False)
+        model_args = f"pretrained={self.save_dir}"
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=self.tasks,
+                              batch_size="auto")
+        res = make_table(res)
+        accuracy = get_accuracy(res)
+        assert accuracy > 0.35
+        shutil.rmtree("./saved", ignore_errors=True)
+
+        ## test auto_awq format
+        autoround.save_quantized(self.save_dir, format="auto_awq", inplace=False)
+        model_args = f"pretrained={self.save_dir}"
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=self.tasks,
+                              batch_size="auto")
+        res = make_table(res)
+        accuracy = get_accuracy(res)
+        assert accuracy > 0.35
+        shutil.rmtree("./saved", ignore_errors=True)
+
+    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
+    def test_fp_layers(self):
+        model_name = "/models/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        from auto_round.utils import get_fp_layer_names
+        layer_names = get_fp_layer_names(model, "model.decoder.layers.0,model.decoder.layers.1")
+        layer_configs = {}
+        for name in layer_names:
+            layer_configs[name] = {"bits": 16}
+        autoround = AutoRound(model, tokenizer, bits=4, group_size=128, layer_config=layer_configs)
+        autoround.quantize()
+
+        ## test auto_round format
+        autoround.save_quantized(self.save_dir, format="auto_round", inplace=False)
+        model_args = f"pretrained={self.save_dir}"
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=self.tasks,
+                              batch_size="auto")
+        res = make_table(res)
+        accuracy = get_accuracy(res)
+        assert accuracy > 0.35
+        shutil.rmtree("./saved", ignore_errors=True)
+
+        ## test auto_awq format
+        autoround.save_quantized(self.save_dir, format="auto_awq", inplace=False)
+        model_args = f"pretrained={self.save_dir}"
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=self.tasks,
+                              batch_size="auto")
+        res = make_table(res)
+        accuracy = get_accuracy(res)
+        assert accuracy > 0.35
+        shutil.rmtree("./saved", ignore_errors=True)
+
+    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
+    def test_undivided_group_size_tuning(self):
+        model_name = "/models/falcon-7b"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        autoround = AutoRound(model, tokenizer, bits=4, group_size=128, nsamples=1, iters=1)
+        autoround.quantize()
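
For context on the rename above: tune() in llm.py/mllm.py and the new test_fp_layers test follow the same pattern, where get_fp_layer_names resolves the comma-separated fp_layers string into concrete module names, those names are pinned to 16 bits in a layer_config dict, and the rest of the model is quantized. Below is a minimal sketch of that flow; it assumes the AutoRound constructor accepts the layer_config mapping directly, and the local model path is a placeholder mirroring the one used in the tests.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound
from auto_round.utils import get_fp_layer_names

# Placeholder local checkpoint path, as in test_cuda_before_release.py.
model_name = "/models/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Resolve the comma-separated FP-layer string into the module names to exclude.
fp_layer_names = get_fp_layer_names(model, "model.decoder.layers.0,model.decoder.layers.1")

# Keep those layers at 16 bits; everything else is quantized to 4 bits with group size 128.
layer_config = {name: {"bits": 16} for name in fp_layer_names}

autoround = AutoRound(model, tokenizer, bits=4, group_size=128, layer_config=layer_config)
autoround.quantize()
autoround.save_quantized("./saved", format="auto_round", inplace=False)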
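
The new tests skip themselves when torch.cuda.is_available() is False and expect local checkpoints under /models/. Given the sys.path.insert(0, "..") at the top of the file, one way to run just this file is from the test directory, assuming pytest is installed (it collects unittest.TestCase classes):

    cd test
    pytest -v test_cuda_before_release.py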