Skip to content

Commit

Permalink
Disabled test for android, added few more tests and addressed more PR…
Browse files Browse the repository at this point in the history
… comments.
  • Loading branch information
pranavsharma committed Mar 18, 2020
1 parent 9ec36cf commit 754ab24
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 9 deletions.
6 changes: 5 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,9 @@ set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest"
set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src})
if(NOT TARGET onnxruntime)
list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC})
list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC})
if (NOT onnxruntime_USE_NNAPI)
list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC})
endif()
endif()
set(all_dependencies ${onnxruntime_test_providers_dependencies} )

Expand Down Expand Up @@ -679,12 +681,14 @@ if (onnxruntime_BUILD_SHARED_LIB)
)

# test inference using global threadpools
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android")
AddTest(DYN
TARGET onnxruntime_global_thread_pools_test
SOURCES ${onnxruntime_global_thread_pools_test_SRC}
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)
endif()
endif()

#some ETW tools
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
session_state_ = onnxruntime::make_unique<SessionState>(execution_providers_,
session_options_.enable_mem_pattern &&
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL,
use_per_session_threads_ ? thread_pool_.get() : intra_op_thread_pool_from_env_,
use_per_session_threads_ ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_);
GetIntraOpThreadPoolToUse(),
GetInterOpThreadPoolToUse());

session_state_->SetLogger(*session_logger_);
session_state_->SetDataTransferMgr(&data_transfer_mgr_);
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,17 @@ class InferenceSession {
// It has a dependency on execution_providers_.
std::unique_ptr<SessionState> session_state_;

// Use these 2 threadpool methods to get access to the threadpools since they rely on
// specific flags in session options
// These methods assume that session options have been finalized before the call.
onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const {
return session_options_.use_per_session_threads ? thread_pool_.get() : intra_op_thread_pool_from_env_;
}

onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const {
return session_options_.use_per_session_threads ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_;
}

private:
// Threadpools per session. These are initialized and used for the entire duration of the session
// when use_per_session_threads is true.
Expand Down
171 changes: 171 additions & 0 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/clog_sink.h"
#include "core/common/profiler.h"
#include "core/framework/compute_capability.h"
#include "core/framework/data_transfer_manager.h"
Expand Down Expand Up @@ -1801,5 +1802,175 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) {
#endif
}

// Global threadpool related tests
// We test for 4 combinations
class InferenceSessionTestGlobalThreadPools : public InferenceSession {
public:
InferenceSessionTestGlobalThreadPools(const SessionOptions& session_options,
const Environment& env) : InferenceSession(session_options, env) {
}

onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const {
return InferenceSession::GetIntraOpThreadPoolToUse();
}

onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const {
return InferenceSession::GetInterOpThreadPoolToUse();
}

const SessionState& GetSessionState() {
return *session_state_;
}
};

// Test 1: env created WITHOUT global tp / use per session tp (default case): in this case per session tps should be in use
TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed) {
SessionOptions so;
so.use_per_session_threads = true;

so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed";
auto logging_manager = onnxruntime::make_unique<logging::LoggingManager>(
std::unique_ptr<ISink>(new CLogSink()), logging::Severity::kVERBOSE, false,
LoggingManager::InstanceType::Temporal);

std::unique_ptr<Environment> env;
auto st = Environment::Create(std::move(logging_manager), env);
ASSERT_TRUE(st.IsOK());

InferenceSessionTestGlobalThreadPools session_object{so, *env.get()};
ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK());
ASSERT_TRUE(session_object.Initialize().IsOK());

// make sure we're using the per session threadpools
auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse();
auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool();
auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse();
auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool();
auto intra_tp_from_env = env->GetIntraOpThreadPool();
auto inter_tp_from_env = env->GetInterOpThreadPool();

// ensure threadpools were set correctly in the session state
ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state);
ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state);

ASSERT_TRUE(intra_tp_from_env == nullptr);
ASSERT_TRUE(inter_tp_from_env == nullptr);

RunOptions run_options;
run_options.run_tag = "RunTag";
run_options.run_log_severity_level = static_cast<int>(Severity::kVERBOSE);
RunModel(session_object, run_options);
}

// Test 2: env created with global tp / DONT use per session tp: in this case global tps should be in use
TEST(InferenceSessionTests, CheckIfGlobalThreadPoolsAreBeingUsed) {
SessionOptions so;
so.use_per_session_threads = false;

so.session_logid = "CheckIfGlobalThreadPoolsAreBeingUsed";
auto logging_manager = onnxruntime::make_unique<logging::LoggingManager>(
std::unique_ptr<ISink>(new CLogSink()), logging::Severity::kVERBOSE, false,
LoggingManager::InstanceType::Temporal);

std::unique_ptr<Environment> env;
ThreadingOptions tp_options{0, 0};
auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/);
ASSERT_TRUE(st.IsOK());

InferenceSessionTestGlobalThreadPools session_object{so, *env.get()};
ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK());
ASSERT_TRUE(session_object.Initialize().IsOK());

// make sure we're using the global threadpools in both session and session state
auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse();
auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool();
auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse();
auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool();
auto intra_tp_from_env = env->GetIntraOpThreadPool();
auto inter_tp_from_env = env->GetInterOpThreadPool();

ASSERT_TRUE(intra_tp_from_session == intra_tp_from_env);
ASSERT_TRUE(inter_tp_from_session == inter_tp_from_env);
ASSERT_TRUE(intra_tp_from_session_state == intra_tp_from_env);
ASSERT_TRUE(inter_tp_from_session_state == inter_tp_from_env);

RunOptions run_options;
run_options.run_tag = "RunTag";
run_options.run_log_severity_level = static_cast<int>(Severity::kVERBOSE);
RunModel(session_object, run_options);
}

// Test 3: env created with global tp / use per session tp: in this case per session tps should be in use
TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed2) {
SessionOptions so;
so.use_per_session_threads = true;

so.session_logid = "CheckIfPerSessionThreadPoolsAreBeingUsed2";
auto logging_manager = onnxruntime::make_unique<logging::LoggingManager>(
std::unique_ptr<ISink>(new CLogSink()), logging::Severity::kVERBOSE, false,
LoggingManager::InstanceType::Temporal);

std::unique_ptr<Environment> env;
ThreadingOptions tp_options{0, 0};
auto st = Environment::Create(std::move(logging_manager), env, &tp_options, true /*create_global_thread_pools*/);
ASSERT_TRUE(st.IsOK());

InferenceSessionTestGlobalThreadPools session_object{so, *env.get()};
ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK());
ASSERT_TRUE(session_object.Initialize().IsOK());

// make sure we're using the per session threadpools
auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse();
auto intra_tp_from_session_state = session_object.GetSessionState().GetThreadPool();
auto inter_tp_from_session = session_object.GetInterOpThreadPoolToUse();
auto inter_tp_from_session_state = session_object.GetSessionState().GetInterOpThreadPool();
auto intra_tp_from_env = env->GetIntraOpThreadPool();
auto inter_tp_from_env = env->GetInterOpThreadPool();

// ensure threadpools were set correctly in the session state
ASSERT_TRUE(intra_tp_from_session == intra_tp_from_session_state);
ASSERT_TRUE(inter_tp_from_session == inter_tp_from_session_state);

// ensure per session thread pools in use are different from the
// env threadpools
if (intra_tp_from_session && intra_tp_from_env) { // both tps could be null on 1 core machines
ASSERT_FALSE(intra_tp_from_session == intra_tp_from_env);
}

if (inter_tp_from_session && inter_tp_from_env) { // both tps could be null on 1 core machines
ASSERT_FALSE(inter_tp_from_session == inter_tp_from_env);
}

RunOptions run_options;
run_options.run_tag = "RunTag";
run_options.run_log_severity_level = static_cast<int>(Severity::kVERBOSE);
RunModel(session_object, run_options);
}

// Test 4: env created WITHOUT global tp / DONT use per session tp --> this should throw an exception
TEST(InferenceSessionTests, InvalidSessionEnvCombination) {
SessionOptions so;
so.use_per_session_threads = false;

so.session_logid = "InvalidSessionEnvCombination";
auto logging_manager = onnxruntime::make_unique<logging::LoggingManager>(
std::unique_ptr<ISink>(new CLogSink()), logging::Severity::kVERBOSE, false,
LoggingManager::InstanceType::Temporal);

std::unique_ptr<Environment> env;
auto st = Environment::Create(std::move(logging_manager), env);
ASSERT_TRUE(st.IsOK());

try {
InferenceSessionTestGlobalThreadPools session_object{so, *env.get()};
} catch (const std::exception& e) {
std::string e_message(std::string(e.what()));
ASSERT_TRUE(e_message.find(
"When the session is not configured to use per session"
" threadpools, the env must be created with the the CreateEnvWithGlobalThreadPools API") !=
std::string::npos);
}
}

} // namespace test
} // namespace onnxruntime
8 changes: 4 additions & 4 deletions onnxruntime/test/global_thread_pools/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
**Tests for global threadpools**

These tests here test the usage of the global threadpools. The reason we need to create a separate exe here is because
we need to create a separate environment that enables the creation of global threadpools. The test environment
used in the other exes create the env without global threadpools and since this env is process wide (a singleton) we
cannot use it for this kind of test.
These tests here test the usage of the global threadpools using the C API data flow. The reason we need to create a
separate exe here is because we need to create a separate environment that enables the creation of global threadpools.
The test environment used in the other exes create the env without global threadpools and since this env is process
wide (a singleton) we cannot use it for this kind of test.
2 changes: 0 additions & 2 deletions onnxruntime/test/global_thread_pools/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ static void TestInference(Ort::Session& session,
const char* output_name,
const std::vector<int64_t>& expected_dims_y,
const std::vector<OutT>& expected_values_y) {
Ort::SessionOptions session_options;
session_options.DisablePerSessionThreads();
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size());

Expand Down

0 comments on commit 754ab24

Please sign in to comment.