[CI] Use lm-eval for model regression tests (#518)
* add lm_eval in model test

* mod default value

* mod clean up

* add native lm_eval

* add bloom lmeval result

* add test cohere lm_eval

* add lm eval model

* update native model score

* remove general assert mod quant data

* mod task name

* mod clean up

* mod clean up

* mod max length 4096

* mod clean up

* add yi

* modify tests/models file

* modify tests/models files

* clean up

* add native value

* mod clean up

* clean temp code

* eval need save path and delete

* check quant model path

* revert

* clean up code

* mod sub test

* mod clean up

* mod clean up

* modify model unit test files

* merge code

* mod value

* mod diff_pct

* mod clean up

* Update model_test.py

* modify tests/models/test_llama3_1.py

* format code

* format code

* mod clean up

* format code

---------

Co-authored-by: root <[email protected]>
Co-authored-by: ZYC <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
4 people authored Nov 5, 2024
1 parent ed9a77d commit c4b8e83
Showing 42 changed files with 302 additions and 282 deletions.
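
In practice, the change converts each per-model regression test from comparing generated text against a hard-coded reference string to checking lm-eval arc_challenge scores against recorded native scores. A minimal sketch of what a converted test looks like (the model ID and score values below are illustrative placeholders, not taken from this commit):

from model_test import ModelTest  # noqa: E402


class TestExampleModel(ModelTest):
    # hypothetical model and reference scores, for illustration only
    NATIVE_MODEL_ID = "org/example-model"
    NATIVE_ARC_CHALLENGE_ACC = 0.40
    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.42
    TRUST_REMOTE_CODE = True

    def test_example_model(self):
        # quantize the model, run lm-eval on arc_challenge, and assert the
        # quantized scores stay within the allowed delta of the native scores
        self.quant_lm_eval()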
23 changes: 13 additions & 10 deletions gptqmodel/models/base.py
@@ -259,7 +259,7 @@ def quantize(
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, CUDA_0)

calibration_dataset = self._prepare_dataset_for_quantization(calibration_dataset, batch_size, tokenizer, )
calibration_dataset = self._prepare_dataset_for_quantization(calibration_dataset, batch_size, tokenizer,)

if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
from auto_round import AutoRound
@@ -659,6 +659,8 @@ def save_pretrained(

def lm_eval(
self,
model: Optional[str] = None,
model_args: str = "",
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = 32,
@@ -688,22 +690,23 @@ def lm_eval(
wandb_project: Optional[str] = None,
wandb_name: Optional[str] = None,
show_config: bool = False,
trust_remote_code: bool = False,
):
LM = HFLM(
pretrained=self,
batch_size=batch_size,
max_batch_size=max_batch_size,
)
if model is None:
model = HFLM(
pretrained=self,
batch_size=batch_size,
max_batch_size=max_batch_size,
trust_remote_code=trust_remote_code,
)
# evaluation_tracker need model_args cannot be None
model_args = ""
if evaluation_tracker is None and output_path is not None:
evaluation_tracker = EvaluationTracker(output_path=output_path)

results = lm_eval.simple_evaluate(
model=LM,
model=model,
model_args=model_args,
tasks=tasks,
device=self.device,
device=str(self.device),
num_fewshot=num_fewshot,
batch_size=batch_size,
max_batch_size=max_batch_size,
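
With this signature change, lm_eval() can either wrap the already-loaded GPTQModel in an HFLM (the default path, when model is None) or hand a backend name plus a model_args string straight to lm_eval.simple_evaluate(). A rough sketch of the two call styles, assuming a loaded GPTQModel instance named model (paths and argument values are illustrative):

# default path: evaluate the in-memory model through HFLM
results = model.lm_eval(tasks=["arc_challenge"], batch_size=32)

# external-backend path: let lm-eval load the quantized checkpoint itself, e.g. via vLLM
results = model.lm_eval(
    model="vllm",
    model_args="pretrained=/path/to/quantized-model,dtype=auto,gpu_memory_utilization=0.8",
    tasks=["arc_challenge"],
    output_path="/tmp/lm_eval_out",
)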
1 change: 0 additions & 1 deletion gptqmodel/utils/model.py
@@ -431,7 +431,6 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon
The max_input_length argument is specific to the exllama backend, that requires to initialize a buffer temp_state.
"""
# post init for bitblas backend.
device_to_buffers_size = {}

model_uses_qbits = False

92 changes: 77 additions & 15 deletions tests/models/model_test.py
@@ -3,21 +3,29 @@

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import tempfile
import unittest
import shutil # noqa: E402
import tempfile # noqa: E402
import unittest # noqa: E402

from datasets import load_dataset # noqa: E402
from gptqmodel import GPTQModel
from gptqmodel import GPTQModel # noqa: E402
from gptqmodel.quantization import FORMAT # noqa: E402
from gptqmodel.quantization.config import QuantizeConfig # noqa: E402
from lm_eval.utils import make_table # noqa: E402
from transformers import AutoTokenizer # noqa: E402


class ModelTest(unittest.TestCase):
GENERATE_EVAL_SIZE = 100
TASK_NAME = "arc_challenge"
# sub test can modify
QUANT_ARC_MAX_NEGATIVE_DELTA = 0.1 # -10%
QUANT_ARC_MAX_POSITIVE_DELTA = 0.2 # 20%
TRUST_REMOTE_CODE = False
APPLY_CHAT_TEMPLATE = False
TORCH_DTYPE = "auto"

def generate(self, model, tokenizer, prompt=None):
if prompt == None:
if prompt is None:
prompt = "I am in Paris and"
device = model.device
inp = tokenizer(prompt, return_tensors="pt").to(device)
@@ -27,7 +35,7 @@ def generate(self, model, tokenizer, prompt=None):
return output

def generateChat(self, model, tokenizer, prompt=None):
if prompt == None:
if prompt is None:
prompt = [
{"role": "system",
"content": "You are a helpful assistant."},
@@ -46,11 +54,13 @@ def load_tokenizer(self, model_name_or_path, trust_remote_code=False):
return tokenizer

def load_dataset(self, tokenizer):
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").filter(lambda x: len(x['text']) >= 512)
calibration_dataset = [tokenizer(example["text"]) for example in traindata.select(range(1024))]
return calibration_dataset
max_length = 4096
traindata = load_dataset("allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz",
split="train").filter(
lambda x: len(x["text"]) >= max_length and len(x["text"]) <= (max_length * 1.5))
return [tokenizer(example["text"]) for example in traindata.select(range(1024))]

def quantModel(self, model_name_or_path, trust_remote_code=False, torch_dtype="auto"):
def quantModel(self, model_name_or_path, trust_remote_code=False, torch_dtype="auto", need_eval=True):
tokenizer = self.load_tokenizer(model_name_or_path, trust_remote_code=trust_remote_code)
calibration_dataset = self.load_dataset(tokenizer)
quantize_config = QuantizeConfig(
@@ -73,16 +83,24 @@ def quantModel(self, model_name_or_path, trust_remote_code=False, torch_dtype="a
model.config.eos_token_id = tokenizer.eos_token_id or 0

model.quantize(calibration_dataset, batch_size=64)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_quantized(tmpdirname)
q_model, q_tokenizer = self.loadQuantModel(tmpdirname, tokenizer_path=model_name_or_path)
if need_eval:
test_dir = os.path.dirname(os.path.abspath(__file__))
save_dir = os.path.join(test_dir, "test_quantized_model")
os.makedirs(save_dir, exist_ok=True)
model.save_quantized(save_dir)
tokenizer.save_pretrained(save_dir)
q_model, q_tokenizer = self.loadQuantModel(save_dir, trust_remote_code=trust_remote_code)
else:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_quantized(tmpdirname)
tokenizer.save_pretrained(tmpdirname)
q_model, q_tokenizer = self.loadQuantModel(tmpdirname, trust_remote_code=trust_remote_code)

return q_model, q_tokenizer


def loadQuantModel(self, model_name_or_path, trust_remote_code=False, tokenizer_path=None):
if tokenizer_path == None:
if tokenizer_path is None:
tokenizer_path = model_name_or_path
else:
trust_remote_code = True
Expand All @@ -94,3 +112,47 @@ def loadQuantModel(self, model_name_or_path, trust_remote_code=False, tokenizer_
)

return model, tokenizer

def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False):
with tempfile.TemporaryDirectory() as tmp_dir:
results = model.lm_eval(
model="vllm",
model_args=f"pretrained={model.model_name_or_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code}",
output_path=tmp_dir,
tasks=self.TASK_NAME,
apply_chat_template=apply_chat_template,
trust_remote_code=trust_remote_code
)
print('--------Eval Result---------')
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
print('--------Eval Result End---------')
task_results = {
metric: value for metric, value in results['results'].get(self.TASK_NAME, {}).items()
if metric != 'alias' and 'stderr' not in metric
}
print(task_results)
if os.path.exists(model.model_name_or_path):
shutil.rmtree(model.model_name_or_path)
return task_results

def calculatorPer(self, filter, value):
if "norm" in filter:
diff_pct = (value / self.NATIVE_ARC_CHALLENGE_ACC_NORM) * 100
print(f"{filter}: {value} diff {diff_pct:.2f}%")
else:
diff_pct = (value / self.NATIVE_ARC_CHALLENGE_ACC) * 100
print(f"{filter}: {value} diff {diff_pct:.2f}%")
return diff_pct

def quant_lm_eval(self):
self.model, self.tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, torch_dtype=self.TORCH_DTYPE)

task_results = self.lm_eval(self.model, trust_remote_code=self.TRUST_REMOTE_CODE, apply_chat_template=self.APPLY_CHAT_TEMPLATE)
for filter, value in task_results.items():
diff_pct = self.calculatorPer(filter=filter, value=value)
negative_pct = 100 * (1 - self.QUANT_ARC_MAX_NEGATIVE_DELTA)
positive_pct = 100 * (1 + self.QUANT_ARC_MAX_POSITIVE_DELTA)
self.assertTrue(negative_pct <= diff_pct <= positive_pct,
f"{filter}: {value} diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%]")
12 changes: 5 additions & 7 deletions tests/models/test_baichuan.py
@@ -1,13 +1,11 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestBaiChuan(ModelTest):
NATIVE_MODEL_ID = "baichuan-inc/Baichuan2-7B-Chat"
NATIVE_ARC_CHALLENGE_ACC = 0.4104
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4317
TRUST_REMOTE_CODE = True

def test_baichuan(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=True)

reference_output = "I am in Paris and I need to go to the airport. How can I get to the airport from here?\nThere are several ways to get to the airport from Paris. The most common way is to take the RER (Regional Express Train). You can take the RER A line from Gare de l'Est or Gare du Nord stations. The other option is to take the Métro (subway). You can take the Métro Line 1 or Line 14 to"
result = self.generate(model, tokenizer)

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
self.quant_lm_eval()
12 changes: 6 additions & 6 deletions tests/models/test_bloom.py
@@ -1,13 +1,13 @@
import torch
from model_test import ModelTest
import torch # noqa: E402
from model_test import ModelTest # noqa: E402


class TestBloom(ModelTest):
NATIVE_MODEL_ID = "bigscience/bloom-560m"
NATIVE_ARC_CHALLENGE_ACC = 0.2201
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2440
TORCH_DTYPE = torch.float16

def test_bloom(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, torch_dtype=torch.float16)
reference_output = "I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and"
result = self.generate(model, tokenizer)
self.quant_lm_eval()

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
10 changes: 3 additions & 7 deletions tests/models/test_chatglm.py
@@ -1,13 +1,9 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestChatGlm(ModelTest):
NATIVE_MODEL_ID = "THUDM/chatglm3-6b"
TRUST_REMOTE_CODE = True

def test_chatglm(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=True)
reference_output = ""
result = self.generate(model, tokenizer)

self.assertTrue(len(result) > 0)
# self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
self.quant_lm_eval()
10 changes: 5 additions & 5 deletions tests/models/test_codegen.py
@@ -1,12 +1,12 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestCodeGen(ModelTest):
NATIVE_MODEL_ID = "Salesforce/codegen2-1B_P"
NATIVE_ARC_CHALLENGE_ACC = 0.1749
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2005
TRUST_REMOTE_CODE = True

def test_codegen(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=True)
reference_output = "I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and I am in Paris. I am in Paris and"
result = self.generate(model, tokenizer)
self.quant_lm_eval()

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
11 changes: 5 additions & 6 deletions tests/models/test_cohere.py
@@ -1,12 +1,11 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestCohere(ModelTest):
NATIVE_MODEL_ID = "CohereForAI/aya-expanse-8b"
NATIVE_ARC_CHALLENGE_ACC = 0.5401
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5640
QUANT_ARC_MAX_NEGATIVE_DELTA = 0.12

def test_cohere(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID)
reference_output = "<BOS_TOKEN>I am in Paris and I am in love. I am in love with the city, the people, the food, the art, the history, the architecture, the fashion, the music, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art, the art,"
result = self.generate(model, tokenizer)

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
self.quant_lm_eval()
12 changes: 6 additions & 6 deletions tests/models/test_deci.py
@@ -1,12 +1,12 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestDeci(ModelTest):
NATIVE_MODEL_ID = "Deci/DeciLM-7B-instruct"
NATIVE_ARC_CHALLENGE_ACC = 0.5239
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5222
QUANT_ARC_MAX_NEGATIVE_DELTA = 0.55
TRUST_REMOTE_CODE = True

def test_deci(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=True)
reference_output = "<s> I am in Paris and I am going to the Eiffel Tower.\n\nQuestion: Where is the narrator going?\n\nAnswer: The Eiffel Tower\n\nTitle: The Eiffel Tower\n\nBackground: The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower. Construction began on 28 January 1887"
result = self.generate(model, tokenizer)

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
self.quant_lm_eval()
12 changes: 7 additions & 5 deletions tests/models/test_deepseekv2_lite.py
@@ -1,12 +1,14 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestDeepseekV2Lite(ModelTest):
NATIVE_MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
NATIVE_ARC_CHALLENGE_ACC = 0.4753
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4855
APPLY_CHAT_TEMPLATE = True
TRUST_REMOTE_CODE = True

def test_deepseekv2lite(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=True)
reference_output = "<|begin▁of▁sentence|>I am in Paris and I am looking for a good place to eat. I am a vegetarian and I am looking for a place that has a good vegetarian menu. I am not looking for a fancy restaurant, just a good place to eat.\nI am looking for a place that has a good vegetarian menu and is not too expensive. I am not looking for a fancy restaurant, just a good place to eat.\nI am in Paris and I am looking for a good place to eat. I am a vegetarian and"
result = self.generate(model, tokenizer)
self.quant_lm_eval()


self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
17 changes: 5 additions & 12 deletions tests/models/test_exaone.py
@@ -1,19 +1,12 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestExaone(ModelTest):
NATIVE_MODEL_ID = "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
prompt = [
{"role": "system",
"content": "You are EXAONE model from LG AI Research, a helpful assistant."},
{"role": "user",
"content": "I am in Shanghai, preparing to visit the natural history museum. Can you tell me the best way to"}
]

NATIVE_ARC_CHALLENGE_ACC = 0.4232
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4164
TRUST_REMOTE_CODE = True
def test_exaone(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=True)
reference_output = "Certainly! Here's how you can get to the Shanghai Natural History Museum:\n\n1. **By Metro**: The museum is located near Line 10 of the Shanghai Metro. You can take the Line 10 train to the People's Park station. From there, it's a short walk to the museum.\n\n2. **By Bus**: Several bus lines pass near the museum. For example, bus routes 10, 11,"
self.quant_lm_eval()

result = self.generateChat(model, tokenizer, prompt=self.prompt)

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
12 changes: 6 additions & 6 deletions tests/models/test_falcon.py
@@ -1,12 +1,12 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestFalcon(ModelTest):
NATIVE_MODEL_ID = "tiiuae/falcon-7b-instruct"
NATIVE_ARC_CHALLENGE_ACC = 0.3993
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4292
APPLY_CHAT_TEMPLATE = True
TRUST_REMOTE_CODE = True

def test_falcon(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID)
reference_output = "I am in Paris and,.....\n,,,,,,,, ,,, and and,, ,, and and and,, ,, and and, and, and, and, and, and, and, and and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and the, and"
result = self.generate(model, tokenizer)

self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])
self.quant_lm_eval()
10 changes: 5 additions & 5 deletions tests/models/test_gemma.py
@@ -1,12 +1,12 @@
from model_test import ModelTest
from model_test import ModelTest # noqa: E402


class TestGemma(ModelTest):
NATIVE_MODEL_ID = "google/gemma-2-9b"
NATIVE_ARC_CHALLENGE_ACC = 0.6143
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.6553

def test_gemma(self):
model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID)
reference_output = "<bos>I am in Paris and I am going to the Louvre. I am going to see the Mona Lisa. I am going to see the Venus de Milo. I am going to see the Winged Victory of Samothrace. I am going to see the Coronation of Napoleon. I am going to see the Raft of the Medusa. I am going to see the Code of Hammurabi. I am going to see the Rosetta Stone. I am going to see the Venus de Milo. I am going to see the Winged"
result = self.generate(model, tokenizer)
self.quant_lm_eval()


self.assertEqual(result[:self.GENERATE_EVAL_SIZE], reference_output[:self.GENERATE_EVAL_SIZE])