🦙 llama2 optimization (#641)
## Describe your changes

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
trajepl authored Oct 28, 2023
1 parent 819c25a commit 649e314
Showing 14 changed files with 873 additions and 19 deletions.
2 changes: 2 additions & 0 deletions examples/llama2/.gitignore
@@ -0,0 +1,2 @@
llama2_cpu*
llama2_gpu*
96 changes: 96 additions & 0 deletions examples/llama2/README.md
@@ -0,0 +1,96 @@
# Llama2 optimization using ORT toolchain
This folder contains a sample use case of Olive to optimize a [Llama2](https://huggingface.co/meta-llama/Llama-2-7b-hf) model using ONNXRuntime tools.

This example performs the following optimization pipelines; each one maps to a `pass_flows` entry in the generated config (sketched after this list):
- CPU, FP32: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32*
- CPU, INT8: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32 -> Onnx Dynamic Quantization*
- CPU, INT4: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32 -> Onnx Block-wise int4 Quantization*
- GPU, FP32: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32*
- GPU, FP16: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp16 + Grouped Query Attention*
- GPU, INT4: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp16 + Grouped Query Attention -> Onnx Block-wise int4 Quantization*
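
As a rough illustration, the pass names below are taken from `llama2_template.json` and `SUPPORTED_WORKFLOWS` in `llama2.py` (both further down in this commit); this sketch is only for orientation and is not something you need to write yourself:

```python
# Each inner list is one pipeline from the list above, executed left to right.
CPU_PASS_FLOWS = [
    ["conversion_merged", "transformers_optimization_fp32"],                             # FP32
    ["conversion_merged", "transformers_optimization_fp32", "onnx_dynamic_quant_int8"],  # INT8
    ["conversion_merged", "transformers_optimization_fp32", "blockwise_quant_int4"],     # INT4
]
GPU_PASS_FLOWS = [
    ["conversion_merged", "transformers_optimization_fp16"],                             # FP16 (+ GQA)
    ["conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4"],     # FP16 + INT4
]
```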

**Note:** Grouped query attention (GQA) is currently only supported on GPU with fp16 and requires a CUDA compute capability of at least 8.0 (`sm_80`). To disable it, set `use_gqa` to `false` in the config file, as shown below:
```json
"transformers_optimization_fp16": {
"type": "OrtTransformersOptimization",
"disable_search": true,
"evaluator": "gqa_evaluator",
"config": {
"save_as_external_data": true,
"all_tensors_to_one_file": true,
"model_type": "gpt2",
"opt_level": 0,
"only_onnxruntime": false,
"keep_io_types": false,
"float16": true,
"use_gqa": false // <----------- disable gqa
}
}
```
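
If you are unsure whether your GPU meets this requirement, here is a minimal check, assuming PyTorch with CUDA support and a visible GPU (torch is already a dependency of this example):

```python
import torch

# GQA needs a GPU with compute capability >= 8.0 (sm_80), e.g. A100 or RTX 30-series and newer.
major, minor = torch.cuda.get_device_capability(0)
if (major, minor) >= (8, 0):
    print(f"Compute capability {major}.{minor}: GQA is supported, use_gqa can stay true.")
else:
    print(f"Compute capability {major}.{minor}: set use_gqa to false in the config.")
```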

## Prerequisites
### Clone the repository and install Olive

Refer to the instructions in the [examples README](../README.md) to clone the repository and install Olive.

### Install onnxruntime
We also need a recent version of onnxruntime (>= 1.17.0), which provides support for int4 quantization and grouped query attention. Install it in one of the following ways:

1. From source:
```bash
git clone https://github.com/microsoft/onnxruntime
cd onnxruntime
# compile ort with cuda support, which requires a machine with cuda and cudnn installed
bash ./build.sh \
    --config=Release \
    --build_dir="./test_build" \
    --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ \
    --cuda_version=11.7 \
    --use_cuda --update --build \
    --build_wheel \
    --parallel \
    --skip_tests \
    --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) CMAKE_CUDA_ARCHITECTURES="70;75;80" \
    --use_mpi=false
```
The built wheel can then be found under the `build_dir` folder (`test_build/Release/dist/` in this case).

2. From a nightly build:

Follow the installation package table: https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages

After installation, run the following to check that onnxruntime is installed successfully:
```python
import onnxruntime as ort

print(ort.__version__)                # should be >= 1.17.0
print(ort.get_available_providers())  # should contain 'CUDAExecutionProvider' for GPU
```

### Install extra dependencies
Install the necessary python packages:
```bash
python -m pip install -r requirements.txt
```

## Run the config to optimize the model
To only generate the optimized config file (for double-checking before running the full optimization pipeline), run:
```bash
python llama2.py --model_name meta-llama/Llama-2-7b-hf --only_config
```

Or run one of the following commands to optimize the model directly:

CPU:
```bash
# run to optimize the model: FP32/INT8/INT4
python llama2.py --model_name meta-llama/Llama-2-7b-hf
```

GPU:
```bash
# run to optimize the model: FP16 / FP16 + INT4
python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu
# use GQA (grouped query attention) instead of MHA (multi-head attention)
python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu --use_gqa
```
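
If you prefer to drive the optimization from Python, here is a minimal sketch using Olive's `olive.workflows.run` entry point (as `llama2.py` does), assuming a config generated with `--only_config` — e.g. `llama2_gpu_gqa.json`, the name `llama2.py` uses for the GPU + GQA run — is in the current directory:

```python
import json

from olive.workflows import run as olive_run

# Load the config that llama2.py dumped with --only_config.
with open("llama2_gpu_gqa.json") as f:
    config = json.load(f)

# Run the same Olive workflow that llama2.py would run.
olive_run(config)
```

This is essentially what `llama2.py` does after patching the template config.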

## TODO
- [ ] Add generation example of the optimized model.
- [ ] Attach the benchmark results.
120 changes: 120 additions & 0 deletions examples/llama2/llama2.py
@@ -0,0 +1,120 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import argparse
import json
import re
from pathlib import Path

from onnxruntime import __version__ as OrtVersion
from packaging import version

import olive.workflows.run as olive_run

SUPPORTED_WORKFLOWS = {
    "cpu": [
        ["conversion_merged", "transformers_optimization_fp32"],
        ["conversion_merged", "transformers_optimization_fp32", "onnx_dynamic_quant_int8"],
        ["conversion_merged", "transformers_optimization_fp32", "blockwise_quant_int4"],
    ],
    "gpu": [
        ["conversion_merged", "transformers_optimization_fp16"],
        ["conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4"],
    ],
}
DEVICE_TO_EP = {
    "cpu": "CPUExecutionProvider",
    "gpu": "CUDAExecutionProvider",
}


def get_args(raw_args):
    parser = argparse.ArgumentParser(description="Llama2 optimization")
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Model name, currently only supports llama2 7B/13B",
    )
    parser.add_argument("--gpu", action="store_true", required=False, help="Whether to use gpu for optimization.")
    parser.add_argument(
        "--use_gqa",
        action="store_true",
        required=False,
        help="Whether to use GQA (grouped query attention) instead of MHA (multi-head attention).",
    )
    parser.add_argument(
        "--only_config",
        action="store_true",
        required=False,
        help="Whether to only dump the optimized config file without running the optimization.",
    )

    return parser.parse_args(raw_args)


def main(raw_args=None):
    args = get_args(raw_args)

    # version check
    version_1_17 = version.parse(OrtVersion) >= version.parse("1.17.0")

    if not version_1_17:
        raise ValueError("Please use onnxruntime>=1.17.0 for llama2 optimization")

    json_file_template = "llama2_template.json"
    with open(json_file_template) as f:  # noqa: PTH123
        template_json = json.load(f)

    model_name = args.model_name
    # update model name
    template_json["input_model"]["config"]["hf_config"]["model_name"] = model_name

    # update ep
    device = "cpu" if not args.gpu else "gpu"
    template_json["pass_flows"] = SUPPORTED_WORKFLOWS[device]

    template_json["engine"]["execution_providers"] = [DEVICE_TO_EP[device]]
    template_json["engine"]["output_dir"] = f"llama2_{device}/{model_name}"

    if not args.use_gqa and args.gpu:
        template_json["passes"]["transformers_optimization_fp16"]["config"]["use_gqa"] = False
        # GQA changes the model inputs and needs a GQA-specific dataloader; with GQA disabled,
        # the gqa_evaluator and its dataloader are not needed, so fall back to the default evaluator.
        del template_json["passes"]["transformers_optimization_fp16"]["evaluator"]
        del template_json["passes"]["blockwise_quant_int4"]["evaluator"]
        del template_json["evaluators"]["gqa_evaluator"]

    # update user script
    user_script_path = Path(__file__).parent / "user_script.py"
    update_user_script(user_script_path, model_name)

    device = "gpu" if args.gpu else "cpu"
    gqa = "gqa" if args.use_gqa else "mha"
    # dump config
    with open(f"llama2_{device}_{gqa}.json", "w") as f:  # noqa: PTH123
        json.dump(template_json, f, indent=4)

    if not args.only_config:
        olive_run(template_json)  # pylint: disable=not-callable


def update_user_script(file_path, model_name):
    with open(file_path) as file:  # noqa: PTH123
        lines = file.readlines()

    new_lines = []
    for line in lines:
        updated_line = line
        if "meta-llama/Llama-2" in line:
            updated_line = re.sub(r"meta-llama/Llama-2-(\d+)b-hf", model_name, line)
        new_lines.append(updated_line)

    with open(file_path, "w") as file:  # noqa: PTH123
        file.writelines(new_lines)


if __name__ == "__main__":
    main()
134 changes: 134 additions & 0 deletions examples/llama2/llama2_template.json
@@ -0,0 +1,134 @@
{
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "model_script": "user_script.py",
            "io_config": "get_merged_decoder_with_past_io_config",
            "dummy_inputs_func": "get_merged_decoder_with_past_kv_inputs",
            "hf_config": {
                "model_name": "meta-llama/Llama-2-7b-hf",
                "model_class": "LlamaForCausalLM"
            }
        }
    },
    "evaluators": {
        "merged_evaluator": {
            "metrics": [
                {
                    "name": "onnx_merged_latency",
                    "type": "latency",
                    "sub_types": [
                        {
                            "name": "avg",
                            "priority": 1
                        }
                    ],
                    "user_config": {
                        "user_script": "user_script.py",
                        "dataloader_func": "dataloader_func_for_merged",
                        "batch_size": 1,
                        "io_bind": true
                    }
                }
            ]
        },
        "gqa_evaluator": {
            "metrics": [
                {
                    "name": "onnx_merged_latency",
                    "type": "latency",
                    "sub_types": [
                        {
                            "name": "avg",
                            "priority": 1
                        }
                    ],
                    "user_config": {
                        "user_script": "user_script.py",
                        "dataloader_func": "dataloader_func_for_merged_gqa",
                        "batch_size": 1,
                        "io_bind": true
                    }
                }
            ]
        }
    },
    "passes": {
        "conversion_merged": {
            "type": "OnnxConversion",
            "config": {
                "target_opset": 13,
                "save_as_external_data": true,
                "all_tensors_to_one_file": true
            }
        },
        "transformers_optimization_fp16": {
            "type": "OrtTransformersOptimization",
            "disable_search": true,
            "evaluator": "gqa_evaluator",
            "config": {
                "save_as_external_data": true,
                "all_tensors_to_one_file": true,
                "model_type": "gpt2",
                "opt_level": 0,
                "only_onnxruntime": false,
                "keep_io_types": false,
                "float16": true,
                "use_gqa": true
            }
        },
        "transformers_optimization_fp32": {
            "type": "OrtTransformersOptimization",
            "disable_search": true,
            "config": {
                "save_as_external_data": true,
                "all_tensors_to_one_file": true,
                "model_type": "gpt2",
                "opt_level": 0,
                "only_onnxruntime": false,
                "keep_io_types": false,
                "float16": false,
                "use_gqa": false
            }
        },
        "onnx_dynamic_quant_int8": {
            "type": "OnnxDynamicQuantization",
            "disable_search": true,
            "config": {
                "save_as_external_data": true,
                "all_tensors_to_one_file": true,
                "op_types_to_quantize": [
                    "MatMul",
                    "Gemm"
                ],
                "per_channel": false,
                "reduce_range": false,
                "MatMulConstBOnly": true
            }
        },
        "blockwise_quant_int4": {
            "type": "OnnxMatMul4Quantizer",
            "disable_search": true,
            "evaluator": "gqa_evaluator",
            "config": {
                "save_as_external_data": true,
                "all_tensors_to_one_file": true,
                "block_size": 32,
                "is_symmetric": true
            }
        }
    },
    "engine": {
        "search_strategy": {
            "execution_order": "pass-by-pass",
            "search_algorithm": "tpe",
            "search_algorithm_config": {
                "num_samples": 3,
                "seed": 0
            }
        },
        "evaluator": "merged_evaluator",
        "cache_dir": "cache",
        "output_dir": "models/llama2-7b"
    }
}
6 changes: 6 additions & 0 deletions examples/llama2/requirements.txt
@@ -0,0 +1,6 @@
git+https://github.com/huggingface/optimum.git
transformers>=4.33.2
onnx>=1.14.0
datasets>=2.8.0
protobuf==3.20.2
torch -i https://download.pytorch.org/whl/nightly/cu118
