Skip to content

Commit

Permalink
fix task stuck in barrier (PaddlePaddle#189)
Browse files Browse the repository at this point in the history
Co-authored-by: yangjunchao <[email protected]>
  • Loading branch information
2 people authored and zmxdream committed Dec 24, 2022
1 parent 91d951f commit 718db77
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/device_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class HogwildWorker : public CPUWorkerBase {
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_;
- static std::atomic<uint64_t> worker_num_stat_;
+ static std::atomic<bool> quit_flag_;
};

class DownpourWorker : public HogwildWorker {
Expand Down
18 changes: 9 additions & 9 deletions paddle/fluid/framework/hogwild_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ DECLARE_bool(enable_exit_when_partial_worker);
namespace paddle {
namespace framework {

- std::atomic<uint64_t> HogwildWorker::worker_num_stat_(0);
+ std::atomic<bool> HogwildWorker::quit_flag_(false);
Barrier g_barrier;

void HogwildWorker::Initialize(const TrainerDesc &desc) {
Expand Down Expand Up @@ -148,7 +148,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
int cur_batch;
int batch_cnt = 0;
if (thread_id_ == 0) {
- worker_num_stat_.store(0);
+ quit_flag_.store(false);
}
g_barrier.wait();
bool train_mode = device_reader_->IsTrainMode();
Expand All @@ -160,11 +160,11 @@ void HogwildWorker::TrainFilesWithProfiler() {
while (1) {
cur_batch = device_reader_->Next();
if (FLAGS_enable_exit_when_partial_worker && train_mode) {
- if (cur_batch > 0) {
- worker_num_stat_.fetch_add(1, std::memory_order_relaxed);
+ if (cur_batch <= 0) {
+ quit_flag_.store(true, std::memory_order_relaxed);
}
g_barrier.wait();
- if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) {
+ if (quit_flag_.load(std::memory_order_relaxed) == true) {
break;
}
}
Expand Down Expand Up @@ -265,7 +265,7 @@ void HogwildWorker::TrainFiles() {
int cur_batch;
int batch_cnt = 0;
if (thread_id_ == 0) {
- worker_num_stat_.store(0);
+ quit_flag_.store(false);
}
g_barrier.wait();

Expand All @@ -280,11 +280,11 @@ void HogwildWorker::TrainFiles() {
while (1) {
cur_batch = device_reader_->Next();
if (FLAGS_enable_exit_when_partial_worker && train_mode) {
- if (cur_batch > 0) {
- worker_num_stat_.fetch_add(1, std::memory_order_relaxed);
+ if (cur_batch <= 0) {
+ quit_flag_.store(true, std::memory_order_relaxed);
}
g_barrier.wait();
- if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) {
+ if (quit_flag_.load(std::memory_order_relaxed) == true) {
break;
}
}
Expand Down

0 comments on commit 718db77

Please sign in to comment.