From 649e314b0a1d8fcb20b320ba83e503e467403cbb Mon Sep 17 00:00:00 2001
From: trajep
Date: Sat, 28 Oct 2023 13:30:40 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=A6=99=20llama2=20optimization=20(#641)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Describe your changes

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
---
 examples/llama2/.gitignore                    |   2 +
 examples/llama2/README.md                     |  96 +++++
 examples/llama2/llama2.py                     | 120 ++++++
 examples/llama2/llama2_template.json          | 134 +++++++
 examples/llama2/requirement.txt               |   6 +
 examples/llama2/user_script.py                | 354 ++++++++++++++++++
 olive/model/__init__.py                       |  14 +-
 olive/model/hf_utils.py                       |   3 +-
 olive/passes/onnx/common.py                   |   1 +
 olive/passes/onnx/conversion.py               |  17 +-
 olive/passes/onnx/inc_quantization.py         |  14 +-
 olive/passes/onnx/optimum_conversion.py       |  23 +-
 olive/passes/onnx/quantization.py             |  49 +++
 olive/passes/onnx/transformer_optimization.py |  59 ++-
 14 files changed, 873 insertions(+), 19 deletions(-)
 create mode 100644 examples/llama2/.gitignore
 create mode 100644 examples/llama2/README.md
 create mode 100644 examples/llama2/llama2.py
 create mode 100644 examples/llama2/llama2_template.json
 create mode 100644 examples/llama2/requirement.txt
 create mode 100644 examples/llama2/user_script.py

diff --git a/examples/llama2/.gitignore b/examples/llama2/.gitignore
new file mode 100644
index 000000000..d3e82a03b
--- /dev/null
+++ b/examples/llama2/.gitignore
@@ -0,0 +1,2 @@
+llama2_cpu*
+llama2_gpu*
diff --git a/examples/llama2/README.md b/examples/llama2/README.md
new file mode 100644
index 000000000..db79ced00
--- /dev/null
+++ b/examples/llama2/README.md
@@ -0,0 +1,96 @@
+# Llama2 optimization using ORT toolchain
+This folder contains a sample use case of Olive to optimize a [Llama2](https://huggingface.co/meta-llama/Llama-2-7b-hf) model using ONNX Runtime tools.
+
+This example performs the following optimization pipelines:
+- CPU, FP32: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32*
+- CPU, INT8: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32 -> Onnx Dynamic Quantization*
+- CPU, INT4: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32 -> Onnx Block wise int4 Quantization*
+- GPU, FP32: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp32*
+- GPU, FP16: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp16 + Grouped Query Attention*
+- GPU, INT4: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp16 + Grouped Query Attention -> Onnx Block wise int4 Quantization*
+
+**Note:** Grouped query attention is currently only supported on GPU with fp16 and requires a CUDA architecture >= 80. You can set `use_gqa` to `false` in the config file to disable it:
+
+```json
+"transformers_optimization_fp16": {
+    "type": "OrtTransformersOptimization",
+    "disable_search": true,
+    "evaluator": "gqa_evaluator",
+    "config": {
+        "save_as_external_data": true,
+        "all_tensors_to_one_file": true,
+        "model_type": "gpt2",
+        "opt_level": 0,
+        "only_onnxruntime": false,
+        "keep_io_types": false,
+        "float16": true,
+        "use_gqa": false // <----------- disable gqa
+    }
+}
+```
+
+## Prerequisites
+### Clone the repository and install Olive
+
+Refer to the instructions in the [examples README](../README.md) to clone the repository and install Olive.
+
+### Install onnxruntime
+This example requires the latest version of onnxruntime, which provides support for int4 quantization and grouped query attention. Install it in one of the following ways:
+
+1. From source:
+    ```bash
+    git clone https://github.com/microsoft/onnxruntime
+    # compile ort with cuda support, which requires an image with cuda and cudnn installed
+    bash ./build.sh \
+        --config=Release \
+        --build_dir="./test_build" \
+        --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ \
+        --cuda_version=11.7 \
+        --use_cuda --update --build \
+        --build_wheel \
+        --parallel \
+        --skip_tests --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) CMAKE_CUDA_ARCHITECTURES="70;75;80" \
+        --use_mpi=false
+    ```
+    You can then find the wheel file under the `build_dir` folder (`test_build/Release/dist/` in this case).
+
+2. From nightly build:
+
+   Installation package table: https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages
+
+After installation, you can run the following to check that onnxruntime is installed successfully:
+```python
+import onnxruntime as ort
+ort.get_available_providers()  # should contain 'CUDAExecutionProvider'
+```
+
+### Install extra dependencies
+Install the necessary python packages:
+```
+python -m pip install -r requirement.txt
+```
+
+## Run the config to optimize the model
+To only generate the optimized config file for inspection, without running the optimization pipeline, run:
+```bash
+python llama2.py --model_name meta-llama/Llama-2-7b-hf --only_config
+```
+
+Or run one of the following commands to directly optimize the model:
+
+CPU:
+```bash
+# run to optimize the model: FP32/INT8/INT4
+python llama2.py --model_name meta-llama/Llama-2-7b-hf
+```
+
+GPU:
+```bash
+# run to optimize the model: FP16/INT4
+python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu
+# use gqa instead of mha
+python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu --use_gqa
+```
+
+## TODO
+- [ ] Add generation example of the optimized model.
+- [ ] Attach the benchmark results.
diff --git a/examples/llama2/llama2.py b/examples/llama2/llama2.py
new file mode 100644
index 000000000..4d55a9334
--- /dev/null
+++ b/examples/llama2/llama2.py
@@ -0,0 +1,120 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+
+import argparse
+import json
+import re
+from pathlib import Path
+
+from onnxruntime import __version__ as OrtVersion
+from packaging import version
+
+import olive.workflows.run as olive_run
+
+SUPPORTED_WORKFLOWS = {
+    "cpu": [
+        ["conversion_merged", "transformers_optimization_fp32"],
+        ["conversion_merged", "transformers_optimization_fp32", "onnx_dynamic_quant_int8"],
+        ["conversion_merged", "transformers_optimization_fp32", "blockwise_quant_int4"],
+    ],
+    "gpu": [
+        ["conversion_merged", "transformers_optimization_fp16"],
+        ["conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4"],
+    ],
+}
+DEVICE_TO_EP = {
+    "cpu": "CPUExecutionProvider",
+    "gpu": "CUDAExecutionProvider",
+}
+
+
+def get_args(raw_args):
+    parser = argparse.ArgumentParser(description="Llama2 optimization")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="meta-llama/Llama-2-7b-hf",
+        help="Model name, currently only supports llama2 7B/13B",
+    )
+    parser.add_argument("--gpu", action="store_true", required=False, help="Whether to use gpu for optimization.")
+    parser.add_argument(
+        "--use_gqa",
+        action="store_true",
+        required=False,
+        help="Whether to use GQA(grouped query attention) instead of MHA(multi-head attention).",
+    )
+    parser.add_argument(
+        "--only_config",
+        action="store_true",
+        required=False,
+        help="Whether to only dump the optimized config file without running the optimization pipeline.",
+    )
+
+    return parser.parse_args(raw_args)
+
+
+def main(raw_args=None):
+    args = get_args(raw_args)
+
+    # version check
+    version_1_17 = version.parse(OrtVersion) >= version.parse("1.17.0")
+
+    if not version_1_17:
+        raise ValueError("Please use onnxruntime>=1.17.0 for llama2 optimization")
+
+    json_file_template = "llama2_template.json"
+    with open(json_file_template) as f:  # noqa: PTH123
+        template_json = json.load(f)
+
+    model_name = args.model_name
+    # update model name
+    template_json["input_model"]["config"]["hf_config"]["model_name"] = model_name
+
+    # update ep
+    device = "cpu" if not args.gpu else "gpu"
+    template_json["pass_flows"] = SUPPORTED_WORKFLOWS[device]
+
+    template_json["engine"]["execution_providers"] = [DEVICE_TO_EP[device]]
+    template_json["engine"]["output_dir"] = f"llama2_{device}/{model_name}"
+
+    if not args.use_gqa and args.gpu:
+        template_json["passes"]["transformers_optimization_fp16"]["config"]["use_gqa"] = False
+        # the gqa evaluator uses a special dataloader because applying GQA changes the model's inputs;
+        # it is not needed when GQA is disabled
+        del template_json["passes"]["transformers_optimization_fp16"]["evaluator"]
+        del template_json["passes"]["blockwise_quant_int4"]["evaluator"]
+        del template_json["evaluators"]["gqa_evaluator"]
+
+    # update user script
+    user_script_path = Path(__file__).parent / "user_script.py"
+    update_user_script(user_script_path, model_name)
+
+    device = "gpu" if args.gpu else "cpu"
+    gqa = "gqa" if args.use_gqa else "mha"
+    # dump config
+    with open(f"llama2_{device}_{gqa}.json", "w") as f:  # noqa: PTH123
+        json.dump(template_json, f, indent=4)
+
+    if not args.only_config:
+        olive_run(template_json)  # pylint: disable=not-callable
+
+
+def update_user_script(file_path, model_name):
+    with open(file_path) as file:  # noqa: PTH123
+        lines = file.readlines()
+
+    new_lines = []
+    for line in lines:
+        updated_line = line
+        if "meta-llama/Llama-2" in line:
+            updated_line = re.sub(r"meta-llama/Llama-2-(\d+)b-hf", model_name, line)
+        new_lines.append(updated_line)
+
+    with 
open(file_path, "w") as file: # noqa: PTH123 + file.writelines(new_lines) + + +if __name__ == "__main__": + main() diff --git a/examples/llama2/llama2_template.json b/examples/llama2/llama2_template.json new file mode 100644 index 000000000..158f38a88 --- /dev/null +++ b/examples/llama2/llama2_template.json @@ -0,0 +1,134 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_script": "user_script.py", + "io_config": "get_merged_decoder_with_past_io_config", + "dummy_inputs_func": "get_merged_decoder_with_past_kv_inputs", + "hf_config": { + "model_name": "meta-llama/Llama-2-7b-hf", + "model_class": "LlamaForCausalLM" + } + } + }, + "evaluators": { + "merged_evaluator": { + "metrics": [ + { + "name": "onnx_merged_latency", + "type": "latency", + "sub_types": [ + { + "name": "avg", + "priority": 1 + } + ], + "user_config": { + "user_script": "user_script.py", + "dataloader_func": "dataloader_func_for_merged", + "batch_size": 1, + "io_bind": true + } + } + ] + }, + "gqa_evaluator": { + "metrics": [ + { + "name": "onnx_merged_latency", + "type": "latency", + "sub_types": [ + { + "name": "avg", + "priority": 1 + } + ], + "user_config": { + "user_script": "user_script.py", + "dataloader_func": "dataloader_func_for_merged_gqa", + "batch_size": 1, + "io_bind": true + } + } + ] + } + }, + "passes": { + "conversion_merged": { + "type": "OnnxConversion", + "config": { + "target_opset": 13, + "save_as_external_data": true, + "all_tensors_to_one_file": true + } + }, + "transformers_optimization_fp16": { + "type": "OrtTransformersOptimization", + "disable_search": true, + "evaluator": "gqa_evaluator", + "config": { + "save_as_external_data": true, + "all_tensors_to_one_file": true, + "model_type": "gpt2", + "opt_level": 0, + "only_onnxruntime": false, + "keep_io_types": false, + "float16": true, + "use_gqa": true + } + }, + "transformers_optimization_fp32": { + "type": "OrtTransformersOptimization", + "disable_search": true, + "config": { + "save_as_external_data": true, + "all_tensors_to_one_file": true, + "model_type": "gpt2", + "opt_level": 0, + "only_onnxruntime": false, + "keep_io_types": false, + "float16": false, + "use_gqa": false + } + }, + "onnx_dynamic_quant_int8": { + "type": "OnnxDynamicQuantization", + "disable_search": true, + "config": { + "save_as_external_data": true, + "all_tensors_to_one_file": true, + "op_types_to_quantize": [ + "MatMul", + "Gemm" + ], + "per_channel": false, + "reduce_range": false, + "MatMulConstBOnly": true + } + }, + "blockwise_quant_int4": { + "type": "OnnxMatMul4Quantizer", + "disable_search": true, + "evaluator": "gqa_evaluator", + "config": { + "save_as_external_data": true, + "all_tensors_to_one_file": true, + "block_size": 32, + "is_symmetric": true + } + } + }, + "engine": { + "search_strategy": { + "execution_order": "pass-by-pass", + "search_algorithm": "tpe", + "search_algorithm_config": { + "num_samples": 3, + "seed": 0 + } + }, + "evaluator": "merged_evaluator", + "cache_dir": "cache", + "output_dir": "models/llama2-7b" + } +} diff --git a/examples/llama2/requirement.txt b/examples/llama2/requirement.txt new file mode 100644 index 000000000..e918f30bf --- /dev/null +++ b/examples/llama2/requirement.txt @@ -0,0 +1,6 @@ +git+https://github.com/huggingface/optimum.git +transformers>=4.33.2 +onnx>=1.14.0 +datasets>=2.8.0 +protobuf==3.20.2 +torch -i https://download.pytorch.org/whl/nightly/cu118 diff --git a/examples/llama2/user_script.py b/examples/llama2/user_script.py new file mode 100644 index 000000000..c3e03b592 --- 
/dev/null +++ b/examples/llama2/user_script.py @@ -0,0 +1,354 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from itertools import chain +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import LlamaConfig, LlamaTokenizer + +from olive.constants import Framework + + +def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if use_past_kv: + position_ids = position_ids[:, -1].unsqueeze(-1) + return position_ids + + +def get_decoder_inputs(model, batch_size=2, seq_len=100, model_id=""): + device = torch.device("cpu") + if model_id: + config = LlamaConfig.from_pretrained(model_id) + else: + config = LlamaConfig.from_pretrained(model.hf_config.model_name) + + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 + ) + attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64) + # position_ids is of shape (batch_size, seq_len) + position_ids = get_position_ids(attention_mask, use_past_kv=False) + + return (input_ids, attention_mask, position_ids) + + +def get_decoder_with_past_kv_inputs(model, batch_size=2, seq_len=1, past_seq_len=100, use_fp16=False, model_id=""): + if model_id: + config = LlamaConfig.from_pretrained(model_id) + else: + config = LlamaConfig.from_pretrained(model.hf_config.model_name) + + device = torch.device("cpu") + + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 + ) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + # position_ids is of shape (batch_size, 1) + position_ids = get_position_ids(attention_mask, use_past_kv=True) + past_key_values = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16=use_fp16) + + return (input_ids, attention_mask, position_ids, past_key_values) + + +def get_merged_decoder_with_past_kv_inputs(model, batch_size=2, seq_len=8, past_seq_len=0, use_fp16=False, model_id=""): + input_ids, attention_mask, position_ids, past_key_values = get_decoder_with_past_kv_inputs( + model, batch_size, seq_len, past_seq_len, use_fp16, model_id + ) + # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation + position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) + + return input_ids, attention_mask, position_ids, past_key_values + + +def get_sample_past_kv_inputs( + config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool +): + num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads + torch_dtype = torch.float16 if use_fp16 else torch.float32 + return [ + ( + torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), + ) + for _ in range(config.num_hidden_layers) + ] + + +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool): + 
past_kv = {} + np_dtype = np.float16 if use_fp16 else np.float32 + # Convert list of past_kv to dict of past_key and past_value + for i, (past_k, past_v) in enumerate(past_key_values): + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype) + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype) + return past_kv + + +def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in input_names: + # shape is (batch_size, sequence_length) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif name == "logits": + # shape is (batch_size, sequence_length, vocab_size) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif "present" in name: + # shape is (batch_size, num_heads, sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"} + else: + raise ValueError("Unknown input or output name found") + return dynamic_axes + + +def get_decoder_io_config(model_name, merged=False): + config = LlamaConfig.from_pretrained(model_name) + + input_names = ["input_ids", "attention_mask", "position_ids"] + output_names = [ + "logits", + *list(chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(config.num_hidden_layers))), + ] + dynamic_axes = get_model_dynamic_axes(input_names, output_names) + return { + "input_names": input_names, + "dynamic_axes": dynamic_axes, + "output_names": output_names, + } + + +def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in {"input_ids", "position_ids"}: + # shape is (batch_size, 1) + dynamic_axes[name] = {0: "batch_size"} + elif name == "attention_mask": + # shape is (batch_size, past_sequence_length + 1) + dynamic_axes[name] = {0: "batch_size", 1: "past_sequence_length + 1"} + elif "past" in name: + # shape is (batch_size, num_heads, past_sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"} + elif name == "logits": + # shape is (batch_size, 1, vocab_size) + dynamic_axes[name] = {0: "batch_size"} + elif "present" in name: + # shape is (batch_size, num_heads, past_sequence_length + 1, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length + 1"} + else: + raise ValueError("Unknown input or output name found") + return dynamic_axes + + +def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in {"input_ids", "position_ids"}: + # shape is (batch_size, sequence_length) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif name == "attention_mask": + # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"} + elif "past" in name: + # shape is (batch_size, num_heads, past_sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"} + elif name == "logits": + # shape is (batch_size, sequence_length, vocab_size) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif "present" in name: + # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) + # = (batch_size, num_heads, 
total_sequence_length, head_size) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"} + else: + raise ValueError("Unknown input or output name found") + return dynamic_axes + + +def get_decoder_with_past_io_config(model_name): + config = LlamaConfig.from_pretrained(model_name) + io_config = get_decoder_io_config(model_name) + + io_config["input_names"].extend( + list( + chain.from_iterable( + (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(config.num_hidden_layers) + ) + ) + ) + io_config["dynamic_axes"] = get_model_with_past_kv_dynamic_axes(io_config["input_names"], io_config["output_names"]) + return io_config + + +def get_merged_decoder_with_past_io_config(model_name): + io_config = get_decoder_with_past_io_config(model_name) + io_config["dynamic_axes"] = get_merged_model_dynamic_axes(io_config["input_names"], io_config["output_names"]) + return io_config + + +class RandomDataLoader: + def __init__( + self, + create_inputs_func, + batch_size, + torch_dtype, + model_framework=Framework.PYTORCH, + onnx_merged=False, + use_gqa=False, + ): + self.create_input_func = create_inputs_func + self.batch_size = batch_size + self.torch_dtype = torch_dtype + self.model_framework = model_framework + self.onnx_merged = onnx_merged + self.use_gqa = use_gqa + + def __getitem__(self, idx): + label = None + return ( + self.create_input_func( + self.batch_size, self.torch_dtype, self.model_framework, self.onnx_merged, self.use_gqa + ), + label, + ) + + +def dummy_inputs_for_latency( + batch_size, torch_dtype, model_framework=Framework.PYTORCH, onnx_merged=False, use_gqa=False +): + model_id = "meta-llama/Llama-2-7b-hf" + if onnx_merged: + input_ids, attention_mask, position_ids, pkv = get_merged_decoder_with_past_kv_inputs( + model=None, model_id=model_id + ) + else: + input_ids, attention_mask, position_ids, pkv = get_decoder_with_past_kv_inputs(model=None, model_id=model_id) + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": pkv, + } + if model_framework == Framework.ONNX: + inputs.update(flatten_past_kv_inputs(pkv, use_fp16=torch_dtype == torch.float16)) + inputs["use_cache_branch"] = torch.ones((1,), dtype=torch.bool) + del inputs["past_key_values"] + else: + inputs["use_cache"] = True + + if use_gqa: + past_seq_len = 0 + inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) + return inputs + + +def dataloader_func(data_dir, batch_size, *args, **kwargs): + model_framework = kwargs.get("model_framework", Framework.PYTORCH) + return RandomDataLoader(dummy_inputs_for_latency, batch_size, torch.float16, model_framework) + + +def dataloader_func_for_merged(data_dir, batch_size, *args, **kwargs): + # TODO(trajep): after optimization, the model's input will be different + model_framework = kwargs.get("model_framework", Framework.PYTORCH) + return RandomDataLoader(dummy_inputs_for_latency, batch_size, torch.float16, model_framework, True) + + +def dataloader_func_for_merged_gqa(data_dir, batch_size, *args, **kwargs): + model_framework = kwargs.get("model_framework", Framework.PYTORCH) + return RandomDataLoader( + dummy_inputs_for_latency, batch_size, torch.float16, model_framework, onnx_merged=True, use_gqa=True + ) + + +def inc_cali_dataloader_func(data_dir, batch_size, *args, **kwargs): + return QuantKVDataLoader( + hf_model_id="meta-llama/Llama-2-7b-hf", + 
dataset_name="NeelNanda/pile-10k", + ) + + +def inc_cali_merged_dataloader_func(data_dir, batch_size, *args, **kwargs): + return QuantKVDataLoader( + hf_model_id="meta-llama/Llama-2-7b-hf", + dataset_name="NeelNanda/pile-10k", + merged=True, + ) + + +class QuantKVDataLoader: + def __init__(self, hf_model_id: str = "", dataset_name: str = "", pad_max: int = 196, merged: bool = False): + self.batch_size = 1 + self.pad_max = pad_max + self.merged = merged + + tokenizer = LlamaTokenizer.from_pretrained(hf_model_id) + dataset = load_dataset(dataset_name, split="train") + dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True) + dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + self.dataloader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=False, + collate_fn=self.collate_batch, + ) + + def collate_batch(self, batch): + input_ids_batched = [] + attention_mask_batched = [] + position_ids_batched = [] + labels = [] + + for text in batch: + # Set inputs for model + input_ids = text["input_ids"] + attention_mask = torch.ones(len(input_ids)) + position_ids = get_position_ids(attention_mask, use_past_kv=False) + label = len(input_ids) - 1 + + # Pad input data because all model inputs must have same shape + pad_len = self.pad_max - input_ids.shape[0] + # pylint: disable=not-callable + input_ids = F.pad(input_ids, (0, pad_len), value=1) + attention_mask = F.pad(attention_mask, (0, pad_len), value=0) + position_ids = F.pad(position_ids, (0, pad_len), value=0) + + input_ids_batched.append(input_ids) + attention_mask_batched.append(attention_mask) + position_ids_batched.append(position_ids) + labels.append(label) + + input_ids_batched = torch.vstack(input_ids_batched) + attention_mask_batched = torch.vstack(attention_mask_batched) + position_ids_batched = torch.vstack(position_ids_batched) + labels = torch.tensor(labels) + + return (input_ids_batched, attention_mask_batched, position_ids_batched), labels + + def __iter__(self): + for (input_ids, attention_mask, position_ids), labels in self.dataloader: + # Inputs for decoder_model.onnx + inputs = { + "input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64), + "attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64), + } + if self.merged: + inputs.pop("attention_mask", None) + label = labels.detach().cpu().numpy() + + # Yield (inputs, label) tuple for Intel's Neural Compressor: + # https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62 + yield (inputs, label) diff --git a/olive/model/__init__.py b/olive/model/__init__.py index 4a0052466..b9b6e751d 100644 --- a/olive/model/__init__.py +++ b/olive/model/__init__.py @@ -475,7 +475,7 @@ def __init__( model_loader: Union[str, Callable] = None, model_script: Union[str, Path] = None, script_dir: Union[str, Path] = None, - io_config: Union[Dict[str, Any], IOConfig] = None, + io_config: Union[Dict[str, Any], IOConfig, str] = None, dummy_inputs_func: Union[str, Callable] = None, hf_config: Union[Dict[str, Any], HFConfig] = None, adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None, @@ -524,6 +524,10 @@ def __init__( ), "model_script must be a local file or a string name." 
# io config for conversion to onnx + # TODO(trajep): support callable io_config + if isinstance(io_config, str): + user_module_loader = UserModuleLoader(self.model_script, self.script_dir) + io_config = user_module_loader.call_object(io_config, self.hf_config.model_name) self.io_config = validate_config(io_config, IOConfig).dict() if io_config else None self.dummy_inputs_func = dummy_inputs_func @@ -688,8 +692,12 @@ def get_component(self, component_name: str) -> "PyTorchModel": components_dict = {component.name: component for component in self.hf_config.components} hf_component = components_dict[component_name] - user_module_loader = UserModuleLoader(self.model_script, self.script_dir) - model_component = user_module_loader.call_object(hf_component.component_func, self.hf_config.model_name) + if hf_component.component_func is None: + logger.debug("component_func is not provided, using hf_config to get component") + model_component = self.hf_config.load_model(self.model_path) + else: + user_module_loader = UserModuleLoader(self.model_script, self.script_dir) + model_component = user_module_loader.call_object(hf_component.component_func, self.hf_config.model_name) io_config = hf_component.io_config if isinstance(io_config, str): diff --git a/olive/model/hf_utils.py b/olive/model/hf_utils.py index 7d3f67b31..e0aa2ff84 100644 --- a/olive/model/hf_utils.py +++ b/olive/model/hf_utils.py @@ -22,8 +22,9 @@ class HFComponent(ConfigBase): name: str + # TODO(trajep): support callable io_config io_config: Union[IOConfig, str, Dict[str, Any]] - component_func: Union[str, Callable] + component_func: Union[str, Callable] = None dummy_inputs_func: Union[str, Callable] diff --git a/olive/passes/onnx/common.py b/olive/passes/onnx/common.py index 5968c1e4b..32d0ce4e0 100644 --- a/olive/passes/onnx/common.py +++ b/olive/passes/onnx/common.py @@ -108,6 +108,7 @@ def model_proto_to_file( raise RuntimeError(f"Output directory ({output_dir}) for external data is not empty.") # save model + # TODO(trajep): complete the argument list onnx.save_model( model, str(output_path), diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index da3040e38..325c15e56 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -9,6 +9,7 @@ import onnx import torch +from packaging import version from olive.common.config_utils import validate_config from olive.common.utils import tensor_data_to_device @@ -98,13 +99,15 @@ def _convert_model_on_device( onnx_model = None if config["use_dynamo_exporter"]: - # TODO(xiaoyu): remove this import check once torch.onnx.dynamo_export is available in stable pytorch - try: - from torch.onnx import dynamo_export - except ImportError: - raise ImportError( - "torch.onnx.dynamo_export is not available. Please upgrade your pytorch version to nightly build." - ) from None + # available since torch==2.1.0 + torch_version = torch.__version__ + if version.parse(torch_version) < version.parse("2.1.0"): + raise RuntimeError( + f"torch.onnx.dynamo_export is not available for torch version {torch_version}. " + "Please upgrade your torch version to 2.1.0 or above." + ) + from torch.onnx import dynamo_export + exported = dynamo_export( pytorch_model, *dummy_inputs, diff --git a/olive/passes/onnx/inc_quantization.py b/olive/passes/onnx/inc_quantization.py index 6304300c5..86f285cc9 100644 --- a/olive/passes/onnx/inc_quantization.py +++ b/olive/passes/onnx/inc_quantization.py @@ -127,6 +127,13 @@ INC weight only quantization config. 
""", ), + "op_type_dict": PassConfigParam( + type_=dict, + default_value={}, + description=""" + INC weight only quantization config. + """, + ), } _inc_static_dataloader_config = { @@ -506,11 +513,16 @@ def _run_for_config( if key in run_config: del run_config[key] + run_config["op_type_dict"] = ( + run_config["op_type_dict"] or {".*": {"weight": weight_only_config}} + if run_config["approach"] == "weight_only" + else None + ) + ptq_config = PostTrainingQuantConfig( **run_config, accuracy_criterion=accuracy_criterion, tuning_criterion=tuning_criterion, - op_type_dict={".*": {"weight": weight_only_config}} if run_config["approach"] == "weight_only" else None, ) inc_calib_dataloader = None diff --git a/olive/passes/onnx/optimum_conversion.py b/olive/passes/onnx/optimum_conversion.py index 170dd9582..f0cdcb927 100644 --- a/olive/passes/onnx/optimum_conversion.py +++ b/olive/passes/onnx/optimum_conversion.py @@ -34,18 +34,29 @@ def _run_for_config( ) -> Union[ONNXModel, CompositeOnnxModel]: assert len(model.model_components) > 0 + from optimum import version as optimum_version from optimum.exporters.onnx import main_export as export_optimum_model + from packaging import version # TODO(jambayk): export into temp dir and then move to sub-dirs of output_model_path # so that we only keep the final model files in the output_model_path # and track external data if present hf_config = deepcopy(model.hf_config) or HFConfig() - export_optimum_model( - model.model_path or hf_config.model_name, - output_model_path, - opset=config["target_opset"], - no_post_process=True, - ) + if version.parse(optimum_version.__version__) < version.parse("1.14.0"): + export_optimum_model( + model.model_path or hf_config.model_name, + output_model_path, + opset=config["target_opset"], + no_post_process=True, + ) + else: + export_optimum_model( + model.model_path or hf_config.model_name, + output_model_path, + opset=config["target_opset"], + legacy=True, + no_post_process=True, + ) onnx_model_components = [ ONNXModel(str(Path(output_model_path) / model_component), model_attributes=model.model_attributes) diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py index f2be1ddf9..3c00e6891 100644 --- a/olive/passes/onnx/quantization.py +++ b/olive/passes/onnx/quantization.py @@ -506,3 +506,52 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa # external data config config.update(get_external_data_config()) return config + + +class OnnxMatMul4Quantizer(Pass): + @staticmethod + def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]: + config = { + "block_size": PassConfigParam( + type_=int, + default_value=32, + description="Block size for quantization. Default value is 32.", + ), + "is_symmetric": PassConfigParam( + type_=bool, + default_value=True, + description="Symmetric quantization. 
Default value is True.", + ), + "nodes_to_exclude": PassConfigParam( + type_=list, + default_value=[], + description="List of node names to exclude from quantization.", + ), + } + config.update(get_external_data_config()) + return config + + def _run_for_config( + self, model: ONNXModel, data_root: str, config: Dict[str, Any], output_model_path: str + ) -> ONNXModel: + from onnxruntime import __version__ as OrtVersion + + if version.parse(OrtVersion) < version.parse("1.17.0"): + raise OlivePassError("OnnxLlamaMatMulWeight4Quantizer is only supported in onnxruntime >= 1.17.0") + + from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer + + quant = MatMul4BitsQuantizer( + model.load_model(), config["block_size"], config["is_symmetric"], config["nodes_to_exclude"] + ) + quant.process() + + # TODO(trajep): add more options to save_model_to_file + new_tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp") + tmp_model_path = str(Path(new_tmp_dir.name) / Path(output_model_path).name) + quant.model.save_model_to_file(tmp_model_path, config["save_as_external_data"]) + + # load the model + onnx_model = onnx.load(tmp_model_path) + new_tmp_dir.cleanup() + return model_proto_to_olive_model(onnx_model, output_model_path, config) diff --git a/olive/passes/onnx/transformer_optimization.py b/olive/passes/onnx/transformer_optimization.py index c6d309925..55f12ee2c 100644 --- a/olive/passes/onnx/transformer_optimization.py +++ b/olive/passes/onnx/transformer_optimization.py @@ -5,7 +5,7 @@ import logging import os from copy import deepcopy -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model import ONNXModel @@ -15,6 +15,9 @@ from olive.passes.pass_config import PassConfigParam from olive.strategy.search_parameter import Boolean, Categorical, Conditional +if TYPE_CHECKING: + from onnxruntime.transformers.onnx_model import OnnxModel + logger = logging.getLogger(__name__) @@ -90,6 +93,11 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa "force_fp32_ops": PassConfigParam( type_=List[str], default_value=None, description="Operators that are forced to run in float32" ), + "use_gqa": PassConfigParam( + type_=bool, + default_value=False, + description="Replace MultiHeadAttention with GroupQueryAttention.", + ), } config.update(get_external_data_config()) return config @@ -109,6 +117,9 @@ def validate_search_point( if accelerator_spec.execution_provider == "CPUExecutionProvider": logger.info("CPUExecutionProvider does not support float16 very well, please avoid to use float16.") return False + if not search_point.get("float16") and search_point.get("use_gqa"): + logger.info("use_gqa is only supported when float16 is True.") + return False if search_point.get("use_gpu") and accelerator_spec.execution_provider == "CPUExecutionProvider": logger.info("CPUExecutionProvider does not support GPU inference, please avoid to use use_gpu.") return False @@ -151,6 +162,7 @@ def _run_for_config( run_config["input_int32"], run_config["keep_io_types"], run_config["force_fp32_ops"], + run_config["use_gqa"], ) for key in get_external_data_config(): del run_config[key] @@ -199,6 +211,18 @@ def _run_for_config( optimizer.convert_float_to_float16( keep_io_types=config["keep_io_types"], op_block_list=op_block_list, force_fp16_inputs=force_fp16_inputs ) + if config["use_gqa"]: + # Replace MultiHeadAttention with GroupQueryAttention and remove 
attention mask nodes + num_kv_heads = model.model_attributes.get("num_key_value_heads", None) + if num_kv_heads is None: + raise ValueError( + "num_key_value_heads is not specified in the model attributes. " + "Please specify it in the model attributes." + ) + optimizer = self._replace_mha_with_gqa(optimizer, kv_num_heads=num_kv_heads) + optimizer.prune_graph() + # add allow_remove_graph_inputs to pass config + optimizer.update_graph(allow_remove_graph_inputs=True) if config["input_int32"]: optimizer.change_graph_inputs_to_int32() @@ -208,3 +232,36 @@ def _run_for_config( # save the model to the output path and return the model return model_proto_to_olive_model(optimizer.model, output_model_path, config) + + @staticmethod + def _replace_mha_with_gqa(model: "OnnxModel", past_seq_len: str = "past_sequence_length", kv_num_heads: int = 0): + import onnx + + if past_seq_len not in model.get_graphs_input_names(): + # Replace model input for past sequence length + new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1]) + model.model.graph.input.append(new_input) + + # Replace MultiHeadAttention with GroupQueryAttention + for node in model.model.graph.node: + if node.op_type == "MultiHeadAttention": + gqa_node = onnx.helper.make_node( + "GroupQueryAttention", + inputs=[ + node.input[0], # query + node.input[1], # key + node.input[2], # value + node.input[6], # past_key + node.input[7], # past_value + past_seq_len, # past_sequence_length + ], + outputs=node.output, + name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), + domain="com.microsoft", + num_heads=node.attribute[0].i, + kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + is_past_bsnh=0, + ) + model.model.graph.node.remove(node) + model.model.graph.node.extend([gqa_node]) + return model
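
A minimal sanity-check sketch (not part of this patch) for the GQA rewrite above. It assumes a hypothetical path to the optimized ONNX model under the `output_dir` set by `llama2.py`, and only checks the graph properties that `_replace_mha_with_gqa` establishes: no remaining `MultiHeadAttention` nodes and a new `past_sequence_length` graph input.

```python
# Hypothetical post-optimization check; adjust the model path to the actual file
# produced under llama2_gpu/<model_name> by the workflow.
import onnx

model = onnx.load("llama2_gpu/meta-llama/Llama-2-7b-hf/model.onnx", load_external_data=False)

op_types = {node.op_type for node in model.graph.node}
assert "MultiHeadAttention" not in op_types, "MultiHeadAttention should have been replaced"
assert "GroupQueryAttention" in op_types, "expected GroupQueryAttention nodes after the rewrite"

# _replace_mha_with_gqa appends past_sequence_length as a new graph input
input_names = [graph_input.name for graph_input in model.graph.input]
assert "past_sequence_length" in input_names
print("GQA rewrite verified")
```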