From dd7ed165d43217af427eb849b830c764db7c9684 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Mon, 9 Mar 2020 19:08:46 -0700 Subject: [PATCH 1/9] Add support for sessions to share a global threadpool. --- .../onnxruntime/core/session/environment.h | 28 +++++- .../core/session/onnxruntime_c_api.h | 24 +++++ .../core/session/onnxruntime_cxx_api.h | 7 +- .../core/session/onnxruntime_cxx_inline.h | 8 ++ onnxruntime/core/framework/session_options.h | 4 + .../core/session/abi_session_options.cc | 5 + onnxruntime/core/session/environment.cc | 22 ++++- onnxruntime/core/session/inference_session.cc | 58 +++++++----- onnxruntime/core/session/inference_session.h | 36 ++++---- onnxruntime/core/session/onnxruntime_c_api.cc | 24 ++++- onnxruntime/core/session/ort_apis.h | 5 + onnxruntime/core/session/ort_env.cc | 63 ++++++------- onnxruntime/core/session/ort_env.h | 15 ++- .../test/contrib_ops/layer_norm_op_test.cc | 6 +- .../test/framework/execution_frame_test.cc | 2 +- onnxruntime/test/framework/float_16_test.cc | 2 +- .../test/framework/inference_session_test.cc | 92 ++++++++++--------- .../framework/local_kernel_registry_test.cc | 6 +- .../test/framework/opaque_kernels_test.cc | 2 +- .../test/framework/sparse_kernels_test.cc | 2 +- .../test/optimizer/graph_transform_test.cc | 20 ++-- .../test/optimizer/nchwc_optimizer_test.cc | 5 +- .../providers/cpu/controlflow/loop_test.cc | 2 +- .../test/providers/provider_test_utils.cc | 18 ++-- onnxruntime/test/shared_lib/test_inference.cc | 4 +- .../test/util/include/test/test_environment.h | 4 + onnxruntime/test/util/test_environment.cc | 7 +- 27 files changed, 307 insertions(+), 164 deletions(-) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 5792e70b6ea39..632c03ac9b053 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -7,7 +7,10 @@ #include #include "core/common/common.h" #include "core/common/status.h" +#include "core/platform/threadpool.h" +#include "core/common/logging/logging.h" +struct ThreadingOptions; namespace onnxruntime { /** TODO: remove this class Provides the runtime environment for onnxruntime. @@ -18,12 +21,33 @@ class Environment { /** Create and initialize the runtime environment. 
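+ logging_manager: the logging manager that the environment owns for its lifetime.
+ tp_options: optional threading options used to size the global thread pools.
+ create_thread_pool: if true, creates global intra-op and inter-op thread pools that
+ sessions opting out of per-session threads will share.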
*/ - static Status Create(std::unique_ptr& environment); + static Status Create(std::unique_ptr logging_manager, + std::unique_ptr& environment, + const ThreadingOptions* tp_options = nullptr, + bool create_thread_pool = false); + + logging::LoggingManager* GetLoggingManager() const { + return logging_manager_.get(); + } + + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const { + return intra_op_thread_pool_.get(); + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const { + return inter_op_thread_pool_.get(); + } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); Environment() = default; - Status Initialize(); + Status Initialize(std::unique_ptr logging_manager, + const ThreadingOptions* tp_options = nullptr, + bool create_global_thread_pool = false); + + std::unique_ptr logging_manager_; + std::unique_ptr intra_op_thread_pool_; + std::unique_ptr inter_op_thread_pool_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 341fc9d7782b9..6a186f05ff3eb 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -212,6 +212,14 @@ typedef enum OrtMemType { OrtMemTypeDefault = 0, // the default allocator for execution provider } OrtMemType; +typedef struct ThreadingOptions { + // number of threads used to parallelize execution of an op + int intra_op_num_threads = 0; // default value + + // number of threads used to parallelize execution across ops + int inter_op_num_threads = 0; // default value +} ThreadingOptions; + struct OrtApi; typedef struct OrtApi OrtApi; @@ -747,6 +755,22 @@ struct OrtApi { OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION; ORT_CLASS_RELEASE(ModelMetadata); + + /* + * Creates an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with DisablePerSessionThreads API or else by default the session will use + * its own thread pools. + */ + OrtStatus*(ORT_API_CALL* CreateEnvWithGlobalThreadPools)(OrtLoggingLevel default_logging_level, _In_ const char* logid, + _In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out) + NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + // TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? + + /* +* Calling this API will make the session use the global threadpools shared across sessions. 
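+* The environment must have been created with CreateEnvWithGlobalThreadPools; otherwise
+* session construction will fail because no global thread pools exist.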
+*/
+  OrtStatus*(ORT_API_CALL* DisablePerSessionThreads)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION;
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 5e1c91b916cf6..e562d3b3c5324 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -82,7 +82,7 @@ struct Base {
   ~Base() { OrtRelease(p_); }
   operator T*() { return p_; }
-  operator const T *() const { return p_; }
+  operator const T*() const { return p_; }
   T* release() {
     T* p = p_;
@@ -123,6 +123,7 @@ struct ModelMetadata;
 struct Env : Base<OrtEnv> {
   Env(std::nullptr_t) {}
   Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
+  Env(ThreadingOptions tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
   Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
   explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
@@ -185,6 +186,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
   SessionOptions& SetLogId(const char* logid);
   SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
+
+  SessionOptions& DisablePerSessionThreads();
 };
 struct ModelMetadata : Base<OrtModelMetadata> {
@@ -289,7 +292,7 @@ struct AllocatorWithDefaultOptions {
   AllocatorWithDefaultOptions();
   operator OrtAllocator*() { return p_; }
-  operator const OrtAllocator *() const { return p_; }
+  operator const OrtAllocator*() const { return p_; }
   void* Alloc(size_t size);
   void Free(void* p);
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index b10f0cebe50f1..a63508797ad90 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -84,6 +84,10 @@ inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLog
   ThrowOnError(Global<void>::api_.CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
 }
+inline Env::Env(ThreadingOptions tp_options, OrtLoggingLevel default_warning_level, const char* logid) {
+  ThrowOnError(Global<void>::api_.CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_));
+}
+
 inline Env& Env::EnableTelemetryEvents() {
   ThrowOnError(Global<void>::api_.EnableTelemetryEvents(p_));
   return *this;
 }
@@ -601,4 +605,8 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
   return out;
 }
+inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
+  ThrowOnError(Global<void>::api_.DisablePerSessionThreads(p_));
+  return *this;
+}
 }  // namespace Ort
\ No newline at end of file
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 1c02f9f0ebbc2..3d79a851c8cdc 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -64,5 +64,9 @@ struct SessionOptions {
   // For models with free input dimensions (most commonly batch size), specifies a set of values to override those
   // free dimensions with, keyed by dimension denotation.
   std::vector<FreeDimensionOverride> free_dimension_overrides;
+
+  // By default the session uses its own set of threadpools, unless this is set to false.
+  // Use this in conjunction with the CreateEnvWithGlobalThreadPools API.
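+  // When false, the session creates no thread pools of its own and instead uses the
+  // global intra-op/inter-op thread pools owned by the environment.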
+  bool use_per_session_threads = true;
 };
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index c15252b71d5e5..a6b08364c46a2 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -155,3 +155,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions
   options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override});
   return nullptr;
 }
+
+ORT_API_STATUS_IMPL(OrtApis::DisablePerSessionThreads, _In_ OrtSessionOptions* options) {
+  options->value.use_per_session_threads = false;
+  return nullptr;
+}
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 744c347491dc3..e7e61bd46a6ee 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -18,6 +18,7 @@
 #endif
 #include "core/platform/env.h"
+#include "core/util/thread_utils.h"
 #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
 #include "core/platform/tracing.h"
@@ -29,15 +30,30 @@ using namespace ONNX_NAMESPACE;
 std::once_flag schemaRegistrationOnceFlag;
-Status Environment::Create(std::unique_ptr<Environment>& environment) {
+Status Environment::Create(std::unique_ptr<logging::LoggingManager> logging_manager,
+                           std::unique_ptr<Environment>& environment,
+                           const ThreadingOptions* tp_options,
+                           bool create_global_thread_pool) {
   environment = std::unique_ptr<Environment>(new Environment());
-  auto status = environment->Initialize();
+  auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pool);
   return status;
 }
-Status Environment::Initialize() {
+Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
+                               const ThreadingOptions* tp_options,
+                               bool create_global_thread_pool) {
   auto status = Status::OK();
+  logging_manager_ = std::move(logging_manager);
+
+  // create thread pools
+  if (create_global_thread_pool) {
+    intra_op_thread_pool_ = concurrency::CreateThreadPool("env_global_intra_op_thread_pool",
+                                                          tp_options->intra_op_num_threads);
+    inter_op_thread_pool_ = concurrency::CreateThreadPool("env_global_inter_op_thread_pool",
+                                                          tp_options->inter_op_num_threads);
+  }
+
   try {
     // Register Microsoft domain with min/max op_set version as 1/1.
     std::call_once(schemaRegistrationOnceFlag, []() {
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index b1d2305594eab..be98c23100f5d 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -160,30 +160,46 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session
 void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
-                                         logging::LoggingManager* logging_manager) {
+                                         const Environment& session_env) {
   auto status = FinalizeSessionOptions(session_options, model_proto_.get(), session_options_);
   ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ",
               status.ErrorMessage());
   graph_transformation_mgr_ = onnxruntime::make_unique<GraphTransformerManager>(
       session_options_.max_num_graph_transformation_steps);
-  logging_manager_ = logging_manager;
-  thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool",
-                                               session_options_.intra_op_num_threads);
+  use_per_session_threads_ = session_options.use_per_session_threads;
-  inter_op_thread_pool_ = session_options_.execution_mode == ExecutionMode::ORT_PARALLEL
-                              ? concurrency::CreateThreadPool("inter_op_thread_pool",
-                                                              session_options_.inter_op_num_threads)
-                              : nullptr;
+  if (use_per_session_threads_) {
+    thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool",
+                                                 session_options_.intra_op_num_threads);
+
+    inter_op_thread_pool_ = session_options_.execution_mode == ExecutionMode::ORT_PARALLEL
+                                ? concurrency::CreateThreadPool("inter_op_thread_pool",
+                                                                session_options_.inter_op_num_threads)
+                                : nullptr;
+  } else {
+    intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool();
+    inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool();
+
+    ORT_ENFORCE(intra_op_thread_pool_from_env_,
+                "Since use_per_session_threads is false, this must be non-nullptr."
+                " You probably didn't create the env using the CreateEnvWithGlobalThreadPools API");
+    ORT_ENFORCE(inter_op_thread_pool_from_env_,
+                "Since use_per_session_threads is false, this must be non-nullptr."
+                " You probably didn't create the env using the CreateEnvWithGlobalThreadPools API");
+    ORT_ENFORCE(thread_pool_ == nullptr, "Since use_per_session_threads is false, per-session threadpools should be nullptr");
+    ORT_ENFORCE(inter_op_thread_pool_ == nullptr, "Since use_per_session_threads is false, per-session threadpools should be nullptr");
+  }
   session_state_ = onnxruntime::make_unique<SessionState>(execution_providers_,
                                                           session_options_.enable_mem_pattern &&
                                                               session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL,
-                                                          thread_pool_.get(),
-                                                          inter_op_thread_pool_.get());
+                                                          use_per_session_threads_ ? thread_pool_.get() : intra_op_thread_pool_from_env_,
+                                                          use_per_session_threads_ ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_);
-  InitLogger(logging_manager);
+  logging_manager_ = session_env.GetLoggingManager();
+  InitLogger(logging_manager_);
   session_state_->SetDataTransferMgr(&data_transfer_mgr_);
   session_profiler_.Initialize(session_logger_);
@@ -198,15 +214,15 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
 }
 InferenceSession::InferenceSession(const SessionOptions& session_options,
-                                   logging::LoggingManager* logging_manager)
+                                   const Environment& session_env)
     : insert_cast_transformer_("CastFloat16Transformer") {
   // Initialize assets of this session instance
-  ConstructorCommon(session_options, logging_manager);
+  ConstructorCommon(session_options, session_env);
 }
 InferenceSession::InferenceSession(const SessionOptions& session_options,
                                    const std::string& model_uri,
-                                   logging::LoggingManager* logging_manager)
+                                   const Environment& session_env)
     : insert_cast_transformer_("CastFloat16Transformer") {
   model_location_ = ToWideString(model_uri);
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -215,13 +231,13 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
               status.ErrorMessage());
   // Finalize session options and initialize assets of this session instance
-  ConstructorCommon(session_options, logging_manager);
+  ConstructorCommon(session_options, session_env);
 }
 #ifdef _WIN32
 InferenceSession::InferenceSession(const SessionOptions& session_options,
                                    const std::wstring& model_uri,
-                                   logging::LoggingManager* logging_manager)
+                                   const Environment& session_env)
     : insert_cast_transformer_("CastFloat16Transformer") {
   model_location_ = ToWideString(model_uri);
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -230,13 +246,13 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
               status.ErrorMessage());
   // Finalize session options and initialize assets of this session instance
-  ConstructorCommon(session_options, logging_manager);
+  ConstructorCommon(session_options, session_env);
 }
 #endif
 InferenceSession::InferenceSession(const SessionOptions& session_options,
                                    std::istream& model_istream,
-                                   logging::LoggingManager* logging_manager)
+                                   const Environment& session_env)
     : insert_cast_transformer_("CastFloat16Transformer") {
   google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream);
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
@@ -244,20 +260,20 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
   ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session");
   // Finalize session options and initialize assets of this session instance
-  ConstructorCommon(session_options, logging_manager);
+  ConstructorCommon(session_options, session_env);
 }
 InferenceSession::InferenceSession(const SessionOptions& session_options,
                                    const void* model_data,
                                    int model_data_len,
-                                   logging::LoggingManager* logging_manager)
+                                   const Environment& session_env)
     : insert_cast_transformer_("CastFloat16Transformer") {
   model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
   const bool result = model_proto_->ParseFromArray(model_data, model_data_len);
   ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session");
   // Finalize session options and initialize assets of this session instance
-  ConstructorCommon(session_options, logging_manager);
+  ConstructorCommon(session_options, session_env);
 }
 InferenceSession::~InferenceSession() {
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index b4977a9de5398..93c1e1b26bf59 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -31,6 +31,7 @@ namespace onnxruntime {
 // forward declarations
 class GraphTransformer;
+class Environment;
 }  // namespace onnxruntime
 namespace ONNX_NAMESPACE {
@@ -104,8 +105,8 @@ class InferenceSession {
     for logging. This will use the default logger id in messages.
     See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works.
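+    session_env: the environment that supplies the logging manager and, when per-session
+    threads are disabled, the shared global thread pools.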
*/ - explicit InferenceSession(const SessionOptions& session_options, - logging::LoggingManager* logging_manager = nullptr); + InferenceSession(const SessionOptions& session_options, + const Environment& session_env); /** Create a new InferenceSession @@ -121,11 +122,11 @@ class InferenceSession { */ InferenceSession(const SessionOptions& session_options, const std::string& model_uri, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env); #ifdef _WIN32 InferenceSession(const SessionOptions& session_options, const std::wstring& model_uri, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env); #endif /** @@ -142,7 +143,7 @@ class InferenceSession { */ InferenceSession(const SessionOptions& session_options, std::istream& model_istream, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env); /** Create a new InferenceSession @@ -160,7 +161,7 @@ class InferenceSession { InferenceSession(const SessionOptions& session_options, const void* model_data, int model_data_len, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env); virtual ~InferenceSession(); @@ -388,7 +389,7 @@ class InferenceSession { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void ConstructorCommon(const SessionOptions& session_options, - logging::LoggingManager* logging_manager); + const Environment& session_env); bool HasLocalSchema() const { return !custom_schema_registries_.empty(); @@ -469,6 +470,9 @@ class InferenceSession { // Threadpool for this session std::unique_ptr thread_pool_; std::unique_ptr inter_op_thread_pool_; + onnxruntime::concurrency::ThreadPool* intra_op_thread_pool_from_env_{}; + onnxruntime::concurrency::ThreadPool* inter_op_thread_pool_from_env_{}; + bool use_per_session_threads_; // initialized from session options KernelRegistryManager kernel_registry_manager_; std::list> custom_schema_registries_; @@ -510,20 +514,20 @@ class InferenceSession { InterOpDomains interop_domains_; #endif // used to support platform telemetry - static std::atomic global_session_id_; // a monotonically increasing session id - uint32_t session_id_; // the current session's id + static std::atomic global_session_id_; // a monotonically increasing session id + uint32_t session_id_; // the current session's id struct Telemetry { Telemetry() : time_sent_last_(), time_sent_last_evalutation_start_() {} - uint32_t total_runs_since_last_ = 0; // the total number of Run() calls since the last report - long long total_run_duration_since_last_ = 0; // the total duration (us) of Run() calls since the last report - std::string event_name_; // where the model is loaded from: ["model_loading_uri", "model_loading_proto", "model_loading_istream"] + uint32_t total_runs_since_last_ = 0; // the total number of Run() calls since the last report + long long total_run_duration_since_last_ = 0; // the total duration (us) of Run() calls since the last report + std::string event_name_; // where the model is loaded from: ["model_loading_uri", "model_loading_proto", "model_loading_istream"] - TimePoint time_sent_last_; // the TimePoint of the last report + TimePoint time_sent_last_; // the TimePoint of the last report TimePoint time_sent_last_evalutation_start_; - // Event Rate per provider < 20 peak events per second - constexpr static long long kDurationBetweenSending = 1000 * 1000 * 60 * 10; // duration in (us). 
send a report every 10 mins - constexpr static long long kDurationBetweenSendingEvaluationStart = 1000 * 50; // duration in (us). send a EvaluationStop Event every 50 ms; + // Event Rate per provider < 20 peak events per second + constexpr static long long kDurationBetweenSending = 1000 * 1000 * 60 * 10; // duration in (us). send a report every 10 mins + constexpr static long long kDurationBetweenSendingEvaluationStart = 1000 * 50; // duration in (us). send a EvaluationStop Event every 50 ms; bool isEvaluationStart = false; } telemetry_; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ae328b70b26e1..03dc991ffae58 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -80,6 +80,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel default_warning_level, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_warning_level, + _In_ const char* logid, _In_ ThreadingOptions tp_options, _Outptr_ OrtEnv** out) { + API_IMPL_BEGIN + OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, default_warning_level, logid}; + Status status; + *out = OrtEnv::GetInstance(lm_info, status, &tp_options); + return ToOrtStatus(status); + API_IMPL_END +} + // enable platform telemetry ORT_API_STATUS_IMPL(OrtApis::EnableTelemetryEvents, _In_ const OrtEnv* ort_env) { API_IMPL_BEGIN @@ -414,7 +424,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O try { sess = onnxruntime::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - model_path, env->GetLoggingManager()); + model_path, env->GetEnvironment()); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -429,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In try { sess = onnxruntime::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - model_data, static_cast(model_data_length), env->GetLoggingManager()); + model_data, static_cast(model_data_length), env->GetEnvironment()); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -1362,7 +1372,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_2 = { +static constexpr OrtApi ort_api_1_to_3 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. 
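  // Entries are append-only; a binary that requests an older version through GetApi relies on
  // this prefix layout staying fixed.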
// Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -1500,6 +1510,10 @@ static constexpr OrtApi ort_api_1_to_2 = { &OrtApis::ModelMetadataLookupCustomMetadataMap, &OrtApis::ModelMetadataGetVersion, &OrtApis::ReleaseModelMetadata, + + // Version 3 + &OrtApis::CreateEnvWithGlobalThreadPools, + &OrtApis::DisablePerSessionThreads, }; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) @@ -1507,8 +1521,8 @@ static constexpr OrtApi ort_api_1_to_2 = { static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change"); ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { - if (version >= 1 && version <= 2) - return &ort_api_1_to_2; + if (version >= 1 && version <= 3) + return &ort_api_1_to_3; return nullptr; // Unsupported version } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 8ebed4a71b930..54de7ee65040f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -178,4 +178,9 @@ ORT_API_STATUS_IMPL(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _ // OrtSequenceTypeInfo Accessors ORT_API_STATUS_IMPL(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_logging_level, _In_ const char* logid, + _In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out) +ORT_ALL_ARGS_NONNULL; + +ORT_API_STATUS_IMPL(DisablePerSessionThreads, _In_ OrtSessionOptions* options); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index a74de2242d510..662cf47bce7d0 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -10,7 +10,6 @@ #include "core/session/environment.h" #include "core/common/logging/sinks/clog_sink.h" #include "core/common/logging/logging.h" -#include "core/session/environment.h" using namespace onnxruntime; using namespace onnxruntime::logging; @@ -30,38 +29,43 @@ void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp logger_id.c_str(), s.c_str(), message.Message().c_str()); } -OrtEnv::OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager) - : value_(std::move(value1)), logging_manager_(std::move(logging_manager)) { +OrtEnv::OrtEnv(std::unique_ptr value1) + : value_(std::move(value1)) { } -OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status) { +OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, + onnxruntime::common::Status& status, + const ThreadingOptions* tp_options) { std::lock_guard lock(m_); + std::unique_ptr lmgr; + std::string name = lm_info.logid; + if (lm_info.logging_function) { + std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, + lm_info.logger_param); + lmgr.reset(new LoggingManager(std::move(logger), + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } else { + lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } + if (!p_instance_) { std::unique_ptr env; - status = onnxruntime::Environment::Create(env); + if (!tp_options) { + status = 
onnxruntime::Environment::Create(std::move(lmgr), env); + } else { + status = onnxruntime::Environment::Create(std::move(lmgr), env, tp_options, true); + } if (!status.IsOK()) { return nullptr; } - - std::unique_ptr lmgr; - std::string name = lm_info.logid; - if (lm_info.logging_function) { - std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, - lm_info.logger_param); - lmgr.reset(new LoggingManager(std::move(logger), - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } else { - lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } - - p_instance_ = new OrtEnv(std::move(env), std::move(lmgr)); + p_instance_ = new OrtEnv(std::move(env)); } ++ref_count_; return p_instance_; @@ -78,13 +82,4 @@ void OrtEnv::Release(OrtEnv* env_ptr) { delete p_instance_; p_instance_ = nullptr; } -} - -LoggingManager* OrtEnv::GetLoggingManager() const { - return logging_manager_.get(); -} - -void OrtEnv::SetLoggingManager(std::unique_ptr logging_manager) { - std::lock_guard lock(m_); - logging_manager_ = std::move(logging_manager); } \ No newline at end of file diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index c93d2937c7a7b..c383639cc814b 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -42,13 +42,19 @@ struct OrtEnv { const char* logid{}; }; - static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status); + static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, + onnxruntime::common::Status& status, + const ThreadingOptions* tp_options = nullptr); static void Release(OrtEnv* env_ptr); - onnxruntime::logging::LoggingManager* GetLoggingManager() const; + const onnxruntime::Environment& GetEnvironment() const { + return *(value_.get()); + } - void SetLoggingManager(std::unique_ptr logging_manager); + // onnxruntime::logging::LoggingManager* GetLoggingManager() const; + + // void SetLoggingManager(std::unique_ptr logging_manager); private: static OrtEnv* p_instance_; @@ -56,9 +62,8 @@ struct OrtEnv { static int ref_count_; std::unique_ptr value_; - std::unique_ptr logging_manager_; - OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager); + OrtEnv(std::unique_ptr value1); ~OrtEnv() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index a19a8754e6ded..e9999eb2f567c 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -159,7 +159,7 @@ void LayerNormOpTester::ComputeWithCPU(std::vector& cpu_fetches) { run_options.run_log_verbosity_level = 1; // run with LayerNormalization - InferenceSession layernorm_session_object{so}; + InferenceSession layernorm_session_object{so, GetEnvironment()}; std::string s1; p_model->ToProto().SerializeToString(&s1); std::istringstream str(s1); @@ -193,7 +193,7 @@ void LayerNormOpTester::ComputeWithCUDA(std::vector& cuda_fetches) { run_options.run_log_verbosity_level = 1; auto cuda_execution_provider = DefaultCudaExecutionProvider(); - InferenceSession cuda_session_object{so}; + InferenceSession cuda_session_object{so, GetEnvironment()}; EXPECT_TRUE(cuda_session_object.RegisterExecutionProvider(std::move(cuda_execution_provider)).IsOK()); 
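  // Serialize the model and hand it to the CUDA session through an istream, mirroring the CPU path above.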
std::string s; @@ -226,7 +226,7 @@ void LayerNormOpTester::ComputeOriSubgraphWithCPU(std::vector& subgraph run_options.run_log_verbosity_level = 1; Status status; - InferenceSession subgraph_session_object{so, &DefaultLoggingManager()}; + InferenceSession subgraph_session_object{so, GetEnvironment()}; ASSERT_TRUE((status = subgraph_session_object.Load("testdata/layernorm.onnx")).IsOK()) << status; ASSERT_TRUE((status = subgraph_session_object.Initialize()).IsOK()) << status; ASSERT_TRUE((status = subgraph_session_object.Run(run_options, feeds, output_names, &subgraph_fetches)).IsOK()) << status; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index c3ca950c81389..bd49cf071f44e 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -277,7 +277,7 @@ TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) { SessionOptions so; so.session_logid = "BadModelInvalidDimParamUsage"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load("testdata/invalid_dim_param_value_repetition.onnx")).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; diff --git a/onnxruntime/test/framework/float_16_test.cc b/onnxruntime/test/framework/float_16_test.cc index f233e6862b2db..2a84fcf048167 100644 --- a/onnxruntime/test/framework/float_16_test.cc +++ b/onnxruntime/test/framework/float_16_test.cc @@ -137,7 +137,7 @@ TEST(Float16_Tests, Mul_16_Test) { so.session_logid = "InferenceSessionTests.NoTimeout"; std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); auto mulfp16_schema = GetMulFP16Schema(); std::vector schemas = {mulfp16_schema}; diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index baa15a8a8a53a..d26be7a40ed21 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -38,6 +38,7 @@ #include "test/optimizer/dummy_graph_transformer.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "gtest/gtest.h" +#include "core/session/environment.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -120,7 +121,7 @@ class FuseExecutionProvider : public IExecutionProvider { class InferenceSessionGetGraphWrapper : public InferenceSession { public: explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) { + const Environment& env) : InferenceSession(session_options, env) { } const Graph& GetGraph() { @@ -337,7 +338,7 @@ TEST(InferenceSessionTests, NoTimeout) { so.session_logid = "InferenceSessionTests.NoTimeout"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load(MODEL_URI)).IsOK()) << st.ErrorMessage(); ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st.ErrorMessage(); @@ -353,7 +354,7 @@ TEST(InferenceSessionTests, DisableCPUArena) { so.session_logid = "InferenceSessionTests.DisableCPUArena"; so.enable_cpu_mem_arena = false; - InferenceSession 
session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -369,7 +370,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { const string test_model = "testdata/transform/abs-id-max.onnx"; so.session_logid = "InferenceSessionTests.TestModelSerialization"; so.graph_optimization_level = TransformerLevel::Default; - InferenceSessionGetGraphWrapper session_object_noopt{so, &DefaultLoggingManager()}; + InferenceSessionGetGraphWrapper session_object_noopt{so, GetEnvironment()}; ASSERT_TRUE(session_object_noopt.Load(test_model).IsOK()); ASSERT_TRUE(session_object_noopt.Initialize().IsOK()); @@ -381,7 +382,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { // Load model with level 1 transform level. so.graph_optimization_level = TransformerLevel::Level1; so.optimized_model_filepath = ToWideString(test_model + "-TransformLevel-" + std::to_string(static_cast(so.graph_optimization_level))); - InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()}; + InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(test_model).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -391,7 +392,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(op_to_count["Identity"] == 0); // Serialize model to the same file path again to make sure that rewrite doesn't fail. - InferenceSession overwrite_session_object{so, &DefaultLoggingManager()}; + InferenceSession overwrite_session_object{so, GetEnvironment()}; ASSERT_TRUE(overwrite_session_object.Load(test_model).IsOK()); ASSERT_TRUE(overwrite_session_object.Initialize().IsOK()); @@ -400,7 +401,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { so_opt.session_logid = "InferenceSessionTests.TestModelSerialization"; so_opt.graph_optimization_level = TransformerLevel::Default; so_opt.optimized_model_filepath = ToWideString(so.optimized_model_filepath) + ToWideString("-TransformLevel-" + std::to_string(static_cast(so_opt.graph_optimization_level))); - InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()}; + InferenceSession session_object_opt{so_opt, GetEnvironment()}; ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK()); ASSERT_TRUE(session_object_opt.Initialize().IsOK()); @@ -419,7 +420,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { // Assert that empty optimized model file-path doesn't fail loading. 
so_opt.optimized_model_filepath = ToWideString(""); - InferenceSession session_object_emptyValidation{so_opt, &DefaultLoggingManager()}; + InferenceSession session_object_emptyValidation{so_opt, GetEnvironment()}; ASSERT_TRUE(session_object_emptyValidation.Load(test_model).IsOK()); ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } @@ -455,7 +456,7 @@ TEST(InferenceSessionTests, ModelMetadata) { SessionOptions so; so.session_logid = "InferenceSessionTests.ModelMetadata"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"); ASSERT_TRUE(session_object.Load(model_uri).IsOK()); @@ -525,7 +526,9 @@ TEST(InferenceSessionTests, CheckRunLogger) { std::unique_ptr(capturing_sink), logging::Severity::kVERBOSE, false, LoggingManager::InstanceType::Temporal); - InferenceSession session_object{so, logging_manager.get()}; + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + InferenceSession session_object{so, *env.get()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -555,7 +558,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { so.enable_profiling = true; so.profile_file_prefix = ORT_TSTR("onnxprofile_profile_test"); - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -594,7 +597,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { so.session_logid = "CheckRunProfiler"; - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -632,7 +635,7 @@ TEST(InferenceSessionTests, MultipleSessionsNoTimeout) { SessionOptions session_options; session_options.session_logid = "InferenceSessionTests.MultipleSessionsNoTimeout"; - InferenceSession session_object{session_options, &DefaultLoggingManager()}; + InferenceSession session_object{session_options, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -657,7 +660,7 @@ TEST(InferenceSessionTests, PreAllocateOutputVector) { so.session_logid = "InferenceSessionTests.PreAllocateOutputVector"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -684,7 +687,9 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { false, LoggingManager::InstanceType::Temporal); - InferenceSession session_object{so, logging_manager.get()}; + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + InferenceSession session_object{so, *env.get()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -719,7 +724,7 @@ TEST(InferenceSessionTests, TestWithIstream) { so.session_logid = "InferenceSessionTests.TestWithIstream"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; std::ifstream model_file_stream(MODEL_URI, ios::in | ios::binary); ASSERT_TRUE(model_file_stream.good()); @@ -736,7 +741,7 @@ TEST(InferenceSessionTests, TestRegisterExecutionProvider) 
{ so.session_logid = "InferenceSessionTests.TestWithIstream"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; CPUExecutionProviderInfo epi; ASSERT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique(epi)).IsOK()); @@ -760,7 +765,7 @@ static void TestBindHelper(const std::string& log_str, so.session_logid = "InferenceSessionTests." + log_str; so.session_log_verbosity_level = 1; // change to 1 for detailed logging - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; if (bind_provider_type == kCudaExecutionProvider || run_provider_type == kCudaExecutionProvider) { #ifdef USE_CUDA @@ -798,7 +803,7 @@ TEST(InferenceSessionTests, TestBindCpu) { TEST(InferenceSessionTests, TestIOBindingReuse) { SessionOptions so; - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); std::unique_ptr p_model; CreateMatMulModel(p_model, kCpuExecutionProvider); @@ -839,7 +844,7 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { so.session_logid = "InferenceSessionTests.InvalidInputTypeOfTensorElement"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -912,7 +917,7 @@ TEST(InferenceSessionTests, ModelWithoutOpset) { so.session_logid = "InferenceSessionTests.ModelWithoutOpset"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status retval = session_object.Load(MODEL_URI_NO_OPSET); ASSERT_FALSE(retval.IsOK()); if (!retval.IsOK()) { @@ -923,11 +928,11 @@ TEST(InferenceSessionTests, ModelWithoutOpset) { static common::Status RunOptionalInputTest(bool add_required_input, bool add_optional_input, bool add_invalid_input, - int model_ir_version) { + int model_ir_version, + const Environment& sess_env) { SessionOptions so; so.session_logid = "RunOptionalInputTest"; - - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, sess_env}; Status status; std::string model_path = "testdata/optional_inputs_ir" + std::to_string(model_ir_version) + ".onnx"; @@ -1001,25 +1006,26 @@ static common::Status RunOptionalInputTest(bool add_required_input, // for V4 allow it TEST(InferenceSessionTests, TestOptionalInputs) { std::vector ir_versions{3, 4}; + const auto& sess_env = GetEnvironment(); for (auto version : ir_versions) { // required input only - auto status = RunOptionalInputTest(true, false, false, version); + auto status = RunOptionalInputTest(true, false, false, version, sess_env); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); // required and optional input - status = RunOptionalInputTest(true, true, false, version); + status = RunOptionalInputTest(true, true, false, version, sess_env); if (version == 3) { ASSERT_FALSE(status.IsOK()) << status.ErrorMessage(); } else { ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); } // required, optional and invalid input - status = RunOptionalInputTest(true, true, true, version); + status = RunOptionalInputTest(true, true, true, version, sess_env); ASSERT_FALSE(status.IsOK()); EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); // missing required - status = RunOptionalInputTest(false, true, false, version); + status = RunOptionalInputTest(false, true, false, version, 
sess_env); ASSERT_FALSE(status.IsOK()); if (version == 3) { EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); @@ -1065,7 +1071,7 @@ TEST(ExecutionProviderTest, FunctionTest) { SessionOptions so; so.session_logid = "ExecutionProviderTest.FunctionTest"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; status = session_object.Load(model_file_name); ASSERT_TRUE(status.IsOK()); status = session_object.Initialize(); @@ -1104,7 +1110,7 @@ TEST(ExecutionProviderTest, FunctionTest) { ASSERT_TRUE(status.IsOK()); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); - InferenceSession session_object_2{so}; + InferenceSession session_object_2{so, GetEnvironment()}; session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider)); session_object_2.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::FuseExecutionProvider>()); status = session_object_2.Load(model_file_name); @@ -1172,7 +1178,7 @@ TEST(ExecutionProviderTest, FunctionInlineTest) { SessionOptions so; so.session_logid = "ExecutionProviderTest.FunctionInlineTest"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; status = session_object.Load(model_file_name); ASSERT_TRUE(status.IsOK()); status = session_object.Initialize(); @@ -1263,7 +1269,7 @@ TEST(InferenceSessionTests, TestTruncatedSequence) { // now run the truncated model SessionOptions so; - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(LSTM_MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -1373,7 +1379,7 @@ TEST(InferenceSessionTests, TestTruncatedSequence) { TEST(InferenceSessionTests, TestCopyToFromDevices) { SessionOptions so; so.session_logid = "InferenceSessionTests.TestCopyToFromDevices"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -1435,7 +1441,7 @@ TEST(InferenceSessionTests, TestRegisterTransformers) { SessionOptions so; so.session_logid = "InferenceSessionTests.TestL1AndL2Transformers"; so.graph_optimization_level = static_cast(i); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; // Create and register dummy graph transformer auto dummy_transformer_unique_ptr = onnxruntime::make_unique("DummyTransformer"); @@ -1465,7 +1471,7 @@ TEST(InferenceSessionTests, TestL1AndL2Transformers) { SessionOptions so; so.session_logid = "InferenceSessionTests.TestL1AndL2Transformers"; so.graph_optimization_level = TransformerLevel::Level2; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(model_uri).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); } @@ -1553,7 +1559,7 @@ TEST(InferenceSessionTests, ModelThatTriggersAllocationPlannerToReuseDoubleTenso so.session_logid = "InferenceSessionTests.ModelThatTriggersAllocationPlannerBug"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load("testdata/test_cast_back_to_back_non_const_mixed_types_origin.onnx")).IsOK()) << st.ErrorMessage(); @@ -1618,7 +1624,7 @@ TEST(InferenceSessionTests, 
LoadModelWithValidOrtConfigJson) { std::string model_path = "testdata/model_with_valid_ort_config_json.onnx"; // Create session - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, model_path, GetEnvironment()}; // Load() and Initialize() the session Status st; @@ -1656,7 +1662,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { so.intra_op_num_threads = 2; // Create session - InferenceSession session_object_2{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_2{so, model_path, GetEnvironment()}; // Load() and Initialize() the session ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage(); @@ -1685,7 +1691,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { // Create session (should throw as the json within the model is invalid/improperly formed) try { - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, model_path, GetEnvironment()}; } catch (const std::exception& e) { std::string e_message(std::string(e.what())); ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. Error Message:") != std::string::npos); @@ -1704,7 +1710,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { so.intra_op_num_threads = 2; // Create session - InferenceSession session_object_2{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_2{so, model_path, GetEnvironment()}; // Load() and Initialize() the session Status st; @@ -1734,7 +1740,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { std::string model_path = "testdata/transform/abs-id-max.onnx"; // Create session - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, model_path, GetEnvironment()}; // Load() and Initialize() the session Status st; @@ -1755,7 +1761,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { #endif // Create session - InferenceSession session_object_2{so, model_path, &DefaultLoggingManager()}; // so has inter_op_num_threads set to 2 + InferenceSession session_object_2{so, model_path, GetEnvironment()}; // so has inter_op_num_threads set to 2 // Load() and Initialize() the session ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage(); @@ -1779,7 +1785,7 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { // Create session (should throw because of the unsupported value for the env var - ORT_LOAD_CONFIG_FROM_MODEL) try { - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, model_path, GetEnvironment()}; } catch (const std::exception& e) { std::string e_message(std::string(e.what())); ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. Error Message:") != std::string::npos); diff --git a/onnxruntime/test/framework/local_kernel_registry_test.cc b/onnxruntime/test/framework/local_kernel_registry_test.cc index 924d8c756b64a..68b1b40dc9993 100644 --- a/onnxruntime/test/framework/local_kernel_registry_test.cc +++ b/onnxruntime/test/framework/local_kernel_registry_test.cc @@ -229,7 +229,7 @@ TEST(CustomKernelTests, CustomKernelWithBuildInSchema) { // Register a foo kernel which is doing Add, but bind to Mul. 
std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); auto def = FooKernelDef("Mul"); @@ -261,7 +261,7 @@ TEST(CustomKernelTests, CustomKernelWithCustomSchema) { std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); //register foo schema @@ -307,7 +307,7 @@ TEST(CustomKernelTests, CustomKernelWithOptionalOutput) { //Register a foo kernel which is doing Add, but bind to Mul. EXPECT_TRUE(registry->RegisterCustomKernel(def, CreateOptionalOpKernel).IsOK()); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); EXPECT_TRUE(session_object.Load(OPTIONAL_MODEL1_URI).IsOK()); EXPECT_TRUE(session_object.Initialize().IsOK()); diff --git a/onnxruntime/test/framework/opaque_kernels_test.cc b/onnxruntime/test/framework/opaque_kernels_test.cc index ae13e2bd699d8..632f68deb064b 100644 --- a/onnxruntime/test/framework/opaque_kernels_test.cc +++ b/onnxruntime/test/framework/opaque_kernels_test.cc @@ -282,7 +282,7 @@ TEST_F(OpaqueTypeTests, RunModel) { // Both the session and the model need custom registries // so we construct it here before the model std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); auto ops_schema = GetConstructSparseTensorSchema(); diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 307ace0fe3caa..e016149cefdc1 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -279,7 +279,7 @@ class SparseTensorTests : public testing::Test { std::vector types; public: - SparseTensorTests() : session_object(SessionOptions(), &DefaultLoggingManager()), + SparseTensorTests() : session_object(SessionOptions(), GetEnvironment()), registry(std::make_shared()), custom_schema_registries_{registry->GetOpschemaRegistry()}, domain_to_version{{onnxruntime::kMLDomain, 10}}, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 39c9eed51dfa5..ebb4f63d36b1a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -263,7 +263,7 @@ TEST(GraphTransformationTests, SubgraphWithConstantInputs) { SessionOptions so; so.graph_optimization_level = TransformerLevel::Level2; so.session_logid = "GraphTransformationTests.LoadModelToTransform"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; @@ -635,7 +635,7 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { SessionOptions so; so.session_logid = "GraphTransformationTests.LoadModelToTransform"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(model_uri).IsOK()); 
std::shared_ptr p_model; @@ -1502,22 +1502,22 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == add_count ); + ASSERT_TRUE(op_to_count["Add"] == add_count); ASSERT_TRUE(op_to_count["Sub"] == 0); ASSERT_TRUE(op_to_count["ReduceMean"] == 0); ASSERT_TRUE(op_to_count["Pow"] == 0); ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count ); - ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count ); + ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count); + ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count); } TEST(GraphTransformationTests, SkipLayerNormFusionTest) { TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1 ); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1 ); - TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1 ); - TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1 ); - TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0 ); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0); } TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) { diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 348a46d7751c6..f51cc59205014 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -9,6 +9,7 @@ #include "test/compare_ortvalue.h" #include "gtest/gtest.h" #include "core/mlas/inc/mlas.h" +#include "core/session/environment.h" namespace onnxruntime { namespace test { @@ -17,7 +18,7 @@ namespace test { class NchwcInferenceSession : public InferenceSession { public: explicit NchwcInferenceSession(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) { + const Environment& env) : InferenceSession(session_options, env) { } std::unordered_map CountOpsInGraph() { @@ -208,7 +209,7 @@ void NchwcOptimizerTester(const std::function& bu SessionOptions session_options; session_options.graph_optimization_level = level; session_options.session_logid = "NchwcOptimizerTests"; - NchwcInferenceSession session{session_options, &DefaultLoggingManager()}; + NchwcInferenceSession session{session_options, GetEnvironment()}; ASSERT_TRUE(session.Load(model_data.data(), static_cast(model_data.size())).IsOK()); ASSERT_TRUE(session.Initialize().IsOK()); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 0b9e688a079cb..0c7b53fef9946 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -580,7 +580,7 @@ TEST(Loop, 
SubgraphInputShadowsOuterScopeValue) { SessionOptions so; so.session_logid = "SubgraphInputShadowsOuterScopeValue"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load("testdata/subgraph_input_shadows_outer_scope_value.onnx")).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index fe69db8a90028..57c3d3acf4c88 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -239,7 +239,7 @@ void Check(const OpTester::Data& expected_data, else { // the default for existing tests const float max_value = fmax(fabs(f_expected[i]), fabs(f_output[i])); - if (max_value != 0) { // max_value = 0 means output and expected are 0s. + if (max_value != 0) { // max_value = 0 means output and expected are 0s. const float rel_error = fabs(f_expected[i] - f_output[i]) / max_value; EXPECT_NEAR(0, rel_error, threshold) << "provider_type: " << provider_type; @@ -687,10 +687,14 @@ void OpTester::Run( FillFeedsAndOutputNames(feeds, output_names); // Run the model static const std::string all_provider_types[] = { - kCpuExecutionProvider, kCudaExecutionProvider, - kDnnlExecutionProvider, kNGraphExecutionProvider, - kNupharExecutionProvider, kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kDmlExecutionProvider, + kCpuExecutionProvider, + kCudaExecutionProvider, + kDnnlExecutionProvider, + kNGraphExecutionProvider, + kNupharExecutionProvider, + kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, + kDmlExecutionProvider, kAclExecutionProvider, }; @@ -705,7 +709,7 @@ void OpTester::Run( } } - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(!execution_providers->empty()) << "Empty execution providers vector."; @@ -733,7 +737,7 @@ void OpTester::Run( so.enable_mem_pattern = false; so.execution_mode = ExecutionMode::ORT_SEQUENTIAL; } - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; for (auto& custom_session_registry : custom_session_registries_) session_object.RegisterCustomRegistry(custom_session_registry); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4b5b6731c061b..2fc69a2bd9871 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -185,8 +185,8 @@ TEST(CApiTest, dim_param) { } INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, - CApiTestWithProvider, - ::testing::Values(0, 1, 2, 3, 4)); + CApiTestWithProvider, + ::testing::Values(0, 1, 2, 3, 4)); struct OrtTensorDimensions : std::vector { OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) { diff --git a/onnxruntime/test/util/include/test/test_environment.h b/onnxruntime/test/util/include/test/test_environment.h index 91e8a2440e8ad..1e0319fcf1104 100644 --- a/onnxruntime/test/util/include/test/test_environment.h +++ b/onnxruntime/test/util/include/test/test_environment.h @@ -10,8 +10,12 @@ #endif namespace onnxruntime { +class Environment; + namespace test { +const ::onnxruntime::Environment& GetEnvironment(); + /** Static logging manager with a CLog based sink so logging macros that use the default logger will work */ diff --git a/onnxruntime/test/util/test_environment.cc 
b/onnxruntime/test/util/test_environment.cc index 3c59d577bd4f6..3f77d8e474229 100644 --- a/onnxruntime/test/util/test_environment.cc +++ b/onnxruntime/test/util/test_environment.cc @@ -12,6 +12,7 @@ #include "core/common/logging/logging.h" #include "core/common/logging/sinks/clog_sink.h" #include "core/session/ort_env.h" +#include "core/session/environment.h" using namespace ::onnxruntime::logging; extern std::unique_ptr ort_env; @@ -21,8 +22,12 @@ namespace test { static std::unique_ptr<::onnxruntime::logging::LoggingManager> s_default_logging_manager; +const ::onnxruntime::Environment& GetEnvironment() { + return ((OrtEnv*)*ort_env.get())->GetEnvironment(); +} + ::onnxruntime::logging::LoggingManager& DefaultLoggingManager() { - return *((OrtEnv*)*ort_env.get())->GetLoggingManager(); + return *((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager(); } } // namespace test From 8dec4b7c01992f23913be856c75a303a726a10c3 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Thu, 12 Mar 2020 18:50:51 -0700 Subject: [PATCH 2/9] Fix build issues --- .../onnxruntime/core/session/environment.h | 4 + .../core/session/onnxruntime_c_api.h | 4 +- onnxruntime/core/session/inference_session.cc | 16 +-- onnxruntime/core/session/inference_session.h | 16 +-- onnxruntime/core/session/onnxruntime_c_api.cc | 4 +- onnxruntime/core/session/ort_env.cc | 8 ++ onnxruntime/core/session/ort_env.h | 5 +- .../python/onnxruntime_pybind_state.cc | 135 +++++++++--------- .../test/framework/cuda/fence_cuda_test.cc | 4 +- .../test/framework/inference_session_test.cc | 16 +-- .../providers/tensorrt/tensorrt_basic_test.cc | 7 +- winml/adapter/winml_adapter_environment.cpp | 2 +- winml/adapter/winml_adapter_session.cpp | 6 +- 13 files changed, 120 insertions(+), 107 deletions(-) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 632c03ac9b053..12d4e931188f2 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -30,6 +30,10 @@ class Environment { return logging_manager_.get(); } + void SetLoggingManager(std::unique_ptr logging_manager) { + logging_manager_ = std::move(logging_manager); + } + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const { return intra_op_thread_pool_.get(); } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6a186f05ff3eb..61860a0c68c57 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -214,10 +214,10 @@ typedef enum OrtMemType { typedef struct ThreadingOptions { // number of threads used to parallelize execution of an op - int intra_op_num_threads = 0; // default value + int intra_op_num_threads; // use 0 if you want onnxruntime to choose a value for you // number of threads used to parallelize execution across ops - int inter_op_num_threads = 0; // default value + int inter_op_num_threads; // use 0 if you want onnxruntime to choose a value for you } ThreadingOptions; struct OrtApi; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index be98c23100f5d..bed40d404f272 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -221,8 +221,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, } InferenceSession::InferenceSession(const SessionOptions& session_options, - const 
std::string& model_uri, - const Environment& session_env) + const Environment& session_env, + const std::string& model_uri) : insert_cast_transformer_("CastFloat16Transformer") { model_location_ = ToWideString(model_uri); model_proto_ = onnxruntime::make_unique(); @@ -236,8 +236,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, #ifdef _WIN32 InferenceSession::InferenceSession(const SessionOptions& session_options, - const std::wstring& model_uri, - const Environment& session_env) + const Environment& session_env, + const std::wstring& model_uri) : insert_cast_transformer_("CastFloat16Transformer") { model_location_ = ToWideString(model_uri); model_proto_ = onnxruntime::make_unique(); @@ -251,8 +251,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, #endif InferenceSession::InferenceSession(const SessionOptions& session_options, - std::istream& model_istream, - const Environment& session_env) + const Environment& session_env, + std::istream& model_istream) : insert_cast_transformer_("CastFloat16Transformer") { google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); model_proto_ = onnxruntime::make_unique(); @@ -264,9 +264,9 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, } InferenceSession::InferenceSession(const SessionOptions& session_options, + const Environment& session_env, const void* model_data, - int model_data_len, - const Environment& session_env) + int model_data_len) : insert_cast_transformer_("CastFloat16Transformer") { model_proto_ = onnxruntime::make_unique(); const bool result = model_proto_->ParseFromArray(model_data, model_data_len); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 93c1e1b26bf59..ec23697b4ce49 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -121,12 +121,12 @@ class InferenceSession { This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, - const std::string& model_uri, - const Environment& session_env); + const Environment& session_env, + const std::string& model_uri); #ifdef _WIN32 InferenceSession(const SessionOptions& session_options, - const std::wstring& model_uri, - const Environment& session_env); + const Environment& session_env, + const std::wstring& model_uri); #endif /** @@ -142,8 +142,8 @@ class InferenceSession { This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, - std::istream& model_istream, - const Environment& session_env); + const Environment& session_env, + std::istream& model_istream); /** Create a new InferenceSession @@ -159,9 +159,9 @@ class InferenceSession { This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, + const Environment& session_env, const void* model_data, - int model_data_len, - const Environment& session_env); + int model_data_len); virtual ~InferenceSession(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 03dc991ffae58..d80d41e1265f5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -424,7 +424,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O try { sess = onnxruntime::make_unique( options == nullptr ? 
onnxruntime::SessionOptions() : options->value, - model_path, env->GetEnvironment()); + env->GetEnvironment(), model_path); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -439,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In try { sess = onnxruntime::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - model_data, static_cast(model_data_length), env->GetEnvironment()); + env->GetEnvironment(), model_data, static_cast(model_data_length)); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 662cf47bce7d0..2b027404cd3b1 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -82,4 +82,12 @@ void OrtEnv::Release(OrtEnv* env_ptr) { delete p_instance_; p_instance_ = nullptr; } +} + +onnxruntime::logging::LoggingManager* OrtEnv::GetLoggingManager() const { + return value_->GetLoggingManager(); +} + +void OrtEnv::SetLoggingManager(std::unique_ptr logging_manager) { + value_->SetLoggingManager(std::move(logging_manager)); } \ No newline at end of file diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index c383639cc814b..81055874a82b6 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -52,9 +52,8 @@ struct OrtEnv { return *(value_.get()); } - // onnxruntime::logging::LoggingManager* GetLoggingManager() const; - - // void SetLoggingManager(std::unique_ptr logging_manager); + onnxruntime::logging::LoggingManager* GetLoggingManager() const; + void SetLoggingManager(std::unique_ptr logging_manager); private: static OrtEnv* p_instance_; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d4baca48da303..ead3ba637a5a1 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -234,28 +234,29 @@ void AddTensorAsPyObj(OrtValue& val, std::vector& pyobjs) { GetPyObjFromTensor(rtensor, obj); pyobjs.push_back(obj); } - class SessionObjectInitializer { public: typedef const SessionOptions& Arg1; - typedef logging::LoggingManager* Arg2; + // typedef logging::LoggingManager* Arg2; + static std::string default_logger_id; operator Arg1() { return GetDefaultCPUSessionOptions(); } - operator Arg2() { - static std::string default_logger_id{"Default"}; - static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, - Severity::kWARNING, false, LoggingManager::InstanceType::Default, - &default_logger_id}; - return &default_logging_manager; - } + // operator Arg2() { + // static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, + // Severity::kWARNING, false, LoggingManager::InstanceType::Default, + // &default_logger_id}; + // return &default_logging_manager; + // } static SessionObjectInitializer Get() { return SessionObjectInitializer(); } }; +std::string SessionObjectInitializer::default_logger_id = "Default"; + inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) { auto p = f.CreateProvider(); OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p))); @@ -350,17 +351,17 @@ void InitializeSession(InferenceSession* sess, const std::vector& p OrtPybindThrowIfError(sess->Initialize()); } -void addGlobalMethods(py::module& m) { +void 
addGlobalMethods(py::module& m, const Environment& env) { m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance."); m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer."); m.def( "get_device", []() -> std::string { return BACKEND_DEVICE; }, "Return the device used to compute the prediction (CPU, MKL, ...)"); m.def( - "set_default_logger_severity", [](int severity) { + "set_default_logger_severity", [&env](int severity) { ORT_ENFORCE(severity >= 0 && severity <= 4, "Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal"); - logging::LoggingManager* default_logging_manager = SessionObjectInitializer::Get(); + logging::LoggingManager* default_logging_manager = env.GetLoggingManager(); default_logging_manager->SetDefaultLoggerSeverity(static_cast(severity)); }, "Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal"); @@ -546,7 +547,7 @@ void addOpSchemaSubmodule(py::module& m) { #endif //onnxruntime_PYBIND_EXPORT_OPSCHEMA -void addObjectMethods(py::module& m) { +void addObjectMethods(py::module& m, Environment& env) { py::enum_(m, "GraphOptimizationLevel") .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL) .value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC) @@ -658,72 +659,70 @@ including arg name, arg type (contains both type and shape).)pbdoc") return *(na.Type()); }, "node type") - .def( - "__str__", [](const onnxruntime::NodeArg& na) -> std::string { - std::ostringstream res; - res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; - auto shape = na.Shape(); - std::vector arr; - if (shape == nullptr || shape->dim_size() == 0) { - res << "[]"; + .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string { + std::ostringstream res; + res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; + auto shape = na.Shape(); + std::vector arr; + if (shape == nullptr || shape->dim_size() == 0) { + res << "[]"; + } else { + res << "["; + for (int i = 0; i < shape->dim_size(); ++i) { + if (utils::HasDimValue(shape->dim(i))) { + res << shape->dim(i).dim_value(); + } else if (utils::HasDimParam(shape->dim(i))) { + res << "'" << shape->dim(i).dim_param() << "'"; } else { - res << "["; - for (int i = 0; i < shape->dim_size(); ++i) { - if (utils::HasDimValue(shape->dim(i))) { - res << shape->dim(i).dim_value(); - } else if (utils::HasDimParam(shape->dim(i))) { - res << "'" << shape->dim(i).dim_param() << "'"; - } else { - res << "None"; - } - - if (i < shape->dim_size() - 1) { - res << ", "; - } - } - res << "]"; + res << "None"; } - res << ")"; - return std::string(res.str()); - }, - "converts the node into a readable string") - .def_property_readonly( - "shape", [](const onnxruntime::NodeArg& na) -> std::vector { - auto shape = na.Shape(); - std::vector arr; - if (shape == nullptr || shape->dim_size() == 0) { - return arr; + if (i < shape->dim_size() - 1) { + res << ", "; } + } + res << "]"; + } + res << ")"; - arr.resize(shape->dim_size()); - for (int i = 0; i < shape->dim_size(); ++i) { - if (utils::HasDimValue(shape->dim(i))) { - arr[i] = py::cast(shape->dim(i).dim_value()); - } else if (utils::HasDimParam(shape->dim(i))) { - arr[i] = py::cast(shape->dim(i).dim_param()); - } else { - arr[i] = py::none(); - } - } - return arr; - }, - "node shape (assuming the node holds a tensor)"); + return std::string(res.str()); + }, + "converts the node into a 
readable string") + .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector { + auto shape = na.Shape(); + std::vector arr; + if (shape == nullptr || shape->dim_size() == 0) { + return arr; + } + + arr.resize(shape->dim_size()); + for (int i = 0; i < shape->dim_size(); ++i) { + if (utils::HasDimValue(shape->dim(i))) { + arr[i] = py::cast(shape->dim(i).dim_value()); + } else if (utils::HasDimParam(shape->dim(i))) { + arr[i] = py::cast(shape->dim(i).dim_param()); + } else { + arr[i] = py::none(); + } + } + return arr; + }, + "node shape (assuming the node holds a tensor)"); py::class_(m, "SessionObjectInitializer"); py::class_(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char* // without any conversion. So this init method can be used for model file path (string) // and model content (bytes) - .def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) { + .def(py::init([&env](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) { // Given arg is the file path. Invoke the corresponding ctor(). if (is_arg_file_name) { - return onnxruntime::make_unique(so, arg, SessionObjectInitializer::Get()); + return onnxruntime::make_unique(so, env, arg); } // Given arg is the model content as bytes. Invoke the corresponding ctor(). std::istringstream buffer(arg); - return onnxruntime::make_unique(so, buffer, SessionObjectInitializer::Get()); + return onnxruntime::make_unique(so, env, buffer); })) .def( "load_model", [](InferenceSession* sess, std::vector& provider_types) { @@ -867,6 +866,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { #endif + static std::unique_ptr env; auto initialize = [&]() { // Initialization of the module ([]() -> void { @@ -874,8 +874,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { import_array1(); })(); - static std::unique_ptr env; - OrtPybindThrowIfError(Environment::Create(env)); + OrtPybindThrowIfError(Environment::Create(std::make_unique( + std::unique_ptr{new CErrSink{}}, + Severity::kWARNING, false, LoggingManager::InstanceType::Default, + &SessionObjectInitializer::default_logger_id), + env)); static bool initialized = false; if (initialized) { @@ -885,8 +888,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { }; initialize(); - addGlobalMethods(m); - addObjectMethods(m); + addGlobalMethods(m, *env); + addObjectMethods(m, *env); #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA addOpSchemaSubmodule(m); diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index 15e4ce84eefbc..288cf1cb690d2 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -37,7 +37,7 @@ typedef std::vector ArgMap; class FenceCudaTestInferenceSession : public InferenceSession { public: - FenceCudaTestInferenceSession(const SessionOptions& so) : InferenceSession(so) {} + FenceCudaTestInferenceSession(const SessionOptions& so, const Environment& env) : InferenceSession(so, env) {} Status LoadModel(onnxruntime::Model& model) { auto model_proto = model.ToProto(); auto st = Load(model_proto); @@ -117,7 +117,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) { DataTypeImpl::GetType()->GetDeleteFunc()); SessionOptions so; - FenceCudaTestInferenceSession session(so); + FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, 
*model); CUDAExecutionProviderInfo xp_info; session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index d26be7a40ed21..a2170ae4abe35 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1536,7 +1536,7 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) { SessionOptions so; so.execution_mode = ExecutionMode::ORT_PARALLEL; so.session_logid = "InferenceSessionTests.TestParallelExecutionWithCudaProvider"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; CUDAExecutionProviderInfo epi; epi.device_id = 0; @@ -1624,7 +1624,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { std::string model_path = "testdata/model_with_valid_ort_config_json.onnx"; // Create session - InferenceSession session_object_1{so, model_path, GetEnvironment()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; // Load() and Initialize() the session Status st; @@ -1662,7 +1662,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { so.intra_op_num_threads = 2; // Create session - InferenceSession session_object_2{so, model_path, GetEnvironment()}; + InferenceSession session_object_2{so, GetEnvironment(), model_path}; // Load() and Initialize() the session ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage(); @@ -1691,7 +1691,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { // Create session (should throw as the json within the model is invalid/improperly formed) try { - InferenceSession session_object_1{so, model_path, GetEnvironment()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; } catch (const std::exception& e) { std::string e_message(std::string(e.what())); ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. 
Error Message:") != std::string::npos); @@ -1710,7 +1710,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { so.intra_op_num_threads = 2; // Create session - InferenceSession session_object_2{so, model_path, GetEnvironment()}; + InferenceSession session_object_2{so, GetEnvironment(), model_path}; // Load() and Initialize() the session Status st; @@ -1740,7 +1740,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { std::string model_path = "testdata/transform/abs-id-max.onnx"; // Create session - InferenceSession session_object_1{so, model_path, GetEnvironment()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; // Load() and Initialize() the session Status st; @@ -1761,7 +1761,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { #endif // Create session - InferenceSession session_object_2{so, model_path, GetEnvironment()}; // so has inter_op_num_threads set to 2 + InferenceSession session_object_2{so, GetEnvironment(), model_path}; // so has inter_op_num_threads set to 2 // Load() and Initialize() the session ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage(); @@ -1785,7 +1785,7 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { // Create session (should throw because of the unsupported value for the env var - ORT_LOAD_CONFIG_FROM_MODEL) try { - InferenceSession session_object_1{so, model_path, GetEnvironment()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; } catch (const std::exception& e) { std::string e_message(std::string(e.what())); ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. Error Message:") != std::string::npos); diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 015c4efed514f..3bcf1263d2a59 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -87,7 +87,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; @@ -201,7 +201,7 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; @@ -215,11 +215,10 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { // Now run status = session_object.Run(run_options, feeds, output_names, &fetches); ASSERT_TRUE(status.IsOK()); - std::vector fetche {fetches.back()}; + std::vector fetche{fetches.back()}; VerifyOutputs(fetche, expected_dims_mul_n, expected_values_mul_n); } - TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { onnxruntime::Model model("graph_removecycleTest", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index d74c35aa3344f..51b93a8cdb6bb 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -14,8 +14,8 @@ #include "abi_custom_registry_impl.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include 
"core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" -#endif USE_DML +#endif USE_DML namespace winmla = Windows::AI::MachineLearning::Adapter; class WinmlAdapterLoggingWrapper : public LoggingWrapper { diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp index 329d713752f70..b95b2e6c8d762 100644 --- a/winml/adapter/winml_adapter_session.cpp +++ b/winml/adapter/winml_adapter_session.cpp @@ -42,7 +42,7 @@ ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ co std::unique_ptr inference_session; try { // Create the inference session - inference_session = std::make_unique(options->value, env->GetLoggingManager()); + inference_session = std::make_unique(options->value, env->GetEnvironment()); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -171,7 +171,7 @@ GetLotusCustomRegistries(IMLOperatorRegistry* registry) { // Get the ORT registry return abi_custom_registry->GetRegistries(); -#endif // USE_DML +#endif // USE_DML } return {}; } @@ -195,7 +195,7 @@ ORT_API_STATUS_IMPL(winmla::CreateCustomRegistry, _Out_ IMLOperatorRegistry** re #ifdef USE_DML auto impl = wil::MakeOrThrow(); *registry = impl.Detach(); -#endif // USE_DML +#endif // USE_DML return nullptr; API_IMPL_END } From 8783fff74b0722257443d146a133ab26efd93834 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Sat, 14 Mar 2020 01:45:05 -0700 Subject: [PATCH 3/9] Add tests, fix build issues. --- cmake/onnxruntime_unittests.cmake | 16 +- .../ACL-ExecutionProvider.md | 10 +- .../DNNL-ExecutionProvider.md | 10 +- .../NNAPI-ExecutionProvider.md | 10 +- .../TensorRT-ExecutionProvider.md | 10 +- .../nGraph-ExecutionProvider.md | 10 +- .../onnxruntime/core/session/environment.h | 12 +- onnxruntime/core/common/logging/logging.cc | 6 +- onnxruntime/core/session/environment.cc | 8 +- onnxruntime/core/session/inference_session.h | 50 ++-- .../python/onnxruntime_pybind_state.cc | 6 +- .../test/framework/cuda/fence_cuda_test.cc | 4 +- .../test/global_thread_pools/test_fixture.h | 17 ++ .../global_thread_pools/test_inference.cc | 223 ++++++++++++++++++ .../test/global_thread_pools/test_main.cc | 61 +++++ .../ngraph/ngraph_execution_provider_test.cc | 18 +- .../test/providers/nnapi/nnapi_basic_test.cc | 3 +- .../providers/tensorrt/tensorrt_basic_test.cc | 2 +- onnxruntime/test/tvm/tvm_basic_test.cc | 2 +- tools/ci_build/build.py | 1 + 20 files changed, 419 insertions(+), 60 deletions(-) create mode 100644 onnxruntime/test/global_thread_pools/test_fixture.h create mode 100644 onnxruntime/test/global_thread_pools/test_inference.cc create mode 100644 onnxruntime/test/global_thread_pools/test_main.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index c5afa1690cc7f..0a02f578d7fe2 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -164,7 +164,7 @@ if (onnxruntime_USE_NNAPI) endif() set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib") - +set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools") set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h @@ -178,6 +178,11 @@ if(onnxruntime_RUN_ONNX_TESTS) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_io_types.cc) endif() +set (onnxruntime_global_thread_pools_test_SRC + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h + 
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc + ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc) + # tests from lowest level library up. # the order of libraries should be maintained, with higher libraries being added first in the list @@ -355,6 +360,7 @@ set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest" set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src}) if(NOT TARGET onnxruntime) list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) + list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC}) endif() set(all_dependencies ${onnxruntime_test_providers_dependencies} ) @@ -671,6 +677,14 @@ if (onnxruntime_BUILD_SHARED_LIB) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + + # test inference using global threadpools + AddTest(DYN + TARGET onnxruntime_global_thread_pools_test + SOURCES ${onnxruntime_global_thread_pools_test_SRC} + LIBS ${onnxruntime_shared_lib_test_LIBS} + DEPENDS ${all_dependencies} + ) endif() #some ETW tools
diff --git a/docs/execution_providers/ACL-ExecutionProvider.md b/docs/execution_providers/ACL-ExecutionProvider.md index dbabac8e33470..390731591283b 100644 --- a/docs/execution_providers/ACL-ExecutionProvider.md +++ b/docs/execution_providers/ACL-ExecutionProvider.md @@ -9,7 +9,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#ARM-Compute-L #### C/C++ To use ACL as execution provider for inferencing, please register it as below. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique<LoggingManager> +(std::unique_ptr<ISink>{new CLogSink{}}, + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id); +Environment::Create(std::move(logging_manager), env); +InferenceSession session_object{so, *env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::ACLExecutionProvider>()); status = session_object.Load(model_file_name); ```
diff --git a/docs/execution_providers/DNNL-ExecutionProvider.md b/docs/execution_providers/DNNL-ExecutionProvider.md index 8e3f8616890c7..fb40679635585 100644 --- a/docs/execution_providers/DNNL-ExecutionProvider.md +++ b/docs/execution_providers/DNNL-ExecutionProvider.md @@ -21,7 +21,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#dnnl-and-mklm ### C/C++ The DNNLExecutionProvider execution provider needs to be registered with ONNX Runtime to enable in the inference session. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique<LoggingManager> +(std::unique_ptr<ISink>{new CLogSink{}}, + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id); +Environment::Create(std::move(logging_manager), env); +InferenceSession session_object{so, *env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::DNNLExecutionProvider>()); status = session_object.Load(model_file_name); ```
diff --git a/docs/execution_providers/NNAPI-ExecutionProvider.md b/docs/execution_providers/NNAPI-ExecutionProvider.md index ed2fdf8c19af2..e6dd8452c8bf6 100644 --- a/docs/execution_providers/NNAPI-ExecutionProvider.md +++ b/docs/execution_providers/NNAPI-ExecutionProvider.md @@ -14,7 +14,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#Android-NNAPI To use NNAPI EP for inferencing, please register it as below.
``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique<LoggingManager> +(std::unique_ptr<ISink>{new CLogSink{}}, + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id); +Environment::Create(std::move(logging_manager), env); +InferenceSession session_object{so, *env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::NnapiExecutionProvider>()); status = session_object.Load(model_file_name); ```
diff --git a/docs/execution_providers/TensorRT-ExecutionProvider.md b/docs/execution_providers/TensorRT-ExecutionProvider.md index 24cb76e1143d5..2fc30ee559e5a 100644 --- a/docs/execution_providers/TensorRT-ExecutionProvider.md +++ b/docs/execution_providers/TensorRT-ExecutionProvider.md @@ -13,7 +13,15 @@ The TensorRT execution provider for ONNX Runtime is built and tested with Tensor ### C/C++ The TensorRT execution provider needs to be registered with ONNX Runtime to enable in the inference session. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique<LoggingManager> +(std::unique_ptr<ISink>{new CLogSink{}}, + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id); +Environment::Create(std::move(logging_manager), env); +InferenceSession session_object{so, *env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::TensorrtExecutionProvider>()); status = session_object.Load(model_file_name); ```
diff --git a/docs/execution_providers/nGraph-ExecutionProvider.md b/docs/execution_providers/nGraph-ExecutionProvider.md index 3c53e5d461ae5..11d403786b0df 100644 --- a/docs/execution_providers/nGraph-ExecutionProvider.md +++ b/docs/execution_providers/nGraph-ExecutionProvider.md @@ -23,7 +23,15 @@ While the nGraph Compiler stack supports various operating systems and backends ### C/C++ To use nGraph as execution provider for inferencing, please register it as below. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique<LoggingManager> +(std::unique_ptr<ISink>{new CLogSink{}}, + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id); +Environment::Create(std::move(logging_manager), env); +InferenceSession session_object{so, *env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::NGRAPHExecutionProvider>()); status = session_object.Load(model_file_name); ```
diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 12d4e931188f2..106d645c6dd62 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -20,11 +20,19 @@ class Environment { public: /** Create and initialize the runtime environment. + @param logging_manager The logging manager instance that will enable per session logger output using + session_options.session_logid as the logger id in messages. + If nullptr, the default LoggingManager MUST have been created previously as it will be used + for logging. This will use the default logger id in messages. + See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param tp_options optional set of parameters controlling the number of intra and inter op threads for the global + threadpools. + @param create_global_thread_pools determines whether this function creates the global threadpools.
*/ static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager, std::unique_ptr<Environment>& environment, const ThreadingOptions* tp_options = nullptr, - bool create_thread_pool = false); + bool create_global_thread_pools = false); logging::LoggingManager* GetLoggingManager() const { return logging_manager_.get(); @@ -48,7 +56,7 @@ class Environment { Environment() = default; Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager, const ThreadingOptions* tp_options = nullptr, - bool create_global_thread_pool = false); + bool create_global_thread_pools = false); std::unique_ptr<logging::LoggingManager> logging_manager_; std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc index 18f926076c412..82243e5cf79c2 100644 --- a/onnxruntime/core/common/logging/logging.cc +++ b/onnxruntime/core/common/logging/logging.cc @@ -22,7 +22,7 @@ #include "core/platform/ort_mutex.h" #if __FreeBSD__ -#include <sys/thr.h> // Use thr_self() syscall under FreeBSD to get thread id +#include <sys/thr.h> // Use thr_self() syscall under FreeBSD to get thread id #endif namespace onnxruntime { @@ -189,8 +189,8 @@ std::exception LoggingManager::LogFatalAndCreateException(const char* category, // create Capture in separate scope so it gets destructed (leading to log output) before we throw. { ::onnxruntime::logging::Capture c{::onnxruntime::logging::LoggingManager::DefaultLogger(), - ::onnxruntime::logging::Severity::kFATAL, category, - ::onnxruntime::logging::DataType::SYSTEM, location}; + ::onnxruntime::logging::Severity::kFATAL, category, + ::onnxruntime::logging::DataType::SYSTEM, location}; va_list args; va_start(args, format_str);
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index e7e61bd46a6ee..191fb2c147db7 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -33,21 +33,21 @@ std::once_flag schemaRegistrationOnceFlag; Status Environment::Create(std::unique_ptr<logging::LoggingManager> logging_manager, std::unique_ptr<Environment>& environment, const ThreadingOptions* tp_options, - bool create_global_thread_pool) { + bool create_global_thread_pools) { environment = std::unique_ptr<Environment>(new Environment()); - auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pool); + auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools); return status; } Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_manager, const ThreadingOptions* tp_options, - bool create_global_thread_pool) { + bool create_global_thread_pools) { auto status = Status::OK(); logging_manager_ = std::move(logging_manager); // create thread pools - if (create_global_thread_pool) { + if (create_global_thread_pools) { intra_op_thread_pool_ = concurrency::CreateThreadPool("env_global_intra_op_thread_pool", tp_options->intra_op_num_threads); inter_op_thread_pool_ = concurrency::CreateThreadPool("env_global_inter_op_thread_pool",
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index bbff806dc2272..0af3d6fc6e966 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -78,7 +78,15 @@ struct ModelMetadata { * CPUExecutionProviderInfo epi; * ProviderOption po{"CPUExecutionProvider", epi}; * SessionOptions so(vector<ProviderOption>{po}); - * InferenceSession session_object{so}; + * string log_id = "Foo"; + * auto logging_manager = std::make_unique<LoggingManager> + * (std::unique_ptr<ISink>{new CLogSink{}}, + * static_cast<Severity>(lm_info.default_warning_level), + * false, + * LoggingManager::InstanceType::Default, + * &log_id); + * Environment::Create(std::move(logging_manager), env); + * InferenceSession session_object{so, *env}; * common::Status status = session_object.Load(MODEL_URI); * common::Status status = session_object.Initialize(); * @@ -98,12 +106,7 @@ class InferenceSession { /** Create a new InferenceSession @param session_options Session options. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. */ explicit InferenceSession(const SessionOptions& session_options, const Environment& session_env); @@ -112,12 +115,7 @@ Create a new InferenceSession @param session_options Session options. @param model_uri absolute path of the model file. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, @@ -133,12 +131,7 @@ Create a new InferenceSession @param session_options Session options. @param istream object of the model. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, @@ -150,12 +143,7 @@ @param session_options Session options. @param model_data Model data buffer. @param model_data_len Model data buffer size. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. This ctor will throw on encountering model parsing issues.
*/ InferenceSession(const SessionOptions& session_options, @@ -467,12 +455,20 @@ class InferenceSession { std::unique_ptr<SessionState> session_state_; private: - // Threadpool for this session + // Threadpools per session. These are initialized and used for the entire duration of the session + // when use_per_session_threads is true. std::unique_ptr<onnxruntime::concurrency::ThreadPool> thread_pool_; std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_; + + // Global threadpools. These are initialized and used when use_per_session_threads is false *and* + // the environment is created with create_global_thread_pools = true. onnxruntime::concurrency::ThreadPool* intra_op_thread_pool_from_env_{}; onnxruntime::concurrency::ThreadPool* inter_op_thread_pool_from_env_{}; - bool use_per_session_threads_; // initialized from session options + + // initialized from session options + // Determines which threadpools will be initialized and used for the duration of this session. + // If true, the per-session threadpools are used; otherwise the global threadpools are used. + bool use_per_session_threads_; KernelRegistryManager kernel_registry_manager_; std::list> custom_schema_registries_;
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ead3ba637a5a1..ff51894e85ba4 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -238,7 +238,7 @@ class SessionObjectInitializer { public: typedef const SessionOptions& Arg1; // typedef logging::LoggingManager* Arg2; - static std::string default_logger_id; + static const std::string default_logger_id; operator Arg1() { return GetDefaultCPUSessionOptions(); } @@ -255,7 +255,7 @@ class SessionObjectInitializer { } }; -std::string SessionObjectInitializer::default_logger_id = "Default"; +const std::string SessionObjectInitializer::default_logger_id = "Default"; inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) { auto p = f.CreateProvider(); OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p))); @@ -874,7 +874,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { import_array1(); })(); - OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>( + OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>( std::unique_ptr<ISink>{new CErrSink{}}, Severity::kWARNING, false, LoggingManager::InstanceType::Default, &SessionObjectInitializer::default_logger_id),
diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index 288cf1cb690d2..2367393769057 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -171,7 +171,7 @@ TEST(CUDAFenceTests, TileWithInitializer) { DataTypeImpl::GetType<Tensor>()->GetDeleteFunc()); SessionOptions so; - FenceCudaTestInferenceSession session(so); + FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; session.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(xp_info)); @@ -236,7 +236,7 @@ TEST(CUDAFenceTests, TileWithComputedInput) { DataTypeImpl::GetType<Tensor>()->GetDeleteFunc()); SessionOptions so; - FenceCudaTestInferenceSession session(so); + FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; session.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(xp_info)); diff --git a/onnxruntime/test/global_thread_pools/test_fixture.h b/onnxruntime/test/global_thread_pools/test_fixture.h new file mode 100644 index 
0000000000000..c9f533c433f8c --- /dev/null +++ b/onnxruntime/test/global_thread_pools/test_fixture.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/onnxruntime_cxx_api.h" +#include + +#ifdef _WIN32 +typedef const wchar_t* PATH_TYPE; +#define TSTR(X) L##X +#else +#define TSTR(X) (X) +typedef const char* PATH_TYPE; +#endif + +//empty +static inline void ORT_API_CALL MyLoggingFunction(void*, OrtLoggingLevel, const char*, const char*, const char*, const char*) { +} \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc new file mode 100644 index 0000000000000..0f52c5978c016 --- /dev/null +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/constants.h" +#include "providers.h" +#include +#include +#include +#include +#include +#include +#include +#include "test_allocator.h" +#include "../shared_lib/onnx_protobuf.h" +#include "../shared_lib/test_fixture.h" +#include + +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + +extern std::unique_ptr ort_env; +static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/squeezenet/model.onnx"); +class CApiTestGlobalThreadPoolsWithProvider : public testing::Test, public ::testing::WithParamInterface { +}; + +template +static void RunSession(OrtAllocator* allocator, Ort::Session& session_object, + const std::vector& inputs, + const char* output_name, + const std::vector& dims_y, + const std::vector& values_y, + Ort::Value* output_tensor) { + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + } + + std::vector ort_outputs; + if (output_tensor) + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1); + else { + ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + output_tensor = &ort_outputs[0]; + } + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), dims_y); + //size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(values_y.size(), 5); + + OutT* f = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != 5; ++i) { + ASSERT_TRUE(abs(values_y[i] - f[i]) < 1e-6); + } +} + +template +static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) { + Ort::SessionOptions session_options; + session_options.DisablePerSessionThreads(); + + if (provider_type == 1) { +#ifdef USE_CUDA + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + std::cout << "Running simple inference with cuda provider" << std::endl; +#else + return Ort::Session(nullptr); +#endif + } else if (provider_type == 2) { +#ifdef USE_DNNL + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1)); + std::cout << "Running simple inference with dnnl provider" << 
std::endl; +#else + return Ort::Session(nullptr); +#endif + } else if (provider_type == 3) { +#ifdef USE_NUPHAR + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, "")); + std::cout << "Running simple inference with nuphar provider" << std::endl; +#else + return Ort::Session(nullptr); +#endif + } else { + std::cout << "Running simple inference with default provider" << std::endl; + } + + // if session creation passes, model loads fine + Ort::Session foo(env, model_uri, session_options); + return std::move(foo); +} + +template +static void TestInference(Ort::Session& session, + const std::vector& inputs, + const char* output_name, + const std::vector& expected_dims_y, + const std::vector& expected_values_y) { + Ort::SessionOptions session_options; + session_options.DisablePerSessionThreads(); + auto default_allocator = onnxruntime::make_unique(); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); + + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); +} + +static void GetInputsAndExpectedOutputs(std::vector& inputs, + std::vector& expected_dims_y, + std::vector& expected_values_y, + std::string& output_name) { + inputs.resize(1); + Input& input = inputs.back(); + input.name = "data_0"; + input.dims = {1, 3, 224, 224}; + size_t input_tensor_size = 224 * 224 * 3; + input.values.resize(input_tensor_size); + auto& input_tensor_values = input.values; + for (unsigned int i = 0; i < input_tensor_size; i++) + input_tensor_values[i] = (float)i / (input_tensor_size + 1); + + // prepare expected inputs and outputs + expected_dims_y = {1, 1000, 1, 1}; + // For this test I'm checking for the first 5 values only since the global thread pool change + // doesn't affect the core op functionality + expected_values_y = {0.000045f, 0.003846f, 0.000125f, 0.001180f, 0.001317f}; + + output_name = "softmaxout_1"; +} + +// All tests below use global threadpools + +// Test 1 +// run inference on a model using just 1 session +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple) { + // prepare inputs/outputs + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + // create session + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session) { + TestInference(session, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } +} + +// Test 2 +// run inference on the same model using 2 sessions +// destruct the 2 sessions only at the end +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple2) { + // prepare inputs/outputs + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + // create sessions + Ort::Session session1 = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + Ort::Session session2 = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session1 && session2) { + TestInference(session1, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + TestInference(session2, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } +} + +// Test 3 +// run inference on the same model using 2 sessions +// one after another 
destructing first session first +// followed by second session +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple3) { + // prepare inputs/outputs + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + // first session + { + // create session + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session) { + TestInference(session, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } + } + + // second session + { + // create session + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session) { + TestInference(session, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } + } +} + +INSTANTIATE_TEST_SUITE_P(CApiTestGlobalThreadPoolsWithProviders, + CApiTestGlobalThreadPoolsWithProvider, + ::testing::Values(0, 1, 2, 3, 4)); \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_main.cc b/onnxruntime/test/global_thread_pools/test_main.cc new file mode 100644 index 0000000000000..9b08aa62b800d --- /dev/null +++ b/onnxruntime/test/global_thread_pools/test_main.cc @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef USE_ONNXRUNTIME_DLL +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#else +#pragma warning(push) +#pragma warning(disable : 4018) /*'expression' : signed/unsigned mismatch */ +#pragma warning(disable : 4065) /*switch statement contains 'default' but no 'case' labels*/ +#pragma warning(disable : 4100) +#pragma warning(disable : 4146) /*unary minus operator applied to unsigned type, result still unsigned*/ +#pragma warning(disable : 4127) +#pragma warning(disable : 4244) /*'conversion' conversion from 'type1' to 'type2', possible loss of data*/ +#pragma warning(disable : 4251) /*'identifier' : class 'type' needs to have dll-interface to be used by clients of class 'type2'*/ +#pragma warning(disable : 4267) /*'var' : conversion from 'size_t' to 'type', possible loss of data*/ +#pragma warning(disable : 4305) /*'identifier' : truncation from 'type1' to 'type2'*/ +#pragma warning(disable : 4307) /*'operator' : integral constant overflow*/ +#pragma warning(disable : 4309) /*'conversion' : truncation of constant value*/ +#pragma warning(disable : 4334) /*'operator' : result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?)*/ +#pragma warning(disable : 4355) /*'this' : used in base member initializer list*/ +#pragma warning(disable : 4506) /*no definition for inline function 'function'*/ +#pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ +#pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ +#endif +#include +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#else +#pragma warning(pop) +#endif +#endif + +#include "core/session/onnxruntime_cxx_api.h" +#include "gtest/gtest.h" +#include "test/test_environment.h" + +std::unique_ptr ort_env; + +int main(int argc, char** argv) { + int status = 0; + try { + ::testing::InitGoogleTest(&argc, argv); + ThreadingOptions tp_options{0, 0}; + ort_env.reset(new Ort::Env(tp_options, ORT_LOGGING_LEVEL_WARNING, "Default")); // this is the only change from 
test/providers/test_main.cc + status = RUN_ALL_TESTS(); + } catch (const std::exception& ex) { + std::cerr << ex.what(); + status = -1; + } + //TODO: Fix the C API issue + ort_env.reset(); //If we don't do this, it will crash + +#ifndef USE_ONNXRUNTIME_DLL + //make memory leak checker happy + ::google::protobuf::ShutdownProtobufLibrary(); +#endif + return status; +} diff --git a/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc b/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc index 6951ff40f0f45..d33bba95014b0 100644 --- a/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc +++ b/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc @@ -70,9 +70,9 @@ void add_feeds(NameMLValMap& feeds, std::string name, std::vector dims, } //TODO:(nivas) Refractor to use existing code -void RunTest(const std::string& model_path, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& expected_shapes, const std::vector>& expected_values) { +void RunTest(const std::string& model_path, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& expected_shapes, const std::vector>& expected_values, const Environment& env) { SessionOptions so; - InferenceSession session_object(so, &DefaultLoggingManager()); + InferenceSession session_object(so, env); EXPECT_TRUE(session_object.RegisterExecutionProvider(DefaultNGraphExecutionProvider()).IsOK()); @@ -148,7 +148,7 @@ TEST(NGraphExecutionProviderTest, Basic_Test) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Basic_Test.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Basic_Test.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -177,7 +177,7 @@ TEST(NGraphExecutionProviderTest, Graph_with_UnSupportedOp) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Graph_with_UnSupportedOp.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Graph_with_UnSupportedOp.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -210,7 +210,7 @@ TEST(NGraphExecutionProviderTest, Two_Subgraphs) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Two_Subgraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Two_Subgraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -246,7 +246,7 @@ TEST(NGraphExecutionProviderTest, ClusterOut_isAlso_GraphOut) { {4}, {4}}; - RunTest("testdata/ngraph/ClusterOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/ClusterOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -282,7 +282,7 @@ TEST(NGraphExecutionProviderTest, InOut_isAlso_GraphOut) { {4}, {4}}; - RunTest("testdata/ngraph/InOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/InOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -314,7 +314,7 @@ TEST(NGraphExecutionProviderTest, Op_with_Optional_or_Unused_Outputs) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Op_with_Optional_or_Unused_Outputs.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Op_with_Optional_or_Unused_Outputs.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -350,7 +350,7 @@ 
TEST(NGraphExecutionProviderTest, Independent_SubGraphs) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Independent_SubGraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Independent_SubGraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } } // namespace test diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index dc497bad25418..9db96b42a3bca 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -85,7 +85,7 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; status = session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>()); ASSERT_TRUE(status.IsOK()); status = session_object.Load(model_file_name); @@ -100,4 +100,3 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { } } // namespace test } // namespace onnxruntime - diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 3bcf1263d2a59..0497d99a3d9da 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -315,7 +315,7 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; diff --git a/onnxruntime/test/tvm/tvm_basic_test.cc b/onnxruntime/test/tvm/tvm_basic_test.cc index 70399f2f03679..9448d92663e93 100644 --- a/onnxruntime/test/tvm/tvm_basic_test.cc +++ b/onnxruntime/test/tvm/tvm_basic_test.cc @@ -310,7 +310,7 @@ TEST(TVMTest, CodeGen_Demo_for_Fuse_Mul) { so.session_logid = "InferenceSessionTests.NoTimeout"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; CPUExecutionProviderInfo info; auto tvm_xp = onnxruntime::make_unique(info); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(tvm_xp)).IsOK()); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 7751425544386..73c8d71bc1f14 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -615,6 +615,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, enab executables = ['onnxruntime_test_all.exe'] if args.build_shared_lib: executables.append('onnxruntime_shared_lib_test.exe') + executables.append('onnxruntime_global_thread_pools_test.exe') run_subprocess(['vstest.console.exe', '--parallel', '--TestAdapterPath:..\\googletestadapter.0.17.1\\build\\_common', '/Logger:trx','/Enablecodecoverage','/Platform:x64',"/Settings:%s" % os.path.join(source_dir, 'cmake\\codeconv.runsettings')] + executables, cwd=cwd2, dll_path=dll_path) else: From 98821b48360dc300d91e208d0316d1f3ff51ccd0 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Sat, 14 Mar 2020 20:06:39 -0700 Subject: [PATCH 4/9] Added some documentation --- docs/C_API.md | 9 +++++++++ include/onnxruntime/core/session/onnxruntime_c_api.h | 9 +++++---- onnxruntime/test/providers/provider_test_utils.cc | 1 - 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/C_API.md 
b/docs/C_API.md
index aa393735826a5..07e32990df25b 100644
--- a/docs/C_API.md
+++ b/docs/C_API.md
@@ -12,6 +12,15 @@
 * Setting graph optimization level for each session.
 * Dynamically loading custom ops. [Instructions](/docs/AddingCustomOp.md)
 * Ability to load a model from a byte array. See ```OrtCreateSessionFromArray``` in [onnxruntime_c_api.h](/include/onnxruntime/core/session/onnxruntime_c_api.h).
+* **Global/shared threadpools:** By default each session creates its own set of threadpools. In situations where multiple
+sessions need to be created (to infer different models) in the same process, you end up with a separate set of threadpools
+created by each session. To address this inefficiency we introduce a new feature called global/shared threadpools.
+The basic idea here is to share a set of global threadpools across multiple sessions. Typical usage of this feature
+is as follows
+  * Populate ```ThreadingOptions```. Use the value of 0 for ORT to pick the defaults.
+  * Create env using ```CreateEnvWithGlobalThreadPools()```
+  * Create session and call ```DisablePerSessionThreads()``` on the session
+  * Call ```Run()``` as usual
 ## Usage Overview
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 61860a0c68c57..0003c9c2fd252 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -758,18 +758,19 @@ struct OrtApi {
   /*
   * Creates an environment with global threadpools that will be shared across sessions.
-  * Use this in conjunction with DisablePerSessionThreads API or else by default the session will use
+  * Use this in conjunction with DisablePerSessionThreads API or else the session will use
   * its own thread pools.
   */
  OrtStatus*(ORT_API_CALL* CreateEnvWithGlobalThreadPools)(OrtLoggingLevel default_logging_level, _In_ const char* logid,
                                                           _In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out)
      NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

-  // TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function?
+  /* TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? */

  /*
-* Calling this API will make the session use the global threadpools shared across sessions.
-*/
+  * Calling this API will make the session use the global threadpools shared across sessions.
+  * This API should be used in conjunction with CreateEnvWithGlobalThreadPools API.
+  */
  OrtStatus*(ORT_API_CALL* DisablePerSessionThreads)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION;
 };
diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc
index dfbeb1c492f1e..4faa3ea922785 100644
--- a/onnxruntime/test/providers/provider_test_utils.cc
+++ b/onnxruntime/test/providers/provider_test_utils.cc
@@ -800,7 +800,6 @@ void OpTester::Run(
       if (!valid)
         continue;

-      InferenceSession session_object{so};
       for (auto& custom_session_registry : custom_session_registries_)
         session_object.RegisterCustomRegistry(custom_session_registry);

From d8b36eecbffb4c8ed75d111f9749b14b3bf491ca Mon Sep 17 00:00:00 2001
From: Pranav Sharma
Date: Mon, 16 Mar 2020 00:01:26 -0700
Subject: [PATCH 5/9] Fix CentOS issue where threadpools become nullptr on machines with 1 core.
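
On machines with a single core, the thread pool factory returns a null pool instead of a pool with one
thread (running the work inline is cheaper than standing up a pool), so the env-owned threadpools can
legitimately be nullptr. The previous ORT_ENFORCE checks assumed they are always non-null. A rough sketch
of the factory behavior this patch accounts for (illustrative only; the real logic lives behind
concurrency::CreateThreadPool in core/platform/threadpool.h and may differ in detail):

```cpp
// Illustrative sketch, not the actual implementation: the factory hands back
// nullptr when there is nothing to parallelize across (e.g. one core).
std::unique_ptr<concurrency::ThreadPool> CreateThreadPool(const std::string& name, int num_threads) {
  if (num_threads == 0)                                  // 0 means "let ORT pick a default"
    num_threads = std::thread::hardware_concurrency();
  if (num_threads <= 1)                                  // callers must tolerate a null pool
    return nullptr;
  return onnxruntime::make_unique<concurrency::ThreadPool>(name, num_threads);
}
```

The session therefore now checks EnvCreatedWithGlobalThreadPools() rather than requiring the env pools
themselves to be non-null.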
--- .../onnxruntime/core/session/environment.h | 5 ++++ onnxruntime/core/session/environment.cc | 1 + onnxruntime/core/session/inference_session.cc | 25 ++++++++----------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 106d645c6dd62..34f3b81e05335 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -50,6 +50,10 @@ class Environment { return inter_op_thread_pool_.get(); } + bool EnvCreatedWithGlobalThreadPools() const { + return create_global_thread_pools_; + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); @@ -61,5 +65,6 @@ class Environment { std::unique_ptr logging_manager_; std::unique_ptr intra_op_thread_pool_; std::unique_ptr inter_op_thread_pool_; + bool create_global_thread_pools_{false}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 191fb2c147db7..114b8985d8d10 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -48,6 +48,7 @@ Status Environment::Initialize(std::unique_ptr logging_ // create thread pools if (create_global_thread_pools) { + create_global_thread_pools_ = true; intra_op_thread_pool_ = concurrency::CreateThreadPool("env_global_intra_op_thread_pool", tp_options->intra_op_num_threads); inter_op_thread_pool_ = concurrency::CreateThreadPool("env_global_inter_op_thread_pool", diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index d687ee4ac8b2a..79aa25d842776 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -166,11 +166,17 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked + // after the invocation of FinalizeSessionOptions. + logging_manager_ = session_env.GetLoggingManager(); + InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. 
+
+  // Update the number of steps for the graph transformer manager using the "finalized" session options
+  graph_transformation_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps);

   use_per_session_threads_ = session_options.use_per_session_threads;

   if (use_per_session_threads_) {
+    LOGS(*session_logger_, INFO) << "Creating and using per session threadpools since use_per_session_threads_ is true";
     thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool",
                                                  session_options_.intra_op_num_threads);
@@ -181,15 +187,10 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
   } else {
     intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool();
     inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool();
-
-    ORT_ENFORCE(intra_op_thread_pool_from_env_,
-                "Since use_per_session_threads is false, this must be non-nullptr"
-                " You probably didn't create the env using the CreateEnvWithGlobalThreadPools API");
-    ORT_ENFORCE(inter_op_thread_pool_from_env_,
-                "Since use_per_session_threads is false, this must be non-nullptr"
-                " You probably didn't create the env using the CreateEnvWithGlobalThreadPools API");
-    ORT_ENFORCE(thread_pool_ == nullptr, "Since use_per_session_threads is false per session threadpools should be nullptr");
-    ORT_ENFORCE(inter_op_thread_pool_ == nullptr, "Since use_per_session_threads is false per session threadpools should be nullptr");
+    ORT_ENFORCE(session_env.EnvCreatedWithGlobalThreadPools() == true,
+                "When the session is not configured to use per session"
+                "threadpools, the env must be created with the CreateEnvWithGlobalThreadPools API.");
+    LOGS(*session_logger_, INFO) << "Using global/env threadpools since use_per_session_threads_ is false";
   }

   session_state_ = onnxruntime::make_unique(execution_providers_,
@@ -198,9 +199,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
                                            use_per_session_threads_ ? thread_pool_.get() : intra_op_thread_pool_from_env_,
                                            use_per_session_threads_ ?
inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_);
-  logging_manager_ = session_env.GetLoggingManager();
-  InitLogger(logging_manager_);
-
+
   session_state_->SetLogger(*session_logger_);
   session_state_->SetDataTransferMgr(&data_transfer_mgr_);
   session_profiler_.Initialize(session_logger_);
   session_state_->SetProfiler(session_profiler_);
@@ -1422,8 +1421,6 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) {
   } else {
     session_logger_ = &logging::LoggingManager::DefaultLogger();
   }
-
-  session_state_->SetLogger(*session_logger_);
 }

 // Registers all the predefined transformers with transformer manager
From c0e64c7bfde8856f99dee0c5fc48b9c643086d3c Mon Sep 17 00:00:00 2001
From: Pranav Sharma
Date: Mon, 16 Mar 2020 14:02:50 -0700
Subject: [PATCH 6/9] Fix mac and x86 build issues
---
 onnxruntime/core/session/inference_session.cc          | 4 ++--
 onnxruntime/test/global_thread_pools/test_inference.cc | 7 +++----
 onnxruntime/test/global_thread_pools/test_main.cc      | 2 +-
 3 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 79aa25d842776..c465533c5b58e 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -185,12 +185,12 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
                                               session_options_.inter_op_num_threads)
                             : nullptr;
   } else {
+    LOGS(*session_logger_, INFO) << "Using global/env threadpools since use_per_session_threads_ is false";
     intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool();
     inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool();
     ORT_ENFORCE(session_env.EnvCreatedWithGlobalThreadPools() == true,
                 "When the session is not configured to use per session"
-                "threadpools, the env must be created with the CreateEnvWithGlobalThreadPools API.");
+                " threadpools, the env must be created with the CreateEnvWithGlobalThreadPools API.");
   }

   session_state_ = onnxruntime::make_unique(execution_providers_,
diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc
index 0f52c5978c016..e1c2072e2f4f7 100644
--- a/onnxruntime/test/global_thread_pools/test_inference.cc
+++ b/onnxruntime/test/global_thread_pools/test_inference.cc
@@ -54,10 +54,10 @@ static void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
   auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
   ASSERT_EQ(type_info.GetShape(), dims_y);
   //size_t total_len = type_info.GetElementCount();
-  ASSERT_EQ(values_y.size(), 5);
+  ASSERT_EQ(values_y.size(), static_cast(5));

   OutT* f = output_tensor->GetTensorMutableData();
-  for (size_t i = 0; i != 5; ++i) {
+  for (size_t i = 0; i != static_cast(5); ++i) {
     ASSERT_TRUE(abs(values_y[i] - f[i]) < 1e-6);
   }
 }
@@ -93,8 +93,7 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type)
   }

   // if session creation succeeds, the model loads fine
-  Ort::Session foo(env, model_uri, session_options);
-  return std::move(foo);
+  return Ort::Session(env, model_uri, session_options);
 }

 template
diff --git a/onnxruntime/test/global_thread_pools/test_main.cc b/onnxruntime/test/global_thread_pools/test_main.cc
index 9b08aa62b800d..6b25c716df1b1 100644
--- a/onnxruntime/test/global_thread_pools/test_main.cc
+++
b/onnxruntime/test/global_thread_pools/test_main.cc @@ -44,7 +44,7 @@ int main(int argc, char** argv) { try { ::testing::InitGoogleTest(&argc, argv); ThreadingOptions tp_options{0, 0}; - ort_env.reset(new Ort::Env(tp_options, ORT_LOGGING_LEVEL_WARNING, "Default")); // this is the only change from test/providers/test_main.cc + ort_env.reset(new Ort::Env(tp_options, ORT_LOGGING_LEVEL_VERBOSE, "Default")); // this is the only change from test/providers/test_main.cc status = RUN_ALL_TESTS(); } catch (const std::exception& ex) { std::cerr << ex.what(); From 9ec36cf7a267d37cd567fc4617adcedb1082b368 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Tue, 17 Mar 2020 01:01:31 -0700 Subject: [PATCH 7/9] Address some PR comments --- docs/C_API.md | 2 +- onnxruntime/core/framework/session_options.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 4 +- .../python/onnxruntime_pybind_state.cc | 92 ++++++++++--------- .../test/global_thread_pools/README.md | 6 ++ .../test/global_thread_pools/test_fixture.h | 17 ---- .../global_thread_pools/test_inference.cc | 6 +- .../test/util/include/test/test_environment.h | 2 +- 8 files changed, 61 insertions(+), 70 deletions(-) create mode 100644 onnxruntime/test/global_thread_pools/README.md delete mode 100644 onnxruntime/test/global_thread_pools/test_fixture.h diff --git a/docs/C_API.md b/docs/C_API.md index 07e32990df25b..c8671cdf4d804 100644 --- a/docs/C_API.md +++ b/docs/C_API.md @@ -19,7 +19,7 @@ The basic idea here is to share a set of global threadpools across multiple sess is as follows * Populate ```ThreadingOptions```. Use the value of 0 for ORT to pick the defaults. * Create env using ```CreateEnvWithGlobalThreadPools()``` - * Create session and call ```DisablePerSessionThreads()``` on the session + * Create session and call ```DisablePerSessionThreads()``` on the session options object * Call ```Run()``` as usual ## Usage Overview diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 3d79a851c8cdc..dbfd12ca8b1ca 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -65,7 +65,7 @@ struct SessionOptions { // free dimensions with, keyed by dimension denotation. std::vector free_dimension_overrides; - // By default the session uses it's own set of threadpools, unless this is set to false. + // By default the session uses its own set of threadpools, unless this is set to false. // Use this in conjunction with the CreateEnvWithGlobalThreadPools API. 
bool use_per_session_threads = true; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d80d41e1265f5..d419c2ad2cb84 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1492,7 +1492,6 @@ static constexpr OrtApi ort_api_1_to_3 = { &OrtApis::ReleaseCustomOpDomain, // End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information) - // Version 2 - In development, feel free to add/remove/rearrange here &OrtApis::GetDenotationFromTypeInfo, &OrtApis::CastTypeInfoToMapTypeInfo, &OrtApis::CastTypeInfoToSequenceTypeInfo, @@ -1510,8 +1509,9 @@ static constexpr OrtApi ort_api_1_to_3 = { &OrtApis::ModelMetadataLookupCustomMetadataMap, &OrtApis::ModelMetadataGetVersion, &OrtApis::ReleaseModelMetadata, + // End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information) - // Version 3 + // Version 3 - In development, feel free to add/remove/rearrange here &OrtApis::CreateEnvWithGlobalThreadPools, &OrtApis::DisablePerSessionThreads, }; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ff51894e85ba4..c50b762cf4974 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -659,55 +659,57 @@ including arg name, arg type (contains both type and shape).)pbdoc") return *(na.Type()); }, "node type") - .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string { - std::ostringstream res; - res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; - auto shape = na.Shape(); - std::vector arr; - if (shape == nullptr || shape->dim_size() == 0) { - res << "[]"; - } else { - res << "["; - for (int i = 0; i < shape->dim_size(); ++i) { - if (utils::HasDimValue(shape->dim(i))) { - res << shape->dim(i).dim_value(); - } else if (utils::HasDimParam(shape->dim(i))) { - res << "'" << shape->dim(i).dim_param() << "'"; + .def( + "__str__", [](const onnxruntime::NodeArg& na) -> std::string { + std::ostringstream res; + res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; + auto shape = na.Shape(); + std::vector arr; + if (shape == nullptr || shape->dim_size() == 0) { + res << "[]"; } else { - res << "None"; + res << "["; + for (int i = 0; i < shape->dim_size(); ++i) { + if (utils::HasDimValue(shape->dim(i))) { + res << shape->dim(i).dim_value(); + } else if (utils::HasDimParam(shape->dim(i))) { + res << "'" << shape->dim(i).dim_param() << "'"; + } else { + res << "None"; + } + + if (i < shape->dim_size() - 1) { + res << ", "; + } + } + res << "]"; } + res << ")"; - if (i < shape->dim_size() - 1) { - res << ", "; + return std::string(res.str()); + }, + "converts the node into a readable string") + .def_property_readonly( + "shape", [](const onnxruntime::NodeArg& na) -> std::vector { + auto shape = na.Shape(); + std::vector arr; + if (shape == nullptr || shape->dim_size() == 0) { + return arr; } - } - res << "]"; - } - res << ")"; - - return std::string(res.str()); - }, - "converts the node into a readable string") - .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector { - auto shape = na.Shape(); - std::vector arr; - if (shape == nullptr || shape->dim_size() == 0) { - return arr; - } - arr.resize(shape->dim_size()); - for (int i = 0; i < shape->dim_size(); ++i) { - if (utils::HasDimValue(shape->dim(i))) { - arr[i] = py::cast(shape->dim(i).dim_value()); - } else if 
(utils::HasDimParam(shape->dim(i))) { - arr[i] = py::cast(shape->dim(i).dim_param()); - } else { - arr[i] = py::none(); - } - } - return arr; - }, - "node shape (assuming the node holds a tensor)"); + arr.resize(shape->dim_size()); + for (int i = 0; i < shape->dim_size(); ++i) { + if (utils::HasDimValue(shape->dim(i))) { + arr[i] = py::cast(shape->dim(i).dim_value()); + } else if (utils::HasDimParam(shape->dim(i))) { + arr[i] = py::cast(shape->dim(i).dim_param()); + } else { + arr[i] = py::none(); + } + } + return arr; + }, + "node shape (assuming the node holds a tensor)"); py::class_(m, "SessionObjectInitializer"); py::class_(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") @@ -875,7 +877,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { })(); OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique( - std::unique_ptr{new CErrSink{}}, + std::unique_ptr{new CLogSink{}}, Severity::kWARNING, false, LoggingManager::InstanceType::Default, &SessionObjectInitializer::default_logger_id), env)); diff --git a/onnxruntime/test/global_thread_pools/README.md b/onnxruntime/test/global_thread_pools/README.md new file mode 100644 index 0000000000000..6677bc580318a --- /dev/null +++ b/onnxruntime/test/global_thread_pools/README.md @@ -0,0 +1,6 @@ +**Tests for global threadpools** + +These tests here test the usage of the global threadpools. The reason we need to create a separate exe here is because +we need to create a separate environment that enables the creation of global threadpools. The test environment +used in the other exes create the env without global threadpools and since this env is process wide (a singleton) we +cannot use it for this kind of test. \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_fixture.h b/onnxruntime/test/global_thread_pools/test_fixture.h deleted file mode 100644 index c9f533c433f8c..0000000000000 --- a/onnxruntime/test/global_thread_pools/test_fixture.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/onnxruntime_cxx_api.h" -#include - -#ifdef _WIN32 -typedef const wchar_t* PATH_TYPE; -#define TSTR(X) L##X -#else -#define TSTR(X) (X) -typedef const char* PATH_TYPE; -#endif - -//empty -static inline void ORT_API_CALL MyLoggingFunction(void*, OrtLoggingLevel, const char*, const char*, const char*, const char*) { -} \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index e1c2072e2f4f7..2dda85ff3902f 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -29,7 +29,7 @@ class CApiTestGlobalThreadPoolsWithProvider : public testing::Test, public ::tes }; template -static void RunSession(OrtAllocator* allocator, Ort::Session& session_object, +static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, const std::vector& inputs, const char* output_name, const std::vector& dims_y, @@ -39,7 +39,7 @@ static void RunSession(OrtAllocator* allocator, Ort::Session& session_object, std::vector input_names; for (size_t i = 0; i < inputs.size(); i++) { input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator.Info(&allocator), const_cast(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); } std::vector ort_outputs; @@ -107,7 +107,7 @@ static void TestInference(Ort::Session& session, auto default_allocator = onnxruntime::make_unique(); Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); - RunSession(default_allocator.get(), + RunSession(*default_allocator, session, inputs, output_name, diff --git a/onnxruntime/test/util/include/test/test_environment.h b/onnxruntime/test/util/include/test/test_environment.h index 1e0319fcf1104..549804595b93d 100644 --- a/onnxruntime/test/util/include/test/test_environment.h +++ b/onnxruntime/test/util/include/test/test_environment.h @@ -14,7 +14,7 @@ class Environment; namespace test { -const ::onnxruntime::Environment& GetEnvironment(); +const onnxruntime::Environment& GetEnvironment(); /** Static logging manager with a CLog based sink so logging macros that use the default logger will work From 754ab243770902a8b9fe8810f3e46d036f4092a2 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Tue, 17 Mar 2020 18:44:14 -0700 Subject: [PATCH 8/9] Disabled test for android, added few more tests and addressed more PR comments. 
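
The new InferenceSession tests cover all four env/session threading combinations:

1. env created without global pools + per-session threads enabled (default): per-session pools are used.
2. env created with global pools + per-session threads disabled: the env's global pools are used.
3. env created with global pools + per-session threads enabled: per-session pools still win.
4. env created without global pools + per-session threads disabled: session construction throws.

For reference, a condensed sketch of combination 2, mirroring CheckIfGlobalThreadPoolsAreBeingUsed below
(logging_manager is built exactly as in the tests; error handling elided):

```cpp
SessionOptions so;
so.use_per_session_threads = false;   // opt in to the env-owned pools

std::unique_ptr<Environment> env;
ThreadingOptions tp_options{0, 0};    // 0 lets ORT pick the defaults
auto st = Environment::Create(std::move(logging_manager), env,
                              &tp_options, true /*create_global_thread_pools*/);

InferenceSession session{so, *env};   // session state now points at the env's pools
```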
--- cmake/onnxruntime_unittests.cmake | 6 +- onnxruntime/core/session/inference_session.cc | 4 +- onnxruntime/core/session/inference_session.h | 11 ++ .../test/framework/inference_session_test.cc | 171 ++++++++++++++++++ .../test/global_thread_pools/README.md | 8 +- .../global_thread_pools/test_inference.cc | 2 - 6 files changed, 193 insertions(+), 9 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0a02f578d7fe2..28aab36631bc0 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -360,7 +360,9 @@ set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest" set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src}) if(NOT TARGET onnxruntime) list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) - list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC}) + if (NOT onnxruntime_USE_NNAPI) + list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC}) + endif() endif() set(all_dependencies ${onnxruntime_test_providers_dependencies} ) @@ -679,12 +681,14 @@ if (onnxruntime_BUILD_SHARED_LIB) ) # test inference using global threadpools + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") AddTest(DYN TARGET onnxruntime_global_thread_pools_test SOURCES ${onnxruntime_global_thread_pools_test_SRC} LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + endif() endif() #some ETW tools diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c465533c5b58e..b254d60d03c1d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -196,8 +196,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, session_state_ = onnxruntime::make_unique(execution_providers_, session_options_.enable_mem_pattern && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL, - use_per_session_threads_ ? thread_pool_.get() : intra_op_thread_pool_from_env_, - use_per_session_threads_ ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_); + GetIntraOpThreadPoolToUse(), + GetInterOpThreadPoolToUse()); session_state_->SetLogger(*session_logger_); session_state_->SetDataTransferMgr(&data_transfer_mgr_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 0af3d6fc6e966..f379d780658e6 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -454,6 +454,17 @@ class InferenceSession { // It has a dependency on execution_providers_. std::unique_ptr session_state_; + // Use these 2 threadpool methods to get access to the threadpools since they rely on + // specific flags in session options + // These methods assume that session options have been finalized before the call. + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const { + return session_options_.use_per_session_threads ? thread_pool_.get() : intra_op_thread_pool_from_env_; + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const { + return session_options_.use_per_session_threads ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_; + } + private: // Threadpools per session. These are initialized and used for the entire duration of the session // when use_per_session_threads is true. 
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index a2170ae4abe35..45aade53bf729 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -12,6 +12,7 @@ #include #include "core/common/logging/logging.h" +#include "core/common/logging/sinks/clog_sink.h" #include "core/common/profiler.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" @@ -1801,5 +1802,175 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { #endif } +// Global threadpool related tests +// We test for 4 combinations +class InferenceSessionTestGlobalThreadPools : public InferenceSession { + public: + InferenceSessionTestGlobalThreadPools(const SessionOptions& session_options, + const Environment& env) : InferenceSession(session_options, env) { + } + + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const { + return InferenceSession::GetIntraOpThreadPoolToUse(); + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const { + return InferenceSession::GetInterOpThreadPoolToUse(); + } + + const SessionState& GetSessionState() { + return *session_state_; + } +}; + +// Test 1: env created WITHOUT global tp / use per session tp (default case): in this case per session tps should be in use +TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed) { + SessionOptions so; + so.use_per_session_threads = true; + + so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the per session threadpools + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + // ensure threadpools were set correctly in the session state + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state); + + ASSERT_TRUE(intra_tp_from_env == nullptr); + ASSERT_TRUE(inter_tp_from_env == nullptr); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 2: env created with global tp / DONT use per session tp: in this case global tps should be in use +TEST(InferenceSessionTests, CheckIfGlobalThreadPoolsAreBeingUsed) { + SessionOptions so; + so.use_per_session_threads = false; + + so.session_logid = "CheckIfGlobalThreadPoolsAreBeingUsed"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + 
ThreadingOptions tp_options{0, 0}; + auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the global threadpools in both session and session state + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_env); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_env); + ASSERT_TRUE(intra_tp_from_session_state == intra_tp_from_env); + ASSERT_TRUE(inter_tp_from_session_state == inter_tp_from_env); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 3: env created with global tp / use per session tp: in this case per session tps should be in use +TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed2) { + SessionOptions so; + so.use_per_session_threads = true; + + so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed2"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + ThreadingOptions tp_options{0, 0}; + auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the per session threadpools + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + // ensure threadpools were set correctly in the session state + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state); + + // ensure per session thread pools in use are different from the + // env threadpools + if (intra_tp_from_session && intra_tp_from_env) { // both tps could be null on 1 core machines + ASSERT_FALSE(intra_tp_from_session == intra_tp_from_env); + } + + if (inter_tp_from_session && inter_tp_from_env) { // both tps could be null on 1 core machines + ASSERT_FALSE(inter_tp_from_session == inter_tp_from_env); + } + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 4: env created WITHOUT global tp / DONT use per session tp --> this should 
throw an exception
+TEST(InferenceSessionTests, InvalidSessionEnvCombination) {
+  SessionOptions so;
+  so.use_per_session_threads = false;
+
+  so.session_logid = "InvalidSessionEnvCombination";
+  auto logging_manager = onnxruntime::make_unique(
+      std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false,
+      LoggingManager::InstanceType::Temporal);
+
+  std::unique_ptr env;
+  auto st = Environment::Create(std::move(logging_manager), env);
+  ASSERT_TRUE(st.IsOK());
+
+  try {
+    InferenceSessionTestGlobalThreadPools session_object{so, *env.get()};
+  } catch (const std::exception& e) {
+    std::string e_message(std::string(e.what()));
+    ASSERT_TRUE(e_message.find(
+                    "When the session is not configured to use per session"
+                    " threadpools, the env must be created with the CreateEnvWithGlobalThreadPools API") !=
+                std::string::npos);
+  }
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/global_thread_pools/README.md b/onnxruntime/test/global_thread_pools/README.md
index 6677bc580318a..d5ada6edef9fb 100644
--- a/onnxruntime/test/global_thread_pools/README.md
+++ b/onnxruntime/test/global_thread_pools/README.md
@@ -1,6 +1,6 @@
 **Tests for global threadpools**

-These tests here test the usage of the global threadpools. The reason we need to create a separate exe here is because
-we need to create a separate environment that enables the creation of global threadpools. The test environment
-used in the other exes create the env without global threadpools and since this env is process wide (a singleton) we
-cannot use it for this kind of test.
\ No newline at end of file
+These tests exercise the global threadpools using the C API data flow. The reason we need a separate exe here is that
+we need a separate environment that enables the creation of global threadpools. The test environment used in the
+other exes creates the env without global threadpools, and since this env is process wide (a singleton) we cannot
+use it for this kind of test.
\ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 2dda85ff3902f..5294c770e8ccc 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -102,8 +102,6 @@ static void TestInference(Ort::Session& session, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y) { - Ort::SessionOptions session_options; - session_options.DisablePerSessionThreads(); auto default_allocator = onnxruntime::make_unique(); Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); From 8d9d0aecccc7f327eb9f6db4f78160a3e2c49566 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Wed, 18 Mar 2020 14:17:38 -0700 Subject: [PATCH 9/9] const_cast --- onnxruntime/test/global_thread_pools/test_inference.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 5294c770e8ccc..daeabcb373017 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -30,7 +30,7 @@ class CApiTestGlobalThreadPoolsWithProvider : public testing::Test, public ::tes template static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, - const std::vector& inputs, + std::vector& inputs, const char* output_name, const std::vector& dims_y, const std::vector& values_y, @@ -39,7 +39,7 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, std::vector input_names; for (size_t i = 0; i < inputs.size(); i++) { input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator.Info(&allocator), const_cast(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator.Info(&allocator), inputs[i].values.data(), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); } std::vector ort_outputs; @@ -98,7 +98,7 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) template static void TestInference(Ort::Session& session, - const std::vector& inputs, + std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y) {