Patch summary (text before the first "diff --git" header is ignored by
`git apply` and `patch`):

  * src/runtime/disco/builtin.cc
      Registers the global "runtime.disco.bind_worker_to_cpu_core". It takes
      an IntTuple of CPU ids, checks that WorkerId() is a valid index into it,
      looks up "tvm.runtime.threading.set_current_thread_affinity" in the
      registry and calls it with the single CPU id selected by the worker id.

  * src/runtime/threading_backend.cc
      - Moves SetThreadAffinity out of ThreadGroup::Impl (private member) to
        a free function defined before the class; the body is unchanged.
      - Replaces std::sort with std::stable_sort when ordering max_freqs.
      - Registers the global
        "tvm.runtime.threading.set_current_thread_affinity", which applies
        SetThreadAffinity to CURRENT_THREAD_HANDLE with the given CPU ids.

NOTE(review): this patch was recovered from a copy in which all line breaks
were collapsed and every angle-bracketed token was stripped. The #include
targets and the template arguments (static_cast<int>,
std::vector<unsigned int>, std::pair<unsigned int, int64_t>,
reinterpret_cast<pthread_internal*>) as well as the leading indentation were
restored by inference -- confirm against the upstream tree before applying.

diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 05961df9d585..906cea1e323e 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -129,6 +129,14 @@ TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t
 TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device {
   return DiscoWorker::ThreadLocal()->default_device;
 });
+TVM_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core").set_body_typed([](IntTuple cpu_ids) {
+  int worker_id = WorkerId();
+  ICHECK_LT(worker_id, static_cast<int>(cpu_ids.size()));
+  const PackedFunc* f_set_thread_affinity =
+      Registry::Get("tvm.runtime.threading.set_current_thread_affinity");
+  ICHECK_NOTNULL(f_set_thread_affinity);
+  (*f_set_thread_affinity)(IntTuple{cpu_ids[worker_id]});
+});

 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc
index b6e12a25cca8..177ecf511070 100644
--- a/src/runtime/threading_backend.cc
+++ b/src/runtime/threading_backend.cc
@@ -22,6 +22,7 @@
  * \brief Native threading backend
  */
 #include <tvm/runtime/logging.h>
+#include <tvm/runtime/registry.h>
 #include <tvm/runtime/threading_backend.h>

 #if defined(__linux__) || defined(__ANDROID__)
@@ -106,6 +107,39 @@ class QuRTThread {
   void* stack_ = nullptr;
 };
 #endif  // __hexagon__
+
+// This is a common function used to set thread affinity.
+void SetThreadAffinity(std::thread::native_handle_type thread,
+                       const std::vector<unsigned int>& ids) {
+#if defined(__linux__) || defined(__ANDROID__)
+  if (pthread_equal(thread, CURRENT_THREAD_HANDLE)) {
+    thread = pthread_self();
+  }
+  cpu_set_t cpuset;
+  CPU_ZERO(&cpuset);
+  for (auto id : ids) {
+    CPU_SET(id, &cpuset);
+  }
+#if defined(__ANDROID__)
+#if __ANDROID_API__ >= 21
+  pid_t tid = pthread_gettid_np(thread);
+#else
+  typedef struct {
+    void* next;
+    void* pred;
+    pid_t tid;
+  } pthread_internal;
+  pid_t tid = reinterpret_cast<pthread_internal*>(thread)->tid;
+#endif
+  if (sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset) != 0) {
+    LOG(WARNING) << "sched_setaffinity failed";
+  }
+#else
+  pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
+#endif
+#endif
+}
+
 thread_local int max_concurrency = 0;
 class ThreadGroup::Impl {
  public:
@@ -158,37 +192,6 @@ class ThreadGroup::Impl {
   }

  private:
-  void SetThreadAffinity(std::thread::native_handle_type thread,
-                         const std::vector<unsigned int>& ids) {
-#if defined(__linux__) || defined(__ANDROID__)
-    if (pthread_equal(thread, CURRENT_THREAD_HANDLE)) {
-      thread = pthread_self();
-    }
-    cpu_set_t cpuset;
-    CPU_ZERO(&cpuset);
-    for (auto id : ids) {
-      CPU_SET(id, &cpuset);
-    }
-#if defined(__ANDROID__)
-#if __ANDROID_API__ >= 21
-    pid_t tid = pthread_gettid_np(thread);
-#else
-    typedef struct {
-      void* next;
-      void* pred;
-      pid_t tid;
-    } pthread_internal;
-    pid_t tid = reinterpret_cast<pthread_internal*>(thread)->tid;
-#endif
-    if (sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset) != 0) {
-      LOG(WARNING) << "sched_setaffinity failed";
-    }
-#else
-    pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
-#endif
-#endif
-  }
-
   // bind worker threads to disjoint cores
   // if worker 0 is offloaded to main, i.e. exclude_worker0 is true,
   // the main thread is bound to core 0.
@@ -326,7 +329,7 @@ class ThreadGroup::Impl {
                          const std::pair<unsigned int, int64_t>& b) {
       return a.second == b.second ? a.first < b.first : a.second > b.second;
     };
-    std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
+    std::stable_sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
     int64_t big_freq = max_freqs.begin()->second;
     int64_t little_freq = max_freqs.rbegin()->second;
     for (auto it = max_freqs.begin(); it != max_freqs.end(); it++) {
@@ -431,6 +434,14 @@ int MaxConcurrency() {
   return std::max(max_concurrency, 1);
 }

+// This global function can be used by disco runtime to bind processes
+// to CPUs.
+TVM_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity")
+    .set_body_typed([](IntTuple cpu_ids) {
+      SetThreadAffinity(CURRENT_THREAD_HANDLE,
+                        std::vector<unsigned int>{cpu_ids.begin(), cpu_ids.end()});
+    });
+
 }  // namespace threading
 }  // namespace runtime
 }  // namespace tvm