ONNXRT LLM examples support latest optimum version (#1578)
Signed-off-by: yuwenzho <[email protected]>
yuwenzho authored Feb 8, 2024
1 parent ac47d9b commit 26b260e
Showing 12 changed files with 316 additions and 232 deletions.
56 changes: 49 additions & 7 deletions examples/.config/model_params_onnxrt.json
@@ -756,45 +756,87 @@
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b": {
"llama-2-7b-sq": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-sq-with-past": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-lwq": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-with-past-lwq": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-rtn": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-rtn-with-past": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-awq": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-awq-with-past": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-gptq": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-gptq-with-past": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-woq_tune": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-woq_tune-with-past": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
"main_script": "main.py",
"batch_size": 1
},
README.md (text_generation/llama/quantization/ptq_static example)
@@ -27,7 +27,9 @@ Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. There are

Export to ONNX model:
```bash
python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" --output_model="./llama-2-7b-hf"
python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" \
--output_model="./llama-2-7b-hf" \
                        --task=text-generation-with-past # or text-generation
```
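
For a quick sanity check of the export, the model can be loaded the same way the updated main.py benchmark does (a minimal sketch, assuming optimum >= 1.14.0 so the export produces a single merged model.onnx; the folder and tokenizer names mirror the command above, and the prompt and generation settings are illustrative, not part of the example scripts):

```python
import os

import onnxruntime as ort
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer

model_dir = "./llama-2-7b-hf"  # output folder of prepare_model.py above

config = LlamaConfig.from_pretrained(model_dir)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

sess_options = ort.SessionOptions()
session = ORTModelForCausalLM.load_model(
    os.path.join(model_dir, "model.onnx"), session_options=sess_options)

# The merged export exposes past_key_values inputs only for the
# text-generation-with-past task; detect that from the input names.
kv_inputs = [i.name for i in session.get_inputs() if ".key" in i.name or ".value" in i.name]
use_cache = len(kv_inputs) > 0

model = ORTModelForCausalLM(session, config, use_cache=use_cache, use_io_binding=use_cache)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```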

# Run
@@ -41,7 +43,7 @@ bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
--output_model=/path/to/model_tune \ # folder path to save onnx model
--batch_size=batch_size # optional \
--dataset NeelNanda/pile-10k \
--alpha 0.6 \ # 0.6 for llama-7b, 0.8 for llama-13b
--alpha 0.75 \
--tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer
--quant_format="QOperator" # or QDQ, optional
```
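
The --alpha value is forwarded to the SmoothQuant recipe that main.py builds with neural_compressor. A minimal sketch of that configuration (values mirror this example's defaults; the commented quantization.fit call stands in for the dataloader wiring done in main.py):

```python
from neural_compressor import PostTrainingQuantConfig, quantization

# SmoothQuant config as assembled in main.py: MatMul/Gather/Conv are quantized,
# every other op type is kept in fp32 via the regex in op_type_dict.
ptq_config = PostTrainingQuantConfig(
    calibration_sampling_size=[8],
    recipes={
        "optypes_to_exclude_output_quant": ["MatMul"],
        "smooth_quant": True,
        "smooth_quant_args": {"alpha": 0.75},  # --alpha from run_quant.sh
        "graph_optimization_level": "ENABLE_EXTENDED",
    },
    op_type_dict={
        "^((?!(MatMul|Gather|Conv)).)*$": {
            "weight": {"dtype": ["fp32"]},
            "activation": {"dtype": ["fp32"]},
        }
    },
)

# q_model = quantization.fit("llama-2-7b-hf/model.onnx", ptq_config,
#                            calib_dataloader=...)  # see KVDataloader in main.py
# q_model.save("llama-2-7b-hf-int8/model.onnx")
```
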
main.py (text_generation/llama/quantization/ptq_static example)
@@ -84,7 +84,7 @@
parser.add_argument(
'--quant_format',
type=str,
default='QOperator',
default='QOperator',
choices=['QOperator', 'QDQ'],
help="quantization format"
)
@@ -124,8 +124,9 @@
)
args = parser.parse_args()

# load model
# load tokenizer and model config
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer)
config = LlamaConfig.from_pretrained(args.model_path)

def tokenize_function(examples):
example = tokenizer(examples['text'])
@@ -134,29 +135,20 @@ def tokenize_function(examples):
def benchmark(model):
import json
import time
config = LlamaConfig.from_pretrained(args.model_path)
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = args.intra_op_num_threads

if os.path.exists(os.path.join(model, "decoder_with_past_model.onnx")):
sessions = ORTModelForCausalLM.load_model( # pylint: disable=E1123
os.path.join(model, "decoder_model.onnx"),
os.path.join(model, "decoder_with_past_model.onnx"),
session_options=sess_options)
model = ORTModelForCausalLM(sessions[0], # pylint: disable=E1121
config,
model,
sessions[1],
use_cache=True)
else:
sessions = ORTModelForCausalLM.load_model( # pylint: disable=E1123
os.path.join(model, "decoder_model.onnx"),
session_options=sess_options)
model = ORTModelForCausalLM(sessions[0], # pylint: disable=E1121
config,
model,
use_cache=False,
use_io_binding=False)

session = ORTModelForCausalLM.load_model( # pylint: disable=E1123
os.path.join(model, "model.onnx"),
session_options=sess_options)
inputs_names = session.get_inputs()
key_value_input_names = [key.name for key in inputs_names if (".key" in key.name) or (".value" in key.name)]
use_cache = len(key_value_input_names) > 0

model = ORTModelForCausalLM(session, # pylint: disable=E1121
config,
use_cache=True if use_cache else False,
use_io_binding=True if use_cache else False,)

input_tokens = '32'
max_new_tokens = 32
@@ -192,7 +184,7 @@ def benchmark(model):
print(args)
throughput = (num_iter - num_warmup) / total_time
print("Throughput: {} samples/s".format(throughput))


def replace_architectures(json_path):
# replace 'LLaMATokenizer' to lowercase 'LlamaTokenizer'
@@ -201,7 +193,7 @@ def replace_architectures(json_path):
with open(json_path, "r") as file:
data = json.load(file)
data["architectures"] = ["LlamaForCausalLM"]

with open(json_path, 'w') as file:
json.dump(data, file, indent=4)

@@ -234,6 +226,7 @@ def eval_func(model):

return eval_acc


class KVDataloader:
def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
self.pad_max = pad_max
@@ -247,10 +240,11 @@ def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
shuffle=False,
collate_fn=self.collate_batch,
)
self.sess = None
if not model_path.endswith('decoder_model.onnx'):
self.sess = ort.InferenceSession(os.path.join(os.path.dirname(model_path), 'decoder_model.onnx'))

session = ort.InferenceSession(model_path)
inputs_names = [input.name for input in session.get_inputs()]
self.key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)]
self.use_cache = len(self.key_value_input_names) > 0
self.session = session if self.use_cache else None

def collate_batch(self, batch):

@@ -269,23 +263,26 @@ def collate_batch(self, batch):
attention_mask_padded.append(attention_mask)
return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)


def __iter__(self):
try:
for (input_ids, attention_mask), last_ind in self.dataloader:
if self.sess is None:
yield {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
'attention_mask':attention_mask[:, :-1].detach().cpu().numpy().astype('int64')}, last_ind.detach().cpu().numpy()
else:
outputs = self.sess.run(None, {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
'attention_mask':attention_mask[:, :-1].detach().cpu().numpy().astype('int64')})
ort_input = {}
ort_input['input_ids'] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype('int64')
for i in range(int((len(outputs) - 1) / 2)):
ort_input['past_key_values.{}.key'.format(i)] = outputs[i*2+1]
ort_input['past_key_values.{}.value'.format(i)] = outputs[i*2+2]
ort_input['attention_mask'] = np.zeros([self.batch_size, ort_input['past_key_values.0.key'].shape[2]+1], dtype='int64')
yield ort_input, last_ind.detach().cpu().numpy()
ort_input = {}
ort_input["input_ids"] = input_ids[:, :-1].detach().cpu().numpy().astype("int64")
ort_input["attention_mask"] = attention_mask[:, :-1].detach().cpu().numpy().astype("int64")
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
ort_input["position_ids"] = position_ids[:,:-1].detach().cpu().numpy().astype("int64")
if self.use_cache:
# Create dummy past_key_values for decoder
num_attention_heads = config.num_key_value_heads
embed_size_per_head = config.hidden_size // config.num_attention_heads
shape = (self.batch_size, num_attention_heads, 0, embed_size_per_head)
key_or_value = np.zeros(shape, dtype=np.float32)
for key_value_input_name in self.key_value_input_names:
ort_input[key_value_input_name] = key_or_value

yield ort_input, last_ind.detach().cpu().numpy()

except StopIteration:
return

@@ -294,43 +291,38 @@ def __iter__(self):
set_workspace(args.workspace)

if args.benchmark:
if args.mode == 'performance':
if args.mode == 'performance':
benchmark(args.model_path)
elif args.mode == 'accuracy':
eval_func(args.model_path)

if args.tune:
from neural_compressor import quantization, PostTrainingQuantConfig

model_name = "model.onnx" # require optimum >= 1.14.0
model_path = os.path.join(args.model_path, model_name)

if args.layer_wise:
# layer-wise quantization for ONNX models is still under development and only supports W8A8 quantization for now
config = PostTrainingQuantConfig(
ptq_config = PostTrainingQuantConfig(
calibration_sampling_size=[8],
recipes={'optypes_to_exclude_output_quant': ['MatMul'],
'layer_wise_quant': True},
'layer_wise_quant': True,
'graph_optimization_level': 'ENABLE_EXTENDED'},
op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
for model in ['decoder_model.onnx']:
# only test decoder_model
q_model = quantization.fit(
os.path.join(args.model_path, model),
config,
calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
q_model.save(os.path.join(args.output_model, model))

tokenizer.save_pretrained(args.output_model)

else:
config = PostTrainingQuantConfig(
ptq_config = PostTrainingQuantConfig(
calibration_sampling_size=[8],
recipes={'optypes_to_exclude_output_quant': ['MatMul'],
'smooth_quant': True,
'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
},
'smooth_quant': True,
'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
'graph_optimization_level': 'ENABLE_EXTENDED'},
op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
q_model = quantization.fit(
os.path.join(args.model_path, model),
config,
calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
q_model.save(os.path.join(args.output_model, model))
tokenizer.save_pretrained(args.output_model)

q_model = quantization.fit(
model_path,
ptq_config,
calib_dataloader=KVDataloader(model_path, pad_max=args.pad_max, batch_size=1))
q_model.save(os.path.join(args.output_model, model_name))

tokenizer.save_pretrained(args.output_model)
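
For reference, the calibration dataloader passed to quantization.fit only needs to expose a batch_size attribute and iterate over (model_inputs, label) pairs; KVDataloader above fills that contract and additionally builds real tokenized batches, position_ids, and zero-length past_key_values. A bare-bones sketch of the interface, with hypothetical all-ones inputs that are not meant to produce meaningful calibration:

```python
import numpy as np

class DummyCalibDataloader:
    """Bare-bones calibration dataloader interface: a batch_size attribute
    plus an iterator that yields (model_inputs, label) pairs."""

    def __init__(self, seq_len=32, batch_size=1, num_samples=8):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.num_samples = num_samples  # matches calibration_sampling_size=[8]

    def __iter__(self):
        for _ in range(self.num_samples):
            # Hypothetical dummy token ids; KVDataloader instead tokenizes
            # NeelNanda/pile-10k and also feeds position_ids and, for merged
            # models with KV cache, empty past_key_values tensors.
            ort_input = {
                "input_ids": np.ones((self.batch_size, self.seq_len), dtype=np.int64),
                "attention_mask": np.ones((self.batch_size, self.seq_len), dtype=np.int64),
            }
            yield ort_input, None  # the label is not used during calibration
```
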
prepare_model.py (text_generation/llama/quantization/ptq_static example)
@@ -10,46 +10,37 @@ def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--input_model", type=str, required=False, default="")
parser.add_argument("--output_model", type=str, required=True)
parser.add_argument("--task",
type=str,
required=False,
default="text-generation-with-past",
choices=["text-generation-with-past", "text-generation"])
return parser.parse_args()


def prepare_model(input_model, output_model):
def prepare_model(input_model, output_model, task):
print("\nexport model...")
if Version(optimum.version.__version__) >= OPTIMUM114_VERSION:
subprocess.run(
[
"optimum-cli",
"export",
"onnx",
"--model",
f"{input_model}",
"--task",
"text-generation-with-past",
"--legacy",
f"{output_model}",
],
stdout=subprocess.PIPE,
text=True,
)
else:
subprocess.run(
[
"optimum-cli",
"export",
"onnx",
"--model",
f"{input_model}",
"--task",
"text-generation-with-past",
f"{output_model}",
],
stdout=subprocess.PIPE,
text=True,
)
if Version(optimum.version.__version__) < OPTIMUM114_VERSION:
raise ImportError("Please upgrade optimum to >= 1.14.0")

subprocess.run(
[
"optimum-cli",
"export",
"onnx",
"--model",
f"{input_model}",
"--task",
task,
f"{output_model}",
],
stdout=subprocess.PIPE,
text=True,
)

assert os.path.exists(output_model), f"{output_model} doesn't exist!"


if __name__ == "__main__":
args = parse_arguments()
prepare_model(args.input_model, args.output_model)
prepare_model(args.input_model, args.output_model, args.task)
