Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make user capable of adding new field in OrtTensorRTProviderOptionsV2 as new provider option #10450

Merged
merged 7 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

/// <summary>
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
/// Please note that this struct is identical to OrtTensorRTProviderOptions but only to be used internally.
/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
/// OrtTensorRTProviderOptions will be deprecated over time.
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ typedef struct OrtTensorRTProviderOptions {
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
// This is the legacy struct; do not add new fields here.
// For a new field that can be represented as a string, add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h instead.
// For a non-string field, a new separate API needs to be created to handle it.
} OrtTensorRTProviderOptions;

/** \brief MIGraphX Provider Options
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX

SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
return *this;
}

// Appends the TensorRT execution provider configured via the V2 options struct.
// Forwards to the C API entry point; a failure status is converted into an Ort exception.
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
  OrtStatus* status = GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(p_, &provider_options);
  ThrowOnError(status);
  return *this;
}

inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options));
return *this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ constexpr const char* kCachePath = "trt_engine_cache_path";
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
// add new provider option name here.
} // namespace provider_option_names
} // namespace tensorrt

Expand Down Expand Up @@ -63,7 +64,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
.Parse(options));
.Parse(options)); // add new provider option here.

return info;
}
Expand All @@ -87,6 +88,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
// add new provider option here.
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <atomic>
#include "tensorrt_execution_provider.h"
#include "core/framework/provider_options.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include <string.h>

using namespace onnxruntime;
Expand Down Expand Up @@ -48,7 +49,7 @@ struct Tensorrt_Provider : Provider {
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* provider_options) override {
auto& options = *reinterpret_cast<const OrtTensorRTProviderOptions*>(provider_options);
auto& options = *reinterpret_cast<const OrtTensorRTProviderOptionsV2*>(provider_options);
TensorrtExecutionProviderInfo info;
info.device_id = options.device_id;
info.has_user_compute_stream = options.has_user_compute_stream != 0;
Expand All @@ -74,7 +75,7 @@ struct Tensorrt_Provider : Provider {

void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options);
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptions*>(provider_options);
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptionsV2*>(provider_options);
trt_options.device_id = internal_options.device_id;
trt_options.trt_max_partition_iterations = internal_options.max_partition_iterations;
trt_options.trt_min_subgraph_size = internal_options.min_subgraph_size;
Expand Down
46 changes: 45 additions & 1 deletion onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,43 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGrap
return nullptr;
}

// Adapter to convert the legacy OrtTensorRTProviderOptions to the latest OrtTensorRTProviderOptionsV2.
// Fields that exist only in the V2 struct keep their default values, since the legacy struct has no
// corresponding data for them.
// NOTE(review): the returned struct aliases the C-string pointers owned by `legacy_trt_options`
// (e.g. trt_engine_cache_path), so the legacy options must outlive the converted struct.
OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(const OrtTensorRTProviderOptions* legacy_trt_options) {
  // Value-initialize so any V2-only field without an explicit in-class initializer starts zeroed
  // instead of holding an indeterminate value.
  OrtTensorRTProviderOptionsV2 trt_options_converted{};

  trt_options_converted.device_id = legacy_trt_options->device_id;
  trt_options_converted.has_user_compute_stream = legacy_trt_options->has_user_compute_stream;
  trt_options_converted.user_compute_stream = legacy_trt_options->user_compute_stream;
  trt_options_converted.trt_max_partition_iterations = legacy_trt_options->trt_max_partition_iterations;
  trt_options_converted.trt_min_subgraph_size = legacy_trt_options->trt_min_subgraph_size;
  trt_options_converted.trt_max_workspace_size = legacy_trt_options->trt_max_workspace_size;
  trt_options_converted.trt_fp16_enable = legacy_trt_options->trt_fp16_enable;
  trt_options_converted.trt_int8_enable = legacy_trt_options->trt_int8_enable;
  trt_options_converted.trt_int8_calibration_table_name = legacy_trt_options->trt_int8_calibration_table_name;
  trt_options_converted.trt_int8_use_native_calibration_table = legacy_trt_options->trt_int8_use_native_calibration_table;
  trt_options_converted.trt_dla_enable = legacy_trt_options->trt_dla_enable;
  trt_options_converted.trt_dla_core = legacy_trt_options->trt_dla_core;
  trt_options_converted.trt_dump_subgraphs = legacy_trt_options->trt_dump_subgraphs;
  trt_options_converted.trt_engine_cache_enable = legacy_trt_options->trt_engine_cache_enable;
  trt_options_converted.trt_engine_cache_path = legacy_trt_options->trt_engine_cache_path;
  trt_options_converted.trt_engine_decryption_enable = legacy_trt_options->trt_engine_decryption_enable;
  trt_options_converted.trt_engine_decryption_lib_path = legacy_trt_options->trt_engine_decryption_lib_path;
  trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build;
  // Add new provider options below.
  // Leave them at their default values, as they are not available in the legacy OrtTensorRTProviderOptions.

  return trt_options_converted;
}

// Creates a TensorRT execution-provider factory from the legacy options struct by bridging
// it to the V2 struct first. Returns nullptr when the provider shared library is unavailable.
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* provider_options) {
  auto* provider = s_library_tensorrt.Get();
  if (provider == nullptr)
    return nullptr;

  OrtTensorRTProviderOptionsV2 trt_options_converted =
      onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options);
  return provider->CreateExecutionProviderFactory(&trt_options_converted);
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* provider_options) {
if (auto* provider = s_library_tensorrt.Get())
return provider->CreateExecutionProviderFactory(provider_options);

Expand Down Expand Up @@ -1420,7 +1456,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or
}

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) {
  API_IMPL_BEGIN
  // Build the factory directly from the V2 options instead of reinterpret_cast-ing down to the
  // legacy OrtTensorRTProviderOptions (the two structs are similar but not layout-identical,
  // so the old cast-based delegation was unsafe once V2 gains new fields).
  auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options);
  if (!factory) {
    // Name the V2 entry point in the error so the failing API is identifiable from logs.
    return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT_V2: Failed to load shared library");
  }

  options->provider_factories.push_back(factory);
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out) {
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"

// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
Expand Down Expand Up @@ -374,7 +375,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
std::string calibration_table, cache_path, lib_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptions params{
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct OrtStatus {
#include "core/providers/providers.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/providers/cpu/cpu_provider_factory_creator.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"

#if defined(USE_CUDA) || defined(USE_ROCM)
#define BACKEND_PROC "GPU"
Expand Down Expand Up @@ -474,6 +475,7 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor);
} // namespace python

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(int device_id);
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "ort_test_session.h"
#include <core/session/onnxruntime_cxx_api.h>
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include <assert.h>
#include "providers.h"
#include "TestCase.h"
Expand Down Expand Up @@ -209,7 +210,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n");
}
}
OrtTensorRTProviderOptions tensorrt_options;
OrtTensorRTProviderOptionsV2 tensorrt_options;
tensorrt_options.device_id = device_id;
tensorrt_options.has_user_compute_stream = 0;
tensorrt_options.user_compute_stream = nullptr;
Expand All @@ -228,7 +229,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable;
tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str();
tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build;
session_options.AppendExecutionProvider_TensorRT(tensorrt_options);
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);

OrtCUDAProviderOptions cuda_options;
cuda_options.device_id=device_id;
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/providers/cpu/model_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/inference_session.h"
#include "core/session/ort_env.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "asserts.h"
#include <core/platform/path_lib.h>
#include "default_providers.h"
Expand Down Expand Up @@ -591,7 +592,7 @@ TEST_P(ModelTest, Run) {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNupharExecutionProvider()));
} else if (provider_name == "tensorrt") {
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
OrtTensorRTProviderOptions params{
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const O
return nullptr;
}

// Builds a TensorRT execution provider from V2 options for tests.
// Returns nullptr when TensorRT support is compiled out or the factory cannot be created.
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
#ifdef USE_TENSORRT
  auto factory = CreateExecutionProviderFactory_Tensorrt(params);
  if (factory != nullptr) {
    return factory->CreateProvider();
  }
#else
  ORT_UNUSED_PARAMETER(params);
#endif
  return nullptr;
}

std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
#ifdef USE_MIGRAPHX
OrtMIGraphXProviderOptions params{
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/util/include/default_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVI
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);

// EP for internal testing
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_InternalTesting(const std::unordered_set<std::string>& supported_ops);
Expand All @@ -38,6 +39,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider(bool allow_un
//std::unique_ptr<IExecutionProvider> DefaultStvmExecutionProvider();
std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider();
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider();
std::unique_ptr<IExecutionProvider> MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
std::unique_ptr<IExecutionProvider> DefaultOpenVINOExecutionProvider();
Expand Down