Skip to content

Commit

Permalink
[psgpu]fix pipe bug:save and pull overlap; test=develop (#37233)
Browse files Browse the repository at this point in the history
  • Loading branch information
danleifeng authored Nov 16, 2021
1 parent f29a3c6 commit 62ec644
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
59 changes: 30 additions & 29 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,8 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {

void PSGPUWrapper::start_build_thread() {
running_ = true;
VLOG(3) << "start build CPU&GPU ps thread.";
VLOG(3) << "start build CPU ps thread.";
pre_build_threads_ = std::thread([this] { pre_build_thread(); });
build_threads_ = std::thread([this] { build_thread(); });
}

void PSGPUWrapper::pre_build_thread() {
Expand All @@ -515,30 +514,28 @@ void PSGPUWrapper::pre_build_thread() {
VLOG(3) << "build cpu thread end";
}

void PSGPUWrapper::build_thread() {
// build: build_pull + build_gputask
while (running_) {
std::shared_ptr<HeterContext> gpu_task = nullptr;
if (!gpu_free_channel_->Get(gpu_task)) {
continue;
}
if (!buildcpu_ready_channel_->Get(gpu_task)) {
continue;
}
VLOG(3) << "thread BuildGPUTask start.";
platform::Timer timer;
timer.Start();
BuildPull(gpu_task);
timer.Pause();
timer.Start();
BuildGPUTask(gpu_task);
timer.Pause();
VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
<< "s";

train_ready_channel_->Put(gpu_task);
void PSGPUWrapper::build_task() {
// build_task: build_pull + build_gputask
std::shared_ptr<HeterContext> gpu_task = nullptr;
// train end, gpu free
if (!gpu_free_channel_->Get(gpu_task)) {
return;
}
// ins and pre_build end
if (!buildcpu_ready_channel_->Get(gpu_task)) {
return;
}
VLOG(3) << "build gpu thread end";

VLOG(1) << "BuildPull start.";
platform::Timer timer;
timer.Start();
BuildPull(gpu_task);
BuildGPUTask(gpu_task);
timer.Pause();
VLOG(1) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec()
<< "s";

current_task_ = gpu_task;
}

void PSGPUWrapper::BeginPass() {
Expand All @@ -548,11 +545,15 @@ void PSGPUWrapper::BeginPass() {
PADDLE_THROW(
platform::errors::Fatal("[BeginPass] current task is not ended."));
}
// load+build done
if (!train_ready_channel_->Get(current_task_)) {
PADDLE_THROW(platform::errors::Fatal("train_ready_channel_ failed."));
}

build_task();
timer.Pause();

if (current_task_ == nullptr) {
PADDLE_THROW(platform::errors::Fatal(
"[BeginPass] after build_task, current task is not null."));
}

VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
}

Expand Down
9 changes: 1 addition & 8 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PSGPUWrapper {
void EndPass();
void start_build_thread();
void pre_build_thread();
void build_thread();
void build_task();

void Finalize() {
VLOG(3) << "PSGPUWrapper Begin Finalize.";
Expand All @@ -101,7 +101,6 @@ class PSGPUWrapper {
data_ready_channel_->Close();
buildcpu_ready_channel_->Close();
gpu_free_channel_->Close();
train_ready_channel_->Close();
running_ = false;
VLOG(3) << "begin stop pre_build_threads_";
pre_build_threads_.join();
Expand Down Expand Up @@ -169,8 +168,6 @@ class PSGPUWrapper {
buildcpu_ready_channel_->SetCapacity(3);
gpu_free_channel_->Open();
gpu_free_channel_->SetCapacity(1);
train_ready_channel_->Open();
train_ready_channel_->SetCapacity(1);

current_task_ = nullptr;
gpu_free_channel_->Put(current_task_);
Expand Down Expand Up @@ -306,10 +303,6 @@ class PSGPUWrapper {
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
gpu_free_channel_ =
paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
std::shared_ptr<
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
train_ready_channel_ =
paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
std::shared_ptr<HeterContext> current_task_ = nullptr;
std::thread pre_build_threads_;
std::thread build_threads_;
Expand Down

0 comments on commit 62ec644

Please sign in to comment.