Refine listen and serve op #10080

Merged
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/reduce_op_handle_test.cc
@@ -109,7 +109,7 @@ struct TestReduceOpHandle {
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
in_var_handle->generated_op_ = nullptr;
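The test change above swaps direct writes to the handle's dev_ctxes_ map for the SetDeviceContext() setter. A minimal standalone sketch of that encapsulation pattern follows; Place and DeviceContext are simplified stand-ins here, not PaddlePaddle's real platform classes.

#include <map>
#include <string>

// Simplified stand-ins for platform::Place and platform::DeviceContext.
using Place = std::string;
struct DeviceContext {};

class OpHandleSketch {
 public:
  // Tests call this setter instead of writing to dev_ctxes_ directly,
  // keeping the place-to-context map an implementation detail.
  void SetDeviceContext(const Place &place, DeviceContext *ctx) {
    dev_ctxes_[place] = ctx;
  }

 private:
  std::map<Place, DeviceContext *> dev_ctxes_;
};

int main() {
  OpHandleSketch handle;
  DeviceContext cpu_ctx;
  handle.SetDeviceContext("cpu", &cpu_ctx);  // mirrors the updated test line
  return 0;
}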
5 changes: 1 addition & 4 deletions paddle/fluid/operators/detail/grpc_server.h
@@ -59,15 +59,13 @@ class AsyncGRPCServer final {

void SetProgram(framework::ProgramDesc *program) { program_ = program; }

void SetPrefetchBlkdId(int blkid) { prefetch_blk_id_ = blkid; }

void SetExecutor(framework::Executor *executor) { executor_ = executor; }

void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}

int GetSelectedPort() { return selected_port_; }
int GetSelectedPort() const { return selected_port_; }

const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }

@@ -114,7 +112,6 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_get_;
std::unique_ptr<std::thread> t_prefetch_;

int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
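With prefetch_blk_id_ removed, the server is configured entirely through pointers handed over by the op: scope, program, executor, and an already-prepared prefetch context, plus a now-const port query. Below is a hedged sketch of that setter surface using opaque placeholder types; it is not the real AsyncGRPCServer, only the shape of its API after this change.

#include <string>

// Opaque placeholders, not the real framework classes.
namespace framework {
struct Scope {};
struct ProgramDesc {};
struct Executor {};
struct ExecutorPrepareContext {};
}  // namespace framework

class GRPCServerSketch {
 public:
  void SetScope(framework::Scope *s) { scope_ = s; }
  void SetProgram(framework::ProgramDesc *p) { program_ = p; }
  void SetExecutor(framework::Executor *e) { executor_ = e; }
  // The prepared context replaces the old block-id setter: whoever prepares
  // the prefetch block passes the raw pointer in directly.
  void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *c) {
    prefetch_ctx_ = c;
  }
  // const: querying the bound port does not mutate the server.
  int GetSelectedPort() const { return selected_port_; }

 private:
  framework::Scope *scope_ = nullptr;
  framework::ProgramDesc *program_ = nullptr;
  framework::Executor *executor_ = nullptr;
  framework::ExecutorPrepareContext *prefetch_ctx_ = nullptr;
  int selected_port_ = 0;
};

int main() {
  framework::Scope scope;
  framework::ExecutorPrepareContext prefetch_ctx;
  GRPCServerSketch server;
  server.SetScope(&scope);
  server.SetPrefetchPreparedCtx(&prefetch_ctx);  // instead of SetPrefetchBlkdId(...)
  (void)server.GetSelectedPort();
  return 0;
}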
114 changes: 55 additions & 59 deletions paddle/fluid/operators/listen_and_serv_op.cc
@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG(4) << "RunServer thread end";
}

static void CreateTensorFromMessageType(framework::Variable *var,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}

static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -62,6 +48,13 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}

static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service->GetSelectedPort();
port_file.close();
}

ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
@@ -77,59 +70,26 @@ void ListenAndServOp::Stop() {
server_thread_->join();
}

void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();

if (!rpc_service_) {
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
}

auto ins = Inputs("X");
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
auto fan_in = Attr<int>("Fanin");
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();

size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");

framework::Executor executor(dev_place);
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
block_list.push_back(blkid);
}
block_list.push_back(blkid);
}
auto optimize_prepared = executor.Prepare(*program, block_list);
auto optimize_prepared = executor->Prepare(*program, block_list);
// Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));

rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
sleep(5);
// Write to a file of server selected port for python use.
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service_->GetSelectedPort();
port_file.close();

bool exit_flag = false;
// Record received sparse variables, so that
// we could reset those after execute optimize program
@@ -170,7 +130,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
break;
}

// NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.

// The optimize blocks which have the same parent ID would run parallel
@@ -182,16 +142,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
}
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";

// Reset the received sparse variables, the sum operator would not
@@ -209,6 +169,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
} // while(true)
}

void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();

PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));

auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);

// prepare rpc_service
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
rpc_service_->SetProgram(program);
rpc_service_->SetExecutor(&executor);

// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();

// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
sleep(5);
// Write to a file of server selected port for python use.
SavePort(rpc_service_);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
}

class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
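Inside RunSyncLoop, optimize blocks are batched by parent: consecutive block ids with the same parent are launched together through ParallelExecuteBlocks, the batch is flushed whenever the parent changes, and the prefetch block is skipped. The following self-contained sketch shows just that grouping logic, with the executor call replaced by a stub and made-up parent data; the real loop seeds the first batch with block 1.

#include <cstdio>
#include <vector>

// Stub standing in for ParallelExecuteBlocks: "runs" one batch of block ids
// that all share the same parent block.
static void RunBatch(const std::vector<size_t> &blkids) {
  std::printf("batch:");
  for (size_t id : blkids) std::printf(" %zu", id);
  std::printf("\n");
}

int main() {
  // parent[i] is the parent block id of block i (made-up example data).
  std::vector<int> parent = {-1, 0, 0, 0, 1, 1, 0};
  const size_t prefetch_blkid = 3;  // hypothetical prefetch block, skipped below

  std::vector<size_t> batch = {1};  // the real loop starts with block 1 in the batch
  int last_parent = parent[1];
  for (size_t blkid = 2; blkid < parent.size(); ++blkid) {
    if (blkid == prefetch_blkid) continue;  // prefetch block is not an optimize block
    if (parent[blkid] != last_parent) {     // parent changed: flush the current batch
      RunBatch(batch);
      batch.clear();
      last_parent = parent[blkid];
    }
    batch.push_back(blkid);
  }
  RunBatch(batch);  // flush the tail batch
  return 0;
}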
17 changes: 11 additions & 6 deletions paddle/fluid/operators/listen_and_serv_op.h
@@ -34,17 +34,22 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);

class ListenAndServOp : public framework::OperatorBase {
public:
ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs);
ListenAndServOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);

int GetSelectedPort() const;

void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;

void Stop() override;

void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override;
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;

protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
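The header now exposes RunSyncLoop alongside a const GetSelectedPort(), so a caller can start the op on its own thread, read the chosen port, and later call Stop() to break the loop. Below is a rough, hypothetical sketch of that lifecycle with atomics standing in for the real gRPC server state; ListenAndServOp itself is not reproduced here.

#include <atomic>
#include <chrono>
#include <thread>

// Hypothetical stand-in for ListenAndServOp's lifecycle; the real operator
// takes a Scope/Place and runs a gRPC service, all of which is elided here.
class ServerOpSketch {
 public:
  void Run() {
    selected_port_ = 12345;       // the real op binds a socket and records the port
    while (!stop_requested_) {    // stands in for RunSyncLoop's receive/execute loop
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
  int GetSelectedPort() const { return selected_port_; }
  void Stop() { stop_requested_ = true; }

 private:
  std::atomic<int> selected_port_{0};
  std::atomic<bool> stop_requested_{false};
};

int main() {
  ServerOpSketch op;
  std::thread server([&] { op.Run(); });                        // "pserver" thread
  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait for the port
  int port = op.GetSelectedPort();                              // trainer discovers the port
  (void)port;
  op.Stop();                                                    // make the sync loop exit
  server.join();
  return 0;
}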
2 changes: 1 addition & 1 deletion paddle/fluid/operators/send_recv_op_test.cc
@@ -127,7 +127,7 @@ void StartServerNet(bool is_sparse) {
const auto &root_block = program.Block(0);
auto *optimize_block = program.AppendBlock(root_block);
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensers, must be of same shape.
// X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);

f::AttributeMap attrs;
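For reference, the server-side test drives the op through attributes: an endpoint, a Fanin count, and the optimize and prefetch blocks appended to the root block. The toy sketch below wires that up with plain containers; only "endpoint" and "Fanin" are attribute names confirmed by this diff, and the block-attribute keys are assumptions.

#include <iostream>
#include <map>
#include <string>

// Plain stand-in: the real test uses paddle::framework::ProgramDesc,
// BlockDesc, and AttributeMap.
struct BlockSketch {
  int id;
};

int main() {
  BlockSketch root{0};
  BlockSketch optimize_block{1};  // AppendBlock(root) in the real test
  BlockSketch prefetch_block{2};  // AppendBlock(root) in the real test

  std::map<std::string, std::string> attrs;
  attrs["endpoint"] = "127.0.0.1:0";  // port 0: let the server pick a free port
  attrs["Fanin"] = "1";               // number of trainers sending gradients
  attrs["OptimizeBlock"] = std::to_string(optimize_block.id);  // assumed key
  attrs["PrefetchBlock"] = std::to_string(prefetch_block.id);  // assumed key

  for (const auto &kv : attrs) {
    std::cout << kv.first << " = " << kv.second << "\n";
  }
  return 0;
}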