From 8fc3bb45193279302f61fa2621137156e9caee68 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Tue, 1 Feb 2022 23:01:22 -0800
Subject: [PATCH 1/6] Support specifying execution providers.

---
 .../python/tools/transformers/benchmark.py     | 16 ++++++++++++----
 .../tools/transformers/benchmark_helper.py     | 12 +++++++++---
 .../python/tools/transformers/profiler.py      | 15 +++++++++------
 3 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py
index 6e5d5b98ef651..ca02754201abe 100644
--- a/onnxruntime/python/tools/transformers/benchmark.py
+++ b/onnxruntime/python/tools/transformers/benchmark.py
@@ -68,13 +68,14 @@ from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model, LxmertConfig)
 
-def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, batch_sizes, sequence_lengths,
+def run_onnxruntime(use_gpu, provider, model_names, model_class, precision, num_threads, batch_sizes, sequence_lengths,
                     repeat_times, input_counts, optimize_onnx, validate_onnx, cache_dir, onnx_dir, verbose, overwrite,
                     disable_ort_io_binding, use_raw_attention_mask, model_fusion_statistics, model_source):
     import onnxruntime
 
     results = []
-    if use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()):
+    if (use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()) and
+            ('ROCMExecutionProvider' not in onnxruntime.get_available_providers())):
         logger.error(
             "Please install onnxruntime-gpu package instead of onnxruntime, and use a machine with GPU for testing gpu performance."
         )
@@ -105,6 +106,7 @@ def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, b
 
             ort_session = create_onnxruntime_session(onnx_model_file,
                                                      use_gpu,
+                                                     provider,
                                                      enable_all_optimization=True,
                                                      num_threads=num_threads,
                                                      verbose=verbose)
@@ -425,7 +427,13 @@ def parse_arguments():
                         default=os.path.join('.', 'onnx_models'),
                         help="Directory to store onnx models")
 
-    parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on cuda device")
+    parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on gpu device")
+
+    parser.add_argument("--provider",
+                        required=False,
+                        type=str,
+                        default='cuda',
+                        help="Execution provider to use")
 
     parser.add_argument(
         "-p",
@@ -545,7 +553,7 @@ def main():
         if enable_onnxruntime:
             try:
                 use_raw_attention_mask = True
-                results += run_onnxruntime(args.use_gpu, args.models, args.model_class, args.precision, num_threads,
+                results += run_onnxruntime(args.use_gpu, args.provider, args.models, args.model_class, args.precision, num_threads,
                                            args.batch_sizes, args.sequence_lengths, args.test_times, args.input_counts,
                                            args.optimize_onnx, args.validate_onnx, args.cache_dir, args.onnx_dir,
                                            args.verbose, args.overwrite, args.disable_ort_io_binding,
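
For context, here is how benchmark.py now hands the provider string through to session creation; a minimal sketch assuming a hypothetical model path, not code from the patch:

    from benchmark_helper import create_onnxruntime_session

    # 'rocm' selects ROCMExecutionProvider (see the benchmark_helper.py diff
    # below); "bert.onnx" is a hypothetical model path.
    session = create_onnxruntime_session("bert.onnx",
                                         use_gpu=True,
                                         provider="rocm",
                                         enable_all_optimization=True,
                                         num_threads=-1,
                                         verbose=False)
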
diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index 8c9afbe6561e0..2d3d2612e4e6e 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -39,11 +39,11 @@ def __str__(self):
 
 def create_onnxruntime_session(onnx_model_path,
                                use_gpu,
+                               provider='cpu',
                                enable_all_optimization=True,
                                num_threads=-1,
                                enable_profiling=False,
-                               verbose=False,
-                               use_dml=False):
+                               verbose=False):
     session = None
     try:
         from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel, __version__ as onnxruntime_version
@@ -68,8 +68,14 @@ def create_onnxruntime_session(onnx_model_path,
         logger.debug(f"Create session for onnx model: {onnx_model_path}")
         if use_gpu:
-            if use_dml:
+            if provider == 'dml':
                 execution_providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'rocm':
+                execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'migraphx':
+                execution_providers = ['MIGraphXExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'tensorrt':
+                execution_providers = ['TensorrtExecutionProvider', 'CPUExecutionProvider']
             else:
                 execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
         else:
diff --git a/onnxruntime/python/tools/transformers/profiler.py b/onnxruntime/python/tools/transformers/profiler.py
index 753b976758af1..ddd00e94ebb7a 100644
--- a/onnxruntime/python/tools/transformers/profiler.py
+++ b/onnxruntime/python/tools/transformers/profiler.py
@@ -86,8 +86,11 @@ def parse_arguments(argv=None):
     parser.add_argument('-g', '--use_gpu', required=False, action='store_true', help="use GPU")
     parser.set_defaults(use_gpu=False)
 
-    parser.add_argument('-d', '--use_dml', required=False, action='store_true', help="use DML")
-    parser.set_defaults(use_dml=False)
+    parser.add_argument('--provider',
+                        required=False,
+                        type=str,
+                        default='cuda',
+                        help="Execution provider to use")
 
     parser.add_argument(
         '--basic_optimization',
@@ -108,15 +111,15 @@ def parse_arguments(argv=None):
     return parser.parse_args(argv)
 
-def run_profile(onnx_model_path, use_gpu, basic_optimization, thread_num, all_inputs, use_dml):
+def run_profile(onnx_model_path, use_gpu, provider, basic_optimization, thread_num, all_inputs):
     from benchmark_helper import create_onnxruntime_session
 
     session = create_onnxruntime_session(onnx_model_path,
                                          use_gpu,
+                                         provider,
                                          enable_all_optimization=not basic_optimization,
                                          num_threads=thread_num,
-                                         enable_profiling=True,
-                                         use_dml=use_dml)
+                                         enable_profiling=True)
 
     for inputs in all_inputs:
         _ = session.run(None, inputs)
@@ -604,7 +607,7 @@ def run(args):
     else:  # default
         all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples)
 
-    profile_file = run_profile(args.model, args.use_gpu, args.basic_optimization, args.thread_num, all_inputs, args.use_dml)
+    profile_file = run_profile(args.model, args.use_gpu, args.provider, args.basic_optimization, args.thread_num, all_inputs)
 
     return profile_file
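
The branch added to create_onnxruntime_session reduces to a small lookup. An equivalent standalone sketch follows; select_execution_providers and _PROVIDER_TO_EPS are names invented here for illustration, not patch code:

    # Provider name -> ONNX Runtime execution-provider priority list,
    # as of this patch (no fallback entries yet; see PATCH 4/6 below).
    _PROVIDER_TO_EPS = {
        'dml': ['DmlExecutionProvider', 'CPUExecutionProvider'],
        'rocm': ['ROCMExecutionProvider', 'CPUExecutionProvider'],
        'migraphx': ['MIGraphXExecutionProvider', 'CPUExecutionProvider'],
        'tensorrt': ['TensorrtExecutionProvider', 'CPUExecutionProvider'],
    }

    def select_execution_providers(use_gpu, provider):
        """Return the EP list the session will try, in priority order."""
        if not use_gpu:
            return ['CPUExecutionProvider']
        # Unrecognized names (including the 'cuda' default) fall through to CUDA.
        return _PROVIDER_TO_EPS.get(provider, ['CUDAExecutionProvider', 'CPUExecutionProvider'])
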
From dea288a53ef1a950eed75b7bbf5128ef44c96607 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Tue, 1 Feb 2022 23:25:30 -0800
Subject: [PATCH 2/6] Change default provider setting to None.

---
 onnxruntime/python/tools/transformers/benchmark.py        | 2 +-
 onnxruntime/python/tools/transformers/benchmark_helper.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py
index ca02754201abe..abaeaf47c984e 100644
--- a/onnxruntime/python/tools/transformers/benchmark.py
+++ b/onnxruntime/python/tools/transformers/benchmark.py
@@ -432,7 +432,7 @@ def parse_arguments():
     parser.add_argument("--provider",
                         required=False,
                         type=str,
-                        default='cuda',
+                        default=None,
                         help="Execution provider to use")
 
     parser.add_argument(
diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index 2d3d2612e4e6e..5decc4863fc2a 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -39,7 +39,7 @@ def __str__(self):
 
 def create_onnxruntime_session(onnx_model_path,
                                use_gpu,
-                               provider='cpu',
+                               provider=None,
                                enable_all_optimization=True,
                                num_threads=-1,
                                enable_profiling=False,
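
With the default changed to None, omitting --provider preserves the old behavior: a plain --use_gpu run matches no elif branch and still lands on CUDA. In terms of the select_execution_providers sketch above:

    # provider=None matches no elif branch, so existing GPU runs are unchanged.
    assert select_execution_providers(True, None) == ['CUDAExecutionProvider', 'CPUExecutionProvider']
    assert select_execution_providers(False, None) == ['CPUExecutionProvider']
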
From cf589b14680f1cf2e34535cad84347dbe65893be Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Wed, 2 Feb 2022 10:24:26 -0800
Subject: [PATCH 3/6] Add support for bert_perf_test script.

---
 .../tools/transformers/bert_perf_test.py | 42 ++++++++++++++++---
 1 file changed, 36 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py
index 45682eb18ab8a..9c48bf0845018 100644
--- a/onnxruntime/python/tools/transformers/bert_perf_test.py
+++ b/onnxruntime/python/tools/transformers/bert_perf_test.py
@@ -36,6 +36,7 @@ class TestSetting:
     test_cases: int
     test_times: int
     use_gpu: bool
+    provider: str
     intra_op_num_threads: int
     seed: int
     verbose: bool
@@ -50,7 +51,7 @@ class ModelSetting:
     opt_level: int
 
-def create_session(model_path, use_gpu, intra_op_num_threads, graph_optimization_level=None):
+def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_optimization_level=None):
     import onnxruntime
 
     if use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()):
@@ -61,8 +62,19 @@ def create_session(model_path, use_gpu, intra_op_num_threads, graph_optimization
     if intra_op_num_threads is None and graph_optimization_level is None:
         session = onnxruntime.InferenceSession(model_path)
     else:
-        execution_providers = ['CPUExecutionProvider'
-                               ] if not use_gpu else ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        if use_gpu:
+            if provider == 'dml':
+                execution_providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'rocm':
+                execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'migraphx':
+                execution_providers = ['MIGraphXExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'tensorrt':
+                execution_providers = ['TensorrtExecutionProvider', 'CPUExecutionProvider']
+            else:
+                execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        else:
+            execution_providers = ['CPUExecutionProvider']
 
         sess_options = onnxruntime.SessionOptions()
         sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
@@ -86,7 +98,19 @@ def create_session(model_path, use_gpu, intra_op_num_threads, graph_optimization
         session = onnxruntime.InferenceSession(model_path, sess_options, providers=execution_providers)
 
     if use_gpu:
-        assert 'CUDAExecutionProvider' in session.get_providers()
+        if provider == 'dml':
+            assert 'DmlExecutionProvider' in session.get_providers()
+        elif provider == 'rocm':
+            assert 'ROCMExecutionProvider' in session.get_providers()
+        elif provider == 'migraphx':
+            assert 'MIGraphXExecutionProvider' in session.get_providers()
+        elif provider == 'tensorrt':
+            assert 'TensorrtExecutionProvider' in session.get_providers()
+        else:
+            assert 'CUDAExecutionProvider' in session.get_providers()
+    else:
+        assert 'CPUExecutionProvider' in session.get_providers()
+
     return session
 
@@ -117,7 +141,7 @@ def to_string(model_path, session, test_setting):
 
 def run_one_test(model_setting, test_setting, perf_results, all_inputs, intra_op_num_threads):
-    session = create_session(model_setting.model_path, test_setting.use_gpu, intra_op_num_threads,
+    session = create_session(model_setting.model_path, test_setting.use_gpu, test_setting.provider, intra_op_num_threads,
                              model_setting.opt_level)
     output_names = [output.name for output in session.get_outputs()]
 
@@ -239,6 +263,12 @@ def parse_arguments():
     parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU")
     parser.set_defaults(use_gpu=False)
 
+    parser.add_argument("--provider",
+                        required=False,
+                        type=str,
+                        default=None,
+                        help="Execution provider to use")
+
     parser.add_argument('-n',
                         '--intra_op_num_threads',
                         required=False,
@@ -276,7 +306,7 @@ def main():
 
     for batch_size in batch_size_set:
         test_setting = TestSetting(batch_size, args.sequence_length, args.samples, args.test_times, args.use_gpu,
-                                   args.intra_op_num_threads, args.seed, args.verbose)
+                                   args.provider, args.intra_op_num_threads, args.seed, args.verbose)
 
         print("test setting", test_setting)
         run_performance(model_setting, test_setting, perf_results)
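
The asserts added to create_session verify which providers the session actually registered, since ONNX Runtime can fall back to other entries in the priority list when a provider fails to load. A minimal standalone check of the same idea (the model path is hypothetical):

    import onnxruntime

    session = onnxruntime.InferenceSession("bert.onnx",  # hypothetical model path
                                           providers=['ROCMExecutionProvider', 'CPUExecutionProvider'])
    # get_providers() lists only the EPs that were actually registered.
    assert 'ROCMExecutionProvider' in session.get_providers()
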
From ed43d91eb0855791d82dd8f5f48da87f74b0efd3 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Wed, 2 Feb 2022 10:54:16 -0800
Subject: [PATCH 4/6] Fall back to ROCM/CUDA EP for MIGraphX/Tensorrt EP.

---
 onnxruntime/python/tools/transformers/benchmark_helper.py | 4 ++--
 onnxruntime/python/tools/transformers/bert_perf_test.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index 5decc4863fc2a..41bdc1966a0c7 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -73,9 +73,9 @@ def create_onnxruntime_session(onnx_model_path,
             elif provider == 'rocm':
                 execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'migraphx':
-                execution_providers = ['MIGraphXExecutionProvider', 'CPUExecutionProvider']
+                execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'tensorrt':
-                execution_providers = ['TensorrtExecutionProvider', 'CPUExecutionProvider']
+                execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
             else:
                 execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
         else:
diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py
index 9c48bf0845018..fed379f762f02 100644
--- a/onnxruntime/python/tools/transformers/bert_perf_test.py
+++ b/onnxruntime/python/tools/transformers/bert_perf_test.py
@@ -68,9 +68,9 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op
             elif provider == 'rocm':
                 execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'migraphx':
-                execution_providers = ['MIGraphXExecutionProvider', 'CPUExecutionProvider']
+                execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'tensorrt':
-                execution_providers = ['TensorrtExecutionProvider', 'CPUExecutionProvider']
+                execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
             else:
                 execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
         else:

From 3573ad15fbe3f46b3055a03a4024e539a7bf6163 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Wed, 2 Feb 2022 10:57:49 -0800
Subject: [PATCH 5/6] Assert fallback EPs are included.

---
 onnxruntime/python/tools/transformers/bert_perf_test.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py
index fed379f762f02..21924e1e089bc 100644
--- a/onnxruntime/python/tools/transformers/bert_perf_test.py
+++ b/onnxruntime/python/tools/transformers/bert_perf_test.py
@@ -104,8 +104,10 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op
             assert 'ROCMExecutionProvider' in session.get_providers()
         elif provider == 'migraphx':
             assert 'MIGraphXExecutionProvider' in session.get_providers()
+            assert 'ROCMExecutionProvider' in session.get_providers()
         elif provider == 'tensorrt':
             assert 'TensorrtExecutionProvider' in session.get_providers()
+            assert 'CUDAExecutionProvider' in session.get_providers()
         else:
             assert 'CUDAExecutionProvider' in session.get_providers()
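
The fallback entries matter because ONNX Runtime assigns each graph node to the first provider in the priority list that supports it, so with MIGraphX or TensorRT first, unsupported nodes now land on ROCM or CUDA rather than the CPU. A sketch of what the asserts in PATCH 5/6 expect (model path hypothetical):

    import onnxruntime

    requested = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
    session = onnxruntime.InferenceSession("bert.onnx", providers=requested)  # hypothetical path
    # Both the primary EP and its fallback should have been registered.
    for ep in ('TensorrtExecutionProvider', 'CUDAExecutionProvider'):
        assert ep in session.get_providers()
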
From 3dd16d084f81b78325e824b889fbb8e8b550b8ae Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Wed, 2 Feb 2022 14:24:28 -0800
Subject: [PATCH 6/6] Add model class AutoModelForCausalLM and other minor updates.

---
 onnxruntime/python/tools/transformers/benchmark_helper.py   | 6 ++++--
 onnxruntime/python/tools/transformers/bert_perf_test.py     | 4 ++++
 onnxruntime/python/tools/transformers/huggingface_models.py | 3 ++-
 onnxruntime/python/tools/transformers/onnx_exporter.py      | 1 +
 4 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index 41bdc1966a0c7..8ef162a1991c2 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -74,6 +74,8 @@ def create_onnxruntime_session(onnx_model_path,
                 execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'migraphx':
                 execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'cuda':
+                execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'tensorrt':
                 execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
             else:
                 execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
@@ -95,7 +97,7 @@ def setup_logger(verbose=True):
     logging.getLogger("transformers").setLevel(logging.WARNING)
 
-def prepare_environment(cache_dir, output_dir, use_gpu, use_dml=False):
+def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
     if cache_dir and not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
 
@@ -104,7 +106,7 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
     import onnxruntime
 
     if use_gpu:
-        if use_dml:
+        if provider == 'dml':
             assert 'DmlExecutionProvider' in onnxruntime.get_available_providers(
             ), "Please install onnxruntime-directml package to test GPU inference."
diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py
index 21924e1e089bc..6b621492b2ec2 100644
--- a/onnxruntime/python/tools/transformers/bert_perf_test.py
+++ b/onnxruntime/python/tools/transformers/bert_perf_test.py
@@ -69,6 +69,8 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op
                 execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'migraphx':
                 execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider']
+            elif provider == 'cuda':
+                execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
             elif provider == 'tensorrt':
                 execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
             else:
@@ -105,6 +107,8 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op
         elif provider == 'migraphx':
             assert 'MIGraphXExecutionProvider' in session.get_providers()
             assert 'ROCMExecutionProvider' in session.get_providers()
+        elif provider == 'cuda':
+            assert 'CUDAExecutionProvider' in session.get_providers()
         elif provider == 'tensorrt':
             assert 'TensorrtExecutionProvider' in session.get_providers()
             assert 'CUDAExecutionProvider' in session.get_providers()
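
prepare_environment now takes the provider string in place of the old use_dml flag; a sketch of the updated call, with hypothetical directory values:

    from benchmark_helper import prepare_environment

    # With provider='dml' this asserts DmlExecutionProvider is available
    # (i.e. the onnxruntime-directml package is installed).
    prepare_environment(cache_dir='./cache_models', output_dir='./onnx_models',
                        use_gpu=True, provider='dml')
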
diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py
index 051480ebb0ade..642669156cbb8 100644
--- a/onnxruntime/python/tools/transformers/huggingface_models.py
+++ b/onnxruntime/python/tools/transformers/huggingface_models.py
@@ -6,7 +6,8 @@
 
 # Maps model class name to a tuple of model class
 MODEL_CLASSES = [
-    'AutoModel', 'AutoModelWithLMHead', 'AutoModelForSequenceClassification', 'AutoModelForQuestionAnswering'
+    'AutoModel', 'AutoModelWithLMHead', 'AutoModelForSequenceClassification', 'AutoModelForQuestionAnswering',
+    'AutoModelForCausalLM',
 ]
 
 # List of pretrained models: https://huggingface.co/transformers/pretrained_models.html
diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py
index d12c1d13070ae..6abf72c237703 100644
--- a/onnxruntime/python/tools/transformers/onnx_exporter.py
+++ b/onnxruntime/python/tools/transformers/onnx_exporter.py
@@ -255,6 +255,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_
         model_class_name = 'TF' + model_class_name
 
     transformers_module = __import__("transformers", fromlist=[model_class_name])
+    logger.info(f"Model class name: {model_class_name}")
    model_class = getattr(transformers_module, model_class_name)
 
     return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
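
For reference, the dynamic lookup in load_pretrained_model that the new MODEL_CLASSES entry feeds into; a minimal sketch where 'gpt2' and the cache directory are hypothetical example values:

    import transformers

    model_class_name = 'AutoModelForCausalLM'
    # Same getattr-based lookup used by load_pretrained_model in onnx_exporter.py.
    model_class = getattr(transformers, model_class_name)
    model = model_class.from_pretrained('gpt2', cache_dir='./cache_models')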