diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0a02f578d7fe2..28aab36631bc0 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -360,7 +360,9 @@ set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest" set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src}) if(NOT TARGET onnxruntime) list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) - list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC}) + if (NOT onnxruntime_USE_NNAPI) + list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC}) + endif() endif() set(all_dependencies ${onnxruntime_test_providers_dependencies} ) @@ -679,12 +681,14 @@ if (onnxruntime_BUILD_SHARED_LIB) ) # test inference using global threadpools + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") AddTest(DYN TARGET onnxruntime_global_thread_pools_test SOURCES ${onnxruntime_global_thread_pools_test_SRC} LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + endif() endif() #some ETW tools diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c465533c5b58e..b254d60d03c1d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -196,8 +196,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, session_state_ = onnxruntime::make_unique(execution_providers_, session_options_.enable_mem_pattern && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL, - use_per_session_threads_ ? thread_pool_.get() : intra_op_thread_pool_from_env_, - use_per_session_threads_ ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_); + GetIntraOpThreadPoolToUse(), + GetInterOpThreadPoolToUse()); session_state_->SetLogger(*session_logger_); session_state_->SetDataTransferMgr(&data_transfer_mgr_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 0af3d6fc6e966..f379d780658e6 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -454,6 +454,17 @@ class InferenceSession { // It has a dependency on execution_providers_. std::unique_ptr session_state_; + // Use these 2 threadpool methods to get access to the threadpools since they rely on + // specific flags in session options + // These methods assume that session options have been finalized before the call. + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const { + return session_options_.use_per_session_threads ? thread_pool_.get() : intra_op_thread_pool_from_env_; + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const { + return session_options_.use_per_session_threads ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_; + } + private: // Threadpools per session. These are initialized and used for the entire duration of the session // when use_per_session_threads is true. diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index a2170ae4abe35..45aade53bf729 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -12,6 +12,7 @@ #include #include "core/common/logging/logging.h" +#include "core/common/logging/sinks/clog_sink.h" #include "core/common/profiler.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" @@ -1801,5 +1802,175 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { #endif } +// Global threadpool related tests +// We test for 4 combinations +class InferenceSessionTestGlobalThreadPools : public InferenceSession { + public: + InferenceSessionTestGlobalThreadPools(const SessionOptions& session_options, + const Environment& env) : InferenceSession(session_options, env) { + } + + onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const { + return InferenceSession::GetIntraOpThreadPoolToUse(); + } + + onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const { + return InferenceSession::GetInterOpThreadPoolToUse(); + } + + const SessionState& GetSessionState() { + return *session_state_; + } +}; + +// Test 1: env created WITHOUT global tp / use per session tp (default case): in this case per session tps should be in use +TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed) { + SessionOptions so; + so.use_per_session_threads = true; + + so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the per session threadpools + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + // ensure threadpools were set correctly in the session state + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state); + + ASSERT_TRUE(intra_tp_from_env == nullptr); + ASSERT_TRUE(inter_tp_from_env == nullptr); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 2: env created with global tp / DONT use per session tp: in this case global tps should be in use +TEST(InferenceSessionTests, CheckIfGlobalThreadPoolsAreBeingUsed) { + SessionOptions so; + so.use_per_session_threads = false; + + so.session_logid = "CheckIfGlobalThreadPoolsAreBeingUsed"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + ThreadingOptions tp_options{0, 0}; + auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the global threadpools in both session and session state + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_env); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_env); + ASSERT_TRUE(intra_tp_from_session_state == intra_tp_from_env); + ASSERT_TRUE(inter_tp_from_session_state == inter_tp_from_env); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 3: env created with global tp / use per session tp: in this case per session tps should be in use +TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed2) { + SessionOptions so; + so.use_per_session_threads = true; + + so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed2"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + ThreadingOptions tp_options{0, 0}; + auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/); + ASSERT_TRUE(st.IsOK()); + + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // make sure we're using the per session threadpools + auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); + auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool(); + auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse(); + auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool(); + auto intra_tp_from_env = env->GetIntraOpThreadPool(); + auto inter_tp_from_env = env->GetInterOpThreadPool(); + + // ensure threadpools were set correctly in the session state + ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state); + ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state); + + // ensure per session thread pools in use are different from the + // env threadpools + if (intra_tp_from_session && intra_tp_from_env) { // both tps could be null on 1 core machines + ASSERT_FALSE(intra_tp_from_session == intra_tp_from_env); + } + + if (inter_tp_from_session && inter_tp_from_env) { // both tps could be null on 1 core machines + ASSERT_FALSE(inter_tp_from_session == inter_tp_from_env); + } + + RunOptions run_options; + run_options.run_tag = "RunTag"; + run_options.run_log_severity_level = static_cast(Severity::kVERBOSE); + RunModel(session_object, run_options); +} + +// Test 4: env created WITHOUT global tp / DONT use per session tp --> this should throw an exception +TEST(InferenceSessionTests, InvalidSessionEnvCombination) { + SessionOptions so; + so.use_per_session_threads = false; + + so.session_logid = "InvalidSessionEnvCombination"; + auto logging_manager = onnxruntime::make_unique( + std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, + LoggingManager::InstanceType::Temporal); + + std::unique_ptr env; + auto st = Environment::Create(std::move(logging_manager), env); + ASSERT_TRUE(st.IsOK()); + + try { + InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; + } catch (const std::exception& e) { + std::string e_message(std::string(e.what())); + ASSERT_TRUE(e_message.find( + "When the session is not configured to use per session" + " threadpools, the env must be created with the the CreateEnvWithGlobalThreadPools API") != + std::string::npos); + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/global_thread_pools/README.md b/onnxruntime/test/global_thread_pools/README.md index 6677bc580318a..d5ada6edef9fb 100644 --- a/onnxruntime/test/global_thread_pools/README.md +++ b/onnxruntime/test/global_thread_pools/README.md @@ -1,6 +1,6 @@ **Tests for global threadpools** -These tests here test the usage of the global threadpools. The reason we need to create a separate exe here is because -we need to create a separate environment that enables the creation of global threadpools. The test environment -used in the other exes create the env without global threadpools and since this env is process wide (a singleton) we -cannot use it for this kind of test. \ No newline at end of file +These tests here test the usage of the global threadpools using the C API data flow. The reason we need to create a +separate exe here is because we need to create a separate environment that enables the creation of global threadpools. +The test environment used in the other exes create the env without global threadpools and since this env is process +wide (a singleton) we cannot use it for this kind of test. \ No newline at end of file diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 2dda85ff3902f..5294c770e8ccc 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -102,8 +102,6 @@ static void TestInference(Ort::Session& session, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y) { - Ort::SessionOptions session_options; - session_options.DisablePerSessionThreads(); auto default_allocator = onnxruntime::make_unique(); Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size());