Skip to content

Commit

Permalink
Add API to allow configuration of the global thread pools. (#5199)
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavsharma authored and edgchen1 committed Sep 24, 2020
1 parent c52d900 commit 3dbbdcc
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 3 deletions.
17 changes: 17 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,23 @@ struct OrtApi {
*/
ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out);

/**
* Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools.
* A value of 0 means ORT will pick the default.
* A value of 1 means the invoking thread will be used; no threads will be created in the thread pool.
*/
ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads);
ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads);

/**
* Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools.
* Allow spinning of thread pools when their queues are empty. This API will set the value for both
* inter_op and intra_op threadpools.
* \param allow_spinning valid values are 1 and 0.
* 1: threadpool will spin to wait for queue to become non-empty, 0: it won't spin.
* Prefer a value of 0 if your CPU usage is very high.
*/
ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning);
};

/*
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1731,7 +1731,7 @@ ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* l
_Outptr_ void** out) {
TENSOR_READWRITE_API_BEGIN

if(tensor->IsDataTypeString()) {
if (tensor->IsDataTypeString()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "this API does not support strings");
}

Expand Down Expand Up @@ -2011,6 +2011,9 @@ static constexpr OrtApi ort_api_1_to_5 = {
&OrtApis::CreateAndRegisterAllocator,
&OrtApis::SetLanguageProjection,
&OrtApis::SessionGetProfilingStartTimeNs,
&OrtApis::SetGlobalIntraOpNumThreads,
&OrtApis::SetGlobalInterOpNumThreads,
&OrtApis::SetGlobalSpinControl,
};

// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,7 @@ ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const

ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
ORT_API_STATUS_IMPL(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out);
ORT_API_STATUS_IMPL(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads);
ORT_API_STATUS_IMPL(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads);
ORT_API_STATUS_IMPL(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning);
} // namespace OrtApis
28 changes: 28 additions & 0 deletions onnxruntime/core/util/thread_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <Windows.h>
#endif
#include <thread>
#include "core/session/ort_apis.h"

namespace onnxruntime {
namespace concurrency {
Expand Down Expand Up @@ -61,4 +62,31 @@ ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out)
ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions* p) {
delete p;
}

ORT_API_STATUS_IMPL(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads) {
if (!tp_options) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
}
tp_options->intra_op_thread_pool_params.thread_pool_size = intra_op_num_threads;
return nullptr;
}
ORT_API_STATUS_IMPL(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads) {
if (!tp_options) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
}
tp_options->inter_op_thread_pool_params.thread_pool_size = inter_op_num_threads;
return nullptr;
}

ORT_API_STATUS_IMPL(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning) {
if (!tp_options) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
}
if (!(allow_spinning == 1 || allow_spinning == 0)) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received invalid value for allow_spinning. Valid values are 0 or 1");
}
tp_options->intra_op_thread_pool_params.allow_spinning = allow_spinning;
tp_options->inter_op_thread_pool_params.allow_spinning = allow_spinning;
return nullptr;
}
} // namespace OrtApis
22 changes: 20 additions & 2 deletions onnxruntime/test/global_thread_pools/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,35 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "gtest/gtest.h"
#include "test/test_environment.h"
#include <thread>

std::unique_ptr<Ort::Env> ort_env;

#define ORT_RETURN_IF_NON_NULL_STATUS(arg) \
if (arg) { \
return -1; \
}

int main(int argc, char** argv) {
int status = 0;
ORT_TRY {
::testing::InitGoogleTest(&argc, argv);
const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
OrtThreadingOptions* tp_options;
OrtStatus* st = g_ort->CreateThreadingOptions(&tp_options);
if (st != nullptr) return -1;
std::unique_ptr<OrtStatus, decltype(OrtApi::ReleaseStatus)> st_ptr(nullptr, g_ort->ReleaseStatus);

st_ptr.reset(g_ort->CreateThreadingOptions(&tp_options));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);

st_ptr.reset(g_ort->SetGlobalSpinControl(tp_options, 0));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);

st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, std::thread::hardware_concurrency()));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);

st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, std::thread::hardware_concurrency()));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);

ort_env.reset(new Ort::Env(tp_options, ORT_LOGGING_LEVEL_VERBOSE, "Default")); // this is the only change from test/providers/test_main.cc
g_ort->ReleaseThreadingOptions(tp_options);
status = RUN_ALL_TESTS();
Expand Down

0 comments on commit 3dbbdcc

Please sign in to comment.