Add GPTJ-6B onnx model example (#724)
Signed-off-by: mengniwa <[email protected]>
Co-authored-by: chensuyue <[email protected]>
mengniwang95 and chensuyue authored Apr 20, 2023
1 parent f248cdc commit ac5a671
Showing 13 changed files with 703 additions and 8 deletions.
14 changes: 14 additions & 0 deletions examples/.config/model_params_onnxrt.json
@@ -735,6 +735,20 @@
"main_script": "main.py",
"batch_size": 1
},
"gpt-j-6B": {
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/gpt-j-6b/model.onnx",
"main_script": "main.py",
"batch_size": 1
},
"gpt-j-6B_dynamic": {
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/ptq_dynamic",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/gpt-j-6b/model.onnx",
"main_script": "main.py",
"batch_size": 1
},
"hf_roberta-large": {
"model_src_dir": "nlp/huggingface_model/question_answering/quantization/ptq_static",
"dataset_location": "/tf_dataset2/datasets/squad",
@@ -0,0 +1,39 @@
Step-by-Step
============

This example loads a gpt-j-6B model and confirms its accuracy and speed on the [lambada](https://huggingface.co/datasets/lambada) dataset.

# Prerequisite

## 1. Environment
```shell
pip install neural-compressor
pip install -r requirements.txt
```
> Note: see the validated ONNX Runtime [versions](/docs/source/installation_guide.md#validated-software-environment).

## 2. Prepare Model

```bash
python -m transformers.onnx --model=EleutherAI/gpt-j-6B model/ --framework pt --opset 13 --feature=causal-lm-with-past
```
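
A quick way to confirm the export succeeded is to list the graph inputs; GPT-J-6B exported with `--feature=causal-lm-with-past` should expose `input_ids`, 28 pairs of `past_key_values` inputs (one per decoder layer), and `attention_mask`. The snippet below is a minimal sketch, assuming the command above wrote `model/model.onnx` (the exact file name can vary with the transformers version):

```python
# Minimal sanity check of the exported graph (path assumed from the export above).
import onnxruntime as ort

session = ort.InferenceSession("model/model.onnx")
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)
```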

# Run

## 1. Quantization

Dynamic quantization:

```bash
# input_model: model path ending in *.onnx; batch_size is optional
bash run_tuning.sh --input_model=/path/to/model \
                   --output_model=/path/to/model_tune \
                   --batch_size=batch_size
```

## 2. Benchmark

```bash
# input_model: model path ending in *.onnx
# mode: performance or accuracy; batch_size is optional
bash run_benchmark.sh --input_model=/path/to/model \
                      --mode=performance \
                      --batch_size=batch_size
```
@@ -0,0 +1,200 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation

import os
import onnx
import torch
import logging
import argparse
import numpy as np
from transformers import AutoTokenizer
from datasets import load_dataset
import onnxruntime as ort
from torch.nn.functional import pad
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.WARN)

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    '--model_path',
    type=str,
    help="path to the pre-trained gpt-j-6B ONNX model file"
)
parser.add_argument(
    '--benchmark',
    action='store_true',
    default=False
)
parser.add_argument(
    '--tune',
    action='store_true',
    default=False,
    help="whether to quantize the model"
)
parser.add_argument(
    '--output_model',
    type=str,
    default=None,
    help="output model path"
)
parser.add_argument(
    '--mode',
    type=str,
    help="benchmark mode of performance or accuracy"
)
parser.add_argument(
    '--batch_size',
    default=1,
    type=int,
)
parser.add_argument(
    '--model_name_or_path',
    type=str,
    help="pretrained model name or path",
    default="EleutherAI/gpt-j-6B"
)
parser.add_argument(
    '--workspace',
    type=str,
    help="workspace to save intermediate files",
    default="nc_workspace"
)
parser.add_argument(
    '--pad_max',
    default=196,
    type=int,
)
args = parser.parse_args()

# load model
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

def tokenize_function(examples):
    example = tokenizer(examples['text'])
    return example

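# Evaluate lambada last-word accuracy: run the ONNX model over padded batches
# and check whether the prediction at each sequence's last real token matches
# the actual final token.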
def eval_func(onnx_model, dataloader, workspace, pad_max):
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if isinstance(onnx_model, str):
        model_path = onnx_model
    else:
        onnx.save(onnx_model, os.path.join(workspace, 'eval.onnx'), save_as_external_data=True)
        model_path = os.path.join(workspace, 'eval.onnx')

    session = ort.InferenceSession(model_path, options, providers=ort.get_available_providers())
    inputs_names = [i.name for i in session.get_inputs()]

    total, hit = 0, 0
    pad_len = 0

    for idx, (batch, last_ind) in enumerate(dataloader):
        ort_inputs = dict(zip(inputs_names, batch))
        label = torch.from_numpy(batch[0][torch.arange(len(last_ind)), last_ind])
        pad_len = pad_max - last_ind - 1

        predictions = session.run(None, ort_inputs)
        outputs = torch.from_numpy(predictions[0])

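        # logits at index -2 - pad_len sit just before the last real token
        # (skipping the right padding), so their argmax predicts the final word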
        last_token_logits = outputs[torch.arange(len(last_ind)), -2 - pad_len, :]
        pred = last_token_logits.argmax(dim=-1)
        total += len(label)
        hit += (pred == label).sum().item()

    acc = hit / total
    return acc

class Dataloader:
    def __init__(self, pad_max=196, batch_size=1):
        self.pad_max = pad_max
        self.batch_size = batch_size
        dataset = load_dataset('lambada', split='validation')
        dataset = dataset.map(tokenize_function, batched=True)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
        self.dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.collate_batch,
        )

    def collate_batch(self, batch):

        input_ids_padded = []
        attention_mask_padded = []
        last_ind = []

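        # right-pad every sample to pad_max (pad id 1, mask value 0) and record
        # the index of its last real token; the mask carries one extra leading
        # slot, zeroed out, apparently for the dummy past-key-value step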
        for text in batch:
            input_ids = text["input_ids"]
            pad_len = self.pad_max - input_ids.shape[0]
            last_ind.append(input_ids.shape[0] - 1)
            attention_mask = torch.ones(len(input_ids) + 1)
            attention_mask[0] = 0
            input_ids = pad(input_ids, (0, pad_len), value=1)
            input_ids_padded.append(input_ids)
            attention_mask = pad(attention_mask, (0, pad_len), value=0)
            attention_mask_padded.append(attention_mask)

        return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)


    def __iter__(self):
        try:
            for (input_ids, attention_mask), last_ind in self.dataloader:
                data = [input_ids.detach().cpu().numpy().astype('int64')]
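                # GPT-J-6B has 28 decoder layers; the with-past graph expects a
                # past key and value tensor per layer, fed here as zeros of shape
                # (batch, 16 heads, 1 step, 256 head_dim)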
                for i in range(28):
                    data.append(np.zeros((input_ids.shape[0], 16, 1, 256), dtype='float32'))
                    data.append(np.zeros((input_ids.shape[0], 16, 1, 256), dtype='float32'))
                data.append(attention_mask.detach().cpu().numpy().astype('int64'))
                yield data, last_ind.detach().cpu().numpy()
        except StopIteration:
            return

if __name__ == "__main__":
    from neural_compressor import set_workspace
    set_workspace(args.workspace)

    dataloader = Dataloader(pad_max=args.pad_max, batch_size=args.batch_size)

    def eval(model):
        return eval_func(model, dataloader, args.workspace, args.pad_max)

    if args.benchmark:
        if args.mode == 'performance':
            from neural_compressor.benchmark import fit
            from neural_compressor.config import BenchmarkConfig
            conf = BenchmarkConfig(iteration=100,
                                   cores_per_instance=28,
                                   num_of_instance=1)
            fit(args.model_path, conf, b_dataloader=dataloader)
        elif args.mode == 'accuracy':
            acc_result = eval(args.model_path)
            print("Batch size = %d" % args.batch_size)
            print("Accuracy: %.5f" % acc_result)

    if args.tune:
        from neural_compressor import quantization, PostTrainingQuantConfig
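        # the negative-lookahead regex matches every op type except MatMul and
        # Gather and pins it to fp32, so only MatMul and Gather get quantized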
        config = PostTrainingQuantConfig(
            approach='dynamic',
            op_type_dict={'^((?!(MatMul|Gather)).)*$': {'weight': {'dtype': ['fp32']},
                                                        'activation': {'dtype': ['fp32']}}})
        q_model = quantization.fit(args.model_path, config, eval_func=eval)
        q_model.save(args.output_model)
@@ -0,0 +1,6 @@
torch
transformers
onnx
onnxruntime
onnxruntime-extensions; python_version < '3.10'
datasets
@@ -0,0 +1,42 @@
#!/bin/bash
set -x

function main {
  init_params "$@"
  run_benchmark
}

# init params
function init_params {
  for var in "$@"
  do
    case $var in
      --input_model=*)
        input_model=$(echo $var | cut -f2 -d=)
        ;;
      --mode=*)
        mode=$(echo $var | cut -f2 -d=)
        ;;
      --batch_size=*)
        batch_size=$(echo $var | cut -f2 -d=)
        ;;
    esac
  done
}

# run_benchmark
function run_benchmark {
  python main.py \
    --model_path ${input_model} \
    --mode=${mode} \
    --batch_size=${batch_size-1} \
    --benchmark
}

main "$@"

@@ -0,0 +1,41 @@
#!/bin/bash
set -x

function main {
  init_params "$@"
  run_tuning
}

# init params
function init_params {
  for var in "$@"
  do
    case $var in
      --input_model=*)
        input_model=$(echo $var | cut -f2 -d=)
        ;;
      --output_model=*)
        output_model=$(echo $var | cut -f2 -d=)
        ;;
      --batch_size=*)
        batch_size=$(echo $var | cut -f2 -d=)
        ;;
    esac
  done
}

# run_tuning
function run_tuning {
  python main.py \
    --model_path ${input_model} \
    --output_model ${output_model} \
    --batch_size ${batch_size-1} \
    --tune
}

main "$@"



@@ -0,0 +1,40 @@
Step-by-Step
============

This example loads a gpt-j-6B model and confirms its accuracy and speed on the [lambada](https://huggingface.co/datasets/lambada) dataset.

# Prerequisite

## 1. Environment
```shell
pip install neural-compressor
pip install -r requirements.txt
```
> Note: see the validated ONNX Runtime [versions](/docs/source/installation_guide.md#validated-software-environment).

## 2. Prepare Model

```bash
python -m transformers.onnx --model=EleutherAI/gpt-j-6B model/ --framework pt --opset 13 --feature=causal-lm-with-past
```

# Run

## 1. Quantization

Static quantization:

```bash
# input_model: model path ending in *.onnx; batch_size is optional
# quant_format: "QOperator" or "QDQ" (optional)
bash run_tuning.sh --input_model=/path/to/model \
                   --output_model=/path/to/model_tune \
                   --batch_size=batch_size \
                   --quant_format="QOperator"
```
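
For orientation, the sketch below outlines the static PTQ call that `run_tuning.sh` ultimately wraps; it is illustrative only, and the `dataloader`/`eval_func` names are assumed to mirror the dynamic example in this PR. `quant_format` selects between the QOperator and QDQ ONNX representations:

```python
# Illustrative static PTQ sketch (not the exact script internals).
from neural_compressor import quantization, PostTrainingQuantConfig

config = PostTrainingQuantConfig(approach='static',
                                 quant_format='QOperator')  # or 'QDQ'
q_model = quantization.fit('/path/to/model.onnx', config,
                           calib_dataloader=dataloader,  # calibration data, e.g. lambada
                           eval_func=eval_func)
q_model.save('/path/to/model_tune')
```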

## 2. Benchmark

```bash
# input_model: model path ending in *.onnx
# mode: performance or accuracy; batch_size is optional
bash run_benchmark.sh --input_model=/path/to/model \
                      --batch_size=batch_size \
                      --mode=performance
```