Skip to content

Commit

Permalink
Add support for sessions to share a global threadpool. (#3177)
Browse files Browse the repository at this point in the history
* Add support for sessions to share a global threadpool.

* Fix build issues

* Add tests, fix build issues.

* Added some documentation

* Fix centos issue when threadpools become nullptr due to 1 core.

* Fix mac and x86 build issues

* Address some PR comments

* Disabled test for android, added few more tests and addressed more PR comments.

* const_cast
  • Loading branch information
pranavsharma authored Mar 18, 2020
1 parent e03b8a1 commit 435f014
Show file tree
Hide file tree
Showing 46 changed files with 936 additions and 233 deletions.
20 changes: 19 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ if (onnxruntime_USE_NNAPI)
endif()

set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib")

set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools")

set (onnxruntime_shared_lib_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
Expand All @@ -178,6 +178,11 @@ if(onnxruntime_RUN_ONNX_TESTS)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_io_types.cc)
endif()

set (onnxruntime_global_thread_pools_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc)

# tests from lowest level library up.
# the order of libraries should be maintained, with higher libraries being added first in the list

Expand Down Expand Up @@ -355,6 +360,9 @@ set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest"
set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src})
if(NOT TARGET onnxruntime)
list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC})
if (NOT onnxruntime_USE_NNAPI)
list(APPEND all_tests ${onnxruntime_global_thread_pools_test_SRC})
endif()
endif()
set(all_dependencies ${onnxruntime_test_providers_dependencies} )

Expand Down Expand Up @@ -671,6 +679,16 @@ if (onnxruntime_BUILD_SHARED_LIB)
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)

# test inference using global threadpools
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android")
AddTest(DYN
TARGET onnxruntime_global_thread_pools_test
SOURCES ${onnxruntime_global_thread_pools_test_SRC}
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)
endif()
endif()

#some ETW tools
Expand Down
9 changes: 9 additions & 0 deletions docs/C_API.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
* Setting graph optimization level for each session.
* Dynamically loading custom ops. [Instructions](/docs/AddingCustomOp.md)
* Ability to load a model from a byte array. See ```OrtCreateSessionFromArray``` in [onnxruntime_c_api.h](/include/onnxruntime/core/session/onnxruntime_c_api.h).
* **Global/shared threadpools:** By default each session creates its own set of threadpools. When multiple sessions
are created in the same process (for example, to run inference on different models), every session creates its own
threadpools, so the process ends up with many of them. To address this inefficiency we introduce a feature called
global/shared threadpools. The basic idea is to share a single set of global threadpools across multiple sessions.
Typical usage of this feature is as follows:
* Populate ```ThreadingOptions```. Use the value of 0 for ORT to pick the defaults.
* Create env using ```CreateEnvWithGlobalThreadPools()```
* Call ```DisablePerSessionThreads()``` on the session options object, then create the session with those options
* Call ```Run()``` as usual

## Usage Overview

Expand Down
10 changes: 9 additions & 1 deletion docs/execution_providers/ACL-ExecutionProvider.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#ARM-Compute-L
#### C/C++
To use ACL as execution provider for inferencing, please register it as below.
```
InferenceSession session_object{so};
string log_id = "Foo";
auto logging_manager = std::make_unique<LoggingManager>
(std::unique_ptr<ISink>{new CLogSink{}},
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&log_id)
Environment::Create(std::move(logging_manager), env)
InferenceSession session_object{so, env};
session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::ACLExecutionProvider>());
status = session_object.Load(model_file_name);
```
Expand Down
10 changes: 9 additions & 1 deletion docs/execution_providers/DNNL-ExecutionProvider.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#dnnl-and-mklm
### C/C++
The DNNLExecutionProvider execution provider needs to be registered with ONNX Runtime to enable in the inference session.
```
InferenceSession session_object{so};
string log_id = "Foo";
auto logging_manager = std::make_unique<LoggingManager>
(std::unique_ptr<ISink>{new CLogSink{}},
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&log_id)
Environment::Create(std::move(logging_manager), env)
InferenceSession session_object{so,env};
session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime:: DNNLExecutionProvider >());
status = session_object.Load(model_file_name);
```
Expand Down
10 changes: 9 additions & 1 deletion docs/execution_providers/NNAPI-ExecutionProvider.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ For build instructions, please see the [BUILD page](../../BUILD.md#Android-NNAPI

To use NNAPI EP for inferencing, please register it as below.
```
InferenceSession session_object{so};
string log_id = "Foo";
auto logging_manager = std::make_unique<LoggingManager>
(std::unique_ptr<ISink>{new CLogSink{}},
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&log_id)
Environment::Create(std::move(logging_manager), env)
InferenceSession session_object{so,env};
session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::NnapiExecutionProvider>());
status = session_object.Load(model_file_name);
```
Expand Down
10 changes: 9 additions & 1 deletion docs/execution_providers/TensorRT-ExecutionProvider.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ The TensorRT execution provider for ONNX Runtime is built and tested with Tensor
### C/C++
The TensorRT execution provider needs to be registered with ONNX Runtime to enable in the inference session.
```
InferenceSession session_object{so};
string log_id = "Foo";
auto logging_manager = std::make_unique<LoggingManager>
(std::unique_ptr<ISink>{new CLogSink{}},
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&log_id)
Environment::Create(std::move(logging_manager), env)
InferenceSession session_object{so,env};
session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::TensorrtExecutionProvider>());
status = session_object.Load(model_file_name);
```
Expand Down
10 changes: 9 additions & 1 deletion docs/execution_providers/nGraph-ExecutionProvider.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@ While the nGraph Compiler stack supports various operating systems and backends
### C/C++
To use nGraph as execution provider for inferencing, please register it as below.
```
InferenceSession session_object{so};
string log_id = "Foo";
auto logging_manager = std::make_unique<LoggingManager>
(std::unique_ptr<ISink>{new CLogSink{}},
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&log_id)
Environment::Create(std::move(logging_manager), env)
InferenceSession session_object{so,env};
session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::NGRAPHExecutionProvider>());
status = session_object.Load(model_file_name);
```
Expand Down
45 changes: 43 additions & 2 deletions include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
#include <memory>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"
#include "core/common/logging/logging.h"

struct ThreadingOptions;
namespace onnxruntime {
/** TODO: remove this class
Provides the runtime environment for onnxruntime.
Expand All @@ -17,13 +20,51 @@ class Environment {
public:
/**
Create and initialize the runtime environment.
@param logging_manager logging manager instance that will enable per session logger output using
session_options.session_logid as the logger id in messages.
If nullptr, the default LoggingManager MUST have been created previously as it will be used
for logging. This will use the default logger id in messages.
See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works.
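@param environment receives the newly created Environment instance.
@param tp_options optional set of parameters controlling the number of intra and inter op threads for the global
threadpools. Must not be nullptr when create_global_thread_pools is true, because it is dereferenced in that case.
@param create_global_thread_pools determines whether this function creates the global threadpools.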
*/
static Status Create(std::unique_ptr<Environment>& environment);
static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
std::unique_ptr<Environment>& environment,
const ThreadingOptions* tp_options = nullptr,
bool create_global_thread_pools = false);

/**
Returns a non-owning pointer to the LoggingManager held in logging_manager_.
The result is nullptr when no manager is currently stored.
*/
logging::LoggingManager* GetLoggingManager() const {
return logging_manager_.get();
}

/**
Takes ownership of logging_manager and stores it in logging_manager_.
Any manager previously held there is released by the unique_ptr assignment.
*/
void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
logging_manager_ = std::move(logging_manager);
}

/**
Returns a non-owning pointer to the environment's intra-op thread pool.
The result is nullptr when intra_op_thread_pool_ holds no pool.
*/
onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
return intra_op_thread_pool_.get();
}

/**
Returns a non-owning pointer to the environment's inter-op thread pool.
The result is nullptr when inter_op_thread_pool_ holds no pool.
*/
onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const {
return inter_op_thread_pool_.get();
}

/**
Returns the create_global_thread_pools_ flag, which defaults to false.
*/
bool EnvCreatedWithGlobalThreadPools() const {
return create_global_thread_pools_;
}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);

Environment() = default;
Status Initialize();
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
const ThreadingOptions* tp_options = nullptr,
bool create_global_thread_pools = false);

std::unique_ptr<logging::LoggingManager> logging_manager_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};
};
} // namespace onnxruntime
25 changes: 25 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ typedef enum OrtMemType {
OrtMemTypeDefault = 0, // the default allocator for execution provider
} OrtMemType;

// Thread-count settings for the global (shared) threadpools.
// Passed by value to CreateEnvWithGlobalThreadPools.
typedef struct ThreadingOptions {
// number of threads used to parallelize execution of an op
int intra_op_num_threads; // use 0 if you want onnxruntime to choose a value for you

// number of threads used to parallelize execution across ops
int inter_op_num_threads; // use 0 if you want onnxruntime to choose a value for you
} ThreadingOptions;

struct OrtApi;
typedef struct OrtApi OrtApi;

Expand Down Expand Up @@ -747,6 +755,23 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION;

ORT_CLASS_RELEASE(ModelMetadata);

/*
* Creates an environment with global threadpools that will be shared across sessions.
* Use this in conjunction with DisablePerSessionThreads API or else the session will use
* its own thread pools.
*/
OrtStatus*(ORT_API_CALL* CreateEnvWithGlobalThreadPools)(OrtLoggingLevel default_logging_level, _In_ const char* logid,
_In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out)
NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

/* TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? */

/*
* Calling this API will make the session use the global threadpools shared across sessions.
* This API should be used in conjunction with CreateEnvWithGlobalThreadPools API.
*/
OrtStatus*(ORT_API_CALL* DisablePerSessionThreads)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION;
};

/*
Expand Down
7 changes: 5 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct Base {
~Base() { OrtRelease(p_); }

operator T*() { return p_; }
operator const T *() const { return p_; }
operator const T*() const { return p_; }

T* release() {
T* p = p_;
Expand Down Expand Up @@ -123,6 +123,7 @@ struct ModelMetadata;
struct Env : Base<OrtEnv> {
Env(std::nullptr_t) {}
Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(ThreadingOptions tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}

Expand Down Expand Up @@ -185,6 +186,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& SetLogId(const char* logid);

SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);

SessionOptions& DisablePerSessionThreads();
};

struct ModelMetadata : Base<OrtModelMetadata> {
Expand Down Expand Up @@ -289,7 +292,7 @@ struct AllocatorWithDefaultOptions {
AllocatorWithDefaultOptions();

operator OrtAllocator*() { return p_; }
operator const OrtAllocator *() const { return p_; }
operator const OrtAllocator*() const { return p_; }

void* Alloc(size_t size);
void Free(void* p);
Expand Down
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLog
ThrowOnError(Global<void>::api_.CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
}

// Constructs an Env through the CreateEnvWithGlobalThreadPools C API entry point.
// The created OrtEnv is written into p_, and the returned status is handed to ThrowOnError.
inline Env::Env(ThreadingOptions tp_options, OrtLoggingLevel default_warning_level, const char* logid) {
  auto* create_status =
      Global<void>::api_.CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_);
  ThrowOnError(create_status);
}

inline Env& Env::EnableTelemetryEvents() {
ThrowOnError(Global<void>::api_.EnableTelemetryEvents(p_));
return *this;
Expand Down Expand Up @@ -601,4 +605,8 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
return out;
}

// Forwards to the DisablePerSessionThreads C API entry point for the wrapped options object.
// The returned status is handed to ThrowOnError; *this is returned so calls can be chained.
inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
  auto* disable_status = Global<void>::api_.DisablePerSessionThreads(p_);
  ThrowOnError(disable_status);
  return *this;
}
} // namespace Ort
6 changes: 3 additions & 3 deletions onnxruntime/core/common/logging/logging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "core/platform/ort_mutex.h"

#if __FreeBSD__
#include <sys/thr.h> // Use thr_self() syscall under FreeBSD to get thread id
#include <sys/thr.h> // Use thr_self() syscall under FreeBSD to get thread id
#endif

namespace onnxruntime {
Expand Down Expand Up @@ -189,8 +189,8 @@ std::exception LoggingManager::LogFatalAndCreateException(const char* category,
// create Capture in separate scope so it gets destructed (leading to log output) before we throw.
{
::onnxruntime::logging::Capture c{::onnxruntime::logging::LoggingManager::DefaultLogger(),
::onnxruntime::logging::Severity::kFATAL, category,
::onnxruntime::logging::DataType::SYSTEM, location};
::onnxruntime::logging::Severity::kFATAL, category,
::onnxruntime::logging::DataType::SYSTEM, location};
va_list args;
va_start(args, format_str);

Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/session_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,9 @@ struct SessionOptions {
// For models with free input dimensions (most commonly batch size), specifies a set of values to override those
// free dimensions with, keyed by dimension denotation.
std::vector<FreeDimensionOverride> free_dimension_overrides;

// By default the session uses its own set of threadpools, unless this is set to false.
// Use this in conjunction with the CreateEnvWithGlobalThreadPools API.
bool use_per_session_threads = true;
};
} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions
options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override});
return nullptr;
}

// Clears use_per_session_threads on the given session options, which per the field's own
// comment makes the session stop using its own set of threadpools.
// The parameter is annotated _Inout_ to match the public declaration in onnxruntime_c_api.h;
// the function writes through the pointer, so _In_ was inaccurate.
// Always returns a null OrtStatus*; `options` is dereferenced without a null check.
ORT_API_STATUS_IMPL(OrtApis::DisablePerSessionThreads, _Inout_ OrtSessionOptions* options) {
  options->value.use_per_session_threads = false;
  return nullptr;
}
23 changes: 20 additions & 3 deletions onnxruntime/core/session/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#endif

#include "core/platform/env.h"
#include "core/util/thread_utils.h"

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
#include "core/platform/tracing.h"
Expand All @@ -29,15 +30,31 @@ using namespace ONNX_NAMESPACE;

std::once_flag schemaRegistrationOnceFlag;

// Allocates a new Environment into `environment` and initializes it with the given
// logging manager and threading options, returning the Status produced by Initialize().
// `environment` is overwritten before Initialize() runs, so it holds the new object even
// when the returned Status reports a failure.
// NOTE: Initialize() dereferences tp_options when create_global_thread_pools is true, so
// callers must pass a non-null tp_options in that case.
Status Environment::Create(std::unique_ptr<Environment>& environment) {
Status Environment::Create(std::unique_ptr<logging::LoggingManager> logging_manager,
std::unique_ptr<Environment>& environment,
const ThreadingOptions* tp_options,
bool create_global_thread_pools) {
environment = std::unique_ptr<Environment>(new Environment());
auto status = environment->Initialize();
// Ownership of the logging manager is transferred to the new environment.
auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools);
return status;
}

Status Environment::Initialize() {
Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
const ThreadingOptions* tp_options,
bool create_global_thread_pools) {
auto status = Status::OK();

logging_manager_ = std::move(logging_manager);

// create thread pools
if (create_global_thread_pools) {
create_global_thread_pools_ = true;
intra_op_thread_pool_ = concurrency::CreateThreadPool("env_global_intra_op_thread_pool",
tp_options->intra_op_num_threads);
inter_op_thread_pool_ = concurrency::CreateThreadPool("env_global_inter_op_thread_pool",
tp_options->inter_op_num_threads);
}

try {
// Register Microsoft domain with min/max op_set version as 1/1.
std::call_once(schemaRegistrationOnceFlag, []() {
Expand Down
Loading

0 comments on commit 435f014

Please sign in to comment.