Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sessions to share a global threadpool. #3177

Merged
merged 11 commits into from
Mar 18, 2020
32 changes: 30 additions & 2 deletions include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
#include <memory>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"
#include "core/common/logging/logging.h"

struct ThreadingOptions;
namespace onnxruntime {
/** TODO: remove this class
Provides the runtime environment for onnxruntime.
Expand All @@ -18,12 +21,37 @@ class Environment {
/**
Create and initialize the runtime environment.
*/
static Status Create(std::unique_ptr<Environment>& environment);
static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
std::unique_ptr<Environment>& environment,
const ThreadingOptions* tp_options = nullptr,
bool create_thread_pool = false);

/**
Returns a non-owning pointer to the logging manager currently held by this environment.
The pointer is null when no logging manager is held.
*/
logging::LoggingManager* GetLoggingManager() const {
return logging_manager_.get();
}

/**
Takes ownership of the given logging manager and stores it in this environment,
replacing whatever logging manager was held before.
*/
void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
logging_manager_ = std::move(logging_manager);
}

/**
Returns a non-owning pointer to the environment-owned intra-op thread pool.
The pointer is null when the environment holds no such pool.
*/
onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
return intra_op_thread_pool_.get();
}

/**
Returns a non-owning pointer to the environment-owned inter-op thread pool.
The pointer is null when the environment holds no such pool.
*/
onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const {
return inter_op_thread_pool_.get();
}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);

Environment() = default;
Status Initialize();
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
const ThreadingOptions* tp_options = nullptr,
bool create_global_thread_pool = false);

std::unique_ptr<logging::LoggingManager> logging_manager_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
};
} // namespace onnxruntime
24 changes: 24 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ typedef enum OrtMemType {
OrtMemTypeDefault = 0, // the default allocator for execution provider
} OrtMemType;

// Thread-count settings passed by value to CreateEnvWithGlobalThreadPools.
typedef struct ThreadingOptions {
  // How many threads may be used to parallelize the execution of a single op.
  // Pass 0 to let onnxruntime choose a value.
  int intra_op_num_threads;

  // How many threads may be used to parallelize execution across ops.
  // Pass 0 to let onnxruntime choose a value.
  int inter_op_num_threads;
} ThreadingOptions;

struct OrtApi;
typedef struct OrtApi OrtApi;

Expand Down Expand Up @@ -747,6 +755,22 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION;

ORT_CLASS_RELEASE(ModelMetadata);

/*
* Creates an environment with global threadpools that will be shared across sessions.
* Use this in conjunction with DisablePerSessionThreads API or else by default the session will use
* its own thread pools.
*/
OrtStatus*(ORT_API_CALL* CreateEnvWithGlobalThreadPools)(OrtLoggingLevel default_logging_level, _In_ const char* logid,
_In_ ThreadingOptions t_options, _Outptr_ OrtEnv** out)
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved
NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

// TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function?

/*
* Calling this API will make the session use the global threadpools shared across sessions.
*/
OrtStatus*(ORT_API_CALL* DisablePerSessionThreads)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION;
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved
};

/*
Expand Down
7 changes: 5 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct Base {
~Base() { OrtRelease(p_); }

operator T*() { return p_; }
operator const T *() const { return p_; }
operator const T*() const { return p_; }

T* release() {
T* p = p_;
Expand Down Expand Up @@ -123,6 +123,7 @@ struct ModelMetadata;
struct Env : Base<OrtEnv> {
Env(std::nullptr_t) {}
Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(ThreadingOptions tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}

Expand Down Expand Up @@ -185,6 +186,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& SetLogId(const char* logid);

SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);

SessionOptions& DisablePerSessionThreads();
};

struct ModelMetadata : Base<OrtModelMetadata> {
Expand Down Expand Up @@ -289,7 +292,7 @@ struct AllocatorWithDefaultOptions {
AllocatorWithDefaultOptions();

operator OrtAllocator*() { return p_; }
operator const OrtAllocator *() const { return p_; }
operator const OrtAllocator*() const { return p_; }

void* Alloc(size_t size);
void Free(void* p);
Expand Down
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLog
ThrowOnError(Global<void>::api_.CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
}

// Constructs an Env through the CreateEnvWithGlobalThreadPools C API entry point,
// forwarding the threading options, logging level and log id. A non-success status
// is handed to ThrowOnError.
inline Env::Env(ThreadingOptions tp_options, OrtLoggingLevel default_warning_level, const char* logid) {
  OrtStatus* const create_status =
      Global<void>::api_.CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_);
  ThrowOnError(create_status);
}

inline Env& Env::EnableTelemetryEvents() {
ThrowOnError(Global<void>::api_.EnableTelemetryEvents(p_));
return *this;
Expand Down Expand Up @@ -601,4 +605,8 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
return out;
}

// Calls the DisablePerSessionThreads C API entry point on the wrapped options handle.
// The resulting status is handed to ThrowOnError; returns *this so calls can be chained.
inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
  OrtStatus* const disable_status = Global<void>::api_.DisablePerSessionThreads(p_);
  ThrowOnError(disable_status);
  return *this;
}
} // namespace Ort
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/session_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,9 @@ struct SessionOptions {
// For models with free input dimensions (most commonly batch size), specifies a set of values to override those
// free dimensions with, keyed by dimension denotation.
std::vector<FreeDimensionOverride> free_dimension_overrides;

// By default the session uses its own set of threadpools, unless this is set to false.
// Use this in conjunction with the CreateEnvWithGlobalThreadPools API.
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved
bool use_per_session_threads = true;
};
} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions
options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override});
return nullptr;
}

// Marks the session options so that sessions built from them do not use their own
// per-session thread pools (use_per_session_threads is set to false).
// `options` is written through, so it is annotated _Inout_ to agree with the OrtApi
// declaration of DisablePerSessionThreads and with the other option setters in this file.
// Always returns nullptr (success).
ORT_API_STATUS_IMPL(OrtApis::DisablePerSessionThreads, _Inout_ OrtSessionOptions* options) {
  options->value.use_per_session_threads = false;
  return nullptr;
}
22 changes: 19 additions & 3 deletions onnxruntime/core/session/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#endif

#include "core/platform/env.h"
#include "core/util/thread_utils.h"

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
#include "core/platform/tracing.h"
Expand All @@ -29,15 +30,30 @@ using namespace ONNX_NAMESPACE;

std::once_flag schemaRegistrationOnceFlag;

Status Environment::Create(std::unique_ptr<Environment>& environment) {
Status Environment::Create(std::unique_ptr<logging::LoggingManager> logging_manager,
std::unique_ptr<Environment>& environment,
const ThreadingOptions* tp_options,
bool create_global_thread_pool) {
environment = std::unique_ptr<Environment>(new Environment());
auto status = environment->Initialize();
auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pool);
return status;
}

Status Environment::Initialize() {
Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
const ThreadingOptions* tp_options,
bool create_global_thread_pool) {
auto status = Status::OK();

logging_manager_ = std::move(logging_manager);

// create thread pools
if (create_global_thread_pool) {
intra_op_thread_pool_ = concurrency::CreateThreadPool("env_global_intra_op_thread_pool",
tp_options->intra_op_num_threads);
inter_op_thread_pool_ = concurrency::CreateThreadPool("env_global_inter_op_thread_pool",
tp_options->inter_op_num_threads);
}

try {
// Register Microsoft domain with min/max op_set version as 1/1.
std::call_once(schemaRegistrationOnceFlag, []() {
Expand Down
66 changes: 40 additions & 26 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,31 +161,45 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session
}

void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
logging::LoggingManager* logging_manager) {
const Environment& session_env) {
auto status = FinalizeSessionOptions(session_options, model_proto_, model_loaded_, session_options_);
ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ",
status.ErrorMessage());

// Update the number of steps for the graph transformer manager using the "finalized" session options
graph_transformation_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps);
use_per_session_threads_ = session_options.use_per_session_threads;

logging_manager_ = logging_manager;
if (use_per_session_threads_) {
thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool",
session_options_.intra_op_num_threads);

thread_pool_ = concurrency::CreateThreadPool("intra_op_thread_pool",
session_options_.intra_op_num_threads);

inter_op_thread_pool_ = session_options_.execution_mode == ExecutionMode::ORT_PARALLEL
? concurrency::CreateThreadPool("inter_op_thread_pool",
session_options_.inter_op_num_threads)
: nullptr;
inter_op_thread_pool_ = session_options_.execution_mode == ExecutionMode::ORT_PARALLEL
? concurrency::CreateThreadPool("inter_op_thread_pool",
session_options_.inter_op_num_threads)
: nullptr;
} else {
intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool();
inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool();

ORT_ENFORCE(intra_op_thread_pool_from_env_,
"Since use_per_session_threads is false, this must be non-nullptr"
" You probably didn't create the env using the CreateEnvWithGlobalThreadPools API");
ORT_ENFORCE(inter_op_thread_pool_from_env_,
"Since use_per_session_threads is false, this must be non-nullptr"
" You probably didn't create the env using the CreateEnvWithGlobalThreadPools API");
ORT_ENFORCE(thread_pool_ == nullptr, "Since use_per_session_threads is false per session threadpools should be nullptr");
ORT_ENFORCE(inter_op_thread_pool_ == nullptr, "Since use_per_session_threads is false per session threadpools should be nullptr");
}

session_state_ = onnxruntime::make_unique<SessionState>(execution_providers_,
session_options_.enable_mem_pattern &&
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL,
thread_pool_.get(),
inter_op_thread_pool_.get());
use_per_session_threads_ ? thread_pool_.get() : intra_op_thread_pool_from_env_,
use_per_session_threads_ ? inter_op_thread_pool_.get() : inter_op_thread_pool_from_env_);

InitLogger(logging_manager);
logging_manager_ = session_env.GetLoggingManager();
InitLogger(logging_manager_);

session_state_->SetDataTransferMgr(&data_transfer_mgr_);
session_profiler_.Initialize(session_logger_);
Expand All @@ -200,16 +214,16 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
}

InferenceSession::InferenceSession(const SessionOptions& session_options,
logging::LoggingManager* logging_manager)
const Environment& session_env)
: graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer") {
// Initialize assets of this session instance
ConstructorCommon(session_options, logging_manager);
ConstructorCommon(session_options, session_env);
}

InferenceSession::InferenceSession(const SessionOptions& session_options,
const std::string& model_uri,
logging::LoggingManager* logging_manager)
const Environment& session_env,
const std::string& model_uri)
: graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer") {
model_location_ = ToWideString(model_uri);
Expand All @@ -218,13 +232,13 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
status.ErrorMessage());
model_loaded_ = true;
// Finalize session options and initialize assets of this session instance
ConstructorCommon(session_options, logging_manager);
ConstructorCommon(session_options, session_env);
}

#ifdef _WIN32
InferenceSession::InferenceSession(const SessionOptions& session_options,
const std::wstring& model_uri,
logging::LoggingManager* logging_manager)
const Environment& session_env,
const std::wstring& model_uri)
: graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer") {
model_location_ = ToWideString(model_uri);
Expand All @@ -233,34 +247,34 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
status.ErrorMessage());
model_loaded_ = true;
// Finalize session options and initialize assets of this session instance
ConstructorCommon(session_options, logging_manager);
ConstructorCommon(session_options, session_env);
}
#endif

InferenceSession::InferenceSession(const SessionOptions& session_options,
std::istream& model_istream,
logging::LoggingManager* logging_manager)
const Environment& session_env,
std::istream& model_istream)
: graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer") {
google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream);
const bool result = model_proto_.ParseFromZeroCopyStream(&zero_copy_input) && model_istream.eof();
ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session");
model_loaded_ = true;
// Finalize session options and initialize assets of this session instance
ConstructorCommon(session_options, logging_manager);
ConstructorCommon(session_options, session_env);
}

InferenceSession::InferenceSession(const SessionOptions& session_options,
const Environment& session_env,
const void* model_data,
int model_data_len,
logging::LoggingManager* logging_manager)
int model_data_len)
: graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer") {
const bool result = model_proto_.ParseFromArray(model_data, model_data_len);
ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session");
model_loaded_ = true;
// Finalize session options and initialize assets of this session instance
ConstructorCommon(session_options, logging_manager);
ConstructorCommon(session_options, session_env);
}

InferenceSession::~InferenceSession() {
Expand Down
Loading