Add updates to LLaMA scripts (#18076)
### Description
This PR adds a few updates to scripts in the LLaMA folder:
- Fixes the precision renaming in the LLaMA export
- Adds a "prerequisites" section in the README
- Adds IO binding synchronizations during benchmarking for other EPs



### Motivation and Context
- With the precision renaming fix, the LLaMA parity check no longer produces errors when creating the FP32 CPU model
- The "prerequisites" section shows that specific package versions are needed
- This allows for benchmarking with other EPs besides CPU and CUDA
kunal-vaishnavi authored Oct 27, 2023
1 parent 0f3a067 commit b79ea74
Showing 7 changed files with 59 additions and 38 deletions.
@@ -1275,7 +1275,7 @@ def find_past_seq_len_usage(subg: GraphProto):
def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
past_seq_len = past_seq_len_input
if past_seq_len not in model.get_graphs_input_names():
# Replace model input for past sequence length
# Add model input for past sequence length
new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1])
model.model.graph.input.append(new_input)
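
For reference, a minimal sketch of the pattern this hunk follows, written with plain `onnx` helpers; the function and input names below are illustrative stand-ins, not the repository's `replace_mha_with_gqa` implementation.

```
# Sketch (assumed names): declare a scalar INT64 past-sequence-length input on
# an ONNX graph if it is not already present, mirroring the change above.
import onnx
from onnx import TensorProto, helper


def ensure_past_seq_len_input(model: onnx.ModelProto, name: str = "past_sequence_length") -> onnx.ModelProto:
    existing_inputs = {graph_input.name for graph_input in model.graph.input}
    if name not in existing_inputs:
        # shape=[1] matches the tensor value info created in the diff above.
        new_input = helper.make_tensor_value_info(name, TensorProto.INT64, shape=[1])
        model.graph.input.append(new_input)
    return model
```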

48 changes: 24 additions & 24 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -1,5 +1,18 @@
# LLaMA-2

## Prerequisites

Please note the package versions needed for using LLaMA-2 in the requirements file that fits your scenario; a quick environment sanity check is sketched after this list.
- `requirements-cpu.txt`
- For running LLaMA-2 on CPU
- `requirements-cuda.txt`
- For running LLaMA-2 on CUDA
- Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file.
- `requirements-quant.txt`
- For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor)
- `requirements.txt`
- Package versions needed in each of the above files
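
As a quick sanity check after installing one of the files above, the short sketch below (not part of the repository's scripts) verifies that the CUDA-enabled `torch` build and the intended ONNX Runtime execution provider are visible:

```
# Illustrative environment check: confirm that torch sees CUDA and that
# onnxruntime exposes the execution provider you intend to use.
import onnxruntime as ort
import torch

print("torch version:", torch.__version__)
print("CUDA available in torch:", torch.cuda.is_available())
print("onnxruntime version:", ort.__version__)
print("Available execution providers:", ort.get_available_providers())
```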

## Exporting LLaMA-2

There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example).
@@ -40,7 +53,7 @@ Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onn

### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum)

Note that this will produce two ONNX models whereas the above two options produce one ONNX model.
Note that this may produce two ONNX models with older Optimum versions, whereas the above two options produce one ONNX model. Installing Optimum from source now produces one ONNX model as well.

First, log into the Hugging Face CLI in your terminal:

@@ -81,7 +94,7 @@ Export for FP32 CUDA
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cuda
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda
```

Export for FP32 CPU
@@ -90,7 +103,7 @@ Export for FP32 CPU
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cpu
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
```

Export for FP16 CUDA
@@ -105,10 +118,10 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
Export for INT8 CPU (SmoothQuant)
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged
```

Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes roughly 30-45 minutes for each of the exported ONNX models.
@@ -128,7 +141,7 @@ Export for INT4 CUDA
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cuda
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
```

Export for INT4 CPU
@@ -137,7 +150,7 @@ Export for INT4 CPU
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cpu
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
```
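
After any of the exports above, the exported model can be opened with an ONNX Runtime `InferenceSession` as a quick smoke test. This is a hedged sketch: the `.onnx` file name under the output folder is a placeholder, so check what `convert_to_onnx` actually wrote for your export.

```
# Illustrative smoke test for an exported CPU model. The path is hypothetical;
# use the .onnx file that convert_to_onnx actually wrote to your output folder.
import onnxruntime as ort

model_path = "llama2-7b-int4-cpu/model.onnx"  # placeholder file name
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
print("Model inputs:", [model_input.name for model_input in session.get_inputs()])
```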

## Benchmark LLaMA-2
@@ -183,20 +196,7 @@ python3 -m models.llama.benchmark \
--auth
```

4. Optimum + ONNX Runtime, FP16, export via convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./llama2-7b-fp16/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda \
--auth
```

5. ONNX Runtime, FP32, Microsoft custom export
4. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
@@ -208,7 +208,7 @@ python3 -m models.llama.benchmark \
--device cpu
```

6. ONNX Runtime, FP16, Microsoft custom export
5. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
@@ -220,7 +220,7 @@ python3 -m models.llama.benchmark \
--device cuda
```

7. ONNX Runtime, FP32, convert_to_onnx
6. ONNX Runtime, FP32, convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
@@ -232,7 +232,7 @@ python3 -m models.llama.benchmark \
--device cpu
```

8. ONNX Runtime, FP16, convert_to_onnx
7. ONNX Runtime, FP16, convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
36 changes: 28 additions & 8 deletions onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -286,31 +286,50 @@ def time_fn(args, fn, inputs):
outputs = fn(inputs)
logger.info(outputs)

input_sync = ( # noqa: E731
lambda *kwargs: args.io_binding.synchronize_inputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
else lambda *kwargs: torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
else lambda *kwargs: None # no-op function
)

output_sync = ( # noqa: E731
lambda *kwargs: args.io_binding.synchronize_outputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
else lambda *kwargs: torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
else lambda *kwargs: None # no-op function
)

for _ in warmup_range:
input_sync()
fn(inputs)
output_sync()

# Benchmark
if args.device != "cpu":
torch.cuda.synchronize()
start_time = time.time()

total_time = 0
bench_range = (
range(args.num_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
input_sync()
start_time = time.time()

fn(inputs)

if args.device != "cpu":
torch.cuda.synchronize()
end_time = time.time()
output_sync()
end_time = time.time()

total_time += end_time - start_time

# Newline print after trange in order to print metrics on new lines without progress bar on same line
if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
logger.info("")

latency = (end_time - start_time) / args.num_runs
latency = total_time / args.num_runs
throughput = args.batch_size / latency

logger.info(f"Batch Size: {args.batch_size}")
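
The nested conditional lambdas above are what the commit ships. As a readability note only, the same selection could be expressed with a small helper that returns a pair of no-argument callables; the names below are assumptions, not code from the repository.

```
# Readability sketch of the synchronization selection in time_fn: choose the
# input/output sync callables once, based on device and benchmark type.
import torch

ORT_BENCHMARK_TYPES = {"ort-msft", "ort-convert-to-onnx"}


def make_sync_fns(args):
    """Return (input_sync, output_sync) callables for the benchmark loops."""
    if args.device != "cpu" and args.benchmark_type in ORT_BENCHMARK_TYPES:
        # ONNX Runtime I/O binding synchronization (CUDA and other GPU EPs).
        return args.io_binding.synchronize_inputs, args.io_binding.synchronize_outputs
    if args.device != "cpu" and torch.cuda.is_available():
        # PyTorch synchronization for the PyTorch / Hugging Face benchmark paths.
        return torch.cuda.synchronize, torch.cuda.synchronize

    def noop():
        return None

    # CPU runs need no synchronization.
    return noop, noop
```

Either form is used identically in the timing loops: call `input_sync()` before `fn(inputs)` and `output_sync()` before reading `time.time()`.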
@@ -467,6 +486,7 @@ def prepare_ort_inputs(inputs):
else:
io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)

setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding

return inputs
@@ -817,7 +817,8 @@ def main():
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32"
if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
if args.precision in {Precision.INT8, Precision.FLOAT32}
or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
else "fp16"
)
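
The original condition compared `args.precision`, a `Precision` enum member, against string literals, which never matches, so those cases fell through to the `"fp16"` branch. Below is a minimal, self-contained sketch of that failure mode; the enum is a stand-in, not the project's `Precision` class.

```
# Illustrative only: membership tests against string literals fail for plain
# Enum members, which is why the check now uses Precision members instead.
from enum import Enum


class Precision(Enum):  # stand-in for the project's Precision enum
    FLOAT32 = "fp32"
    FLOAT16 = "fp16"
    INT8 = "int8"
    INT4 = "int4"


precision = Precision.FLOAT32
print(precision in {"int8", "fp32"})                     # False: enum member is not a str
print(precision in {Precision.INT8, Precision.FLOAT32})  # True: the fixed check
```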

@@ -113,10 +113,10 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
if args.execution_provider != "cpu":
io_binding = add_io_bindings(args, ort_model, inputs)

torch.cuda.synchronize()
io_binding.synchronize_inputs()
start_time = time.time()
ort_model.run_with_iobinding(io_binding)
torch.cuda.synchronize()
io_binding.synchronize_outputs()
end_time = time.time()

ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
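
For reference, the logits pulled back with `copy_outputs_to_cpu()` are then compared against the PyTorch model's logits; the sketch below shows one way such a comparison could look, with tolerances and names chosen for illustration rather than taken from the script.

```
# Illustrative parity comparison between PyTorch and ONNX Runtime logits.
# Convert torch tensors to NumPy first, e.g. logits.detach().cpu().numpy().
import numpy as np


def logits_match(pt_logits: np.ndarray, ort_logits: np.ndarray, rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    """Return True when the two sets of logits agree within the given tolerances."""
    return bool(np.allclose(pt_logits, ort_logits, rtol=rtol, atol=atol))
```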
@@ -1,2 +1,2 @@
-r requirements.txt
onnxruntime>=1.17.0
onnxruntime>=1.16.2
@@ -1,4 +1,4 @@
-r requirements.txt
# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system.
# Instructions can be found here: https://pytorch.org/get-started/locally/
onnxruntime-gpu>=1.17.0
onnxruntime-gpu>=1.16.2
