Add model class AutoModelForCausalLM and other minor updates.
zhangyaobit committed Feb 2, 2022
Parent: 3573ad1 · Commit: 3dd16d0
Showing 4 changed files with 11 additions and 3 deletions.
6 changes: 4 additions & 2 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -74,6 +74,8 @@ def create_onnxruntime_session(onnx_model_path,
             execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
         elif provider == 'migraphx':
             execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider']
+        elif provider == 'cuda':
+            execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
         elif provider == 'tensorrt':
             execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
         else:
@@ -95,7 +97,7 @@ def setup_logger(verbose=True):
     logging.getLogger("transformers").setLevel(logging.WARNING)
 
 
-def prepare_environment(cache_dir, output_dir, use_gpu, use_dml=False):
+def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
     if cache_dir and not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
@@ -104,7 +106,7 @@ def prepare_environment(cache_dir, output_dir, use_gpu, use_dml=False):
 
     import onnxruntime
     if use_gpu:
-        if use_dml:
+        if provider == 'dml':
             assert 'DmlExecutionProvider' in onnxruntime.get_available_providers(
             ), "Please install onnxruntime-directml package to test GPU inference."
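For context, the new 'cuda' branch gives the session an explicit execution-provider priority list; onnxruntime tries the providers in order and falls back down the list when one is unavailable. A minimal sketch of the same pattern (the model.onnx path is hypothetical):

    import onnxruntime

    # Prefer the CUDA EP, fall back to CPU, mirroring the new 'cuda' branch.
    execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

    # InferenceSession registers the providers from the list that are
    # actually available in the installed onnxruntime build.
    session = onnxruntime.InferenceSession("model.onnx", providers=execution_providers)
    print(session.get_providers())  # the providers actually registered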
4 changes: 4 additions & 0 deletions onnxruntime/python/tools/transformers/bert_perf_test.py
@@ -69,6 +69,8 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op
             execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
         elif provider == 'migraphx':
             execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider']
+        elif provider == 'cuda':
+            execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
         elif provider == 'tensorrt':
             execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
         else:
@@ -105,6 +107,8 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op
         elif provider == 'migraphx':
             assert 'MIGraphXExecutionProvider' in session.get_providers()
             assert 'ROCMExecutionProvider' in session.get_providers()
+        elif provider == 'cuda':
+            assert 'CUDAExecutionProvider' in session.get_providers()
         elif provider == 'tensorrt':
             assert 'TensorrtExecutionProvider' in session.get_providers()
             assert 'CUDAExecutionProvider' in session.get_providers()
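The second hunk adds a post-creation check so the benchmark fails fast instead of silently measuring CPU when the CUDA EP did not load. A standalone sketch of that check (the bert.onnx path is hypothetical):

    import onnxruntime

    session = onnxruntime.InferenceSession(
        "bert.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    # Mirror the new assertion: verify the CUDA EP actually registered.
    assert 'CUDAExecutionProvider' in session.get_providers(), \
        "CUDAExecutionProvider not registered; check the onnxruntime-gpu install."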
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/huggingface_models.py
@@ -6,7 +6,8 @@
 
 # Maps model class name to a tuple of model class
 MODEL_CLASSES = [
-    'AutoModel', 'AutoModelWithLMHead', 'AutoModelForSequenceClassification', 'AutoModelForQuestionAnswering'
+    'AutoModel', 'AutoModelWithLMHead', 'AutoModelForSequenceClassification', 'AutoModelForQuestionAnswering',
+    'AutoModelForCausalLM',
 ]
 
 # List of pretrained models: https://huggingface.co/transformers/pretrained_models.html
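With 'AutoModelForCausalLM' in MODEL_CLASSES, decoder-only models can be loaded through the standard transformers auto class. A minimal usage sketch ('gpt2' is only an illustrative checkpoint):

    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2", config=config)
    model.eval()  # inference mode, as used for export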
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/onnx_exporter.py
@@ -255,6 +255,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_
         model_class_name = 'TF' + model_class_name
 
     transformers_module = __import__("transformers", fromlist=[model_class_name])
+    logger.info(f"Model class name: {model_class_name}")
     model_class = getattr(transformers_module, model_class_name)
 
     return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
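The logging line lands in the dynamic-import path, which resolves a model class from its string name. A self-contained sketch of that pattern:

    # Resolve a transformers class by name, as load_pretrained_model does.
    model_class_name = "AutoModelForCausalLM"

    # __import__ with a fromlist returns the 'transformers' module itself,
    # so the class can then be looked up by name.
    transformers_module = __import__("transformers", fromlist=[model_class_name])
    model_class = getattr(transformers_module, model_class_name)
    print(model_class.__name__)  # AutoModelForCausalLM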
