diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b32d55357afa5..6fce71661da1b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -164,7 +164,7 @@ if (onnxruntime_USE_NNAPI) endif() set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib") - +set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools") set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h @@ -178,6 +178,11 @@ if(onnxruntime_RUN_ONNX_TESTS) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_io_types.cc) endif() +set (onnxruntime_global_thread_pools_test_SRC + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h + ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc + ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc) + # tests from lowest level library up. # the order of libraries should be maintained, with higher libraries being added first in the list @@ -355,6 +360,9 @@ set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest" set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src}) if(NOT TARGET onnxruntime) list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) + if (NOT onnxruntime_USE_NNAPI) + list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC}) + endif() endif() set(all_dependencies ${onnxruntime_test_providers_dependencies} ) @@ -671,6 +679,16 @@ if (onnxruntime_BUILD_SHARED_LIB) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + + # test inference using global threadpools + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") + AddTest(DYN + TARGET onnxruntime_global_thread_pools_test + SOURCES ${onnxruntime_global_thread_pools_test_SRC} + LIBS ${onnxruntime_shared_lib_test_LIBS} + DEPENDS ${all_dependencies} + ) + endif() endif() #some ETW tools diff --git a/docs/C_API.md b/docs/C_API.md index aa393735826a5..c8671cdf4d804 100644 --- a/docs/C_API.md +++ b/docs/C_API.md @@ -12,6 +12,15 @@ * Setting graph optimization level for each session. * Dynamically loading custom ops. [Instructions](/docs/AddingCustomOp.md) * Ability to load a model from a byte array. See ```OrtCreateSessionFromArray``` in [onnxruntime_c_api.h](/include/onnxruntime/core/session/onnxruntime_c_api.h). +* **Global/shared threadpools:** By default each session creates its own set of threadpools. In situations where multiple +sessions need to be created (to infer different models) in the same process, you end up with several threadpools created +by each session. In order to address this inefficiency we introduce a new feature called global/shared threadpools. +The basic idea here is to share a set of global threadpools across multiple sessions. Typical usage of this feature +is as follows + * Populate ```ThreadingOptions```. Use the value of 0 for ORT to pick the defaults. + * Create env using ```CreateEnvWithGlobalThreadPools()``` + * Create session and call ```DisablePerSessionThreads()``` on the session options object + * Call ```Run()``` as usual ## Usage Overview diff --git a/docs/execution_providers/ACL-ExecutionProvider.md b/docs/execution_providers/ACL-ExecutionProvider.md index dbabac8e33470..390731591283b 100644 --- a/docs/execution_providers/ACL-ExecutionProvider.md +++ b/docs/execution_providers/ACL-ExecutionProvider.md @@ -9,7 +9,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#ARM-Compute-L #### C/C++ To use ACL as execution provider for inferencing, please register it as below. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique +(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id) +Environment::Create(std::move(logging_manager), env) +InferenceSession session_object{so, env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::ACLExecutionProvider>()); status = session_object.Load(model_file_name); ``` diff --git a/docs/execution_providers/DNNL-ExecutionProvider.md b/docs/execution_providers/DNNL-ExecutionProvider.md index 8e3f8616890c7..fb40679635585 100644 --- a/docs/execution_providers/DNNL-ExecutionProvider.md +++ b/docs/execution_providers/DNNL-ExecutionProvider.md @@ -21,7 +21,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#dnnl-and-mklm ### C/C++ The DNNLExecutionProvider execution provider needs to be registered with ONNX Runtime to enable in the inference session. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique +(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id) +Environment::Create(std::move(logging_manager), env) +InferenceSession session_object{so,env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime:: DNNLExecutionProvider >()); status = session_object.Load(model_file_name); ``` diff --git a/docs/execution_providers/NNAPI-ExecutionProvider.md b/docs/execution_providers/NNAPI-ExecutionProvider.md index ed2fdf8c19af2..e6dd8452c8bf6 100644 --- a/docs/execution_providers/NNAPI-ExecutionProvider.md +++ b/docs/execution_providers/NNAPI-ExecutionProvider.md @@ -14,7 +14,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#Android-NNAPI To use NNAPI EP for inferencing, please register it as below. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique +(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id) +Environment::Create(std::move(logging_manager), env) +InferenceSession session_object{so,env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::NnapiExecutionProvider>()); status = session_object.Load(model_file_name); ``` diff --git a/docs/execution_providers/TensorRT-ExecutionProvider.md b/docs/execution_providers/TensorRT-ExecutionProvider.md index 24cb76e1143d5..2fc30ee559e5a 100644 --- a/docs/execution_providers/TensorRT-ExecutionProvider.md +++ b/docs/execution_providers/TensorRT-ExecutionProvider.md @@ -13,7 +13,15 @@ The TensorRT execution provider for ONNX Runtime is built and tested with Tensor ### C/C++ The TensorRT execution provider needs to be registered with ONNX Runtime to enable in the inference session. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique +(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id) +Environment::Create(std::move(logging_manager), env) +InferenceSession session_object{so,env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::TensorrtExecutionProvider>()); status = session_object.Load(model_file_name); ``` diff --git a/docs/execution_providers/nGraph-ExecutionProvider.md b/docs/execution_providers/nGraph-ExecutionProvider.md index 3c53e5d461ae5..11d403786b0df 100644 --- a/docs/execution_providers/nGraph-ExecutionProvider.md +++ b/docs/execution_providers/nGraph-ExecutionProvider.md @@ -23,7 +23,15 @@ While the nGraph Compiler stack supports various operating systems and backends ### C/C++ To use nGraph as execution provider for inferencing, please register it as below. ``` -InferenceSession session_object{so}; +string log_id = "Foo"; +auto logging_manager = std::make_unique +(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id) +Environment::Create(std::move(logging_manager), env) +InferenceSession session_object{so,env}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::NGRAPHExecutionProvider>()); status = session_object.Load(model_file_name); ``` diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 5792e70b6ea39..34f3b81e05335 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -7,7 +7,10 @@ #include #include "core/common/common.h" #include "core/common/status.h" +#include "core/platform/threadpool.h" +#include "core/common/logging/logging.h" +struct ThreadingOptions; namespace onnxruntime { /** TODO: remove this class Provides the runtime environment for onnxruntime. @@ -17,13 +20,51 @@ class Environment { public: /** Create and initialize the runtime environment. + @param logging manager instance that will enable per session logger output using + session_options.session_logid as the logger id in messages. + If nullptr, the default LoggingManager MUST have been created previously as it will be used + for logging. This will use the default logger id in messages. + See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param tp_options optional set of parameters controlling the number of intra and inter op threads for the global + threadpools. + @param create_global_thread_pools determine if this function will create the global threadpools or not. */ - static Status Create(std::unique_ptr& environment); + static Status Create(std::unique_ptr logging_manager, + std::unique_ptr& environment, + const ThreadingOptions* tp_options = nullptr, + bool create_global_thread_pools = false); + + logging::LoggingManager* GetLoggingManager() const { + return logging_manager_.get(); + } + + void SetLoggingManager(std::unique_ptr logging_manager) { + logging_manager_ = std::move(logging_manager); + } + + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const { + return intra_op_thread_pool_.get(); + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const { + return inter_op_thread_pool_.get(); + } + + bool EnvCreatedWithGlobalThreadPools() const { + return create_global_thread_pools_; + } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); Environment() = default; - Status Initialize(); + Status Initialize(std::unique_ptr logging_manager, + const ThreadingOptions* tp_options = nullptr, + bool create_global_thread_pools = false); + + std::unique_ptr logging_manager_; + std::unique_ptr intra_op_thread_pool_; + std::unique_ptr inter_op_thread_pool_; + bool create_global_thread_pools_{false}; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 341fc9d7782b9..0003c9c2fd252 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -212,6 +212,14 @@ typedef enum OrtMemType { OrtMemTypeDefault = 0, // the default allocator for execution provider } OrtMemType; +typedef struct ThreadingOptions { + // number of threads used to parallelize execution of an op + int intra_op_num_threads; // use 0 if you want onnxruntime to choose a value for you + + // number of threads used to parallelize execution across ops + int inter_op_num_threads; // use 0 if you want onnxruntime to choose a value for you +} ThreadingOptions; + struct OrtApi; typedef struct OrtApi OrtApi; @@ -747,6 +755,23 @@ struct OrtApi { OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION; ORT_CLASS_RELEASE(ModelMetadata); + + /* + * Creates an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with DisablePerSessionThreads API or else the session will use + * its own thread pools. + */ + OrtStatus*(ORT_API_CALL* CreateEnvWithGlobalThreadPools)(OrtLoggingLevel default_logging_level, _In_ const char* logid, + _In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out) + NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /* TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? */ + + /* + * Calling this API will make the session use the global threadpools shared across sessions. + * This API should be used in conjunction with CreateEnvWithGlobalThreadPools API. + */ + OrtStatus*(ORT_API_CALL* DisablePerSessionThreads)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5e1c91b916cf6..e562d3b3c5324 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -82,7 +82,7 @@ struct Base { ~Base() { OrtRelease(p_); } operator T*() { return p_; } - operator const T *() const { return p_; } + operator const T*() const { return p_; } T* release() { T* p = p_; @@ -123,6 +123,7 @@ struct ModelMetadata; struct Env : Base { Env(std::nullptr_t) {} Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + Env(ThreadingOptions tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); explicit Env(OrtEnv* p) : Base{p} {} @@ -185,6 +186,8 @@ struct SessionOptions : Base { SessionOptions& SetLogId(const char* logid); SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); + + SessionOptions& DisablePerSessionThreads(); }; struct ModelMetadata : Base { @@ -289,7 +292,7 @@ struct AllocatorWithDefaultOptions { AllocatorWithDefaultOptions(); operator OrtAllocator*() { return p_; } - operator const OrtAllocator *() const { return p_; } + operator const OrtAllocator*() const { return p_; } void* Alloc(size_t size); void Free(void* p); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index b10f0cebe50f1..a63508797ad90 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -84,6 +84,10 @@ inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLog ThrowOnError(Global::api_.CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_)); } +inline Env::Env(ThreadingOptions tp_options, OrtLoggingLevel default_warning_level, const char* logid) { + ThrowOnError(Global::api_.CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_)); +} + inline Env& Env::EnableTelemetryEvents() { ThrowOnError(Global::api_.EnableTelemetryEvents(p_)); return *this; @@ -601,4 +605,8 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, return out; } +inline SessionOptions& SessionOptions::DisablePerSessionThreads() { + ThrowOnError(Global::api_.DisablePerSessionThreads(p_)); + return *this; +} } // namespace Ort \ No newline at end of file diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc index 18f926076c412..82243e5cf79c2 100644 --- a/onnxruntime/core/common/logging/logging.cc +++ b/onnxruntime/core/common/logging/logging.cc @@ -22,7 +22,7 @@ #include "core/platform/ort_mutex.h" #if __FreeBSD__ -#include // Use thr_self() syscall under FreeBSD to get thread id +#include // Use thr_self() syscall under FreeBSD to get thread id #endif namespace onnxruntime { @@ -189,8 +189,8 @@ std::exception LoggingManager::LogFatalAndCreateException(const char* category, // create Capture in separate scope so it gets destructed (leading to log output) before we throw. { ::onnxruntime::logging::Capture c{::onnxruntime::logging::LoggingManager::DefaultLogger(), - ::onnxruntime::logging::Severity::kFATAL, category, - ::onnxruntime::logging::DataType::SYSTEM, location}; + ::onnxruntime::logging::Severity::kFATAL, category, + ::onnxruntime::logging::DataType::SYSTEM, location}; va_list args; va_start(args, format_str); diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 1c02f9f0ebbc2..dbfd12ca8b1ca 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -64,5 +64,9 @@ struct SessionOptions { // For models with free input dimensions (most commonly batch size), specifies a set of values to override those // free dimensions with, keyed by dimension denotation. std::vector free_dimension_overrides; + + // By default the session uses its own set of threadpools, unless this is set to false. + // Use this in conjunction with the CreateEnvWithGlobalThreadPools API. + bool use_per_session_threads = true; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index c15252b71d5e5..a6b08364c46a2 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -155,3 +155,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override}); return nullptr; } + +ORT_API_STATUS_IMPL(OrtApis::DisablePerSessionThreads, _In_ OrtSessionOptions* options) { + options->value.use_per_session_threads = false; + return nullptr; +} diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 744c347491dc3..114b8985d8d10 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -18,6 +18,7 @@ #endif #include "core/platform/env.h" +#include "core/util/thread_utils.h" #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT #include "core/platform/tracing.h" @@ -29,15 +30,31 @@ using namespace ONNX_NAMESPACE; std::once_flag schemaRegistrationOnceFlag; -Status Environment::Create(std::unique_ptr& environment) { +Status Environment::Create(std::unique_ptr logging_manager, + std::unique_ptr& environment, + const ThreadingOptions* tp_options, + bool create_global_thread_pools) { environment = std::unique_ptr(new Environment()); - auto status = environment->Initialize(); + auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools); return status; } -Status Environment::Initialize() { +Status Environment::Initialize(std::unique_ptr logging_manager, + const ThreadingOptions* tp_options, + bool create_global_thread_pools) { auto status = Status::OK(); + logging_manager_ = std::move(logging_manager); + + // create thread pools + if (create_global_thread_pools) { + create_global_thread_pools_ = true; + intra_op_thread_pool_ = concurrency::CreateThreadPool("env_global_intra_op_thread_pool", + tp_options->intra_op_num_threads); + inter_op_thread_pool_ = concurrency::CreateThreadPool("env_global_inter_op_thread_pool", + tp_options->inter_op_num_threads); + } + try { // Register Microsoft domain with min/max op_set version as 1/1. std::call_once(schemaRegistrationOnceFlag, []() { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 0de00bad3a308..b254d60d03c1d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -161,32 +161,45 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session } void InferenceSession::ConstructorCommon(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) { + const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, model_loaded_, session_options_); ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked + // after the invocation of FinalizeSessionOptions. + logging_manager_ = session_env.GetLoggingManager(); + InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. + // Update the number of steps for the graph transformer manager using the "finalized" session options graph_transformation_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps); + use_per_session_threads_ = session_options.use_per_session_threads; - logging_manager_ = logging_manager; - - thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool", - session_options_.intra_op_num_threads); + if (use_per_session_threads_) { + LOGS(*session_logger_, INFO) << "Creating and using per session threadpools since use_per_session_threads_ is true"; + thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool", + session_options_.intra_op_num_threads); - inter_op_thread_pool_ = session_options_.execution_mode == ExecutionMode::ORT_PARALLEL - ? concurrency::CreateThreadPool("inter_op_thread_pool", - session_options_.inter_op_num_threads) - : nullptr; + inter_op_thread_pool_ = session_options_.execution_mode == ExecutionMode::ORT_PARALLEL + ? concurrency::CreateThreadPool("inter_op_thread_pool", + session_options_.inter_op_num_threads) + : nullptr; + } else { + LOGS(*session_logger_, INFO) << "Using global/env threadpools since use_per_session_threads_ is false"; + intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool(); + inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool(); + ORT_ENFORCE(session_env.EnvCreatedWithGlobalThreadPools() == true, + "When the session is not configured to use per session" + " threadpools, the env must be created with the the CreateEnvWithGlobalThreadPools API."); + } session_state_ = onnxruntime::make_unique(execution_providers_, session_options_.enable_mem_pattern && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL, - thread_pool_.get(), - inter_op_thread_pool_.get()); - - InitLogger(logging_manager); + GetIntraOpThreadPoolToUse(), + GetInterOpThreadPoolToUse()); + session_state_->SetLogger(*session_logger_); session_state_->SetDataTransferMgr(&data_transfer_mgr_); session_profiler_.Initialize(session_logger_); session_state_->SetProfiler(session_profiler_); @@ -200,16 +213,16 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } InferenceSession::InferenceSession(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) + const Environment& session_env) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), insert_cast_transformer_("CastFloat16Transformer") { // Initialize assets of this session instance - ConstructorCommon(session_options, logging_manager); + ConstructorCommon(session_options, session_env); } InferenceSession::InferenceSession(const SessionOptions& session_options, - const std::string& model_uri, - logging::LoggingManager* logging_manager) + const Environment& session_env, + const std::string& model_uri) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), insert_cast_transformer_("CastFloat16Transformer") { model_location_ = ToWideString(model_uri); @@ -218,13 +231,13 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, status.ErrorMessage()); model_loaded_ = true; // Finalize session options and initialize assets of this session instance - ConstructorCommon(session_options, logging_manager); + ConstructorCommon(session_options, session_env); } #ifdef _WIN32 InferenceSession::InferenceSession(const SessionOptions& session_options, - const std::wstring& model_uri, - logging::LoggingManager* logging_manager) + const Environment& session_env, + const std::wstring& model_uri) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), insert_cast_transformer_("CastFloat16Transformer") { model_location_ = ToWideString(model_uri); @@ -233,13 +246,13 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, status.ErrorMessage()); model_loaded_ = true; // Finalize session options and initialize assets of this session instance - ConstructorCommon(session_options, logging_manager); + ConstructorCommon(session_options, session_env); } #endif InferenceSession::InferenceSession(const SessionOptions& session_options, - std::istream& model_istream, - logging::LoggingManager* logging_manager) + const Environment& session_env, + std::istream& model_istream) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), insert_cast_transformer_("CastFloat16Transformer") { google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); @@ -247,20 +260,20 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); model_loaded_ = true; // Finalize session options and initialize assets of this session instance - ConstructorCommon(session_options, logging_manager); + ConstructorCommon(session_options, session_env); } InferenceSession::InferenceSession(const SessionOptions& session_options, + const Environment& session_env, const void* model_data, - int model_data_len, - logging::LoggingManager* logging_manager) + int model_data_len) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), insert_cast_transformer_("CastFloat16Transformer") { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); model_loaded_ = true; // Finalize session options and initialize assets of this session instance - ConstructorCommon(session_options, logging_manager); + ConstructorCommon(session_options, session_env); } InferenceSession::~InferenceSession() { @@ -1408,8 +1421,6 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { } else { session_logger_ = &logging::LoggingManager::DefaultLogger(); } - - session_state_->SetLogger(*session_logger_); } // Registers all the predefined transformers with transformer manager diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index a2f5fa3dc18e3..f379d780658e6 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -31,6 +31,7 @@ namespace onnxruntime { // forward declarations class GraphTransformer; +class Environment; } // namespace onnxruntime namespace ONNX_NAMESPACE { @@ -77,7 +78,15 @@ struct ModelMetadata { * CPUExecutionProviderInfo epi; * ProviderOption po{"CPUExecutionProvider", epi}; * SessionOptions so(vector{po}); - * InferenceSession session_object{so}; + * string log_id = "Foo"; + * auto logging_manager = std::make_unique + (std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &log_id) + * Environment::Create(std::move(logging_manager), env) + * InferenceSession session_object{so,env}; * common::Status status = session_object.Load(MODEL_URI); * common::Status status = session_object.Initialize(); * @@ -97,70 +106,50 @@ class InferenceSession { /** Create a new InferenceSession @param session_options Session options. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. */ explicit InferenceSession(const SessionOptions& session_options, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env); /** Create a new InferenceSession @param session_options Session options. @param model_uri absolute path of the model file. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, - const std::string& model_uri, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env, + const std::string& model_uri); #ifdef _WIN32 InferenceSession(const SessionOptions& session_options, - const std::wstring& model_uri, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env, + const std::wstring& model_uri); #endif /** Create a new InferenceSession @param session_options Session options. @param istream object of the model. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, - std::istream& model_istream, - logging::LoggingManager* logging_manager = nullptr); + const Environment& session_env, + std::istream& model_istream); /** Create a new InferenceSession @param session_options Session options. @param model_data Model data buffer. @param model_data_len Model data buffer size. - @param logging_manager - Optional logging manager instance that will enable per session logger output using - session_options.session_logid as the logger id in messages. - If nullptr, the default LoggingManager MUST have been created previously as it will be used - for logging. This will use the default logger id in messages. - See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works. + @param session_env This represents the context for the session and contains the logger and the global threadpools. This ctor will throw on encountering model parsing issues. */ InferenceSession(const SessionOptions& session_options, + const Environment& session_env, const void* model_data, - int model_data_len, - logging::LoggingManager* logging_manager = nullptr); + int model_data_len); virtual ~InferenceSession(); @@ -388,7 +377,7 @@ class InferenceSession { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void ConstructorCommon(const SessionOptions& session_options, - logging::LoggingManager* logging_manager); + const Environment& session_env); bool HasLocalSchema() const { return !custom_schema_registries_.empty(); @@ -465,11 +454,33 @@ class InferenceSession { // It has a dependency on execution_providers_. std::unique_ptr session_state_; + // Use these 2 threadpool methods to get access to the threadpools since they rely on + // specific flags in session options + // These methods assume that session options have been finalized before the call. + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const { + return session_options_.use_per_session_threads ? thread_pool_.get() : intra_op_thread_pool_from_env_; + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const { + return session_options_.use_per_session_threads ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_; + } + private: - // Threadpool for this session + // Threadpools per session. These are initialized and used for the entire duration of the session + // when use_per_session_threads is true. std::unique_ptr thread_pool_; std::unique_ptr inter_op_thread_pool_; + // Global threadpools. These are intialized and used when use_per_session_threads is false *and* + // the environment is created with create_global_thread_pools = true. + onnxruntime::concurrency::ThreadPool* intra_op_thread_pool_from_env_{}; + onnxruntime::concurrency::ThreadPool* inter_op_thread_pool_from_env_{}; + + // initialized from session options + // Determines which threadpools will be intialized and used for the duration of this session. + // If true, use the per session ones, or else the global threadpools. + bool use_per_session_threads_; + KernelRegistryManager kernel_registry_manager_; std::list> custom_schema_registries_; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ae328b70b26e1..d419c2ad2cb84 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -80,6 +80,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel default_warning_level, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_warning_level, + _In_ const char* logid, _In_ ThreadingOptions tp_options, _Outptr_ OrtEnv** out) { + API_IMPL_BEGIN + OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, default_warning_level, logid}; + Status status; + *out = OrtEnv::GetInstance(lm_info, status, &tp_options); + return ToOrtStatus(status); + API_IMPL_END +} + // enable platform telemetry ORT_API_STATUS_IMPL(OrtApis::EnableTelemetryEvents, _In_ const OrtEnv* ort_env) { API_IMPL_BEGIN @@ -414,7 +424,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O try { sess = onnxruntime::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - model_path, env->GetLoggingManager()); + env->GetEnvironment(), model_path); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -429,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In try { sess = onnxruntime::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - model_data, static_cast(model_data_length), env->GetLoggingManager()); + env->GetEnvironment(), model_data, static_cast(model_data_length)); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -1362,7 +1372,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_2 = { +static constexpr OrtApi ort_api_1_to_3 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -1482,7 +1492,6 @@ static constexpr OrtApi ort_api_1_to_2 = { &OrtApis::ReleaseCustomOpDomain, // End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information) - // Version 2 - In development, feel free to add/remove/rearrange here &OrtApis::GetDenotationFromTypeInfo, &OrtApis::CastTypeInfoToMapTypeInfo, &OrtApis::CastTypeInfoToSequenceTypeInfo, @@ -1500,6 +1509,11 @@ static constexpr OrtApi ort_api_1_to_2 = { &OrtApis::ModelMetadataLookupCustomMetadataMap, &OrtApis::ModelMetadataGetVersion, &OrtApis::ReleaseModelMetadata, + // End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information) + + // Version 3 - In development, feel free to add/remove/rearrange here + &OrtApis::CreateEnvWithGlobalThreadPools, + &OrtApis::DisablePerSessionThreads, }; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) @@ -1507,8 +1521,8 @@ static constexpr OrtApi ort_api_1_to_2 = { static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change"); ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { - if (version >= 1 && version <= 2) - return &ort_api_1_to_2; + if (version >= 1 && version <= 3) + return &ort_api_1_to_3; return nullptr; // Unsupported version } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 8ebed4a71b930..54de7ee65040f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -178,4 +178,9 @@ ORT_API_STATUS_IMPL(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _ // OrtSequenceTypeInfo Accessors ORT_API_STATUS_IMPL(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_logging_level, _In_ const char* logid, + _In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out) +ORT_ALL_ARGS_NONNULL; + +ORT_API_STATUS_IMPL(DisablePerSessionThreads, _In_ OrtSessionOptions* options); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index a74de2242d510..2b027404cd3b1 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -10,7 +10,6 @@ #include "core/session/environment.h" #include "core/common/logging/sinks/clog_sink.h" #include "core/common/logging/logging.h" -#include "core/session/environment.h" using namespace onnxruntime; using namespace onnxruntime::logging; @@ -30,38 +29,43 @@ void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp logger_id.c_str(), s.c_str(), message.Message().c_str()); } -OrtEnv::OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager) - : value_(std::move(value1)), logging_manager_(std::move(logging_manager)) { +OrtEnv::OrtEnv(std::unique_ptr value1) + : value_(std::move(value1)) { } -OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status) { +OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, + onnxruntime::common::Status& status, + const ThreadingOptions* tp_options) { std::lock_guard lock(m_); + std::unique_ptr lmgr; + std::string name = lm_info.logid; + if (lm_info.logging_function) { + std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, + lm_info.logger_param); + lmgr.reset(new LoggingManager(std::move(logger), + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } else { + lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } + if (!p_instance_) { std::unique_ptr env; - status = onnxruntime::Environment::Create(env); + if (!tp_options) { + status = onnxruntime::Environment::Create(std::move(lmgr), env); + } else { + status = onnxruntime::Environment::Create(std::move(lmgr), env, tp_options, true); + } if (!status.IsOK()) { return nullptr; } - - std::unique_ptr lmgr; - std::string name = lm_info.logid; - if (lm_info.logging_function) { - std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, - lm_info.logger_param); - lmgr.reset(new LoggingManager(std::move(logger), - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } else { - lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } - - p_instance_ = new OrtEnv(std::move(env), std::move(lmgr)); + p_instance_ = new OrtEnv(std::move(env)); } ++ref_count_; return p_instance_; @@ -80,11 +84,10 @@ void OrtEnv::Release(OrtEnv* env_ptr) { } } -LoggingManager* OrtEnv::GetLoggingManager() const { - return logging_manager_.get(); +onnxruntime::logging::LoggingManager* OrtEnv::GetLoggingManager() const { + return value_->GetLoggingManager(); } -void OrtEnv::SetLoggingManager(std::unique_ptr logging_manager) { - std::lock_guard lock(m_); - logging_manager_ = std::move(logging_manager); +void OrtEnv::SetLoggingManager(std::unique_ptr logging_manager) { + value_->SetLoggingManager(std::move(logging_manager)); } \ No newline at end of file diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index c93d2937c7a7b..81055874a82b6 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -42,12 +42,17 @@ struct OrtEnv { const char* logid{}; }; - static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status); + static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, + onnxruntime::common::Status& status, + const ThreadingOptions* tp_options = nullptr); static void Release(OrtEnv* env_ptr); - onnxruntime::logging::LoggingManager* GetLoggingManager() const; + const onnxruntime::Environment& GetEnvironment() const { + return *(value_.get()); + } + onnxruntime::logging::LoggingManager* GetLoggingManager() const; void SetLoggingManager(std::unique_ptr logging_manager); private: @@ -56,9 +61,8 @@ struct OrtEnv { static int ref_count_; std::unique_ptr value_; - std::unique_ptr logging_manager_; - OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager); + OrtEnv(std::unique_ptr value1); ~OrtEnv() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d4baca48da303..c50b762cf4974 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -234,28 +234,29 @@ void AddTensorAsPyObj(OrtValue& val, std::vector& pyobjs) { GetPyObjFromTensor(rtensor, obj); pyobjs.push_back(obj); } - class SessionObjectInitializer { public: typedef const SessionOptions& Arg1; - typedef logging::LoggingManager* Arg2; + // typedef logging::LoggingManager* Arg2; + static const std::string default_logger_id; operator Arg1() { return GetDefaultCPUSessionOptions(); } - operator Arg2() { - static std::string default_logger_id{"Default"}; - static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, - Severity::kWARNING, false, LoggingManager::InstanceType::Default, - &default_logger_id}; - return &default_logging_manager; - } + // operator Arg2() { + // static LoggingManager default_logging_manager{std::unique_ptr{new CErrSink{}}, + // Severity::kWARNING, false, LoggingManager::InstanceType::Default, + // &default_logger_id}; + // return &default_logging_manager; + // } static SessionObjectInitializer Get() { return SessionObjectInitializer(); } }; +const std::string SessionObjectInitializer::default_logger_id = "Default"; + inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) { auto p = f.CreateProvider(); OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p))); @@ -350,17 +351,17 @@ void InitializeSession(InferenceSession* sess, const std::vector& p OrtPybindThrowIfError(sess->Initialize()); } -void addGlobalMethods(py::module& m) { +void addGlobalMethods(py::module& m, const Environment& env) { m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance."); m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer."); m.def( "get_device", []() -> std::string { return BACKEND_DEVICE; }, "Return the device used to compute the prediction (CPU, MKL, ...)"); m.def( - "set_default_logger_severity", [](int severity) { + "set_default_logger_severity", [&env](int severity) { ORT_ENFORCE(severity >= 0 && severity <= 4, "Invalid logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal"); - logging::LoggingManager* default_logging_manager = SessionObjectInitializer::Get(); + logging::LoggingManager* default_logging_manager = env.GetLoggingManager(); default_logging_manager->SetDefaultLoggerSeverity(static_cast(severity)); }, "Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal"); @@ -546,7 +547,7 @@ void addOpSchemaSubmodule(py::module& m) { #endif //onnxruntime_PYBIND_EXPORT_OPSCHEMA -void addObjectMethods(py::module& m) { +void addObjectMethods(py::module& m, Environment& env) { py::enum_(m, "GraphOptimizationLevel") .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL) .value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC) @@ -715,15 +716,15 @@ including arg name, arg type (contains both type and shape).)pbdoc") // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char* // without any conversion. So this init method can be used for model file path (string) // and model content (bytes) - .def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) { + .def(py::init([&env](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) { // Given arg is the file path. Invoke the corresponding ctor(). if (is_arg_file_name) { - return onnxruntime::make_unique(so, arg, SessionObjectInitializer::Get()); + return onnxruntime::make_unique(so, env, arg); } // Given arg is the model content as bytes. Invoke the corresponding ctor(). std::istringstream buffer(arg); - return onnxruntime::make_unique(so, buffer, SessionObjectInitializer::Get()); + return onnxruntime::make_unique(so, env, buffer); })) .def( "load_model", [](InferenceSession* sess, std::vector& provider_types) { @@ -867,6 +868,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { #endif + static std::unique_ptr env; auto initialize = [&]() { // Initialization of the module ([]() -> void { @@ -874,8 +876,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { import_array1(); })(); - static std::unique_ptr env; - OrtPybindThrowIfError(Environment::Create(env)); + OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique( + std::unique_ptr{new CLogSink{}}, + Severity::kWARNING, false, LoggingManager::InstanceType::Default, + &SessionObjectInitializer::default_logger_id), + env)); static bool initialized = false; if (initialized) { @@ -885,8 +890,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { }; initialize(); - addGlobalMethods(m); - addObjectMethods(m); + addGlobalMethods(m, *env); + addObjectMethods(m, *env); #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA addOpSchemaSubmodule(m); diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index a19a8754e6ded..e9999eb2f567c 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -159,7 +159,7 @@ void LayerNormOpTester::ComputeWithCPU(std::vector& cpu_fetches) { run_options.run_log_verbosity_level = 1; // run with LayerNormalization - InferenceSession layernorm_session_object{so}; + InferenceSession layernorm_session_object{so, GetEnvironment()}; std::string s1; p_model->ToProto().SerializeToString(&s1); std::istringstream str(s1); @@ -193,7 +193,7 @@ void LayerNormOpTester::ComputeWithCUDA(std::vector& cuda_fetches) { run_options.run_log_verbosity_level = 1; auto cuda_execution_provider = DefaultCudaExecutionProvider(); - InferenceSession cuda_session_object{so}; + InferenceSession cuda_session_object{so, GetEnvironment()}; EXPECT_TRUE(cuda_session_object.RegisterExecutionProvider(std::move(cuda_execution_provider)).IsOK()); std::string s; @@ -226,7 +226,7 @@ void LayerNormOpTester::ComputeOriSubgraphWithCPU(std::vector& subgraph run_options.run_log_verbosity_level = 1; Status status; - InferenceSession subgraph_session_object{so, &DefaultLoggingManager()}; + InferenceSession subgraph_session_object{so, GetEnvironment()}; ASSERT_TRUE((status = subgraph_session_object.Load("testdata/layernorm.onnx")).IsOK()) << status; ASSERT_TRUE((status = subgraph_session_object.Initialize()).IsOK()) << status; ASSERT_TRUE((status = subgraph_session_object.Run(run_options, feeds, output_names, &subgraph_fetches)).IsOK()) << status; diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index 15e4ce84eefbc..2367393769057 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -37,7 +37,7 @@ typedef std::vector ArgMap; class FenceCudaTestInferenceSession : public InferenceSession { public: - FenceCudaTestInferenceSession(const SessionOptions& so) : InferenceSession(so) {} + FenceCudaTestInferenceSession(const SessionOptions& so, const Environment& env) : InferenceSession(so, env) {} Status LoadModel(onnxruntime::Model& model) { auto model_proto = model.ToProto(); auto st = Load(model_proto); @@ -117,7 +117,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) { DataTypeImpl::GetType()->GetDeleteFunc()); SessionOptions so; - FenceCudaTestInferenceSession session(so); + FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); @@ -171,7 +171,7 @@ TEST(CUDAFenceTests, TileWithInitializer) { DataTypeImpl::GetType()->GetDeleteFunc()); SessionOptions so; - FenceCudaTestInferenceSession session(so); + FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); @@ -236,7 +236,7 @@ TEST(CUDAFenceTests, TileWithComputedInput) { DataTypeImpl::GetType()->GetDeleteFunc()); SessionOptions so; - FenceCudaTestInferenceSession session(so); + FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index c3ca950c81389..bd49cf071f44e 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -277,7 +277,7 @@ TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) { SessionOptions so; so.session_logid = "BadModelInvalidDimParamUsage"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load("testdata/invalid_dim_param_value_repetition.onnx")).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; diff --git a/onnxruntime/test/framework/float_16_test.cc b/onnxruntime/test/framework/float_16_test.cc index f233e6862b2db..2a84fcf048167 100644 --- a/onnxruntime/test/framework/float_16_test.cc +++ b/onnxruntime/test/framework/float_16_test.cc @@ -137,7 +137,7 @@ TEST(Float16_Tests, Mul_16_Test) { so.session_logid = "InferenceSessionTests.NoTimeout"; std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); auto mulfp16_schema = GetMulFP16Schema(); std::vector schemas = {mulfp16_schema}; diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index baa15a8a8a53a..45aade53bf729 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -12,6 +12,7 @@ #include #include "core/common/logging/logging.h" +#include "core/common/logging/sinks/clog_sink.h" #include "core/common/profiler.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" @@ -38,6 +39,7 @@ #include "test/optimizer/dummy_graph_transformer.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "gtest/gtest.h" +#include "core/session/environment.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -120,7 +122,7 @@ class FuseExecutionProvider : public IExecutionProvider { class InferenceSessionGetGraphWrapper : public InferenceSession { public: explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) { + const Environment& env) : InferenceSession(session_options, env) { } const Graph& GetGraph() { @@ -337,7 +339,7 @@ TEST(InferenceSessionTests, NoTimeout) { so.session_logid = "InferenceSessionTests.NoTimeout"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load(MODEL_URI)).IsOK()) << st.ErrorMessage(); ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st.ErrorMessage(); @@ -353,7 +355,7 @@ TEST(InferenceSessionTests, DisableCPUArena) { so.session_logid = "InferenceSessionTests.DisableCPUArena"; so.enable_cpu_mem_arena = false; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -369,7 +371,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { const string test_model = "testdata/transform/abs-id-max.onnx"; so.session_logid = "InferenceSessionTests.TestModelSerialization"; so.graph_optimization_level = TransformerLevel::Default; - InferenceSessionGetGraphWrapper session_object_noopt{so, &DefaultLoggingManager()}; + InferenceSessionGetGraphWrapper session_object_noopt{so, GetEnvironment()}; ASSERT_TRUE(session_object_noopt.Load(test_model).IsOK()); ASSERT_TRUE(session_object_noopt.Initialize().IsOK()); @@ -381,7 +383,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { // Load model with level 1 transform level. so.graph_optimization_level = TransformerLevel::Level1; so.optimized_model_filepath = ToWideString(test_model + "-TransformLevel-" + std::to_string(static_cast(so.graph_optimization_level))); - InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()}; + InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(test_model).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -391,7 +393,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(op_to_count["Identity"] == 0); // Serialize model to the same file path again to make sure that rewrite doesn't fail. - InferenceSession overwrite_session_object{so, &DefaultLoggingManager()}; + InferenceSession overwrite_session_object{so, GetEnvironment()}; ASSERT_TRUE(overwrite_session_object.Load(test_model).IsOK()); ASSERT_TRUE(overwrite_session_object.Initialize().IsOK()); @@ -400,7 +402,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { so_opt.session_logid = "InferenceSessionTests.TestModelSerialization"; so_opt.graph_optimization_level = TransformerLevel::Default; so_opt.optimized_model_filepath = ToWideString(so.optimized_model_filepath) + ToWideString("-TransformLevel-" + std::to_string(static_cast(so_opt.graph_optimization_level))); - InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()}; + InferenceSession session_object_opt{so_opt, GetEnvironment()}; ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK()); ASSERT_TRUE(session_object_opt.Initialize().IsOK()); @@ -419,7 +421,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { // Assert that empty optimized model file-path doesn't fail loading. so_opt.optimized_model_filepath = ToWideString(""); - InferenceSession session_object_emptyValidation{so_opt, &DefaultLoggingManager()}; + InferenceSession session_object_emptyValidation{so_opt, GetEnvironment()}; ASSERT_TRUE(session_object_emptyValidation.Load(test_model).IsOK()); ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } @@ -455,7 +457,7 @@ TEST(InferenceSessionTests, ModelMetadata) { SessionOptions so; so.session_logid = "InferenceSessionTests.ModelMetadata"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"); ASSERT_TRUE(session_object.Load(model_uri).IsOK()); @@ -525,7 +527,9 @@ TEST(InferenceSessionTests, CheckRunLogger) { std::unique_ptr(capturing_sink), logging::Severity::kVERBOSE, false, LoggingManager::InstanceType::Temporal); - InferenceSession session_object{so, logging_manager.get()}; + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + InferenceSession session_object{so, *env.get()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -555,7 +559,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { so.enable_profiling = true; so.profile_file_prefix = ORT_TSTR("onnxprofile_profile_test"); - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -594,7 +598,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { so.session_logid = "CheckRunProfiler"; - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -632,7 +636,7 @@ TEST(InferenceSessionTests, MultipleSessionsNoTimeout) { SessionOptions session_options; session_options.session_logid = "InferenceSessionTests.MultipleSessionsNoTimeout"; - InferenceSession session_object{session_options, &DefaultLoggingManager()}; + InferenceSession session_object{session_options, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -657,7 +661,7 @@ TEST(InferenceSessionTests, PreAllocateOutputVector) { so.session_logid = "InferenceSessionTests.PreAllocateOutputVector"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -684,7 +688,9 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { false, LoggingManager::InstanceType::Temporal); - InferenceSession session_object{so, logging_manager.get()}; + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + InferenceSession session_object{so, *env.get()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -719,7 +725,7 @@ TEST(InferenceSessionTests, TestWithIstream) { so.session_logid = "InferenceSessionTests.TestWithIstream"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; std::ifstream model_file_stream(MODEL_URI, ios::in | ios::binary); ASSERT_TRUE(model_file_stream.good()); @@ -736,7 +742,7 @@ TEST(InferenceSessionTests, TestRegisterExecutionProvider) { so.session_logid = "InferenceSessionTests.TestWithIstream"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; CPUExecutionProviderInfo epi; ASSERT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique(epi)).IsOK()); @@ -760,7 +766,7 @@ static void TestBindHelper(const std::string& log_str, so.session_logid = "InferenceSessionTests." + log_str; so.session_log_verbosity_level = 1; // change to 1 for detailed logging - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; if (bind_provider_type == kCudaExecutionProvider || run_provider_type == kCudaExecutionProvider) { #ifdef USE_CUDA @@ -798,7 +804,7 @@ TEST(InferenceSessionTests, TestBindCpu) { TEST(InferenceSessionTests, TestIOBindingReuse) { SessionOptions so; - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); std::unique_ptr p_model; CreateMatMulModel(p_model, kCpuExecutionProvider); @@ -839,7 +845,7 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { so.session_logid = "InferenceSessionTests.InvalidInputTypeOfTensorElement"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -912,7 +918,7 @@ TEST(InferenceSessionTests, ModelWithoutOpset) { so.session_logid = "InferenceSessionTests.ModelWithoutOpset"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status retval = session_object.Load(MODEL_URI_NO_OPSET); ASSERT_FALSE(retval.IsOK()); if (!retval.IsOK()) { @@ -923,11 +929,11 @@ TEST(InferenceSessionTests, ModelWithoutOpset) { static common::Status RunOptionalInputTest(bool add_required_input, bool add_optional_input, bool add_invalid_input, - int model_ir_version) { + int model_ir_version, + const Environment& sess_env) { SessionOptions so; so.session_logid = "RunOptionalInputTest"; - - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, sess_env}; Status status; std::string model_path = "testdata/optional_inputs_ir" + std::to_string(model_ir_version) + ".onnx"; @@ -1001,25 +1007,26 @@ static common::Status RunOptionalInputTest(bool add_required_input, // for V4 allow it TEST(InferenceSessionTests, TestOptionalInputs) { std::vector ir_versions{3, 4}; + const auto& sess_env = GetEnvironment(); for (auto version : ir_versions) { // required input only - auto status = RunOptionalInputTest(true, false, false, version); + auto status = RunOptionalInputTest(true, false, false, version, sess_env); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); // required and optional input - status = RunOptionalInputTest(true, true, false, version); + status = RunOptionalInputTest(true, true, false, version, sess_env); if (version == 3) { ASSERT_FALSE(status.IsOK()) << status.ErrorMessage(); } else { ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); } // required, optional and invalid input - status = RunOptionalInputTest(true, true, true, version); + status = RunOptionalInputTest(true, true, true, version, sess_env); ASSERT_FALSE(status.IsOK()); EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); // missing required - status = RunOptionalInputTest(false, true, false, version); + status = RunOptionalInputTest(false, true, false, version, sess_env); ASSERT_FALSE(status.IsOK()); if (version == 3) { EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name")); @@ -1065,7 +1072,7 @@ TEST(ExecutionProviderTest, FunctionTest) { SessionOptions so; so.session_logid = "ExecutionProviderTest.FunctionTest"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; status = session_object.Load(model_file_name); ASSERT_TRUE(status.IsOK()); status = session_object.Initialize(); @@ -1104,7 +1111,7 @@ TEST(ExecutionProviderTest, FunctionTest) { ASSERT_TRUE(status.IsOK()); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); - InferenceSession session_object_2{so}; + InferenceSession session_object_2{so, GetEnvironment()}; session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider)); session_object_2.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::FuseExecutionProvider>()); status = session_object_2.Load(model_file_name); @@ -1172,7 +1179,7 @@ TEST(ExecutionProviderTest, FunctionInlineTest) { SessionOptions so; so.session_logid = "ExecutionProviderTest.FunctionInlineTest"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; status = session_object.Load(model_file_name); ASSERT_TRUE(status.IsOK()); status = session_object.Initialize(); @@ -1263,7 +1270,7 @@ TEST(InferenceSessionTests, TestTruncatedSequence) { // now run the truncated model SessionOptions so; - InferenceSession session_object(so); + InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(LSTM_MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -1373,7 +1380,7 @@ TEST(InferenceSessionTests, TestTruncatedSequence) { TEST(InferenceSessionTests, TestCopyToFromDevices) { SessionOptions so; so.session_logid = "InferenceSessionTests.TestCopyToFromDevices"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -1435,7 +1442,7 @@ TEST(InferenceSessionTests, TestRegisterTransformers) { SessionOptions so; so.session_logid = "InferenceSessionTests.TestL1AndL2Transformers"; so.graph_optimization_level = static_cast(i); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; // Create and register dummy graph transformer auto dummy_transformer_unique_ptr = onnxruntime::make_unique("DummyTransformer"); @@ -1465,7 +1472,7 @@ TEST(InferenceSessionTests, TestL1AndL2Transformers) { SessionOptions so; so.session_logid = "InferenceSessionTests.TestL1AndL2Transformers"; so.graph_optimization_level = TransformerLevel::Level2; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(model_uri).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); } @@ -1530,7 +1537,7 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) { SessionOptions so; so.execution_mode = ExecutionMode::ORT_PARALLEL; so.session_logid = "InferenceSessionTests.TestParallelExecutionWithCudaProvider"; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; CUDAExecutionProviderInfo epi; epi.device_id = 0; @@ -1553,7 +1560,7 @@ TEST(InferenceSessionTests, ModelThatTriggersAllocationPlannerToReuseDoubleTenso so.session_logid = "InferenceSessionTests.ModelThatTriggersAllocationPlannerBug"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load("testdata/test_cast_back_to_back_non_const_mixed_types_origin.onnx")).IsOK()) << st.ErrorMessage(); @@ -1618,7 +1625,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { std::string model_path = "testdata/model_with_valid_ort_config_json.onnx"; // Create session - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; // Load() and Initialize() the session Status st; @@ -1656,7 +1663,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { so.intra_op_num_threads = 2; // Create session - InferenceSession session_object_2{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_2{so, GetEnvironment(), model_path}; // Load() and Initialize() the session ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage(); @@ -1685,7 +1692,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { // Create session (should throw as the json within the model is invalid/improperly formed) try { - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; } catch (const std::exception& e) { std::string e_message(std::string(e.what())); ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. Error Message:") != std::string::npos); @@ -1704,7 +1711,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { so.intra_op_num_threads = 2; // Create session - InferenceSession session_object_2{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_2{so, GetEnvironment(), model_path}; // Load() and Initialize() the session Status st; @@ -1734,7 +1741,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { std::string model_path = "testdata/transform/abs-id-max.onnx"; // Create session - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; // Load() and Initialize() the session Status st; @@ -1755,7 +1762,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { #endif // Create session - InferenceSession session_object_2{so, model_path, &DefaultLoggingManager()}; // so has inter_op_num_threads set to 2 + InferenceSession session_object_2{so, GetEnvironment(), model_path}; // so has inter_op_num_threads set to 2 // Load() and Initialize() the session ASSERT_TRUE((st = session_object_2.Load()).IsOK()) << st.ErrorMessage(); @@ -1779,7 +1786,7 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { // Create session (should throw because of the unsupported value for the env var - ORT_LOAD_CONFIG_FROM_MODEL) try { - InferenceSession session_object_1{so, model_path, &DefaultLoggingManager()}; + InferenceSession session_object_1{so, GetEnvironment(), model_path}; } catch (const std::exception& e) { std::string e_message(std::string(e.what())); ASSERT_TRUE(e_message.find("Could not finalize session options while constructing the inference session. Error Message:") != std::string::npos); @@ -1795,5 +1802,175 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { #endif } +// Global threadpool related tests +// We test for 4 combinations +class InferenceSessionTestGlobalThreadPools : public InferenceSession { + public: + InferenceSessionTestGlobalThreadPools(const SessionOptions& session_options, + const Environment& env) : InferenceSession(session_options, env) { + } + + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const { + return InferenceSession::GetIntraOpThreadPoolToUse(); + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const { + return InferenceSession::GetInterOpThreadPoolToUse(); + } + + const SessionState& GetSessionState() { + return *session_state_; + } +}; + +// Test 1: env created WITHOUT global tp / use per session tp (default case): in this case per session tps should be in use +TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed) { + SessionOptions so; + so.use_per_session_threads = true; + + so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the per session threadpools + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + // ensure threadpools were set correctly in the session state + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state); + + ASSERT_TRUE(intra_tp_from_env == nullptr); + ASSERT_TRUE(inter_tp_from_env == nullptr); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 2: env created with global tp / DONT use per session tp: in this case global tps should be in use +TEST(InferenceSessionTests, CheckIfGlobalThreadPoolsAreBeingUsed) { + SessionOptions so; + so.use_per_session_threads = false; + + so.session_logid = "CheckIfGlobalThreadPoolsAreBeingUsed"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + ThreadingOptions tp_options{0, 0}; + auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the global threadpools in both session and session state + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_env); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_env); + ASSERT_TRUE(intra_tp_from_session_state == intra_tp_from_env); + ASSERT_TRUE(inter_tp_from_session_state == inter_tp_from_env); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 3: env created with global tp / use per session tp: in this case per session tps should be in use +TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed2) { + SessionOptions so; + so.use_per_session_threads = true; + + so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed2"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + ThreadingOptions tp_options{0, 0}; + auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the per session threadpools + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + // ensure threadpools were set correctly in the session state + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state); + + // ensure per session thread pools in use are different from the + // env threadpools + if (intra_tp_from_session && intra_tp_from_env) { // both tps could be null on 1 core machines + ASSERT_FALSE(intra_tp_from_session == intra_tp_from_env); + } + + if (inter_tp_from_session && inter_tp_from_env) { // both tps could be null on 1 core machines + ASSERT_FALSE(inter_tp_from_session == inter_tp_from_env); + } + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 4: env created WITHOUT global tp / DONT use per session tp --> this should throw an exception +TEST(InferenceSessionTests, InvalidSessionEnvCombination) { + SessionOptions so; + so.use_per_session_threads = false; + + so.session_logid = "InvalidSessionEnvCombination"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + ASSERT_TRUE(st.IsOK()); + + try { + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + } catch (const std::exception& e) { + std::string e_message(std::string(e.what())); + ASSERT_TRUE(e_message.find( + "When the session is not configured to use per session" + " threadpools, the env must be created with the the CreateEnvWithGlobalThreadPools API") != + std::string::npos); + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/local_kernel_registry_test.cc b/onnxruntime/test/framework/local_kernel_registry_test.cc index 924d8c756b64a..68b1b40dc9993 100644 --- a/onnxruntime/test/framework/local_kernel_registry_test.cc +++ b/onnxruntime/test/framework/local_kernel_registry_test.cc @@ -229,7 +229,7 @@ TEST(CustomKernelTests, CustomKernelWithBuildInSchema) { // Register a foo kernel which is doing Add, but bind to Mul. std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); auto def = FooKernelDef("Mul"); @@ -261,7 +261,7 @@ TEST(CustomKernelTests, CustomKernelWithCustomSchema) { std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); //register foo schema @@ -307,7 +307,7 @@ TEST(CustomKernelTests, CustomKernelWithOptionalOutput) { //Register a foo kernel which is doing Add, but bind to Mul. EXPECT_TRUE(registry->RegisterCustomKernel(def, CreateOptionalOpKernel).IsOK()); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); EXPECT_TRUE(session_object.Load(OPTIONAL_MODEL1_URI).IsOK()); EXPECT_TRUE(session_object.Initialize().IsOK()); diff --git a/onnxruntime/test/framework/opaque_kernels_test.cc b/onnxruntime/test/framework/opaque_kernels_test.cc index ae13e2bd699d8..632f68deb064b 100644 --- a/onnxruntime/test/framework/opaque_kernels_test.cc +++ b/onnxruntime/test/framework/opaque_kernels_test.cc @@ -282,7 +282,7 @@ TEST_F(OpaqueTypeTests, RunModel) { // Both the session and the model need custom registries // so we construct it here before the model std::shared_ptr registry = std::make_shared(); - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); auto ops_schema = GetConstructSparseTensorSchema(); diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 307ace0fe3caa..e016149cefdc1 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -279,7 +279,7 @@ class SparseTensorTests : public testing::Test { std::vector types; public: - SparseTensorTests() : session_object(SessionOptions(), &DefaultLoggingManager()), + SparseTensorTests() : session_object(SessionOptions(), GetEnvironment()), registry(std::make_shared()), custom_schema_registries_{registry->GetOpschemaRegistry()}, domain_to_version{{onnxruntime::kMLDomain, 10}}, diff --git a/onnxruntime/test/global_thread_pools/README.md b/onnxruntime/test/global_thread_pools/README.md new file mode 100644 index 0000000000000..d5ada6edef9fb --- /dev/null +++ b/onnxruntime/test/global_thread_pools/README.md @@ -0,0 +1,6 @@ +**Tests for global threadpools** + +These tests here test the usage of the global threadpools using the C API data flow. The reason we need to create a +separate exe here is because we need to create a separate environment that enables the creation of global threadpools. +The test environment used in the other exes create the env without global threadpools and since this env is process +wide (a singleton) we cannot use it for this kind of test. \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc new file mode 100644 index 0000000000000..daeabcb373017 --- /dev/null +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/constants.h" +#include "providers.h" +#include +#include +#include +#include +#include +#include +#include +#include "test_allocator.h" +#include "../shared_lib/onnx_protobuf.h" +#include "../shared_lib/test_fixture.h" +#include + +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + +extern std::unique_ptr ort_env; +static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/squeezenet/model.onnx"); +class CApiTestGlobalThreadPoolsWithProvider : public testing::Test, public ::testing::WithParamInterface { +}; + +template +static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, + std::vector& inputs, + const char* output_name, + const std::vector& dims_y, + const std::vector& values_y, + Ort::Value* output_tensor) { + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator.Info(&allocator), inputs[i].values.data(), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + } + + std::vector ort_outputs; + if (output_tensor) + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1); + else { + ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + output_tensor = &ort_outputs[0]; + } + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), dims_y); + //size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(values_y.size(), static_cast(5)); + + OutT* f = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != static_cast(5); ++i) { + ASSERT_TRUE(abs(values_y[i] - f[i]) < 1e-6); + } +} + +template +static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) { + Ort::SessionOptions session_options; + session_options.DisablePerSessionThreads(); + + if (provider_type == 1) { +#ifdef USE_CUDA + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + std::cout << "Running simple inference with cuda provider" << std::endl; +#else + return Ort::Session(nullptr); +#endif + } else if (provider_type == 2) { +#ifdef USE_DNNL + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1)); + std::cout << "Running simple inference with dnnl provider" << std::endl; +#else + return Ort::Session(nullptr); +#endif + } else if (provider_type == 3) { +#ifdef USE_NUPHAR + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, "")); + std::cout << "Running simple inference with nuphar provider" << std::endl; +#else + return Ort::Session(nullptr); +#endif + } else { + std::cout << "Running simple inference with default provider" << std::endl; + } + + // if session creation passes, model loads fine + return Ort::Session(env, model_uri, session_options); +} + +template +static void TestInference(Ort::Session& session, + std::vector& inputs, + const char* output_name, + const std::vector& expected_dims_y, + const std::vector& expected_values_y) { + auto default_allocator = onnxruntime::make_unique(); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); + + RunSession(*default_allocator, + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); +} + +static void GetInputsAndExpectedOutputs(std::vector& inputs, + std::vector& expected_dims_y, + std::vector& expected_values_y, + std::string& output_name) { + inputs.resize(1); + Input& input = inputs.back(); + input.name = "data_0"; + input.dims = {1, 3, 224, 224}; + size_t input_tensor_size = 224 * 224 * 3; + input.values.resize(input_tensor_size); + auto& input_tensor_values = input.values; + for (unsigned int i = 0; i < input_tensor_size; i++) + input_tensor_values[i] = (float)i / (input_tensor_size + 1); + + // prepare expected inputs and outputs + expected_dims_y = {1, 1000, 1, 1}; + // For this test I'm checking for the first 5 values only since the global thread pool change + // doesn't affect the core op functionality + expected_values_y = {0.000045f, 0.003846f, 0.000125f, 0.001180f, 0.001317f}; + + output_name = "softmaxout_1"; +} + +// All tests below use global threadpools + +// Test 1 +// run inference on a model using just 1 session +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple) { + // prepare inputs/outputs + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + // create session + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session) { + TestInference(session, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } +} + +// Test 2 +// run inference on the same model using 2 sessions +// destruct the 2 sessions only at the end +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple2) { + // prepare inputs/outputs + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + // create sessions + Ort::Session session1 = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + Ort::Session session2 = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session1 && session2) { + TestInference(session1, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + TestInference(session2, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } +} + +// Test 3 +// run inference on the same model using 2 sessions +// one after another destructing first session first +// followed by second session +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple3) { + // prepare inputs/outputs + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + // first session + { + // create session + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session) { + TestInference(session, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } + } + + // second session + { + // create session + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + + // run session + if (session) { + TestInference(session, inputs, output_name.c_str(), expected_dims_y, expected_values_y); + } + } +} + +INSTANTIATE_TEST_SUITE_P(CApiTestGlobalThreadPoolsWithProviders, + CApiTestGlobalThreadPoolsWithProvider, + ::testing::Values(0, 1, 2, 3, 4)); \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_main.cc b/onnxruntime/test/global_thread_pools/test_main.cc new file mode 100644 index 0000000000000..6b25c716df1b1 --- /dev/null +++ b/onnxruntime/test/global_thread_pools/test_main.cc @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef USE_ONNXRUNTIME_DLL +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#else +#pragma warning(push) +#pragma warning(disable : 4018) /*'expression' : signed/unsigned mismatch */ +#pragma warning(disable : 4065) /*switch statement contains 'default' but no 'case' labels*/ +#pragma warning(disable : 4100) +#pragma warning(disable : 4146) /*unary minus operator applied to unsigned type, result still unsigned*/ +#pragma warning(disable : 4127) +#pragma warning(disable : 4244) /*'conversion' conversion from 'type1' to 'type2', possible loss of data*/ +#pragma warning(disable : 4251) /*'identifier' : class 'type' needs to have dll-interface to be used by clients of class 'type2'*/ +#pragma warning(disable : 4267) /*'var' : conversion from 'size_t' to 'type', possible loss of data*/ +#pragma warning(disable : 4305) /*'identifier' : truncation from 'type1' to 'type2'*/ +#pragma warning(disable : 4307) /*'operator' : integral constant overflow*/ +#pragma warning(disable : 4309) /*'conversion' : truncation of constant value*/ +#pragma warning(disable : 4334) /*'operator' : result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?)*/ +#pragma warning(disable : 4355) /*'this' : used in base member initializer list*/ +#pragma warning(disable : 4506) /*no definition for inline function 'function'*/ +#pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ +#pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ +#endif +#include +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#else +#pragma warning(pop) +#endif +#endif + +#include "core/session/onnxruntime_cxx_api.h" +#include "gtest/gtest.h" +#include "test/test_environment.h" + +std::unique_ptr ort_env; + +int main(int argc, char** argv) { + int status = 0; + try { + ::testing::InitGoogleTest(&argc, argv); + ThreadingOptions tp_options{0, 0}; + ort_env.reset(new Ort::Env(tp_options, ORT_LOGGING_LEVEL_VERBOSE, "Default")); // this is the only change from test/providers/test_main.cc + status = RUN_ALL_TESTS(); + } catch (const std::exception& ex) { + std::cerr << ex.what(); + status = -1; + } + //TODO: Fix the C API issue + ort_env.reset(); //If we don't do this, it will crash + +#ifndef USE_ONNXRUNTIME_DLL + //make memory leak checker happy + ::google::protobuf::ShutdownProtobufLibrary(); +#endif + return status; +} diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 92e4632411e8c..7356d024d14f6 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -263,7 +263,7 @@ TEST(GraphTransformationTests, SubgraphWithConstantInputs) { SessionOptions so; so.graph_optimization_level = TransformerLevel::Level2; so.session_logid = "GraphTransformationTests.LoadModelToTransform"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; @@ -635,7 +635,7 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { SessionOptions so; so.session_logid = "GraphTransformationTests.LoadModelToTransform"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 348a46d7751c6..f51cc59205014 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -9,6 +9,7 @@ #include "test/compare_ortvalue.h" #include "gtest/gtest.h" #include "core/mlas/inc/mlas.h" +#include "core/session/environment.h" namespace onnxruntime { namespace test { @@ -17,7 +18,7 @@ namespace test { class NchwcInferenceSession : public InferenceSession { public: explicit NchwcInferenceSession(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) { + const Environment& env) : InferenceSession(session_options, env) { } std::unordered_map CountOpsInGraph() { @@ -208,7 +209,7 @@ void NchwcOptimizerTester(const std::function& bu SessionOptions session_options; session_options.graph_optimization_level = level; session_options.session_logid = "NchwcOptimizerTests"; - NchwcInferenceSession session{session_options, &DefaultLoggingManager()}; + NchwcInferenceSession session{session_options, GetEnvironment()}; ASSERT_TRUE(session.Load(model_data.data(), static_cast(model_data.size())).IsOK()); ASSERT_TRUE(session.Initialize().IsOK()); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 0b9e688a079cb..0c7b53fef9946 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -580,7 +580,7 @@ TEST(Loop, SubgraphInputShadowsOuterScopeValue) { SessionOptions so; so.session_logid = "SubgraphInputShadowsOuterScopeValue"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; Status st; ASSERT_TRUE((st = session_object.Load("testdata/subgraph_input_shadows_outer_scope_value.onnx")).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; diff --git a/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc b/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc index 6951ff40f0f45..d33bba95014b0 100644 --- a/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc +++ b/onnxruntime/test/providers/ngraph/ngraph_execution_provider_test.cc @@ -70,9 +70,9 @@ void add_feeds(NameMLValMap& feeds, std::string name, std::vector dims, } //TODO:(nivas) Refractor to use existing code -void RunTest(const std::string& model_path, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& expected_shapes, const std::vector>& expected_values) { +void RunTest(const std::string& model_path, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& expected_shapes, const std::vector>& expected_values, const Environment& env) { SessionOptions so; - InferenceSession session_object(so, &DefaultLoggingManager()); + InferenceSession session_object(so, env); EXPECT_TRUE(session_object.RegisterExecutionProvider(DefaultNGraphExecutionProvider()).IsOK()); @@ -148,7 +148,7 @@ TEST(NGraphExecutionProviderTest, Basic_Test) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Basic_Test.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Basic_Test.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -177,7 +177,7 @@ TEST(NGraphExecutionProviderTest, Graph_with_UnSupportedOp) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Graph_with_UnSupportedOp.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Graph_with_UnSupportedOp.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -210,7 +210,7 @@ TEST(NGraphExecutionProviderTest, Two_Subgraphs) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Two_Subgraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Two_Subgraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -246,7 +246,7 @@ TEST(NGraphExecutionProviderTest, ClusterOut_isAlso_GraphOut) { {4}, {4}}; - RunTest("testdata/ngraph/ClusterOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/ClusterOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -282,7 +282,7 @@ TEST(NGraphExecutionProviderTest, InOut_isAlso_GraphOut) { {4}, {4}}; - RunTest("testdata/ngraph/InOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/InOut_isAlso_GraphOut.onnx", feeds, {"Y", "Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -314,7 +314,7 @@ TEST(NGraphExecutionProviderTest, Op_with_Optional_or_Unused_Outputs) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Op_with_Optional_or_Unused_Outputs.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Op_with_Optional_or_Unused_Outputs.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } /* @@ -350,7 +350,7 @@ TEST(NGraphExecutionProviderTest, Independent_SubGraphs) { std::vector> expected_shapes = { {4}}; - RunTest("testdata/ngraph/Independent_SubGraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values); + RunTest("testdata/ngraph/Independent_SubGraphs.onnx", feeds, {"Z"}, expected_shapes, expected_values, GetEnvironment()); } } // namespace test diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index dc497bad25418..9db96b42a3bca 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -85,7 +85,7 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; status = session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>()); ASSERT_TRUE(status.IsOK()); status = session_object.Load(model_file_name); @@ -100,4 +100,3 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { } } // namespace test } // namespace onnxruntime - diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 332826dbfb4df..4faa3ea922785 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -239,7 +239,7 @@ void Check(const OpTester::Data& expected_data, else { // the default for existing tests const float max_value = fmax(fabs(f_expected[i]), fabs(f_output[i])); - if (max_value != 0) { // max_value = 0 means output and expected are 0s. + if (max_value != 0) { // max_value = 0 means output and expected are 0s. const float rel_error = fabs(f_expected[i] - f_output[i]) / max_value; EXPECT_NEAR(0, rel_error, threshold) << "provider_type: " << provider_type; @@ -687,10 +687,14 @@ void OpTester::Run( FillFeedsAndOutputNames(feeds, output_names); // Run the model static const std::string all_provider_types[] = { - kCpuExecutionProvider, kCudaExecutionProvider, - kDnnlExecutionProvider, kNGraphExecutionProvider, - kNupharExecutionProvider, kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kDmlExecutionProvider, + kCpuExecutionProvider, + kCudaExecutionProvider, + kDnnlExecutionProvider, + kNGraphExecutionProvider, + kNupharExecutionProvider, + kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, + kDmlExecutionProvider, kAclExecutionProvider, }; @@ -705,7 +709,7 @@ void OpTester::Run( } } - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(!execution_providers->empty()) << "Empty execution providers vector."; @@ -733,6 +737,10 @@ void OpTester::Run( so.enable_mem_pattern = false; so.execution_mode = ExecutionMode::ORT_SEQUENTIAL; } + InferenceSession session_object{so, GetEnvironment()}; + + for (auto& custom_session_registry : custom_session_registries_) + session_object.RegisterCustomRegistry(custom_session_registry); std::unique_ptr execution_provider; if (provider_type == onnxruntime::kCpuExecutionProvider) @@ -792,7 +800,6 @@ void OpTester::Run( if (!valid) continue; - InferenceSession session_object{so}; for (auto& custom_session_registry : custom_session_registries_) session_object.RegisterCustomRegistry(custom_session_registry); diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 015c4efed514f..0497d99a3d9da 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -87,7 +87,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; @@ -201,7 +201,7 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; @@ -215,11 +215,10 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { // Now run status = session_object.Run(run_options, feeds, output_names, &fetches); ASSERT_TRUE(status.IsOK()); - std::vector fetche {fetches.back()}; + std::vector fetche{fetches.back()}; VerifyOutputs(fetche, expected_dims_mul_n, expected_values_mul_n); } - TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { onnxruntime::Model model("graph_removecycleTest", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); @@ -316,7 +315,7 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { RunOptions run_options; run_options.run_tag = so.session_logid; - InferenceSession session_object{so}; + InferenceSession session_object{so, GetEnvironment()}; TensorrtExecutionProviderInfo epi; epi.device_id = 0; diff --git a/onnxruntime/test/tvm/tvm_basic_test.cc b/onnxruntime/test/tvm/tvm_basic_test.cc index 70399f2f03679..9448d92663e93 100644 --- a/onnxruntime/test/tvm/tvm_basic_test.cc +++ b/onnxruntime/test/tvm/tvm_basic_test.cc @@ -310,7 +310,7 @@ TEST(TVMTest, CodeGen_Demo_for_Fuse_Mul) { so.session_logid = "InferenceSessionTests.NoTimeout"; - InferenceSession session_object{so, &DefaultLoggingManager()}; + InferenceSession session_object{so, GetEnvironment()}; CPUExecutionProviderInfo info; auto tvm_xp = onnxruntime::make_unique(info); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(tvm_xp)).IsOK()); diff --git a/onnxruntime/test/util/include/test/test_environment.h b/onnxruntime/test/util/include/test/test_environment.h index 91e8a2440e8ad..549804595b93d 100644 --- a/onnxruntime/test/util/include/test/test_environment.h +++ b/onnxruntime/test/util/include/test/test_environment.h @@ -10,8 +10,12 @@ #endif namespace onnxruntime { +class Environment; + namespace test { +const onnxruntime::Environment& GetEnvironment(); + /** Static logging manager with a CLog based sink so logging macros that use the default logger will work */ diff --git a/onnxruntime/test/util/test_environment.cc b/onnxruntime/test/util/test_environment.cc index 3c59d577bd4f6..3f77d8e474229 100644 --- a/onnxruntime/test/util/test_environment.cc +++ b/onnxruntime/test/util/test_environment.cc @@ -12,6 +12,7 @@ #include "core/common/logging/logging.h" #include "core/common/logging/sinks/clog_sink.h" #include "core/session/ort_env.h" +#include "core/session/environment.h" using namespace ::onnxruntime::logging; extern std::unique_ptr ort_env; @@ -21,8 +22,12 @@ namespace test { static std::unique_ptr<::onnxruntime::logging::LoggingManager> s_default_logging_manager; +const ::onnxruntime::Environment& GetEnvironment() { + return ((OrtEnv*)*ort_env.get())->GetEnvironment(); +} + ::onnxruntime::logging::LoggingManager& DefaultLoggingManager() { - return *((OrtEnv*)*ort_env.get())->GetLoggingManager(); + return *((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager(); } } // namespace test diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 4cb237a45d0e1..ad869f324c0c0 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -612,6 +612,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, enab executables = ['onnxruntime_test_all.exe'] if args.build_shared_lib: executables.append('onnxruntime_shared_lib_test.exe') + executables.append('onnxruntime_global_thread_pools_test.exe') run_subprocess(['vstest.console.exe', '--parallel', '--TestAdapterPath:..\\googletestadapter.0.17.1\\build\\_common', '/Logger:trx','/Enablecodecoverage','/Platform:x64',"/Settings:%s" % os.path.join(source_dir, 'cmake\\codeconv.runsettings')] + executables, cwd=cwd2, dll_path=dll_path) else: diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index d74c35aa3344f..51b93a8cdb6bb 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -14,8 +14,8 @@ #include "abi_custom_registry_impl.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" -#endif USE_DML +#endif USE_DML namespace winmla = Windows::AI::MachineLearning::Adapter; class WinmlAdapterLoggingWrapper : public LoggingWrapper { diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp index 329d713752f70..b95b2e6c8d762 100644 --- a/winml/adapter/winml_adapter_session.cpp +++ b/winml/adapter/winml_adapter_session.cpp @@ -42,7 +42,7 @@ ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ co std::unique_ptr inference_session; try { // Create the inference session - inference_session = std::make_unique(options->value, env->GetLoggingManager()); + inference_session = std::make_unique(options->value, env->GetEnvironment()); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } @@ -171,7 +171,7 @@ GetLotusCustomRegistries(IMLOperatorRegistry* registry) { // Get the ORT registry return abi_custom_registry->GetRegistries(); -#endif // USE_DML +#endif // USE_DML } return {}; } @@ -195,7 +195,7 @@ ORT_API_STATUS_IMPL(winmla::CreateCustomRegistry, _Out_ IMLOperatorRegistry** re #ifdef USE_DML auto impl = wil::MakeOrThrow(); *registry = impl.Detach(); -#endif // USE_DML +#endif // USE_DML return nullptr; API_IMPL_END }