Skip to content

Commit

Permalink
Merge pull request #11321 from Yancey1989/polish_sparse_update
Browse files Browse the repository at this point in the history
polish sparse update logic
  • Loading branch information
typhoonzero authored Jun 10, 2018
2 parents eced973 + 5696494 commit 7bcc980
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 21 deletions.
12 changes: 10 additions & 2 deletions paddle/fluid/operators/detail/request_handler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,21 @@ bool RequestSendHandler::Handle(const std::string& varname,
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
rpc_server_->RecordSparseVar(invar);
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
}

return true;
}

void RequestSendHandler::ResetSparseVarRecorder() {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}

bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/detail/request_handler_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
void ResetSparseVarRecorder();

private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
};

class RequestGetHandler final : public RequestHandler {
Expand Down
13 changes: 0 additions & 13 deletions paddle/fluid/operators/detail/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,6 @@ void RPCServer::ResetBarrierCounter() {
t.second = 0;
}
}
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
sparse_vars_.push_back(sparse_var);
}

void RPCServer::ResetSparseVarsRecorder() {
VLOG(3) << "RPCServer reset sparse vars recorder.";
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}

void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
Expand Down
5 changes: 0 additions & 5 deletions paddle/fluid/operators/detail/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class RPCServer {
void IncreaseBatchBarrier(const std::string rpc_name);

void ResetBarrierCounter();
void RecordSparseVar(framework::Variable* sparse_var);
void ResetSparseVarsRecorder();

protected:
virtual void ShutDownImpl() = 0;
Expand All @@ -77,9 +75,6 @@ class RPCServer {
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;

std::vector<framework::Variable*> sparse_vars_;
std::mutex mutex_sparse_var_recorder_;

protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter();
rpc_service_->ResetSparseVarsRecorder();
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
} // while(true)
}

Expand Down

0 comments on commit 7bcc980

Please sign in to comment.