From b79ea7481930ae09980f61982a8b7b19303a0eca Mon Sep 17 00:00:00 2001
From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Date: Thu, 26 Oct 2023 21:54:23 -0700
Subject: [PATCH] Add updates to LLaMA scripts (#18076)

### Description
This PR adds a few updates to scripts in the LLaMA folder:
- Fixes the precision re-naming in the LLaMA export
- Adds a "prerequisites" section in the README
- Adds IO binding synchronizations during benchmarking for other EPs

### Motivation and Context
- With the precision re-naming, the LLaMA parity check no longer produces errors when creating the FP32 CPU model
- The "prerequisites" section documents the specific package versions needed
- This allows for benchmarking with other EPs besides CPU and CUDA
---
 .../tools/transformers/convert_generation.py  |  2 +-
 .../tools/transformers/models/llama/README.md | 48 +++++++++----------
 .../transformers/models/llama/benchmark.py    | 36 ++++++++++----
 .../models/llama/convert_to_onnx.py           |  3 +-
 .../transformers/models/llama/llama_parity.py |  4 +-
 .../models/llama/requirements-cpu.txt         |  2 +-
 .../models/llama/requirements-cuda.txt        |  2 +-
 7 files changed, 59 insertions(+), 38 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 4228c892d03ae..b32ae64c5b0c0 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1275,7 +1275,7 @@ def find_past_seq_len_usage(subg: GraphProto):
 def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
     past_seq_len = past_seq_len_input
     if past_seq_len not in model.get_graphs_input_names():
-        # Replace model input for past sequence length
+        # Add model input for past sequence length
         new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1])
         model.model.graph.input.append(new_input)
 
diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md
index 6057b46667fe6..9619e6cb52a91 100644
--- a/onnxruntime/python/tools/transformers/models/llama/README.md
+++ b/onnxruntime/python/tools/transformers/models/llama/README.md
@@ -1,5 +1,18 @@
 # LLaMA-2
 
+## Prerequisites
+
+Please note the package versions needed for using LLaMA-2 in the `requirements` file that fits your scenario.
+- `requirements-cpu.txt`
+  - For running LLaMA-2 on CPU
+- `requirements-cuda.txt`
+  - For running LLaMA-2 on CUDA
+  - Note that `torch` with CUDA enabled is not installed automatically because it must match the CUDA version installed on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to install the `torch` build that matches your CUDA version and satisfies the requirement listed in the file.
+- `requirements-quant.txt`
+  - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor)
+- `requirements.txt`
+  - Package versions needed in each of the above files
+
 ## Exporting LLaMA-2
 
 There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example).
@@ -40,7 +53,7 @@ Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onn
 
 ### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum)
 
-Note that this will produce two ONNX models whereas the above two options produce one ONNX model.
+Note that this may produce two ONNX models with older Optimum versions, whereas the above two options produce one ONNX model. Installing Optimum from source will now produce one ONNX model as well.
 
 First, log into the Hugging Face CLI in your terminal:
 
@@ -81,7 +94,7 @@ Export for FP32 CUDA
 $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cuda
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda
 ```
 
 Export for FP32 CPU
@@ -90,7 +103,7 @@ Export for FP32 CPU
 $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cpu
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
 ```
 
 Export for FP16 CUDA
@@ -105,10 +118,10 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
 Export for INT8 CPU (SmoothQuant)
 ```
 # From source:
-$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged
 ```
 
 Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models.
@@ -128,7 +141,7 @@ Export for INT4 CUDA
 $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cuda
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
 ```
 
 Export for INT4 CPU
@@ -137,7 +150,7 @@ Export for INT4 CPU
 $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cpu
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
 ```
 
 ## Benchmark LLaMA-2
@@ -183,20 +196,7 @@ python3 -m models.llama.benchmark \
     --auth
 ```
 
-4. Optimum + ONNX Runtime, FP16, export via convert_to_onnx
-```
-python3 -m models.llama.benchmark \
-    --benchmark-type hf-ort \
-    --hf-ort-dir-path ./llama2-7b-fp16/ \
-    --model-name meta-llama/Llama-2-7b-hf \
-    --precision fp16 \
-    --batch-sizes "1 2" \
-    --sequence-lengths "8 16" \
-    --device cuda \
-    --auth
-```
-
-5. ONNX Runtime, FP32, Microsoft custom export
+4. ONNX Runtime, FP32, Microsoft custom export
 ```
 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
@@ -208,7 +208,7 @@ python3 -m models.llama.benchmark \
     --device cpu
 ```
 
-6. ONNX Runtime, FP16, Microsoft custom export
+5. ONNX Runtime, FP16, Microsoft custom export
 ```
 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
@@ -220,7 +220,7 @@ python3 -m models.llama.benchmark \
     --device cuda
 ```
 
-7. ONNX Runtime, FP32, convert_to_onnx
+6. ONNX Runtime, FP32, convert_to_onnx
 ```
 python3 -m models.llama.benchmark \
     --benchmark-type ort-convert-to-onnx \
@@ -232,7 +232,7 @@ python3 -m models.llama.benchmark \
     --device cpu
 ```
 
-8. ONNX Runtime, FP16, convert_to_onnx
+7. ONNX Runtime, FP16, convert_to_onnx
 ```
 python3 -m models.llama.benchmark \
     --benchmark-type ort-convert-to-onnx \
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
index 976de2abc7c57..a721979eb0bcb 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -286,31 +286,50 @@ def time_fn(args, fn, inputs):
         outputs = fn(inputs)
         logger.info(outputs)
 
+    input_sync = (  # noqa: E731
+        (lambda *kwargs: args.io_binding.synchronize_inputs())  # ORT synchronize
+        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
+        else (lambda *kwargs: torch.cuda.synchronize())  # PyTorch synchronize
+        if args.device != "cpu" and torch.cuda.is_available()
+        else (lambda *kwargs: None)  # no-op function
+    )
+
+    output_sync = (  # noqa: E731
+        (lambda *kwargs: args.io_binding.synchronize_outputs())  # ORT synchronize
+        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
+        else (lambda *kwargs: torch.cuda.synchronize())  # PyTorch synchronize
+        if args.device != "cpu" and torch.cuda.is_available()
+        else (lambda *kwargs: None)  # no-op function
+    )
+
     for _ in warmup_range:
+        input_sync()
         fn(inputs)
+        output_sync()
 
     # Benchmark
-    if args.device != "cpu":
-        torch.cuda.synchronize()
-    start_time = time.time()
-
+    total_time = 0
     bench_range = (
         range(args.num_runs)
         if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
         else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
     )
     for _ in bench_range:
+        input_sync()
+        start_time = time.time()
+
         fn(inputs)
 
-    if args.device != "cpu":
-        torch.cuda.synchronize()
-    end_time = time.time()
+        output_sync()
+        end_time = time.time()
+
+        total_time += end_time - start_time
 
     # Newline print after trange in order to print metrics on new lines without progress bar on same line
     if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
         logger.info("")
 
-    latency = (end_time - start_time) / args.num_runs
+    latency = total_time / args.num_runs
     throughput = args.batch_size / latency
 
     logger.info(f"Batch Size: {args.batch_size}")
@@ -467,6 +486,7 @@ def prepare_ort_inputs(inputs):
             else:
                 io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)
 
+        setattr(args, "io_binding", io_binding)  # noqa: B010
         return io_binding
 
     return inputs
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index 61d71bc38f4e9..69603fd3ed488 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -817,7 +817,8 @@ def main():
     # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
     args.precision = (
         "fp32"
-        if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
+        if args.precision in {Precision.INT8, Precision.FLOAT32}
+        or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
         else "fp16"
     )
 
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 6bfcb9b4f290d..4353d0606803d 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -113,10 +113,10 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
     if args.execution_provider != "cpu":
         io_binding = add_io_bindings(args, ort_model, inputs)
 
-        torch.cuda.synchronize()
+        io_binding.synchronize_inputs()
         start_time = time.time()
         ort_model.run_with_iobinding(io_binding)
-        torch.cuda.synchronize()
+        io_binding.synchronize_outputs()
         end_time = time.time()
 
         ort_outputs = io_binding.copy_outputs_to_cpu()[0]  # Get logits
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt
index e06c3ada834b0..3d707fa13e3c8 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt
@@ -1,2 +1,2 @@
 -r requirements.txt
-onnxruntime>=1.17.0
\ No newline at end of file
+onnxruntime>=1.16.2
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
index 773680937bd21..b634bcc50f6e4 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
@@ -1,4 +1,4 @@
 -r requirements.txt
 # Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system.
 # Instructions can be found here: https://pytorch.org/get-started/locally/
-onnxruntime-gpu>=1.17.0
\ No newline at end of file
+onnxruntime-gpu>=1.16.2
\ No newline at end of file
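
For reference, the IO binding synchronization pattern that `benchmark.py` and `llama_parity.py` now use can be exercised on its own. The sketch below is illustrative only: the model path `model.onnx`, the `input_ids` input name, and the input shape are placeholders rather than artifacts produced by this patch, and it assumes an ONNX Runtime build with the CUDA execution provider available.

```python
import time

import numpy as np
import onnxruntime as ort

# Placeholder model and input name -- swap in a real exported LLaMA-2 model.
session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

io_binding = session.io_binding()
io_binding.bind_cpu_input("input_ids", np.ones((1, 8), dtype=np.int64))
for output in session.get_outputs():
    io_binding.bind_output(output.name, device_type="cuda", device_id=0)

io_binding.synchronize_inputs()  # ensure inputs are on the device before starting the timer
start_time = time.time()
session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()  # wait for device work to finish before stopping the timer
end_time = time.time()

print(f"Latency: {end_time - start_time:.4f} s")
```

The explicit `synchronize_inputs`/`synchronize_outputs` calls mirror what the updated scripts do so that the timed region covers the device-side work rather than only the host-side call, without requiring `torch.cuda.synchronize()` for non-CUDA execution providers.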