add gpu test (#367)
wenhuach21 authored Dec 4, 2024
1 parent 3acb119 commit 1a250eb
Showing 5 changed files with 128 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -26,8 +26,8 @@ more accuracy data and recipes across various models.
<div align="left">

## What's New

-* [2024/11] We provide experimental support for VLLM quantization, please check out
+* [2024/12] Many quantized LLMs/VLMs AutoRound are released in [OPEA Space](https://huggingface.co/OPEA)
+* [2024/11] We provide experimental support for VLM quantization, please check out
the [README](./auto_round/mllm/README.md)
* [2024/11] We provide some tips and tricks for LLM&VLM quantization, please check
out [this blog](https://medium.com/@NeuralCompressor/10-tips-for-quantizing-llms-and-vlms-with-autoround-923e733879a7)
4 changes: 2 additions & 2 deletions auto_round/script/llm.py
@@ -27,7 +27,7 @@
# limitations under the License.
import argparse

-from auto_round.utils import detect_device, set_layer_config_by_fp_layers
+from auto_round.utils import detect_device, get_fp_layer_names


class BasicArgumentParser(argparse.ArgumentParser):
@@ -380,7 +380,7 @@ def tune(args):
            " resulting in an exporting issue to autogptq")

    layer_config = {}
-    not_quantize_layer_names = set_layer_config_by_fp_layers(model, args.fp_layers)
+    not_quantize_layer_names = get_fp_layer_names(model, args.fp_layers)
    for name in not_quantize_layer_names:
        layer_config[name] = {"bits": 16}
    if len(not_quantize_layer_names) > 0:
4 changes: 2 additions & 2 deletions auto_round/script/mllm.py
@@ -24,7 +24,7 @@
torch.use_deterministic_algorithms(True, warn_only=True)
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoProcessor

-from auto_round.utils import detect_device, set_layer_config_by_fp_layers
+from auto_round.utils import detect_device, get_fp_layer_names
from auto_round.utils import logger


@@ -327,7 +327,7 @@ def tune(args):
        round = AutoRoundMLLM

    layer_config = {}
-    not_quantize_layer_names = set_layer_config_by_fp_layers(model, args.fp_layers)
+    not_quantize_layer_names = get_fp_layer_names(model, args.fp_layers)
    for name in not_quantize_layer_names:
        layer_config[name] = {"bits": 16}
    if len(not_quantize_layer_names) > 0:
2 changes: 1 addition & 1 deletion auto_round/utils.py
@@ -1045,7 +1045,7 @@ def can_pack_with_numba(): # pragma: no cover
    return True


-def set_layer_config_by_fp_layers(model, fp_layers):
+def get_fp_layer_names(model, fp_layers):
    """Identifies and returns layers in the model to exclude from quantization.

    This function processes a comma-separated list of fully precision (FP) layers,
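For context, a minimal sketch of how the renamed get_fp_layer_names helper is meant to be used, mirroring the calling code in llm.py/mllm.py above. The model path, the layer prefixes, and the layer_config keyword on AutoRound are illustrative assumptions, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound
from auto_round.utils import get_fp_layer_names

model_name = "facebook/opt-125m"  # hypothetical model; any causal LM works
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# fp_layers is a comma-separated list of layer-name prefixes to keep in full precision;
# the helper expands it to the matching module names found in the model.
fp_layers = "model.decoder.layers.0,model.decoder.layers.1"
layer_config = {name: {"bits": 16} for name in get_fp_layer_names(model, fp_layers)}

# Assumed wiring: the scripts above build layer_config this way before handing it
# to AutoRound, so the listed layers are skipped during 4-bit quantization.
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, layer_config=layer_config)
autoround.quantize()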
121 changes: 121 additions & 0 deletions test/test_cuda_before_release.py
@@ -0,0 +1,121 @@
import copy
import shutil
import sys
import unittest
import re

sys.path.insert(0, "..")
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound
from auto_round.eval.evaluation import simple_evaluate
from lm_eval.utils import make_table # pylint: disable=E0401


def get_accuracy(data):
    # Pull the "acc" value out of the markdown table that lm-eval's make_table
    # produces, e.g. a row containing "|acc   |↑  | 0.3791|" yields 0.3791.
    match = re.search(r'\|acc\s+\|[↑↓]\s+\|\s+([\d.]+)\|', data)

    if match:
        accuracy = float(match.group(1))
        return accuracy
    else:
        return 0.0


class TestAutoRound(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        self.save_dir = "./saved"
        self.tasks = "lambada_openai"

    @classmethod
    def tearDownClass(self):
        shutil.rmtree("./saved", ignore_errors=True)
        shutil.rmtree("runs", ignore_errors=True)

    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
    def test_backend(self):
        model_name = "/models/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        autoround = AutoRound(model, tokenizer, bits=4, group_size=128)
        autoround.quantize()

        ## test auto_round format
        autoround.save_quantized(self.save_dir, format="auto_round", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        res = make_table(res)
        accuracy = get_accuracy(res)
        assert accuracy > 0.35
        shutil.rmtree("./saved", ignore_errors=True)

        ## test auto_gptq format
        autoround.save_quantized(self.save_dir, format="auto_gptq", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        res = make_table(res)
        accuracy = get_accuracy(res)
        assert accuracy > 0.35
        shutil.rmtree("./saved", ignore_errors=True)

        ## test auto_awq format
        autoround.save_quantized(self.save_dir, format="auto_awq", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        res = make_table(res)
        accuracy = get_accuracy(res)
        assert accuracy > 0.35
        shutil.rmtree("./saved", ignore_errors=True)

    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
    def test_fp_layers(self):
        model_name = "/models/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        from auto_round.utils import get_fp_layer_names
        layer_names = get_fp_layer_names(model, "model.decoder.layers.0,model.decoder.layers.1")
        layer_configs = {}
        for name in layer_names:
            layer_configs[name] = {"bits": 16}
        # hand the FP-layer config to AutoRound so the listed layers stay in 16 bits
        autoround = AutoRound(model, tokenizer, bits=4, group_size=128, layer_config=layer_configs)
        autoround.quantize()

        ## test auto_round format
        autoround.save_quantized(self.save_dir, format="auto_round", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        res = make_table(res)
        accuracy = get_accuracy(res)
        assert accuracy > 0.35
        shutil.rmtree("./saved", ignore_errors=True)

        ## test auto_awq format
        autoround.save_quantized(self.save_dir, format="auto_awq", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        res = make_table(res)
        accuracy = get_accuracy(res)
        assert accuracy > 0.35
        shutil.rmtree("./saved", ignore_errors=True)

    @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
    def test_undivided_group_size_tuning(self):
        model_name = "/models/falcon-7b"
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        autoround = AutoRound(model, tokenizer, bits=4, group_size=128, nsamples=1, iters=1)
        autoround.quantize()
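For reference, a small self-contained sketch of the kind of table row that get_accuracy above is written to match; the sample row is only illustrative, since the exact column layout emitted by lm-eval's make_table can vary:

import re

# Illustrative lm-eval style row (assumed layout): the pattern looks for the
# "acc" metric cell, the arrow column, and then the numeric value.
sample_row = "|lambada_openai|      1|none  |     0|acc   |↑  | 0.3791|±  |0.0068|"
match = re.search(r'\|acc\s+\|[↑↓]\s+\|\s+([\d.]+)\|', sample_row)
print(float(match.group(1)) if match else 0.0)  # -> 0.3791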
