From 4d6d4dfb9de610fbb09e23b36379473b4005a99d Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 10 Feb 2022 08:51:01 -0800 Subject: [PATCH] Add TRT ep perf benchmark (#10470) --- .../python/tools/transformers/benchmark.py | 62 +++++++++++-------- .../tools/transformers/benchmark_helper.py | 28 ++++++--- .../tools/transformers/dev_benchmark.cmd | 7 ++- .../tools/transformers/onnx_exporter.py | 24 +++---- .../tools/transformers/run_benchmark.sh | 25 +++++++- .../python/transformers/test_optimizer.py | 6 +- .../transformers/test_shape_infer_helper.py | 6 +- 7 files changed, 105 insertions(+), 53 deletions(-) diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index abaeaf47c984e..c6e042d4947a0 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -48,9 +48,9 @@ import psutil import onnx from enum import Enum -from benchmark_helper import (create_onnxruntime_session, Precision, setup_logger, get_latency_result, output_details, - output_summary, output_fusion_statistics, inference_ort, inference_ort_with_io_binding, - allocateOutputBuffers) +from benchmark_helper import (OptimizerInfo, create_onnxruntime_session, Precision, setup_logger, get_latency_result, + output_details, output_summary, output_fusion_statistics, inference_ort, + inference_ort_with_io_binding, allocateOutputBuffers) from quantize_helper import QuantizeHelper from onnx_exporter import create_onnxruntime_input, load_pretrained_model, export_onnx_model_from_pt, export_onnx_model_from_tf @@ -69,18 +69,28 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, precision, num_threads, batch_sizes, sequence_lengths, - repeat_times, input_counts, optimize_onnx, validate_onnx, cache_dir, onnx_dir, verbose, overwrite, + repeat_times, input_counts, optimizer_info, validate_onnx, cache_dir, onnx_dir, verbose, overwrite, disable_ort_io_binding, use_raw_attention_mask, model_fusion_statistics, model_source): import onnxruntime results = [] - if (use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()) and - ('ROCMExecutionProvider' not in onnxruntime.get_available_providers())): + if (use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()) + and ('ROCMExecutionProvider' not in onnxruntime.get_available_providers())): logger.error( "Please install onnxruntime-gpu package instead of onnxruntime, and use a machine with GPU for testing gpu performance." ) return results + warm_up_repeat = 0 + if provider == 'tensorrt': + optimizer_info = OptimizerInfo.NOOPT + warm_up_repeat = 5 + if 'TensorrtExecutionProvider' not in onnxruntime.get_available_providers(): + logger.error( + "Please install onnxruntime-gpu-tensorrt package, and use a machine with GPU for testing gpu performance." 
+ ) + return results + for model_name in model_names: all_input_names = MODELS[model_name][0] for num_inputs in input_counts: @@ -93,12 +103,12 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, precision, num_ with torch.no_grad(): onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length = export_onnx_model_from_pt( model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], model_class, - cache_dir, onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx, + cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info, validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics) if 'tf' in model_source: onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length = export_onnx_model_from_tf( model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], model_class, - cache_dir, onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx, + cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info, validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics) if not is_valid_onnx_model: @@ -134,8 +144,9 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, precision, num_ result_template = { "engine": "onnxruntime", "version": onnxruntime.__version__, + "providers": provider, "device": device, - "optimizer": optimize_onnx, + "optimizer": optimizer_info, "precision": precision, "io_binding": not disable_ort_io_binding, "model_name": model_name, @@ -150,7 +161,8 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, precision, num_ [batch_size, sequence_length])) if disable_ort_io_binding: - result = inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size) + result = inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, + warm_up_repeat) else: # Get output sizes from a dummy ort run ort_outputs = ort_session.run(ort_output_names, ort_inputs) @@ -165,7 +177,8 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, precision, num_ data_type = numpy.longlong if 'pt' in model_source else numpy.intc result = inference_ort_with_io_binding(ort_session, ort_inputs, result_template, repeat_times, ort_output_names, ort_outputs, output_buffers, - output_buffer_max_sizes, batch_size, device, data_type) + output_buffer_max_sizes, batch_size, device, data_type, + warm_up_repeat) logger.info(result) results.append(result) @@ -429,11 +442,7 @@ def parse_arguments(): parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on gpu device") - parser.add_argument("--provider", - required=False, - type=str, - default=None, - help="Execution provider to use") + parser.add_argument("--provider", required=False, type=str, default=None, help="Execution provider to use") parser.add_argument( "-p", @@ -447,11 +456,14 @@ def parse_arguments(): parser.add_argument("--overwrite", required=False, action="store_true", help="Overwrite existing models") - parser.add_argument("-o", - "--optimize_onnx", - required=False, - action="store_true", - help="Use optimizer.py to optimize onnx model") + parser.add_argument( + "-o", + "--optimizer_info", + type=OptimizerInfo, + default=OptimizerInfo.BYSCRIPT, + choices=list(OptimizerInfo), + help="Optimizer info: Use optimizer.py to optimize onnx model as default. 
Can also choose from by_ort and no_opt" + ) parser.add_argument("-v", "--validate_onnx", required=False, action="store_true", help="Validate ONNX model") @@ -553,10 +565,10 @@ def main(): if enable_onnxruntime: try: use_raw_attention_mask = True - results += run_onnxruntime(args.use_gpu, args.provider, args.models, args.model_class, args.precision, num_threads, - args.batch_sizes, args.sequence_lengths, args.test_times, args.input_counts, - args.optimize_onnx, args.validate_onnx, args.cache_dir, args.onnx_dir, - args.verbose, args.overwrite, args.disable_ort_io_binding, + results += run_onnxruntime(args.use_gpu, args.provider, args.models, args.model_class, args.precision, + num_threads, args.batch_sizes, args.sequence_lengths, args.test_times, + args.input_counts, args.optimizer_info, args.validate_onnx, args.cache_dir, + args.onnx_dir, args.verbose, args.overwrite, args.disable_ort_io_binding, use_raw_attention_mask, model_fusion_statistics, args.model_source) except: logger.error(f"Exception", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 8ef162a1991c2..03296eb451394 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -31,6 +31,17 @@ def __str__(self): return self.value +class OptimizerInfo(Enum): + # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as + # graph optimization level is not 0 (disable all). + NOOPT = 'no_opt' + BYORT = 'by_ort' + BYSCRIPT = 'by_script' + + def __str__(self): + return self.value + + IO_BINDING_DATA_TYPE_MAP = { "float32": numpy.float32, # TODO: Add more. @@ -114,7 +125,6 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None): assert 'CUDAExecutionProvider' in onnxruntime.get_available_providers( ), "Please install onnxruntime-gpu package to test GPU inference." 
- import transformers logger.info(f'PyTorch Version:{torch.__version__}') logger.info(f'Transformers Version:{transformers.__version__}') @@ -146,9 +156,9 @@ def get_latency_result(runtimes, batch_size): def output_details(results, csv_filename): with open(csv_filename, mode="a", newline='') as csv_file: column_names = [ - "engine", "version", "device", "precision", "optimizer", "io_binding", "model_name", "inputs", "threads", - "batch_size", "sequence_length", "datetime", "test_times", "QPS", "average_latency_ms", "latency_variance", - "latency_90_percentile", "latency_95_percentile", "latency_99_percentile" + "engine", "version", "providers", "device", "precision", "optimizer", "io_binding", "model_name", "inputs", + "threads", "batch_size", "sequence_length", "datetime", "test_times", "QPS", "average_latency_ms", + "latency_variance", "latency_90_percentile", "latency_95_percentile", "latency_99_percentile" ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) @@ -162,7 +172,8 @@ def output_details(results, csv_filename): def output_summary(results, csv_filename, args): with open(csv_filename, mode="a", newline='') as csv_file: header_names = [ - "model_name", "inputs", "engine", "version", "device", "precision", "optimizer", "io_binding", "threads" + "model_name", "inputs", "engine", "version", "providers", "device", "precision", "optimizer", "io_binding", + "threads" ] data_names = [] for batch_size in args.batch_sizes: @@ -213,8 +224,9 @@ def output_fusion_statistics(model_fusion_statistics, csv_filename): logger.info(f"Fusion statistics is saved to csv file: {csv_filename}") -def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size): +def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0): result = {} + timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat) # Dry run runtimes = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times) result.update(result_template) result.update({"io_binding": False}) @@ -232,7 +244,8 @@ def inference_ort_with_io_binding(ort_session, output_buffer_max_sizes, batch_size, device, - data_type=numpy.longlong): + data_type=numpy.longlong, + warm_up_repeat=0): result = {} # Bind inputs and outputs to onnxruntime session @@ -250,6 +263,7 @@ def inference_ort_with_io_binding(ort_session, for i in range(len(ort_output_names)): io_binding.bind_output(ort_output_names[i], output_buffers[i].device.type, 0, numpy.float32, ort_outputs[i].shape, output_buffers[i].data_ptr()) + timeit.repeat(lambda: ort_session.run_with_iobinding(io_binding), number=1, repeat=warm_up_repeat) # Dry run runtimes = timeit.repeat(lambda: ort_session.run_with_iobinding(io_binding), number=1, repeat=repeat_times) result.update(result_template) result.update({"io_binding": True}) diff --git a/onnxruntime/python/tools/transformers/dev_benchmark.cmd b/onnxruntime/python/tools/transformers/dev_benchmark.cmd index 3f0b397a14eb2..7a9b3254a1708 100644 --- a/onnxruntime/python/tools/transformers/dev_benchmark.cmd +++ b/onnxruntime/python/tools/transformers/dev_benchmark.cmd @@ -86,8 +86,11 @@ set onnx_export_options=-i %input_counts% -v -b 0 -f fusion.csv --overwrite set benchmark_options=-b %batch_sizes% -s %sequence_length% -t %average_over% -f fusion.csv -r result.csv -d detail.csv if %use_optimizer% == true ( - set onnx_export_options=%onnx_export_options% -o - set benchmark_options=%benchmark_options% -o + set 
onnx_export_options=%onnx_export_options% -o by_script + set benchmark_options=%benchmark_options% -o by_script +) else ( + set onnx_export_options=%onnx_export_options% -o by_ort + set benchmark_options=%benchmark_options% -o by_ort ) if %run_gpu_fp32% == true ( diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 6abf72c237703..3029bb8f416c8 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -11,7 +11,7 @@ from pathlib import Path from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig from affinity_helper import AffinitySetting -from benchmark_helper import create_onnxruntime_session, Precision +from benchmark_helper import create_onnxruntime_session, Precision, OptimizerInfo from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState from quantize_helper import QuantizeHelper from huggingface_models import MODEL_CLASSES @@ -305,7 +305,7 @@ def validate_and_optimize_onnx(model_name, input_names, use_gpu, precision, - optimize_onnx, + optimize_info, validate_onnx, use_raw_attention_mask, overwrite, @@ -319,8 +319,10 @@ def validate_and_optimize_onnx(model_name, if validate_onnx: is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, False, output_names) + if optimize_info == OptimizerInfo.NOOPT: + return onnx_model_path, is_valid_onnx_model, config.vocab_size - if optimize_onnx or precision == Precision.FLOAT16 or precision == Precision.INT8: # Use script (optimizer.py) to optimize + if optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8: # Use script (optimizer.py) to optimize optimized_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), True, use_gpu, precision, False, use_external_data_format) optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model_type, config.num_attention_heads, @@ -337,7 +339,7 @@ def validate_and_optimize_onnx(model_name, QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format) logger.info(f"Finished quantizing model: {onnx_model_path}") - else: # Use OnnxRuntime to optimize + if optimize_info == OptimizerInfo.BYORT: # Use OnnxRuntime to optimize if is_valid_onnx_model: ort_model_path = add_filename_suffix(onnx_model_path, '_ort') optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics) @@ -346,7 +348,7 @@ def validate_and_optimize_onnx(model_name, def export_onnx_model_from_pt(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir, - onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx, + onnx_dir, input_names, use_gpu, precision, optimizer_info, validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics): config, model = load_pt_model(model_name, model_class, cache_dir) @@ -394,15 +396,15 @@ def export_onnx_model_from_pt(model_name, opset_version, use_external_data_forma logger.info(f"Skip export since model existed: {onnx_model_path}") onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx( - model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx, + model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimizer_info, validate_onnx, 
use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, - example_inputs, example_outputs_flatten) + example_inputs, example_outputs_flatten, None) return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size def export_onnx_model_from_tf(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir, - onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx, + onnx_dir, input_names, use_gpu, precision, optimizer_info, validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics): # Use CPU to export import tensorflow as tf @@ -487,9 +489,9 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma logger.info(f"Skip export since model existed: {onnx_model_path}") model_type = model_type + '_tf' - onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx( - model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx, + opt_onnx_model_file, onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx( + model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimizer_info, validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten, output_names) - return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size + return opt_onnx_model_file, onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index 1fe18be104b37..5962a4df27e49 100644 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -16,7 +16,9 @@ use_package=true run_install=true # Engines to test. 
+# To run ort_trt, you need to build and install the onnxruntime-gpu-tensorrt package on your own run_ort=true +run_ort_trt=false run_torch=false run_torchscript=true run_tensorflow=false @@ -107,8 +109,11 @@ if [ "$export_onnx_from_tf" = true ] ; then fi if [ "$use_optimizer" = true ] ; then - onnx_export_options="$onnx_export_options -o" - benchmark_options="$benchmark_options -o" + onnx_export_options="$onnx_export_options -o by_script" + benchmark_options="$benchmark_options -o by_script" +else + onnx_export_options="$onnx_export_options -o by_ort" + benchmark_options="$benchmark_options -o by_ort" fi # ------------------------------------------- @@ -122,6 +127,16 @@ run_one_test() { fi fi + if [ "$run_ort_trt" = true ] ; then + trt_options="--provider tensorrt --disable_ort_io_binding" + echo python $benchmark_script -m $1 $onnx_export_options $trt_options $2 $3 $4 >> benchmark.log + echo python $benchmark_script -m $1 $benchmark_options $trt_options $2 $3 $4 -i $input_counts >> benchmark.log + if [ "$run_tests" = true ] ; then + python $benchmark_script -m $1 $onnx_export_options $trt_options $2 $3 $4 + python $benchmark_script -m $1 $benchmark_options $trt_options $2 $3 $4 -i $input_counts + fi + fi + if [ "$run_torch" = true ] ; then echo python $benchmark_script -e torch -m $1 $benchmark_options $2 $3 $4 >> benchmark.log if [ "$run_tests" = true ] ; then @@ -146,6 +161,9 @@ run_one_test() { # ------------------------------------------- if [ "$run_gpu_fp32" = true ] ; then + if [ "$run_ort_trt" = true ] ; then + export ORT_TENSORRT_FP16_ENABLE=0 + fi for m in $models_to_test do echo Run GPU FP32 Benchmark on model ${m} @@ -154,6 +172,9 @@ if [ "$run_gpu_fp32" = true ] ; then fi if [ "$run_gpu_fp16" = true ] ; then + if [ "$run_ort_trt" = true ] ; then + export ORT_TENSORRT_FP16_ENABLE=1 + fi for m in $models_to_test do echo Run GPU FP16 Benchmark on model ${m} diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py index b149edbb145f4..b4d944caf998f 100644 --- a/onnxruntime/test/python/transformers/test_optimizer.py +++ b/onnxruntime/test/python/transformers/test_optimizer.py @@ -20,13 +20,13 @@ from onnx_model import OnnxModel from onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt from huggingface_models import MODELS - from benchmark_helper import Precision + from benchmark_helper import Precision, OptimizerInfo else: from onnxruntime.transformers.optimizer import optimize_model from onnxruntime.transformers.onnx_model import OnnxModel from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.benchmark_helper import Precision + from onnxruntime.transformers.benchmark_helper import Precision, OptimizerInfo BERT_TEST_MODELS = { "bert_keras_0": ('models', 'TFBertForSequenceClassification_1.onnx'), # bert_mrpc_tensorflow2.1_opset10 @@ -78,7 +78,7 @@ def _test_optimizer_on_huggingface_model(self, MODELS[model_name][2], MODELS[model_name][3], None, './cache_models', './onnx_models', input_names[:inputs_count], False, - Precision.FLOAT32, True, True, True, True, + Precision.FLOAT32, OptimizerInfo.BYSCRIPT, True, True, True, model_fusion_statistics) onnx_model = list(model_fusion_statistics.keys())[0] diff --git a/onnxruntime/test/python/transformers/test_shape_infer_helper.py b/onnxruntime/test/python/transformers/test_shape_infer_helper.py index 
c38f249a0e907..1a9f22d1477cd 100644 --- a/onnxruntime/test/python/transformers/test_shape_infer_helper.py +++ b/onnxruntime/test/python/transformers/test_shape_infer_helper.py @@ -5,12 +5,12 @@ if find_transformers_source(): from onnx_exporter import export_onnx_model_from_pt from huggingface_models import MODELS - from benchmark_helper import Precision + from benchmark_helper import Precision, OptimizerInfo from shape_infer_helper import SymbolicShapeInferenceHelper else: from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.benchmark_helper import Precision + from onnxruntime.transformers.benchmark_helper import Precision, OptimizerInfo from onnxruntime.transformers.shape_infer_helper import SymbolicShapeInferenceHelper @@ -22,7 +22,7 @@ def _load_onnx(self, model_name): with torch.no_grad(): export_onnx_model_from_pt(model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], None, '../cache_models', base_path, input_names[:1], False, Precision.FLOAT32, - True, True, True, False, {}) + OptimizerInfo.BYSCRIPT, True, True, False, {}) model_path = base_path + model_name.replace('-', '_') + "_1.onnx" import onnx return onnx.load_model(model_path)
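
Usage sketch: the commands below show one way the TensorRT path added in this patch could be exercised; they are illustrative, not part of the diff. They assume onnxruntime-gpu was built with the TensorRT execution provider and that bert-base-cased is an entry in the script's MODELS table (model name, batch size, sequence length, and test count are assumptions). Per the patch, passing --provider tensorrt makes run_onnxruntime force the optimizer to no_opt and adds a 5-iteration warm-up, while run_benchmark.sh toggles ORT_TENSORRT_FP16_ENABLE between the FP32 and FP16 passes.

# Illustrative invocation; flags mirror those added or used in this patch.
export ORT_TENSORRT_FP16_ENABLE=0        # FP32 pass; run_benchmark.sh sets this to 1 for the FP16 pass
python benchmark.py -g --provider tensorrt \
    -m bert-base-cased -b 1 -s 128 -t 100 \
    -o no_opt --disable_ort_io_binding \
    -r result.csv -d detail.csv

# Alternatively, set run_ort_trt=true in run_benchmark.sh and invoke that script directly.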