From 044e133f7a4640e5b8ef20cdaa18e156b608cb0d Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 4 Jun 2018 10:11:24 +0000 Subject: [PATCH 1/9] add rpc_client.h --- paddle/fluid/operators/detail/grpc_client.cc | 58 +++++++++--------- paddle/fluid/operators/detail/grpc_client.h | 38 ++++++------ .../operators/detail/grpc_server_test.cc | 6 +- paddle/fluid/operators/detail/rpc_client.h | 60 +++++++++++++++++++ paddle/fluid/operators/fetch_barrier_op.cc | 2 +- paddle/fluid/operators/recv_op.cc | 2 +- paddle/fluid/operators/send_barrier_op.cc | 2 +- paddle/fluid/operators/send_op.cc | 2 +- paddle/fluid/operators/send_vars_op.cc | 2 +- paddle/fluid/operators/test_send_nccl_id.cc | 10 ++-- 10 files changed, 123 insertions(+), 59 deletions(-) create mode 100644 paddle/fluid/operators/detail/rpc_client.h diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index da9ca1a0c1d55..cf037a47376ed 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,26 +25,26 @@ namespace paddle { namespace operators { namespace detail { -std::once_flag RPCClient::init_flag_; +std::once_flag GRPCClient::init_flag_; -std::unique_ptr RPCClient::rpc_client_(nullptr); +std::unique_ptr GRPCClient::rpc_client_(nullptr); -RPCClient* RPCClient::GetInstance() { - std::call_once(init_flag_, &RPCClient::Init); +GRPCClient* GRPCClient::GetInstance() { + std::call_once(init_flag_, &GRPCClient::Init); return rpc_client_.get(); } -void RPCClient::Init() { +void GRPCClient::Init() { if (rpc_client_.get() == nullptr) { - rpc_client_.reset(new RPCClient()); + rpc_client_.reset(new GRPCClient()); } } -bool RPCClient::AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { +bool GRPCClient::AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; @@ -94,11 +94,11 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { result->Swap(&tmp); } -bool RPCClient::AsyncGetVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { +bool GRPCClient::AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; @@ -136,12 +136,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::AsyncPrefetchVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out) { +bool GRPCClient::AsyncPrefetchVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; @@ -179,7 +179,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, return true; } -void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); @@ -192,7 +193,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { req_count_++; } -void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); s->Prepare(time_out); @@ -204,8 +206,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { req_count_++; } -bool RPCClient::Wait() { - VLOG(3) << "RPCClient begin Wait()" +bool GRPCClient::Wait() { + VLOG(3) << "GRPCClient begin Wait()" << " req_count_:" << req_count_; if (req_count_ <= 0) { return true; @@ -239,7 +241,7 @@ bool RPCClient::Wait() { return true; } -bool RPCClient::Proceed() { +bool GRPCClient::Proceed() { void* tag = NULL; bool ok = false; @@ -265,7 +267,7 @@ bool RPCClient::Proceed() { delete c; return true; } -std::shared_ptr RPCClient::GetChannel(const std::string& ep) { +std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { // TODO(Yancey1989): make grpc client completely thread-safe std::unique_lock lock(mutex_); auto it = channels_.find(ep); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 449d5105afb8c..e9a2e47819100 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -35,6 +35,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN @@ -161,38 +162,39 @@ class FetchBarrierProcessor : public BaseProcessor { std::unique_ptr stub_; }; -class RPCClient { +class GRPCClient : public RPCClient { public: - RPCClient() {} + GRPCClient() {} - static RPCClient* GetInstance(); + static GRPCClient* GetInstance(); bool AsyncSendVariable(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, - int64_t time_out = 600 * 1000); + int64_t time_out = RPCClient::rpc_time_out) override; bool AsyncGetVariable(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, - int64_t time_out = 600 * 1000); + int64_t time_out = RPCClient::rpc_time_out) override; - bool AsyncPrefetchVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = 600 * 1000); + bool AsyncPrefetchVariable( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = RPCClient::rpc_time_out) override; - void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = 600 * 1000); + void AsyncSendBatchBarrier( + const std::string& ep, + int64_t time_out = RPCClient::rpc_time_out) override; - void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = 600 * 1000); + void AsyncSendFetchBarrier( + const std::string& ep, + int64_t time_out = RPCClient::rpc_time_out) override; - bool Wait(); + bool Wait() override; private: bool Proceed(); @@ -205,9 +207,9 @@ class RPCClient { std::map> channels_; std::atomic req_count_{0}; std::mutex mutex_; - static std::unique_ptr rpc_client_; + static std::unique_ptr rpc_client_; static std::once_flag init_flag_; - DISABLE_COPY_AND_ASSIGN(RPCClient); + DISABLE_COPY_AND_ASSIGN(GRPCClient); }; } // namespace detail diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index f97f638701cfb..b57deffc41515 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -127,7 +127,7 @@ TEST(PREFETCH, CPU) { std::thread server_thread(StartServer); g_rpc_service->WaitServerReady(); - detail::RPCClient client; + std::unique_ptr client(new detail::GRPCClient); int port = g_rpc_service->GetSelectedPort(); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); @@ -141,8 +141,8 @@ TEST(PREFETCH, CPU) { std::string in_var_name("ids"); std::string out_var_name("out"); - client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); - client.Wait(); + client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); + client->Wait(); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); auto ptr = value.mutable_data(place); diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h new file mode 100644 index 0000000000000..482e63c55d8b9 --- /dev/null +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -0,0 +1,60 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace operators { +namespace detail { + +class RPCClient { + public: + virtual bool AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual bool AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual bool AsyncPrefetchVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = rpc_time_out) = 0; + + virtual void AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out = rpc_time_out) = 0; + + virtual bool Wait() = 0; + + static const int64_t rpc_time_out = 600 * 1000; +}; +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 79ec02f520941..fec9b9257f319 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -43,7 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + auto rpc_client = detail::GRPCClient::GetInstance(); PADDLE_ENFORCE(rpc_client->Wait()); diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index d8ddb7b448910..44f042946fa48 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + auto rpc_client = detail::GRPCClient::GetInstance(); for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index bcd8e81609a37..97955def7b1d4 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -44,7 +44,7 @@ class SendBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + auto rpc_client = detail::GRPCClient::GetInstance(); VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index a5150f242ca3b..f39c7b1c1dad6 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -49,7 +49,7 @@ class SendOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + auto rpc_client = detail::GRPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index fe839dab69246..3db8d9297ca26 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -45,7 +45,7 @@ class SendVarsOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + auto rpc_client = detail::GRPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc index a845ba2eb038f..97fb4d3214926 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -88,12 +88,12 @@ TEST(SendNcclId, GrpcServer) { int port = g_rpc_service->GetSelectedPort(); std::string ep = string::Sprintf("127.0.0.1:%d", port); - detail::RPCClient client; + std::unique_ptr client(new detail::GRPCClient); LOG(INFO) << "connect to server" << ep; - client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); - client.Wait(); - client.AsyncSendBatchBarrier(ep); - client.Wait(); + client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); + client->Wait(); + client->AsyncSendBatchBarrier(ep); + client->Wait(); server_thread.join(); g_rpc_service.reset(nullptr); From 07872134ce207cca5d3c01bda3ffe5a6f404dd11 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 4 Jun 2018 10:56:14 +0000 Subject: [PATCH 2/9] fix compile error --- paddle/fluid/operators/gen_nccl_id_op.cc | 2 +- paddle/fluid/operators/prefetch_op.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc index 4bce2d322d825..30d935d14688e 100644 --- a/paddle/fluid/operators/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -61,7 +61,7 @@ class GenNCCLIdOp : public framework::OperatorBase { std::vector endpoint_list = Attr>("endpoint_list"); - detail::RPCClient client; + detail::GRPCClient client; for (auto& ep : endpoint_list) { VLOG(3) << "sending nccl id to " << ep; client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME); diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index e0a9b24ac8978..a90c60ef76ec2 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -41,7 +41,7 @@ class PrefetchOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - auto rpc_client = detail::RPCClient::GetInstance(); + auto rpc_client = detail::GRPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { From 8d4050aca747f16bc15828887d12468a3092c2fe Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 4 Jun 2018 12:24:41 +0000 Subject: [PATCH 3/9] follow comments --- paddle/fluid/operators/detail/grpc_client.cc | 8 ++++---- paddle/fluid/operators/detail/grpc_client.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index cf037a47376ed..e39c66b3e1734 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -27,16 +27,16 @@ namespace detail { std::once_flag GRPCClient::init_flag_; -std::unique_ptr GRPCClient::rpc_client_(nullptr); +std::unique_ptr GRPCClient::grpc_client_(nullptr); GRPCClient* GRPCClient::GetInstance() { std::call_once(init_flag_, &GRPCClient::Init); - return rpc_client_.get(); + return grpc_client_.get(); } void GRPCClient::Init() { - if (rpc_client_.get() == nullptr) { - rpc_client_.reset(new GRPCClient()); + if (grpc_client_.get() == nullptr) { + grpc_client_.reset(new GRPCClient()); } } diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index e9a2e47819100..8506478dc4eac 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -207,7 +207,7 @@ class GRPCClient : public RPCClient { std::map> channels_; std::atomic req_count_{0}; std::mutex mutex_; - static std::unique_ptr rpc_client_; + static std::unique_ptr grpc_client_; static std::once_flag init_flag_; DISABLE_COPY_AND_ASSIGN(GRPCClient); }; From 8432c209a9cc52aa2a2210ff7e61cde00de20c54 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 5 Jun 2018 02:13:14 +0000 Subject: [PATCH 4/9] follow comments --- paddle/fluid/operators/detail/grpc_client.cc | 30 +++++++++--------- paddle/fluid/operators/detail/grpc_client.h | 23 ++++++-------- paddle/fluid/operators/detail/rpc_client.h | 32 ++++++++++---------- 3 files changed, 40 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index e39c66b3e1734..fd077057a392e 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -40,11 +40,10 @@ void GRPCClient::Init() { } } -bool GRPCClient::AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { +bool GRPCClient::AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; @@ -94,11 +93,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { result->Swap(&tmp); } -bool GRPCClient::AsyncGetVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { +bool GRPCClient::AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; @@ -136,12 +134,12 @@ bool GRPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool GRPCClient::AsyncPrefetchVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out) { +bool GRPCClient::AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 8506478dc4eac..7f6f2f9f90871 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -168,24 +168,21 @@ class GRPCClient : public RPCClient { static GRPCClient* GetInstance(); - bool AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = RPCClient::rpc_time_out) override; + bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + int64_t time_out = RPCClient::rpc_time_out) override; - bool AsyncGetVariable(const std::string& ep, + bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + int64_t time_out = RPCClient::rpc_time_out) override; + + bool AsyncPrefetchVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, - const std::string& var_name, + const std::string& in_var_name, + const std::string& out_var_name, int64_t time_out = RPCClient::rpc_time_out) override; - bool AsyncPrefetchVariable( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = RPCClient::rpc_time_out) override; - void AsyncSendBatchBarrier( const std::string& ep, int64_t time_out = RPCClient::rpc_time_out) override; diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 482e63c55d8b9..747369cf09031 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -26,25 +26,25 @@ namespace detail { class RPCClient { public: - virtual bool AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = rpc_time_out) = 0; - - virtual bool AsyncGetVariable(const std::string& ep, + virtual bool AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual bool AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual bool AsyncPrefetchVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, - const std::string& var_name, + const std::string& in_var_name, + const std::string& out_var_name, int64_t time_out = rpc_time_out) = 0; - virtual bool AsyncPrefetchVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = rpc_time_out) = 0; - virtual void AsyncSendBatchBarrier(const std::string& ep, int64_t time_out = rpc_time_out) = 0; @@ -53,7 +53,7 @@ class RPCClient { virtual bool Wait() = 0; - static const int64_t rpc_time_out = 600 * 1000; + static constexpr int64_t rpc_time_out = 600 * 1000; }; } // namespace detail } // namespace operators From 60f0ed78dce0cee69458508670bb9abf402ee312 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 5 Jun 2018 03:25:16 +0000 Subject: [PATCH 5/9] add getinstance template --- paddle/fluid/operators/detail/CMakeLists.txt | 2 +- paddle/fluid/operators/detail/grpc_client.cc | 15 --------------- paddle/fluid/operators/detail/grpc_client.h | 6 ------ .../fluid/operators/detail/grpc_server_test.cc | 6 ++++-- paddle/fluid/operators/detail/rpc_client.h | 18 ++++++++++++++++++ paddle/fluid/operators/fetch_barrier_op.cc | 4 +++- paddle/fluid/operators/gen_nccl_id_op.cc | 7 ++++--- paddle/fluid/operators/prefetch_op.cc | 6 +++--- paddle/fluid/operators/recv_op.cc | 5 +++-- paddle/fluid/operators/send_barrier_op.cc | 3 ++- paddle/fluid/operators/send_op.cc | 7 ++++--- paddle/fluid/operators/send_vars_op.cc | 5 +++-- paddle/fluid/operators/test_send_nccl_id.cc | 5 +++-- 13 files changed, 48 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index cf20530513cf6..c29dc5d7e077a 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_DISTRIBUTE) grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc - request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor + request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows memory) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index fd077057a392e..0b5b815277967 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,21 +25,6 @@ namespace paddle { namespace operators { namespace detail { -std::once_flag GRPCClient::init_flag_; - -std::unique_ptr GRPCClient::grpc_client_(nullptr); - -GRPCClient* GRPCClient::GetInstance() { - std::call_once(init_flag_, &GRPCClient::Init); - return grpc_client_.get(); -} - -void GRPCClient::Init() { - if (grpc_client_.get() == nullptr) { - grpc_client_.reset(new GRPCClient()); - } -} - bool GRPCClient::AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 7f6f2f9f90871..d148ad711f373 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -166,8 +166,6 @@ class GRPCClient : public RPCClient { public: GRPCClient() {} - static GRPCClient* GetInstance(); - bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, int64_t time_out = RPCClient::rpc_time_out) override; @@ -196,16 +194,12 @@ class GRPCClient : public RPCClient { private: bool Proceed(); std::shared_ptr GetChannel(const std::string& ep); - // Init is called by GetInstance. - static void Init(); private: grpc::CompletionQueue cq_; std::map> channels_; std::atomic req_count_{0}; std::mutex mutex_; - static std::unique_ptr grpc_client_; - static std::once_flag init_flag_; DISABLE_COPY_AND_ASSIGN(GRPCClient); }; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index b57deffc41515..0b1689091aec9 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_registry.h" @@ -127,7 +128,8 @@ TEST(PREFETCH, CPU) { std::thread server_thread(StartServer); g_rpc_service->WaitServerReady(); - std::unique_ptr client(new detail::GRPCClient); + detail::RPCClient* client = + detail::RPCClient::GetInstance(); int port = g_rpc_service->GetSelectedPort(); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); @@ -141,7 +143,7 @@ TEST(PREFETCH, CPU) { std::string in_var_name("ids"); std::string out_var_name("out"); - client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); + client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name); client->Wait(); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 747369cf09031..9cf706bb536df 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -54,6 +54,24 @@ class RPCClient { virtual bool Wait() = 0; static constexpr int64_t rpc_time_out = 600 * 1000; + + template + static RPCClient* GetInstance() { + std::call_once(init_flag_, &RPCClient::Init); + return rpc_client_.get(); + } + + // Init is called by GetInstance. + template + static void Init() { + if (rpc_client_.get() == nullptr) { + rpc_client_.reset(new T()); + } + } + + private: + static std::once_flag init_flag_; + static std::unique_ptr rpc_client_; }; } // namespace detail } // namespace operators diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index fec9b9257f319..b2a71cf261631 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::GRPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); PADDLE_ENFORCE(rpc_client->Wait()); diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc index 30d935d14688e..547de4fa49dc1 100644 --- a/paddle/fluid/operators/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase { std::vector endpoint_list = Attr>("endpoint_list"); - detail::GRPCClient client; + detail::RPCClient* client = + detail::RPCClient::GetInstance(); for (auto& ep : endpoint_list) { VLOG(3) << "sending nccl id to " << ep; - client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME); + client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME); } - client.Wait(); + client->Wait(); VLOG(3) << "sending completed..."; } diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index a90c60ef76ec2..27a65fbb19cc7 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - auto rpc_client = detail::GRPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " << outs[i] << " back"; - rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i], - outs[i]); + rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 44f042946fa48..7e34291a0fcc5 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::GRPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); } if (sync_mode) { PADDLE_ENFORCE(rpc_client->Wait()); diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 97955def7b1d4..a1497bb365921 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::GRPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index f39c7b1c1dad6..e326375f65bb3 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::GRPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } @@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase { if (outs.size() > 0) { for (size_t i = 0; i < outs.size(); i++) { VLOG(2) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); } PADDLE_ENFORCE(rpc_client->Wait()); // tell pservers that current trainer have called fetch diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index 3db8d9297ca26..564e40461f8f8 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::GRPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; // TODO(Yancey1989): we need to use an IO threadpool which has // a larger number of threads than the computing threadpool. - rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc index 97fb4d3214926..90ff7b9571a31 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -88,9 +88,10 @@ TEST(SendNcclId, GrpcServer) { int port = g_rpc_service->GetSelectedPort(); std::string ep = string::Sprintf("127.0.0.1:%d", port); - std::unique_ptr client(new detail::GRPCClient); + detail::RPCClient* client = + detail::RPCClient::GetInstance(); LOG(INFO) << "connect to server" << ep; - client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); + client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->Wait(); client->AsyncSendBatchBarrier(ep); client->Wait(); From a1db10bf962ad2ad7e12290795756388d38d8bd8 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 5 Jun 2018 03:26:23 +0000 Subject: [PATCH 6/9] add rpc_client.cc --- paddle/fluid/operators/detail/rpc_client.cc | 26 +++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 paddle/fluid/operators/detail/rpc_client.cc diff --git a/paddle/fluid/operators/detail/rpc_client.cc b/paddle/fluid/operators/detail/rpc_client.cc new file mode 100644 index 0000000000000..9a791403e3d6b --- /dev/null +++ b/paddle/fluid/operators/detail/rpc_client.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/detail/rpc_client.h" + +namespace paddle { +namespace operators { +namespace detail { + +std::once_flag RPCClient::init_flag_; +std::unique_ptr RPCClient::rpc_client_(nullptr); + +} // namespace detail +} // namespace operators +} // namespace paddle From 0e35c704ac1ef4f30a9105ed47380dd32a110a5e Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 6 Jun 2018 01:25:28 +0000 Subject: [PATCH 7/9] fix bug --- paddle/fluid/operators/detail/grpc_client.cc | 22 +++++++++++++++++++- paddle/fluid/operators/detail/grpc_client.h | 3 +++ paddle/fluid/operators/detail/rpc_client.h | 4 ++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 3073eee1ce6ff..056bbbd6790c6 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,6 +25,26 @@ namespace paddle { namespace operators { namespace detail { +void GRPCClient::Init() { rpc_client_->InitEventLoop(); } + +void GRPCClient::InitEventLoop() { + // start the client process thread + // TODO(wuyi): can make this in a threadpool + client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this))); +} + +GRPCClient::~GRPCClient() { + Wait(); + cq_.Shutdown(); + { + std::lock_guard guard(chan_mutex_); + for (auto& it : channels_) { + it.second.reset(); + } + } + client_thread_->join(); +} + bool GRPCClient::AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, @@ -194,7 +214,7 @@ void GRPCClient::Wait() { sync_cond_.wait(lk, [this] { return req_count_ == 0; }); } -void RPCClient::Proceed() { +void GRPCClient::Proceed() { void* tag = nullptr; bool ok = false; diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 4df65d4dfcef3..9677caf801052 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -210,6 +210,9 @@ class GRPCClient : public RPCClient { std::condition_variable sync_cond_; std::atomic req_count_{0}; std::mutex mutex_; + + // mutex for GetChannel thread safety + std::mutex chan_mutex_; DISABLE_COPY_AND_ASSIGN(GRPCClient); }; diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 9cf706bb536df..03249453cfa89 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -66,9 +66,13 @@ class RPCClient { static void Init() { if (rpc_client_.get() == nullptr) { rpc_client_.reset(new T()); + rpc_client_->Init(); } } + protected: + virtual void Init() = 0; + private: static std::once_flag init_flag_; static std::unique_ptr rpc_client_; From 06f72958ded8c75616b2665c6126dbe2c0e5d7b9 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 6 Jun 2018 01:43:39 +0000 Subject: [PATCH 8/9] fix return value --- paddle/fluid/operators/detail/grpc_client.cc | 4 ++-- paddle/fluid/operators/detail/grpc_client.h | 11 ++++++++--- paddle/fluid/operators/detail/rpc_client.h | 6 +++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 056bbbd6790c6..fae39418b4166 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,12 +25,12 @@ namespace paddle { namespace operators { namespace detail { -void GRPCClient::Init() { rpc_client_->InitEventLoop(); } +void GRPCClient::InitImpl() { InitEventLoop(); } void GRPCClient::InitEventLoop() { // start the client process thread // TODO(wuyi): can make this in a threadpool - client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this))); + client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } GRPCClient::~GRPCClient() { diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 9677caf801052..8db73f875e3e2 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -168,6 +168,7 @@ class FetchBarrierProcessor : public BaseProcessor { class GRPCClient : public RPCClient { public: GRPCClient() {} + virtual ~GRPCClient(); bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, @@ -192,12 +193,17 @@ class GRPCClient : public RPCClient { const std::string& ep, int64_t time_out = RPCClient::rpc_time_out) override; - bool Wait() override; + void Wait() override; + + protected: + void InitImpl() override; + + private: // InitEventLoop should only be called by Init() void InitEventLoop(); - private: void Proceed(); + std::shared_ptr GetChannel(const std::string& ep); private: @@ -209,7 +215,6 @@ class GRPCClient : public RPCClient { std::mutex sync_mutex_; std::condition_variable sync_cond_; std::atomic req_count_{0}; - std::mutex mutex_; // mutex for GetChannel thread safety std::mutex chan_mutex_; diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 03249453cfa89..7e4f9a0bb81c1 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -51,7 +51,7 @@ class RPCClient { virtual void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = rpc_time_out) = 0; - virtual bool Wait() = 0; + virtual void Wait() = 0; static constexpr int64_t rpc_time_out = 600 * 1000; @@ -66,12 +66,12 @@ class RPCClient { static void Init() { if (rpc_client_.get() == nullptr) { rpc_client_.reset(new T()); - rpc_client_->Init(); + rpc_client_->InitImpl(); } } protected: - virtual void Init() = 0; + virtual void InitImpl() = 0; private: static std::once_flag init_flag_; From 806026345895ca347e94a8385bda111c89fdfe26 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 6 Jun 2018 07:22:56 +0000 Subject: [PATCH 9/9] follow comments --- paddle/fluid/operators/detail/rpc_client.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 7e4f9a0bb81c1..8e3717f076db6 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -53,7 +53,7 @@ class RPCClient { virtual void Wait() = 0; - static constexpr int64_t rpc_time_out = 600 * 1000; + static constexpr int64_t rpc_time_out = 120 * 1000; template static RPCClient* GetInstance() { @@ -71,7 +71,7 @@ class RPCClient { } protected: - virtual void InitImpl() = 0; + virtual void InitImpl() {} private: static std::once_flag init_flag_;