Skip to content

Commit

Permalink
update eval and fix example (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
n1ck-guo authored Sep 14, 2024
1 parent 3275df9 commit 7816eea
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 539 deletions.
35 changes: 6 additions & 29 deletions auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import argparse

import torch
import subprocess
import transformers
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True, warn_only=True)
Expand All @@ -26,7 +25,7 @@

from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device
from auto_round.utils import detect_device, get_library_version

def setup_parser():
parser = argparse.ArgumentParser()
Expand All @@ -36,7 +35,7 @@ def setup_parser():
)

parser.add_argument('--eval', action='store_true',
help="whether to use eval mode.")
help="whether to use eval only mode.")

parser.add_argument("--bits", default=4, type=int,
help="number of bits")
Expand Down Expand Up @@ -109,7 +108,7 @@ def setup_parser():
help="Where to store the final model.")

parser.add_argument("--disable_eval", action='store_true',
help="Whether to do lmeval evaluation.")
help="Whether to do lm-eval evaluation after tuning.")

parser.add_argument("--disable_amp", action='store_true',
help="disable amp")
Expand Down Expand Up @@ -177,7 +176,6 @@ def tune(args):
args.low_cpu_mem_tmp_dir = os.path.join(args.output_dir, "low_cpu_mem_tmp")
if args.low_cpu_mem_mode == 2:
from auto_round.low_cpu_mem.utils import load_model_with_hooks

model = load_model_with_hooks(
model_name,
model_cls,
Expand All @@ -189,7 +187,6 @@ def tune(args):
)
elif args.low_cpu_mem_mode == 1:
from auto_round.low_cpu_mem.utils import load_empty_model

low_cpu_mem_usage = True
model = load_empty_model(
model_name,
Expand All @@ -205,18 +202,16 @@ def tune(args):
trust_remote_code=not args.disable_trust_remote_code
)

from auto_round import (AutoRound,
AutoAdamRound)
from auto_round import AutoRound, AutoAdamRound

model = model.eval()
# align with GPTQ to eval ppl
seqlen = args.seqlen
if "opt" in model_name:
seqlen = model.config.max_position_embeddings
model.seqlen = model.config.max_position_embeddings
else:
seqlen = 2048
model.seqlen = seqlen
seqlen = args.seqlen

if args.model_dtype != None:
if args.model_dtype == "float16" or args.model_dtype == "fp16":
Expand Down Expand Up @@ -262,7 +257,6 @@ def tune(args):
lm_head_layer_name = n
if args.quant_lm_head:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
if config.tie_word_embeddings and hasattr(model, "_tied_weights_keys"):
tied_keys = model._tied_weights_keys
Expand Down Expand Up @@ -299,7 +293,6 @@ def tune(args):
model_name = args.model.rstrip("/")
if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
import shutil

shutil.rmtree(args.low_cpu_mem_tmp_dir, ignore_errors=True)

model.eval()
Expand All @@ -308,22 +301,12 @@ def tune(args):

export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-w{args.bits}g{args.group_size}"


format_list = args.format.replace(' ', '').split(',')
inplace = False if len(format_list) > 1 else True
for format_ in format_list:
eval_folder = f'{export_dir}-{format_}'
autoround.save_quantized(eval_folder, format=format_, inplace=inplace)


def get_library_version(library_name):
try:
version = subprocess.check_output(['pip', 'show', library_name]).decode().split('\n')[1].split(': ')[1]
return version
except subprocess.CalledProcessError:
return "Library not found"


lm_eval_version = get_library_version("lm-eval")

if isinstance(tasks, str):
Expand All @@ -343,28 +326,22 @@ def get_library_version(library_name):
tasks=tasks,
batch_size=args.eval_bs,
user_model=user_model)

print(make_table(res))


def eval(args):
quantization_config = AutoRoundConfig(backend=args.device)
device_str = detect_device(args.device)
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map=device_str, quantization_config=quantization_config)
model_args = f"pretrained={args.model},trust_remote_code={not args.disable_trust_remote_code}"
if isinstance(args.tasks, str):
tasks = args.tasks.split(',')
res = simple_evaluate(
model="hf",
model_args=model_args,
user_model=user_model,
tasks=tasks,
device=device_str,
batch_size=args.eval_bs)

from lm_eval.utils import make_table # pylint: disable=E0401

print(make_table(res))


Expand Down
Loading

0 comments on commit 7816eea

Please sign in to comment.