Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Disco] Support setting workers' CPU affinity #16807

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t
// Global function "runtime.disco.device": returns the `default_device` field of the
// object handed back by DiscoWorker::ThreadLocal().
TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device {
  auto&& worker = DiscoWorker::ThreadLocal();
  return worker->default_device;
});
// Global function "runtime.disco.bind_worker_to_cpu_core": binds the calling worker to a
// single CPU core.
//
// `cpu_ids` is indexed by worker id: the worker whose WorkerId() is i is bound to
// cpu_ids[i], so the tuple must hold more than WorkerId() entries. The actual binding is
// delegated to the packed function registered under
// "tvm.runtime.threading.set_current_thread_affinity".
TVM_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core").set_body_typed([](IntTuple cpu_ids) {
  const int worker_id = WorkerId();
  const int num_cpu_ids = static_cast<int>(cpu_ids.size());
  // Fail with enough context to see which worker was left without a core.
  ICHECK_LT(worker_id, num_cpu_ids)
      << "runtime.disco.bind_worker_to_cpu_core: no CPU id supplied for worker " << worker_id
      << "; cpu_ids has " << num_cpu_ids << " entries";
  // The affinity setter is resolved through the global registry by name rather than
  // called directly, so its absence has to be checked at run time.
  const PackedFunc* f_set_thread_affinity =
      Registry::Get("tvm.runtime.threading.set_current_thread_affinity");
  ICHECK_NOTNULL(f_set_thread_affinity);
  // Pass a one-element tuple: only this worker's own core id.
  (*f_set_thread_affinity)(IntTuple{cpu_ids[worker_id]});
});

} // namespace runtime
} // namespace tvm
75 changes: 43 additions & 32 deletions src/runtime/threading_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Native threading backend
*/
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>

#if defined(__linux__) || defined(__ANDROID__)
Expand Down Expand Up @@ -106,6 +107,39 @@ class QuRTThread {
void* stack_ = nullptr;
};
#endif // __hexagon__

/*!
 * \brief Shared helper that sets the CPU affinity mask of a thread.
 *
 * The body is only compiled for Linux and Android; on other platforms the
 * function does nothing.
 *
 * \param thread Native handle of the thread to bind. A handle that compares
 *        equal to CURRENT_THREAD_HANDLE is replaced by pthread_self().
 * \param ids CPU ids to include in the affinity mask.
 */
void SetThreadAffinity(std::thread::native_handle_type thread,
                       const std::vector<unsigned int>& ids) {
#if defined(__linux__) || defined(__ANDROID__)
  // Resolve the "current thread" placeholder to the caller's real handle.
  const bool targets_caller = pthread_equal(thread, CURRENT_THREAD_HANDLE) != 0;
  if (targets_caller) {
    thread = pthread_self();
  }

  // Build the mask: start empty, then switch on every requested core.
  cpu_set_t affinity_mask;
  CPU_ZERO(&affinity_mask);
  for (const unsigned int core : ids) {
    CPU_SET(core, &affinity_mask);
  }

#if defined(__ANDROID__)
  // On Android the mask is applied with sched_setaffinity on a thread id,
  // so the handle has to be turned into a tid first.
#if __ANDROID_API__ >= 21
  const pid_t kernel_tid = pthread_gettid_np(thread);
#else
  // Below API 21 the tid is read through a struct overlay on the handle.
  // NOTE(review): presumably this mirrors the platform's internal pthread
  // layout -- confirm before changing the field order.
  struct PthreadInternalLayout {
    void* next;
    void* pred;
    pid_t tid;
  };
  const pid_t kernel_tid = reinterpret_cast<PthreadInternalLayout*>(thread)->tid;
#endif
  const int status = sched_setaffinity(kernel_tid, sizeof(cpu_set_t), &affinity_mask);
  if (status != 0) {
    LOG(WARNING) << "sched_setaffinity failed";
  }
#else
  // The result of pthread_setaffinity_np is not inspected here.
  pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinity_mask);
#endif
#endif
}

thread_local int max_concurrency = 0;
class ThreadGroup::Impl {
public:
Expand Down Expand Up @@ -158,37 +192,6 @@ class ThreadGroup::Impl {
}

private:
/*!
 * \brief Apply a CPU affinity mask to a thread (Linux / Android only).
 *
 * On any other platform the body is compiled out and nothing happens.
 *
 * \param thread Native handle of the target thread; a handle equal to
 *        CURRENT_THREAD_HANDLE is swapped for pthread_self().
 * \param ids CPU ids that make up the affinity mask.
 */
void SetThreadAffinity(std::thread::native_handle_type thread,
                       const std::vector<unsigned int>& ids) {
#if defined(__linux__) || defined(__ANDROID__)
  // Translate the placeholder handle into the calling thread's handle.
  if (pthread_equal(thread, CURRENT_THREAD_HANDLE) != 0) {
    thread = pthread_self();
  }

  // Collect the requested cores into an initially empty mask.
  cpu_set_t core_mask;
  CPU_ZERO(&core_mask);
  for (const unsigned int cpu : ids) {
    CPU_SET(cpu, &core_mask);
  }

#if defined(__ANDROID__)
  // Android path: sched_setaffinity takes a thread id, not a pthread handle.
#if __ANDROID_API__ >= 21
  const pid_t target_tid = pthread_gettid_np(thread);
#else
  // Below API 21 the tid is pulled out of the handle via a struct overlay.
  // NOTE(review): the field order presumably matches the platform's
  // internal pthread record -- confirm before touching it.
  struct HandleOverlay {
    void* next;
    void* pred;
    pid_t tid;
  };
  const pid_t target_tid = reinterpret_cast<HandleOverlay*>(thread)->tid;
#endif
  const bool bind_failed = sched_setaffinity(target_tid, sizeof(cpu_set_t), &core_mask) != 0;
  if (bind_failed) {
    LOG(WARNING) << "sched_setaffinity failed";
  }
#else
  // The return value of pthread_setaffinity_np is discarded.
  pthread_setaffinity_np(thread, sizeof(cpu_set_t), &core_mask);
#endif
#endif
}

// bind worker threads to disjoint cores
// if worker 0 is offloaded to main, i.e. exclude_worker0 is true,
// the main thread is bound to core 0.
Expand Down Expand Up @@ -326,7 +329,7 @@ class ThreadGroup::Impl {
const std::pair<unsigned int, int64_t>& b) {
return a.second == b.second ? a.first < b.first : a.second > b.second;
};
std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
std::stable_sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);
int64_t big_freq = max_freqs.begin()->second;
int64_t little_freq = max_freqs.rbegin()->second;
for (auto it = max_freqs.begin(); it != max_freqs.end(); it++) {
Expand Down Expand Up @@ -431,6 +434,14 @@ int MaxConcurrency() {
return std::max(max_concurrency, 1);
}

// Global function "tvm.runtime.threading.set_current_thread_affinity": binds the
// calling thread to the CPUs listed in `cpu_ids`. It is exposed through the registry
// so that the disco runtime can use it to bind its workers to CPUs.
TVM_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity")
    .set_body_typed([](IntTuple cpu_ids) {
      // Each tuple element is converted to unsigned int while the vector is built.
      const std::vector<unsigned int> core_ids(cpu_ids.begin(), cpu_ids.end());
      SetThreadAffinity(CURRENT_THREAD_HANDLE, core_ids);
    });

} // namespace threading
} // namespace runtime
} // namespace tvm
Loading