From 41701969a9e73fe85bbcbf99265cb84ecf512f4d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 13 Jun 2018 22:34:00 +0800 Subject: [PATCH 01/67] [wip] ckpt m2 develop --- .../fluid/operators/detail/request_handler.h | 1 + .../operators/detail/request_handler_impl.h | 10 ++++ paddle/fluid/operators/listen_and_serv_op.cc | 8 +++ paddle/fluid/operators/listen_and_serv_op.h | 2 + paddle/fluid/operators/save_op.cc | 50 +++++++++++++++---- .../fluid/transpiler/distribute_transpiler.py | 20 ++++++++ 6 files changed, 81 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index a2d08747d5922..cb480accb4ea2 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -36,6 +36,7 @@ namespace detail { constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestPrefetch[] = "RequestPrefetch"; +constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 3f77c09a9598b..643eae4d31438 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -66,6 +66,16 @@ class RequestPrefetchHandler final : public RequestHandler { const std::string& out_var_name = "") override; }; +class RequestCheckpointHandler final : public RequestHandler { + public: + explicit RequestCheckpointHandler(bool sync_mode) + : RequestHandler(sync_mode) {} + virtual ~RequestCheckpointHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; +}; + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4d12278799f66..0804a266d0f1e 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -253,11 +253,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_prefetch_handler_.reset( new detail::RequestPrefetchHandler(sync_mode)); + request_checkpoint_handler_.reset( + new detail::RequestCheckpointHandler(sync_mode)); rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestPrefetch, request_prefetch_handler_.get()); + rpc_service_->RegisterRPC(detail::kRequestCheckpoint, + request_checkpoint_handler_.get()); auto *optimize_block = Attr(kOptimizeBlock); auto *program = optimize_block->Program(); @@ -300,6 +304,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, f(request_send_handler_.get()); f(request_get_handler_.get()); f(request_prefetch_handler_.get()); + f(request_checkpoint_handler_.get()); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); @@ -344,6 +349,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); + AddAttr(kCheckpointBlockId, + "BolckID to run save checkpoint on pserer.") + .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 46c3a19e20b3f..b00ad195e9e16 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -32,6 +32,7 @@ namespace operators { constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; +constexpr char kCheckpointBlockId[] = "checkpint_block_id"; void RunServer(std::shared_ptr service); @@ -66,6 +67,7 @@ class ListenAndServOp : public framework::OperatorBase { mutable std::shared_ptr request_send_handler_; mutable std::shared_ptr request_get_handler_; mutable std::shared_ptr request_prefetch_handler_; + mutable std::shared_ptr request_checkpoint_handler_; mutable std::shared_ptr server_thread_; }; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index e6d27e2dedd76..410796eeb6cd1 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase { MkDirRecursively(DirName(filename).c_str()); - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ofstream fout(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", - filename); - auto iname = Input("X"); auto *var = scope.FindVar(iname); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", iname); - PADDLE_ENFORCE(var->IsType(), - "SaveOp only support LoDTensor, %s has wrong type", iname); + if (var->IsType()) { + SaveLodTensor(filename, place, var); + } else if (var->IsType()) { + SaveSelectedRows(filename, place, var); + } else { + PADDLE_ENFORCE( + false, + "SaveOp only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } + } + SaveLodTensor(const string &filename, const platform::Place &place, + Variable *var) { auto &tensor = var->Get(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } + fout.close() + } + + SaveSelectedRows(const string &filename, const platform::Place &place, + Variable *var) { + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + framework::SerializeToStream(fout, selectedRows, dev_ctx); + fout.close() } }; class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor ) Input tensor to be saved"); + AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved"); AddComment(R"DOC( Save operator -This operator will serialize and write a tensor variable to file on disk. +This operator will serialize and write a tensor/selected rows variable to file on disk. )DOC"); AddAttr("overwrite", "(boolean, default true)" diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2480d4e76a1b5..caad745b1fb73 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -522,6 +522,8 @@ def __op_have_grad_input__(op): pserver_index, pserver_program, pre_block_idx, grad_to_block_id) prefetch_var_name_to_block_id = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) + checkpoint_block_id = self._create_checkpoint_save_block( + pserver_program, table_opt_block.idx) # NOTE: if has_distributed_lookup_table is False, then prefetch_block will # not be executed, so it's safe to use optimize_block to hold the place @@ -540,6 +542,7 @@ def __op_have_grad_input__(op): if len(prefetch_var_name_to_block_id) > 0: attrs['prefetch_var_name_to_block_id'] \ = prefetch_var_name_to_block_id + attrs['checkpint_block_id'] = checkpoint_block_id # step5 append the listen_and_serv op pserver_program.global_block().append_op( @@ -824,6 +827,23 @@ def _create_table_optimize_block(self, pserver_index, pserver_program, return table_opt_block + def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): + """ + create a new block to handle save checkpoint. + """ + import os + + checkpoint_save_block = pserver_program.create_block(pre_block_idx) + checkpoint_save_block.append_op( + type='save', + inputs={'X': [self.table_name]}, + outputs={}, + attrs={ + 'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) + }) + + return checkpoint_save_block.idx + def _create_vars_from_blocklist(self, program, block_list, From a8959162749257cb52449a8effda19bd0c191205 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 14 Jun 2018 11:33:39 +0800 Subject: [PATCH 02/67] [wip] add load lookup table in io and trianer --- python/paddle/fluid/io.py | 13 ++++++++++- python/paddle/fluid/trainer.py | 42 ++++++++++++++++------------------ 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 6323c9899e008..0fb88de0bbbf6 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -25,7 +25,8 @@ 'load_persistables', 'save_inference_model', 'load_inference_model', 'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'clean_checkpoint', 'load_persist_vars_without_grad', - 'save_persist_vars_without_grad', 'get_latest_checkpoint_serial' + 'load_lookup_table_vars', 'save_persist_vars_without_grad', + 'get_latest_checkpoint_serial' ] @@ -459,7 +460,9 @@ def get_parameter_value_by_name(name, executor, program=None): SUCCESS_MARK_FILENAME = "_SUCCESS" CHECKPOINT_PREFIX = "checkpoint" MODEL_DIR = "__model__" +LOOKUP_TABLE_DIR = "__lookup_table__" TRAINER_PREFIX = "trainer" +PSERVER_PREFIX = "pserver" CHECKPOINT_SEPARATOR = "_" @@ -567,6 +570,14 @@ def load_persist_vars_without_grad(executor, filename=None) +def load_lookup_table_vars(executor, dirname, pserver_id, table_name): + lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + table_file = table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX + CHECKPOINT_SEPARATOR + str( + pserver_id) + + load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file) + + def save_persist_vars_without_grad(executor, dirname, program): """ save_persist_vars_without_grad will save variables to a directory by an executor, diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index efc28d899304b..2cb908f799bf8 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -62,27 +62,20 @@ def __init__(self, max_num_checkpoints=3, epoch_interval=1, step_interval=10): - if checkpoint_dir is None: - self.checkpoint_dir = os.getcwd() - else: - self.checkpoint_dir = checkpoint_dir - - self.max_num_checkpoints = max_num_checkpoints - - if epoch_interval < 1: - self.epoch_interval = 1 - else: - self.epoch_interval = epoch_interval - if step_interval < 1: - self.step_interval = 10 - else: - self.step_interval = step_interval + assert epoch_interval >= 1 + assert step_interval >= 1 + self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd( + ) + self.max_num_checkpoints = max_num_checkpoints + self.epoch_interval = epoch_interval + self.step_interval = step_interval self.epoch_id = 0 self.step_id = 0 self.load_serial = None self.is_pserver = False + self.has_lookup_table = False def check_and_get_place(place): @@ -181,13 +174,18 @@ def __init__(self, self.checkpoint_cfg.load_serial, self.startup_program) - if not self.checkpoint_cfg.is_pserver: - epoch_id, step_id = io.load_trainer_args( - self.checkpoint_cfg.checkpoint_dir, - self.checkpoint_cfg.load_serial, self.trainer_id, - self._get_checkpoint_load_args()) - self.checkpoint_cfg.epoch_id = int(epoch_id) - self.checkpoint_cfg.step_id = int(step_id) + if not self.checkpoint_cfg.is_pserver: + epoch_id, step_id = io.load_trainer_args( + self.checkpoint_cfg.checkpoint_dir, + self.checkpoint_cfg.load_serial, self.trainer_id, + self._get_checkpoint_load_args()) + self.checkpoint_cfg.epoch_id = int(epoch_id) + self.checkpoint_cfg.step_id = int(step_id) + else: + if self.checkpoint_cfg.has_lookup_table: + io.load_lookup_table_vars( + exe, self.checkpoint_cfg.checkpoint_dir, 0, + "table_name") if param_path and os.path.isdir(param_path): # load params from param_path into scope From 8a178165a6ae4d19f226e2e13d291d96be61798e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 14 Jun 2018 20:33:51 +0800 Subject: [PATCH 03/67] add lookuo table in python --- python/paddle/fluid/io.py | 30 +++++++++++++++++++++++++++++- python/paddle/fluid/trainer.py | 3 ++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 0fb88de0bbbf6..6a0e422cb3761 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -500,6 +500,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) + save_pserver_vars_by_notify(executor, cur_dir, "") _scroll_delete(checkpoint_dir, max_num_checkpoints) @@ -530,7 +531,8 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): def clean_checkpoint(checkpoint_dir, delete_dir=False): """ - clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. + clean the checkpoint dir, when the train exits normally, + the trainer will call clean_checkpoint to delete checkpoint directory saved before. delete_dir only works when the directory is empty, otherwise, OSError is raised. :param checkpoint_dir @@ -598,6 +600,23 @@ def save_persist_vars_without_grad(executor, dirname, program): _write_success(cur_dir) +def save_pserver_vars_by_notify(executor, dirname, epmap): + """ + """ + cur_dir = _get_lookuptable_dir(dirname) + + checkpoint_notify_program = Program() + checkpoint_notify_block = checkpoint_notify_program.global_block() + + attrs = {} + attrs['epmap'] = None + attrs['dir'] = cur_dir + + checkpoint_notify_block.append_op( + type='checkpointnotify', inputs={}, output={}, attrs=attrs) + executor.run(checkpoint_notify_program) + + def save_trainer_args(dirname, trainer_id, trainer_args): assert isinstance(trainer_args, dict) @@ -680,6 +699,15 @@ def _get_model_dir(dirname): return model_dir +def _get_lookuptable_dir(dirname): + lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + + if not os.path.isdir(lookuptable_dir): + os.makedirs(lookuptable_dir) + + return lookuptable_dir + + def _get_trainer_dir(dirname, trainer_id): trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) trainer_dir = os.path.join(dirname, trainer_folder) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 2cb908f799bf8..f77c0f65dcb97 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -446,7 +446,8 @@ def _get_checkpoint_save_args(self, epoch_id, step_id): def _save_checkpoint(self, epoch_id, step_id): assert self.checkpoint_cfg - if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0: + if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \ + and step_id % self.checkpoint_cfg.step_interval == 0: exe = executor.Executor(self.place) io.save_checkpoint( executor=exe, From 12de20f5f75f2943f24651dd01bfa32f7f491a4e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 14 Jun 2018 20:35:15 +0800 Subject: [PATCH 04/67] add checkpoint_notify_op for trainer to notify pserver, update listen_and_serv_op --- .../fluid/operators/checkpoint_notify_op.cc | 81 +++++++++++++++++++ paddle/fluid/operators/listen_and_serv_op.cc | 13 ++- 2 files changed, 91 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/checkpoint_notify_op.cc diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc new file mode 100644 index 0000000000000..1b922e08907dc --- /dev/null +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/macros.h" +#include "paddle/fluid/operators/send_recv_util.h" + +namespace paddle { +namespace operators { + +class CheckpointNotifyOp : public framework::OperatorBase { + public: + CheckpointNotifyOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + std::vector epmap = Attr>("epmap"); + std::string dir = Attr("dir"); + + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); + VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " + << outs[i] << " back"; + rpc_client->AsyncCheckpointNotify(epmap[i], dir); + rpc_client->Wait(); + } +}; + +class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddAttr>( + "epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input variables for mapping") + .SetDefault({"127.0.0.1:6164"}); + AddAttr( + "dir", "(string, default '') indicate the folder checkpoint will use"); + AddComment(R"DOC( +Prefetch operator + +This operator will send Ids variables to listen_and_serve op at +the parameter server and fetch result back. +)DOC"); + } +}; + +class CheckpointNotifyOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(checkpointnotify, ops::CheckpointNotifyOp, + paddle::framework::EmptyGradOpMaker, + ops::CheckpointNotifyOpMaker, + ops::CheckpointNotifyOpShapeInference); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 0804a266d0f1e..088366dac7b28 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -221,6 +221,7 @@ static void FillRequestCtx( std::unordered_map> *prefetch_ctx, + std::shared_ptr checkpoint_ctx, detail::RPCServer *rpc_server) { h->SetScope(scope); h->SetDevCtx(dev_ctx); @@ -228,6 +229,7 @@ static void FillRequestCtx( h->SetProgram(program); h->SetPrefetchPreparedCtx(prefetch_ctx); h->SetRPCServer(rpc_server); + h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); } void ListenAndServOp::RunImpl(const framework::Scope &scope, @@ -297,9 +299,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, - &dev_ctx, &executor, program, - &prefetch_var_name_to_prepared_ctx, rpc_service_.get()); + int checkpoint_point_block_id = Attr(kCheckpointBlockId); + std::shared_ptr ckpt_pre_context = + executor.Prepare(*program, checkpoint_point_block_id); + + auto f = + std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, + &executor, program, &prefetch_var_name_to_prepared_ctx, + &ckpt_pre_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get()); From b089b8098872669e2f7ec1125b37e37a033b8b8e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 14 Jun 2018 20:35:58 +0800 Subject: [PATCH 05/67] update rpc to add checkpoint notify --- paddle/fluid/operators/detail/grpc_client.cc | 16 ++++++++++++++++ paddle/fluid/operators/detail/grpc_client.h | 18 ++++++++++++++++++ paddle/fluid/operators/detail/grpc_service.h | 3 +++ .../fluid/operators/detail/request_handler.h | 10 ++++++++++ .../operators/detail/request_handler_impl.cc | 6 ++++++ paddle/fluid/operators/detail/rpc_client.h | 4 ++++ paddle/fluid/operators/detail/send_recv.proto | 7 +++++++ 7 files changed, 64 insertions(+) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 02ffe3651e1de..8898438675687 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -229,6 +229,22 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { req_count_++; } +void GRPCClient::AsyncCheckpointNotify(const std::string& ep, + const std::string& dir, + int64_t time_out) { + const auto ch = GetChannel(ep); + CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); + s.prepare(time_out); + + sendrecv::CheckpointMessage req; + req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); + req.set_checkpoint_dir(dir); + + auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; +} + void GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 44000c028b499..bc3deff47cec1 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -165,6 +165,20 @@ class FetchBarrierProcessor : public BaseProcessor { std::unique_ptr stub_; }; +class CheckpointNotifyProcessor : public BaseProcessor { + public: + explicit CheckpointNotifyProcessor(std::shared_ptr ch) + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } + + virtual ~CheckpointNotifyProcessor() {} + + virtual void Process() {} + sendrecv::VoidMessage reply_; + std::unique_ptr stub_; +} + class GRPCClient : public RPCClient { public: GRPCClient() {} @@ -193,6 +207,10 @@ class GRPCClient : public RPCClient { const std::string& ep, int64_t time_out = RPCClient::rpc_time_out) override; + void AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = RPCClient::rpc_time_out) override; + void Wait() override; void SendComplete() override; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index e0505c2b9d090..69200a01d3c8c 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -79,6 +79,7 @@ enum class GrpcMethod { kSendVariable, kGetVariable, kPrefetchVariable, + kCheckpointNotify, }; static const int kGrpcNumMethods = @@ -92,6 +93,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { return "/sendrecv.SendRecvService/GetVariable"; case GrpcMethod::kPrefetchVariable: return "/sendrecv.SendRecvService/PrefetchVariable"; + case GrpcMethod::kCheckpointNotify: + return "/sendrecv.SendRecvService/CheckpointNotify"; } // Shouldn't be reached. diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index cb480accb4ea2..fd33521fd1488 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -43,6 +43,9 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" +#define CHECKPOINT_SAVE_MESSAGE "SAVE" +#define CHECKPOINT_LOAD_MESSAGE "LOAD" + class RPCServer; class RequestHandler { @@ -70,6 +73,11 @@ class RequestHandler { prefetch_var_name_to_prepared_ctx_ = g; } + void SetCheckpointNotifyPreparedCtx( + std::shared_ptr g) { + checkpoint_prepared_ctx_ = g; + } + // Used for async. void SetGradToPreparedCtx( std::unordered_map< @@ -116,6 +124,8 @@ class RequestHandler { std::unordered_map>* prefetch_var_name_to_prepared_ctx_; + // used for checkpoint notify + std::shared_ptr checkpoint_prepared_ctx_; // Used for async. std::unordered_map Date: Fri, 15 Jun 2018 14:28:24 +0800 Subject: [PATCH 06/67] bug fix --- paddle/fluid/operators/detail/grpc_client.cc | 2 +- paddle/fluid/operators/detail/request_handler_impl.cc | 5 ++++- paddle/fluid/operators/save_op.cc | 6 +++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 8898438675687..1dff3bfa3cbb9 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -240,7 +240,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); req.set_checkpoint_dir(dir); - auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq); + auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index ffad66700e589..de6ce72d4dcd3 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -123,7 +123,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, framework::Variable** outvar, - const std::string& out_var_name) {} + const std::string& out_var_name) { + executor_->RunPreparedContext(checkpoint_prepared_ctx_); + return true; +} } // namespace detail } // namespace operators diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 410796eeb6cd1..3d114538eb881 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -96,7 +96,7 @@ class SaveOp : public framework::OperatorBase { } } - SaveLodTensor(const string &filename, const platform::Place &place, + SaveLodTensor(const std::string &filename, const platform::Place &place, Variable *var) { auto &tensor = var->Get(); @@ -127,7 +127,7 @@ class SaveOp : public framework::OperatorBase { fout.close() } - SaveSelectedRows(const string &filename, const platform::Place &place, + SaveSelectedRows(const std::string &filename, const platform::Place &place, Variable *var) { auto &selectedRows = var->Get(); @@ -141,7 +141,7 @@ class SaveOp : public framework::OperatorBase { PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); framework::SerializeToStream(fout, selectedRows, dev_ctx); - fout.close() + fout.close(); } }; From 1cb0ab36f08b4746fce77a4ff3444507226d3e52 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 14:34:13 +0800 Subject: [PATCH 07/67] bug fix --- paddle/fluid/operators/detail/grpc_client.cc | 2 +- paddle/fluid/operators/detail/grpc_client.h | 2 +- paddle/fluid/operators/detail/request_handler_impl.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 1dff3bfa3cbb9..9a25ec8fdb459 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -234,7 +234,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); - s.prepare(time_out); + s->Prepare(time_out); sendrecv::CheckpointMessage req; req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index bc3deff47cec1..0c54ec0efefcc 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -177,7 +177,7 @@ class CheckpointNotifyProcessor : public BaseProcessor { virtual void Process() {} sendrecv::VoidMessage reply_; std::unique_ptr stub_; -} +}; class GRPCClient : public RPCClient { public: diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index de6ce72d4dcd3..ba7d027637a24 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -124,7 +124,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const std::string& out_var_name) { - executor_->RunPreparedContext(checkpoint_prepared_ctx_); + executor_->RunPreparedContext(checkpoint_prepared_ctx_, scope); return true; } From fb27c9a5a34469cfd93fe1974c8ea43e444a4591 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 14:37:16 +0800 Subject: [PATCH 08/67] bug fix --- paddle/fluid/operators/detail/request_handler_impl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index ba7d027637a24..41b22e2143f00 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -124,7 +124,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const std::string& out_var_name) { - executor_->RunPreparedContext(checkpoint_prepared_ctx_, scope); + executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } From fe76244f0ee363c194265ebac7abbcc9cf5e5e68 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 15:13:02 +0800 Subject: [PATCH 09/67] bug fix --- paddle/fluid/operators/save_op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 3d114538eb881..3277d09ab20ae 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -70,7 +71,6 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); - auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -97,7 +97,7 @@ class SaveOp : public framework::OperatorBase { } SaveLodTensor(const std::string &filename, const platform::Place &place, - Variable *var) { + framework::Variable *var) { auto &tensor = var->Get(); // get device context from pool @@ -110,6 +110,7 @@ class SaveOp : public framework::OperatorBase { PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); + auto save_as_fp16 = Attr("save_as_fp16"); auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -124,11 +125,11 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } - fout.close() + fout.close(); } SaveSelectedRows(const std::string &filename, const platform::Place &place, - Variable *var) { + framework::Variable *var) { auto &selectedRows = var->Get(); // get device context from pool From 98c30c7cbea4f4ed00ed22fdc8fff06268e74629 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 15:39:03 +0800 Subject: [PATCH 10/67] bug fix --- paddle/fluid/operators/save_op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 3277d09ab20ae..b54bd7db36745 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -96,8 +96,8 @@ class SaveOp : public framework::OperatorBase { } } - SaveLodTensor(const std::string &filename, const platform::Place &place, - framework::Variable *var) { + void SaveLodTensor(const std::string &filename, const platform::Place &place, + framework::Variable *var) const { auto &tensor = var->Get(); // get device context from pool @@ -128,8 +128,9 @@ class SaveOp : public framework::OperatorBase { fout.close(); } - SaveSelectedRows(const std::string &filename, const platform::Place &place, - framework::Variable *var) { + void SaveSelectedRows(const std::string &filename, + const platform::Place &place, + framework::Variable *var) const { auto &selectedRows = var->Get(); // get device context from pool From f224948f310963d91d823679cfd6e16d0186ae00 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 16:38:44 +0800 Subject: [PATCH 11/67] bug fix --- paddle/fluid/operators/listen_and_serv_op.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 088366dac7b28..f235c86ad5699 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -300,8 +300,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, } int checkpoint_point_block_id = Attr(kCheckpointBlockId); + auto *ctx = new ExecutorPrepareContext(*program, checkpoint_point_block_id); + std::shared_ptr ckpt_pre_context = - executor.Prepare(*program, checkpoint_point_block_id); + std::shared_ptr(ctx); auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, From 8d46d1ddf2dac143bfb009da2205ea68215d5cd8 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 17:08:26 +0800 Subject: [PATCH 12/67] bug fix --- paddle/fluid/operators/checkpoint_notify_op.cc | 7 ++++--- paddle/fluid/operators/listen_and_serv_op.cc | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 1b922e08907dc..c229cbf4984d9 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -39,9 +39,10 @@ class CheckpointNotifyOp : public framework::OperatorBase { detail::RPCClient* rpc_client = detail::RPCClient::GetInstance(); - VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " - << outs[i] << " back"; - rpc_client->AsyncCheckpointNotify(epmap[i], dir); + for (size_t i = 0; i < epmap.size(); i++) { + VLOG(3) << "sending to " << epmap[i] << " to checkpoint notify ... "; + rpc_client->AsyncCheckpointNotify(epmap[i], dir); + } rpc_client->Wait(); } }; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index f235c86ad5699..780d47f385a15 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -300,10 +300,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, } int checkpoint_point_block_id = Attr(kCheckpointBlockId); - auto *ctx = new ExecutorPrepareContext(*program, checkpoint_point_block_id); + auto ctx = executor.Prepare(*program, checkpoint_point_block_id); std::shared_ptr ckpt_pre_context = - std::shared_ptr(ctx); + std::move(ctx); auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, From 860360d96d0e5f0606a6f528047158f87b7e2e17 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 17:29:31 +0800 Subject: [PATCH 13/67] bug fix --- paddle/fluid/operators/listen_and_serv_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 780d47f385a15..13dfe45bb292e 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -308,7 +308,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, program, &prefetch_var_name_to_prepared_ctx, - &ckpt_pre_context, rpc_service_.get()); + ckpt_pre_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get()); From 1c2e9bdd493686721cfec61d52e481b2a4380d52 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 18:13:43 +0800 Subject: [PATCH 14/67] fix cmakelist --- paddle/fluid/operators/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d6a36eff09c7f..8c08ae3430674 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -197,6 +197,8 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + op_library(checkpoint_notify_op DEPS ${DISTRIBUTE_DEPS}) + set_source_files_properties(checkpoint_notify_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(recv_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS}) @@ -223,7 +225,7 @@ if(WITH_DISTRIBUTE) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) endif() else() - set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op) + set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op) endif() op_library(cross_entropy_op DEPS cross_entropy) From 985026ce42f538f61dad9286c9dfb86929f5115a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 18:37:14 +0800 Subject: [PATCH 15/67] add checkpoint_notify in python --- paddle/fluid/operators/checkpoint_notify_op.cc | 2 +- python/paddle/fluid/framework.py | 2 +- python/paddle/fluid/io.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index c229cbf4984d9..026ad722c275f 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -76,7 +76,7 @@ class CheckpointNotifyOpShapeInference : public framework::InferShapeBase { namespace ops = paddle::operators; -REGISTER_OPERATOR(checkpointnotify, ops::CheckpointNotifyOp, +REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp, paddle::framework::EmptyGradOpMaker, ops::CheckpointNotifyOpMaker, ops::CheckpointNotifyOpShapeInference); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index bbd35aaecba27..edc7ba69dd684 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -382,7 +382,7 @@ class Operator(object): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', 'channel_send', - 'channel_recv', 'select' + 'channel_recv', 'select', 'checkpoint_notify' } def __init__(self, diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 6a0e422cb3761..253fd5651c67d 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -613,7 +613,7 @@ def save_pserver_vars_by_notify(executor, dirname, epmap): attrs['dir'] = cur_dir checkpoint_notify_block.append_op( - type='checkpointnotify', inputs={}, output={}, attrs=attrs) + type='checkpoint_notify', inputs={}, output={}, attrs=attrs) executor.run(checkpoint_notify_program) From 925e2324b3b7920d66217ec2c870010f24c0967a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 10:50:12 +0800 Subject: [PATCH 16/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 33 ++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 2d34f85838c34..3e5625fa28cea 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -185,6 +185,37 @@ class RequestPrefetch final : public RequestBase { framework::Scope* local_scope_; }; +class RequestCheckpointNotify final : public RequestBase { + public: + explicit RequestCheckpointNotify(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), + responder_(&ctx_), + local_scope_(nullptr) { + request_.reset(new VariableResponse(request_handler->scope(), + request_handler->dev_ctx(), true)); + int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); + } + + virtual ~RequestCheckpointNotify() {} + + std::string GetReqName() override { return request_->Varname(); } + + void Process() override { + auto scope = request_->GetMutableLocalScope(); + std::string nullptr_str = nullptr; + framework::Variable* invar = nullptr; + framework::Variable* outvar = nullptr; + + request_handler_->Handle(nullptr_str, scope, invar, &outvar, nullptr_str); + Finish(reply_, &responder_); + } +} + void AsyncGRPCServer::WaitServerReady() { VLOG(3) << "AsyncGRPCServer is wait server ready"; std::unique_lock lock(this->mutex_ready_); @@ -288,6 +319,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, b = new RequestGet(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); + } else if (rpc_name == kRequestCheckpoint) { + b = new RequestCheckpoin } else { PADDLE_ENFORCE(false, "not supported rpc"); } From a9ac2007f2248ba8c08e8f42acab9d49f7d13dff Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 10:58:58 +0800 Subject: [PATCH 17/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 3e5625fa28cea..75f0a1f789311 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -320,7 +320,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestCheckpoint) { - b = new RequestCheckpoin + b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); } From 36d17d11a41bf08a1b93491183ae249d48e7ffe7 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 11:13:32 +0800 Subject: [PATCH 18/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 75f0a1f789311..21bd232260f2e 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -203,7 +203,7 @@ class RequestCheckpointNotify final : public RequestBase { virtual ~RequestCheckpointNotify() {} - std::string GetReqName() override { return request_->Varname(); } + std::string GetReqName() override { return "checkpoint_notify"; } void Process() override { auto scope = request_->GetMutableLocalScope(); @@ -214,6 +214,11 @@ class RequestCheckpointNotify final : public RequestBase { request_handler_->Handle(nullptr_str, scope, invar, &outvar, nullptr_str); Finish(reply_, &responder_); } + + protected: + sendrecv::CheckpointMessage request_; + sendrecv::VoidMessage reply_; + ServerAsyncResponseWriter responder_; } void AsyncGRPCServer::WaitServerReady() { From 74384b750eb35648bdd317b4f153de60ea21179e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 11:18:47 +0800 Subject: [PATCH 19/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 21bd232260f2e..4079df7ab1099 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -216,10 +216,10 @@ class RequestCheckpointNotify final : public RequestBase { } protected: - sendrecv::CheckpointMessage request_; + std::shared_ptr request_; sendrecv::VoidMessage reply_; ServerAsyncResponseWriter responder_; -} +}; void AsyncGRPCServer::WaitServerReady() { VLOG(3) << "AsyncGRPCServer is wait server ready"; From 050b66e27c8240ace296bcec2592d6908c8271ed Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 11:26:39 +0800 Subject: [PATCH 20/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 4079df7ab1099..238457b3e15f9 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -190,12 +190,10 @@ class RequestCheckpointNotify final : public RequestBase { explicit RequestCheckpointNotify(GrpcService::AsyncService* service, ::grpc::ServerCompletionQueue* cq, RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), - responder_(&ctx_), - local_scope_(nullptr) { + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { request_.reset(new VariableResponse(request_handler->scope(), request_handler->dev_ctx(), true)); - int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); + int method_id = static_cast(detail::GrpcMethod::kCheckpointNotify); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); From 54013a93b15cde1863f7fc2a5ea89f6b8655aae9 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 12:52:55 +0800 Subject: [PATCH 21/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 5 +++++ paddle/fluid/operators/detail/request_handler.h | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 238457b3e15f9..61afffeebada3 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -247,6 +247,9 @@ void AsyncGRPCServer::StartServer() { std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, std::placeholders::_1, std::placeholders::_2); + LOG(INFO) << "Server StartServer on " + << "TryToRegisterNewOne bind finished"; + for (auto& t : rpc_call_map_) { auto& rpc_name = t.first; auto& cq = rpc_cq_[rpc_name]; @@ -255,6 +258,8 @@ void AsyncGRPCServer::StartServer() { reqs.reserve(kRequestBufSize); + LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << "I: " << i; + for (int i = 0; i < kRequestBufSize; i++) { TryToRegisterNewOne(rpc_name, i); } diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index fd33521fd1488..387a6b119030c 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -43,8 +43,8 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" -#define CHECKPOINT_SAVE_MESSAGE "SAVE" -#define CHECKPOINT_LOAD_MESSAGE "LOAD" +#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" +#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" class RPCServer; From 15532c74b19f87b7ea3df16969d76299b366801c Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 13:04:22 +0800 Subject: [PATCH 22/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 61afffeebada3..5adb516292331 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -247,9 +247,6 @@ void AsyncGRPCServer::StartServer() { std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, std::placeholders::_1, std::placeholders::_2); - LOG(INFO) << "Server StartServer on " - << "TryToRegisterNewOne bind finished"; - for (auto& t : rpc_call_map_) { auto& rpc_name = t.first; auto& cq = rpc_cq_[rpc_name]; @@ -258,7 +255,7 @@ void AsyncGRPCServer::StartServer() { reqs.reserve(kRequestBufSize); - LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << "I: " << i; + LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i; for (int i = 0; i < kRequestBufSize; i++) { TryToRegisterNewOne(rpc_name, i); @@ -313,8 +310,11 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, return; } - VLOG(4) << "register send rpc_name:" << rpc_name - << ", handler:" << rpc_call_map_[kRequestSend]; + LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name + << " REQ ID: " << req_id; + + // VLOG(4) << "register send rpc_name:" << rpc_name + // << ", handler:" << rpc_call_map_[kRequestSend]; auto& reqs = rpc_reqs_[rpc_name]; auto& handler = rpc_call_map_[rpc_name]; @@ -328,6 +328,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestCheckpoint) { + LOG(INFO) << "TryToRegisterNewOne on RPC kRequestCheckpoint"; b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); From bbb349fbf075eb67536f1c488cbc395f5fb04d46 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 14:40:15 +0800 Subject: [PATCH 23/67] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 8 ++------ paddle/fluid/operators/detail/grpc_service.h | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 5adb516292331..9a8c4196720d0 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -255,9 +255,9 @@ void AsyncGRPCServer::StartServer() { reqs.reserve(kRequestBufSize); - LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i; - for (int i = 0; i < kRequestBufSize; i++) { + LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name + << " I: " << i; TryToRegisterNewOne(rpc_name, i); } @@ -313,9 +313,6 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " REQ ID: " << req_id; - // VLOG(4) << "register send rpc_name:" << rpc_name - // << ", handler:" << rpc_call_map_[kRequestSend]; - auto& reqs = rpc_reqs_[rpc_name]; auto& handler = rpc_call_map_[rpc_name]; auto& cq = rpc_cq_[rpc_name]; @@ -328,7 +325,6 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestCheckpoint) { - LOG(INFO) << "TryToRegisterNewOne on RPC kRequestCheckpoint"; b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index 69200a01d3c8c..cb745e125a9a5 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -83,7 +83,7 @@ enum class GrpcMethod { }; static const int kGrpcNumMethods = - static_cast(GrpcMethod::kPrefetchVariable) + 1; + static_cast(GrpcMethod::kCheckpointNotify) + 1; inline const char* GrpcMethodName(GrpcMethod id) { switch (id) { From 527b86b7d07d8d403423caeece302416ee44d2de Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 20:14:31 +0800 Subject: [PATCH 24/67] bug fix --- paddle/fluid/operators/listen_and_serv_op.cc | 6 ++++-- paddle/fluid/operators/listen_and_serv_op.h | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 13dfe45bb292e..698ff22997891 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -99,7 +99,8 @@ static int64_t GetTimestamp() { void ListenAndServOp::RunSyncLoop( framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope, - const std::vector &prefetch_block_id_list) const { + const std::vector &prefetch_block_id_list, + const int checkpoint_point_block_id) const { size_t num_blocks = program->Size(); PADDLE_ENFORCE_GE(num_blocks, 2, "server program should have at least 2 blocks"); @@ -107,7 +108,8 @@ void ListenAndServOp::RunSyncLoop( std::vector optimize_block_id_list; for (int blkid = 1; blkid < num_blocks; ++blkid) { if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(), - blkid) == prefetch_block_id_list.end()) { + blkid) == prefetch_block_id_list.end() && + blkid != checkpoint_point_block_id) { optimize_block_id_list.push_back(blkid); } } diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index b00ad195e9e16..ca2dafb737a82 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -48,7 +48,8 @@ class ListenAndServOp : public framework::OperatorBase { void RunSyncLoop(framework::Executor* executor, framework::ProgramDesc* program, framework::Scope* recv_scope, - const std::vector& prefetch_block_id_list) const; + const std::vector& prefetch_block_id_list, + const int checkpoint_point_block_id) const; void RunAsyncLoop(framework::Executor* executor, framework::ProgramDesc* program) const; From 85215df087d1d6f8f9af19f71feb4140c72765fe Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 22:08:58 +0800 Subject: [PATCH 25/67] move checkpoint message to variable message --- paddle/fluid/operators/detail/grpc_client.cc | 14 +++++++++++--- paddle/fluid/operators/detail/send_recv.proto | 7 ++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 9a25ec8fdb459..17476ab513b55 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -233,12 +233,20 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out) { const auto ch = GetChannel(ep); + CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); s->Prepare(time_out); + s->response_call_back_ = nullptr; - sendrecv::CheckpointMessage req; - req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); - req.set_checkpoint_dir(dir); + sendrecv::VariableMessage req; + req.set_varname(CHECKPOINT_SAVE_MESSAGE); + req.out_varname(dir); + + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/CheckpointNotify", req, + &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index cc6529cea7451..f5800cdb7f7e1 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -26,7 +26,7 @@ service SendRecvService { // pre-fetch variable by given variable name and Ids rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} - rpc CheckpointNotify(CheckpointMessage) returns (VoidMessage) {} + rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. @@ -83,6 +83,7 @@ message VariableMessage { message VoidMessage {} message CheckpointMessage { - string notify_type = 1; - string checkpoint_dir = 2; + string varname = 1; + string notify_type = 2; + string checkpoint_dir = 3; } From 8af8da4fe4fa59e35b0d033286df574df7117067 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 22:14:39 +0800 Subject: [PATCH 26/67] move checkpoint message to variable message --- paddle/fluid/operators/detail/grpc_server.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 9a8c4196720d0..b7f4032c5af69 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -201,15 +201,19 @@ class RequestCheckpointNotify final : public RequestBase { virtual ~RequestCheckpointNotify() {} - std::string GetReqName() override { return "checkpoint_notify"; } + std::string GetReqName() override { return request_->Varname(); } void Process() override { auto scope = request_->GetMutableLocalScope(); - std::string nullptr_str = nullptr; + + std::string checkpoint_notify = request_->Varname(); + std::string checkpoint_dir = request_->Varname(); + framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; - request_handler_->Handle(nullptr_str, scope, invar, &outvar, nullptr_str); + request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, + checkpoint_dir); Finish(reply_, &responder_); } From 5553adf85ddb1e3839c30f7dd7d89909eabfb8ca Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 22:45:04 +0800 Subject: [PATCH 27/67] move checkpoint message to variable message --- paddle/fluid/operators/detail/grpc_client.cc | 9 +-------- paddle/fluid/operators/listen_and_serv_op.cc | 3 ++- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 17476ab513b55..7a63e39d5aae6 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -236,17 +236,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); s->Prepare(time_out); - s->response_call_back_ = nullptr; sendrecv::VariableMessage req; req.set_varname(CHECKPOINT_SAVE_MESSAGE); - req.out_varname(dir); - - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/CheckpointNotify", req, - &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req.set_out_varname(dir); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 698ff22997891..7294acc3e357d 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -329,7 +329,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, // Write to a file of server selected port for python use. SavePort(); if (sync_mode) { - RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list); + RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, + checkpoint_point_block_id); } else { RunAsyncLoop(&executor, program); } From 752eb08b4b40b7fa44c21f1760ba71a790186b67 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 22:52:59 +0800 Subject: [PATCH 28/67] move checkpoint message to variable message --- paddle/fluid/operators/detail/grpc_server.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index b7f4032c5af69..2f58b7d15e9fc 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -208,10 +208,12 @@ class RequestCheckpointNotify final : public RequestBase { std::string checkpoint_notify = request_->Varname(); std::string checkpoint_dir = request_->Varname(); - framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; + VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify + << ", dir: " << checkpoint_dir; + request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, checkpoint_dir); Finish(reply_, &responder_); From ae12281d9b91b4d13bf0979d92cc1b3587c4fd1b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 02:12:27 +0800 Subject: [PATCH 29/67] checkpoint notify --- paddle/fluid/operators/checkpoint_notify_op.cc | 9 +++++++-- paddle/fluid/operators/detail/grpc_server.cc | 5 ++++- .../operators/detail/request_handler_impl.cc | 7 +++++++ paddle/fluid/operators/save_op.cc | 12 ++++++++++-- python/paddle/fluid/io.py | 15 ++++++++++----- .../fluid/transpiler/distribute_transpiler.py | 4 +++- 6 files changed, 41 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 026ad722c275f..3e5019dd4b167 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/string/printf.h" namespace paddle { namespace operators { @@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase { const platform::Place& place) const override { std::vector epmap = Attr>("epmap"); std::string dir = Attr("dir"); + std::string lookup_table_name = Attr("lookup_table"); detail::RPCClient* rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { - VLOG(3) << "sending to " << epmap[i] << " to checkpoint notify ... "; - rpc_client->AsyncCheckpointNotify(epmap[i], dir); + VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... "; + auto serial_looku_table = string::Sprintf("%s/%s.%d", dir, lookup_table_name, i); + rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); } rpc_client->Wait(); } @@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({"127.0.0.1:6164"}); AddAttr( "dir", "(string, default '') indicate the folder checkpoint will use"); + AddAttr( + "lookup_table", "(string, default '') the lookup table name"); AddComment(R"DOC( Prefetch operator diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index ed3e60ec4504f..9f4971dc12cfb 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase { auto scope = request_->GetMutableLocalScope(); std::string checkpoint_notify = request_->Varname(); - std::string checkpoint_dir = request_->Varname(); + std::string checkpoint_dir = request_->OutVarname(); framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; + VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify + << ", dir: " << checkpoint_dir; + request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, checkpoint_dir); Finish(reply_, &responder_); diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 41b22e2143f00..487397312217b 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/string/printf.h" namespace paddle { namespace operators { @@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const std::string& out_var_name) { + + auto lt_varname = string::Sprintf("%s.path", varname); + auto *lt_var = scope->FindVar(lt_varname)->GetMutable(); + lt_var->clear(); + lt_var->append(out_var_name); + VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index b54bd7db36745..005e03e69d2b8 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase { if (var->IsType()) { SaveLodTensor(filename, place, var); } else if (var->IsType()) { - SaveSelectedRows(filename, place, var); + SaveSelectedRows(scope, place, var); } else { PADDLE_ENFORCE( false, @@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase { fout.close(); } - void SaveSelectedRows(const std::string &filename, + void SaveSelectedRows(const framework::Scope &scope, const platform::Place &place, framework::Variable *var) const { + + auto lt_varname = string::Sprintf("%s.path", Input("X")); + auto *lt_var = scope.FindVar(lt_varname)->GetMutable(); + PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable %s for SaveSelectedRows", + lt_varname); + std::string filename = lt_var->data(); + VLOG(4) << "SaveSelectedRows get File name: " << filename; + auto &selectedRows = var->Get(); // get device context from pool diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 253fd5651c67d..ce82b6b904b0a 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -471,7 +471,10 @@ def save_checkpoint(executor, trainer_id, trainer_args=None, main_program=None, - max_num_checkpoints=3): + max_num_checkpoints=3, + lookup_table=None, + ps_endpoint_list=None + ): """ Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy @@ -500,7 +503,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) - save_pserver_vars_by_notify(executor, cur_dir, "") + save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table) _scroll_delete(checkpoint_dir, max_num_checkpoints) @@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program): _write_success(cur_dir) -def save_pserver_vars_by_notify(executor, dirname, epmap): +def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list): """ """ cur_dir = _get_lookuptable_dir(dirname) @@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap): checkpoint_notify_block = checkpoint_notify_program.global_block() attrs = {} - attrs['epmap'] = None + attrs['epmap'] = ps_endpoint_list attrs['dir'] = cur_dir + attrs['lookup_table'] = lookup_table checkpoint_notify_block.append_op( - type='checkpoint_notify', inputs={}, output={}, attrs=attrs) + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) executor.run(checkpoint_notify_program) @@ -783,3 +787,4 @@ def has_success(checkpoint_dir, cur_dir): if success_num > current_dir: current_dir = success_num return current_dir + diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 55a439660f19d..d5ce6e2704a92 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -838,13 +838,15 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): """ import os + pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW) + checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block.append_op( type='save', inputs={'X': [self.table_name]}, outputs={}, attrs={ - 'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) + 'file_path': self.table_name) }) return checkpoint_save_block.idx From af0a6a149f7e77ffa3b3768f27dd4cc0615cab90 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 02:56:37 +0800 Subject: [PATCH 30/67] checkpoint notify --- .../operators/detail/request_handler_impl.cc | 5 ++-- paddle/fluid/operators/save_op.cc | 29 +++++++++++++++++-- .../fluid/transpiler/distribute_transpiler.py | 2 +- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 487397312217b..87fa5842c4e9a 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -126,11 +126,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable** outvar, const std::string& out_var_name) { - auto lt_varname = string::Sprintf("%s.path", varname); - auto *lt_var = scope->FindVar(lt_varname)->GetMutable(); + auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name; + VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 005e03e69d2b8..13798c88b1856 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -182,9 +182,32 @@ This operator will serialize and write a tensor/selected rows variable to file o } }; -} // namespace operators -} // namespace paddle +class SaveOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output("loopup_table_path").front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class SaveOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; +} +} + +// namespace operators +// namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker); +REGISTER_OPERATOR(save, ops::SaveOp, + paddle::framework::EmptyGradOpMaker, + ops::SaveOpProtoMaker, + ops::SaveOpVarTypeInference, + ops::SaveOpShapeInference); + diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d5ce6e2704a92..f9c39262ce32e 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -838,7 +838,7 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): """ import os - pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW) + pserver_program.global_block().create_var(name="loopup_table_path", persistable=True, type=core.VarDesc.VarType.RAW) checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block.append_op( From 549f0aa0d3ee482afdac53f72cc532f5f42e0382 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 03:16:38 +0800 Subject: [PATCH 31/67] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 39 +++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 8f4b5049271c9..dc5457dba8794 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -44,7 +44,24 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); - auto *tensor = out_var->GetMutable(); + } + } + + void LoadLodTensor(const std::string &filename, const platform::Place &place, + framework::Variable *var) const { + auto &tensor = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + + auto *tensor = out_var->GetMutable(); DeserializeFromStream(fin, tensor, *dev_ctx); @@ -67,7 +84,25 @@ class LoadOp : public framework::OperatorBase { tensor = out_var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); - } + } + + void LoadSelectedRows(const std::string &filename, + const framework::Scope &scope, + const platform::Place &place, + framework::Variable *var) const { + + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to write", + filename); + framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } }; From a501766ab16362a0cc35d6ad75e68c35859df166 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 03:22:03 +0800 Subject: [PATCH 32/67] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index dc5457dba8794..7308330e74e67 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); + if (out_var->IsType()) { + SaveLodTensor(filename, place, out_var); + } else if (out_var->IsType()) { + SaveSelectedRows(filename, scope, place, out_var); + } else { + PADDLE_ENFORCE( + false, + "Load only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } } } @@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase { const platform::Place &place, framework::Variable *var) const { - auto &selectedRows = var->Get(); + auto *selectedRows = var->GetMutable(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); From ca27f78e299a86fc1aca2c087270a6133eb1a79e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 08:16:40 +0800 Subject: [PATCH 33/67] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 7308330e74e67..dd24dacf42e61 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -1,3 +1,4 @@ + /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,22 +46,19 @@ class LoadOp : public framework::OperatorBase { out_var_name); if (out_var->IsType()) { - SaveLodTensor(filename, place, out_var); + LoadLodTensor(filename, place, out_var); } else if (out_var->IsType()) { - SaveSelectedRows(filename, scope, place, out_var); + LoadSelectedRows(filename, scope, place, out_var); } else { PADDLE_ENFORCE( false, "Load only support LoDTensor and SelectedRows, %s has wrong type", - iname); + out_var_name); } } - } void LoadLodTensor(const std::string &filename, const platform::Place &place, framework::Variable *var) const { - auto &tensor = var->Get(); - // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); @@ -68,10 +66,10 @@ class LoadOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", filename); - auto *tensor = out_var->GetMutable(); + auto *tensor = var->GetMutable(); DeserializeFromStream(fin, tensor, *dev_ctx); @@ -90,10 +88,11 @@ class LoadOp : public framework::OperatorBase { &fp16_tensor); // reset output tensor - out_var->Clear(); - tensor = out_var->GetMutable(); + var->Clear(); + tensor = var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); + } } void LoadSelectedRows(const std::string &filename, @@ -110,7 +109,7 @@ class LoadOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to write", + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", filename); framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } From ee64f577d4dabcf886e30f7dd13b839d6cfc4097 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 08:18:32 +0800 Subject: [PATCH 34/67] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index dd24dacf42e61..e75fc4d674171 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -71,7 +71,7 @@ class LoadOp : public framework::OperatorBase { auto *tensor = var->GetMutable(); - DeserializeFromStream(fin, tensor, *dev_ctx); + DeserializeFromStream(fin, tensor, dev_ctx); auto load_as_fp16 = Attr("load_as_fp16"); auto in_dtype = framework::ToDataType(tensor->type()); From 1296d96e2e1d143bf002732b4bb138d93a1187cd Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 09:11:04 +0800 Subject: [PATCH 35/67] add raw clone --- python/paddle/fluid/framework.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e27444fb10bb5..78c4aea92587d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1014,6 +1014,9 @@ def clone_variable(self, var): if var.type == core.VarDesc.VarType.STEP_SCOPES: ret_var = self.create_var( name=var.name, persistable=var.persistable, type=var.type) + elif var.type == core.VarDesc.VarType.RAW: + ret_var = self.create_var( + name=var.name, persistable=var.persistable, type=var.type) elif var.type == core.VarDesc.VarType.SELECTED_ROWS: ret_var = self.create_var( name=var.name, From 620698e7e6f37188ba5bbd6851933a558c97f10b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 09:41:15 +0800 Subject: [PATCH 36/67] bug fux --- paddle/fluid/operators/save_op.cc | 23 ++++++++++--------- python/paddle/fluid/io.py | 2 +- .../fluid/transpiler/distribute_transpiler.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 13798c88b1856..7a0b566ea87c6 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto filename = Attr("file_path"); - auto overwrite = Attr("overwrite"); - - if (FileExists(filename) && !overwrite) { - PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", - filename, overwrite); - } - - MkDirRecursively(DirName(filename).c_str()); auto iname = Input("X"); auto *var = scope.FindVar(iname); @@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase { iname); if (var->IsType()) { - SaveLodTensor(filename, place, var); + SaveLodTensor(place, var); } else if (var->IsType()) { SaveSelectedRows(scope, place, var); } else { @@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase { } } - void SaveLodTensor(const std::string &filename, const platform::Place &place, + void SaveLodTensor( const platform::Place &place, framework::Variable *var) const { + auto filename = Attr("file_path"); + auto overwrite = Attr("overwrite"); + + if (FileExists(filename) && !overwrite) { + PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", + filename, overwrite); + } + + MkDirRecursively(DirName(filename).c_str()); + auto &tensor = var->Get(); // get device context from pool diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index ce82b6b904b0a..ffe0021e96c3c 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -503,7 +503,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) - save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table) + save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) _scroll_delete(checkpoint_dir, max_num_checkpoints) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index f9c39262ce32e..a1617600d6260 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -846,7 +846,7 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): inputs={'X': [self.table_name]}, outputs={}, attrs={ - 'file_path': self.table_name) + 'file_path': self.table_name }) return checkpoint_save_block.idx From 459690ae3ba23a2edb79119a0cc12d70086f3068 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 10:04:48 +0800 Subject: [PATCH 37/67] bug fux --- paddle/fluid/operators/save_op.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 7a0b566ea87c6..d43216749cd33 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -132,11 +132,8 @@ class SaveOp : public framework::OperatorBase { void SaveSelectedRows(const framework::Scope &scope, const platform::Place &place, framework::Variable *var) const { - - auto lt_varname = string::Sprintf("%s.path", Input("X")); - auto *lt_var = scope.FindVar(lt_varname)->GetMutable(); - PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable %s for SaveSelectedRows", - lt_varname); + auto *lt_var = scope.FindVar("loopup_table_path")->GetMutable(); + PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable loopup_table_path for SaveSelectedRows"); std::string filename = lt_var->data(); VLOG(4) << "SaveSelectedRows get File name: " << filename; From 5250ca8c879cb2027818375ddd3f65d83b5d1dcb Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 10:09:05 +0800 Subject: [PATCH 38/67] bug fux --- paddle/fluid/operators/checkpoint_notify_op.cc | 2 +- python/paddle/fluid/io.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 3e5019dd4b167..72976e22cacb5 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -43,7 +43,7 @@ class CheckpointNotifyOp : public framework::OperatorBase { detail::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... "; - auto serial_looku_table = string::Sprintf("%s/%s.%d", dir, lookup_table_name, i); + auto serial_looku_table = string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); } rpc_client->Wait(); diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index ffe0021e96c3c..629ded7f7a6e2 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -462,7 +462,6 @@ def get_parameter_value_by_name(name, executor, program=None): MODEL_DIR = "__model__" LOOKUP_TABLE_DIR = "__lookup_table__" TRAINER_PREFIX = "trainer" -PSERVER_PREFIX = "pserver" CHECKPOINT_SEPARATOR = "_" @@ -577,8 +576,7 @@ def load_persist_vars_without_grad(executor, def load_lookup_table_vars(executor, dirname, pserver_id, table_name): lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - table_file = table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX + CHECKPOINT_SEPARATOR + str( - pserver_id) + table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file) From bccf8df51bd7b2ff8f40540e409a620ad62c27de Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 11:20:05 +0800 Subject: [PATCH 39/67] bug fix --- python/paddle/fluid/io.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 629ded7f7a6e2..ac91c367962d0 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -574,11 +574,28 @@ def load_persist_vars_without_grad(executor, filename=None) -def load_lookup_table_vars(executor, dirname, pserver_id, table_name): +def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): + + for var in program.list_vars(): + if var.name == table_name: + lookup_table_var = var + break + + assert lookup_table_var is not None + lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) + table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) + + load_prog = Program() + load_block = load_prog.global_block() + + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [lookup_table_var]}, + attrs={'file_path': os.path.join(lookup_table_dir, table_file)}) - load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file) + executor.run(load_prog) def save_persist_vars_without_grad(executor, dirname, program): From 7efd73ac53839ced86bdc5b4ea10061b4df730af Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 13:38:09 +0800 Subject: [PATCH 40/67] code clean --- paddle/fluid/operators/detail/send_recv.proto | 6 ------ python/paddle/fluid/trainer.py | 19 ++++++++++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index f5800cdb7f7e1..e0902320cff00 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -81,9 +81,3 @@ message VariableMessage { } message VoidMessage {} - -message CheckpointMessage { - string varname = 1; - string notify_type = 2; - string checkpoint_dir = 3; -} diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index f77c0f65dcb97..6fc456f475562 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -74,8 +74,8 @@ def __init__(self, self.epoch_id = 0 self.step_id = 0 self.load_serial = None - self.is_pserver = False - self.has_lookup_table = False + self.pserver_id = -1, + self.lookup_table_name = None def check_and_get_place(place): @@ -174,7 +174,7 @@ def __init__(self, self.checkpoint_cfg.load_serial, self.startup_program) - if not self.checkpoint_cfg.is_pserver: + if self.checkpoint_cfg.pserver_id != -1: epoch_id, step_id = io.load_trainer_args( self.checkpoint_cfg.checkpoint_dir, self.checkpoint_cfg.load_serial, self.trainer_id, @@ -182,10 +182,12 @@ def __init__(self, self.checkpoint_cfg.epoch_id = int(epoch_id) self.checkpoint_cfg.step_id = int(step_id) else: - if self.checkpoint_cfg.has_lookup_table: + if self.checkpoint_cfg.lookup_table_name: io.load_lookup_table_vars( - exe, self.checkpoint_cfg.checkpoint_dir, 0, - "table_name") + exe, self.checkpoint_cfg.checkpoint_dir, + self.startup_program, + self.checkpoint_cfg.pserver_id, + self.checkpoint_cfg.lookup_table_name) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -255,7 +257,10 @@ def _dist_transpile_if_necessary(self, optimize_ops, params_grads): self.trainer_id, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": if self.checkpoint_cfg: - self.is_pserver = True + pserver_id = eplist.index(current_endpoint) + self.checkpoint_cfg.pserver_id = pserver_id + if t.has_distributed_lookup_table: + self.checkpoint_cfg.lookup_table_name = t.table_name self.train_program = t.get_pserver_program(current_endpoint) self.startup_program = t.get_startup_program(current_endpoint, From 1571c25ae927f035afc67f7fdb9fb10fd09e90b0 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 14:26:59 +0800 Subject: [PATCH 41/67] code style fix --- python/paddle/fluid/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 6fc456f475562..b4cb019aea347 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -66,8 +66,8 @@ def __init__(self, assert epoch_interval >= 1 assert step_interval >= 1 - self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd( - ) + self.checkpoint_dir = checkpoint_dir \ + if checkpoint_dir is not None else os.getcwd() self.max_num_checkpoints = max_num_checkpoints self.epoch_interval = epoch_interval self.step_interval = step_interval From d93dc81c4eeaa070586ed25055933a4e6bda57e4 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 15:14:10 +0800 Subject: [PATCH 42/67] add handle when checkpoint_notify_id = -1 --- .../operators/detail/request_handler_impl.cc | 8 ++++++-- .../operators/detail/request_handler_impl.h | 9 +++++++-- paddle/fluid/operators/listen_and_serv_op.cc | 18 ++++++++++-------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 87fa5842c4e9a..859f6a75781cd 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const std::string& out_var_name) { + PADDLE_ENFORCE( + checkpoint_notify_id != -1, + "when checkpoint_notify_id = -1, there should be no RPC invoke."); - auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable(); + auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name; + VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " + << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 643eae4d31438..b7cebf1a61940 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler { class RequestCheckpointHandler final : public RequestHandler { public: - explicit RequestCheckpointHandler(bool sync_mode) - : RequestHandler(sync_mode) {} + explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id) + : RequestHandler(sync_mode) { + this.checkpoint_notify_id = checkpoint_notify_id; + } virtual ~RequestCheckpointHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, const std::string& out_var_name = "") override; + + private: + int checkpoint_notify_id; }; } // namespace detail diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 78b8c96f4f746..477cb90efb668 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); + int checkpoint_point_block_id = Attr(kCheckpointBlockId); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in - << ", end_point:" << endpoint; + << ", end_point:" << endpoint + << ", CheckpointNotify Id: " << checkpoint_notify_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_prefetch_handler_.reset( new detail::RequestPrefetchHandler(sync_mode)); request_checkpoint_handler_.reset( - new detail::RequestCheckpointHandler(sync_mode)); + new detail::RequestCheckpointHandler(sync_mode, checkpoint_notify_id)); rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); @@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->RegisterRPC(detail::kRequestCheckpoint, request_checkpoint_handler_.get()); + std::shared_ptr ckpt_pre_context = nullptr; + if (checkpoint_notify_id != -1) { + auto ctx = executor.Prepare(*program, checkpoint_point_block_id); + ckpt_pre_context = std::move(ctx); + } + auto *optimize_block = Attr(kOptimizeBlock); auto *program = optimize_block->Program(); framework::Executor executor(dev_place); @@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - int checkpoint_point_block_id = Attr(kCheckpointBlockId); - auto ctx = executor.Prepare(*program, checkpoint_point_block_id); - - std::shared_ptr ckpt_pre_context = - std::move(ctx); - auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, program, &prefetch_var_name_to_prepared_ctx, From 8c0e1d5cba715cae9d9a47c788f3a75da6efba2c Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 15:33:17 +0800 Subject: [PATCH 43/67] unittest case fix --- paddle/fluid/operators/save_load_op_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/save_load_op_test.cc b/paddle/fluid/operators/save_load_op_test.cc index c4fcc61af4b75..ccaea0eef2906 100644 --- a/paddle/fluid/operators/save_load_op_test.cc +++ b/paddle/fluid/operators/save_load_op_test.cc @@ -139,6 +139,7 @@ TEST(LoadFP16Op, CPU) { save_op->Run(scope, place); auto load_var = scope.Var("out_var"); + load_var->GetMutable(); auto load_op = paddle::framework::OpRegistry::CreateOp( "load", {}, {{"Out", {"out_var"}}}, attrs); load_op->Run(scope, place); From 49c2d0c5fb216909cf3101629558c99092895e6d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 16:30:02 +0800 Subject: [PATCH 44/67] bug fix --- paddle/fluid/operators/detail/request_handler_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index b7cebf1a61940..689c6893cf86b 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -70,7 +70,7 @@ class RequestCheckpointHandler final : public RequestHandler { public: explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id) : RequestHandler(sync_mode) { - this.checkpoint_notify_id = checkpoint_notify_id; + this->checkpoint_notify_id = checkpoint_notify_id; } virtual ~RequestCheckpointHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, From 16ecead837c940d52109769c60791af874eee51a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 17:53:38 +0800 Subject: [PATCH 45/67] load op optimize --- paddle/fluid/operators/load_op.cc | 37 +++++++++---------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index e75fc4d674171..6be8fdb0de46b 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -1,4 +1,5 @@ - +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,6 +36,8 @@ class LoadOp : public framework::OperatorBase { auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); platform::RecordEvent record_event(Type(), dev_ctx); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. auto filename = Attr("file_path"); std::ifstream fin(filename); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", @@ -46,31 +49,23 @@ class LoadOp : public framework::OperatorBase { out_var_name); if (out_var->IsType()) { - LoadLodTensor(filename, place, out_var); + LoadLodTensor(fin, place, out_var); } else if (out_var->IsType()) { - LoadSelectedRows(filename, scope, place, out_var); + LoadSelectedRows(fin, place, out_var); } else { PADDLE_ENFORCE( false, "Load only support LoDTensor and SelectedRows, %s has wrong type", out_var_name); } - } + } - void LoadLodTensor(const std::string &filename, const platform::Place &place, + void LoadLodTensor(std::istream &fin, const platform::Place &place, framework::Variable *var) const { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", - filename); - - auto *tensor = var->GetMutable(); - + auto *tensor = var->GetMutable(); DeserializeFromStream(fin, tensor, dev_ctx); auto load_as_fp16 = Attr("load_as_fp16"); @@ -92,25 +87,15 @@ class LoadOp : public framework::OperatorBase { tensor = var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); - } + } } - void LoadSelectedRows(const std::string &filename, - const framework::Scope &scope, - const platform::Place &place, + void LoadSelectedRows(std::istream &fin, const platform::Place &place, framework::Variable *var) const { - auto *selectedRows = var->GetMutable(); - // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", - filename); framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } }; From 6abf07693ae721ca8c6f01fe1269a38b4c4106dd Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 17:56:57 +0800 Subject: [PATCH 46/67] checkpoint_notify_id rename --- paddle/fluid/operators/listen_and_serv_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 477cb90efb668..b0eec2eb443e2 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -247,7 +247,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - int checkpoint_point_block_id = Attr(kCheckpointBlockId); + int checkpoint_notify_id = Attr(kCheckpointBlockId); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in << ", end_point:" << endpoint @@ -271,7 +271,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, std::shared_ptr ckpt_pre_context = nullptr; if (checkpoint_notify_id != -1) { - auto ctx = executor.Prepare(*program, checkpoint_point_block_id); + auto ctx = executor.Prepare(*program, checkpoint_notify_id); ckpt_pre_context = std::move(ctx); } From 28482f81a8ad8d7f5e11b987337b5ba164eb856e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 19:49:02 +0800 Subject: [PATCH 47/67] bug fix --- paddle/fluid/operators/listen_and_serv_op.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index b0eec2eb443e2..463677c75e2b9 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -269,16 +269,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->RegisterRPC(detail::kRequestCheckpoint, request_checkpoint_handler_.get()); + auto *optimize_block = Attr(kOptimizeBlock); + auto *program = optimize_block->Program(); + framework::Executor executor(dev_place); + std::shared_ptr ckpt_pre_context = nullptr; if (checkpoint_notify_id != -1) { auto ctx = executor.Prepare(*program, checkpoint_notify_id); ckpt_pre_context = std::move(ctx); } - auto *optimize_block = Attr(kOptimizeBlock); - auto *program = optimize_block->Program(); - framework::Executor executor(dev_place); - // prepare for prefetch std::vector prefetch_block_id_list; std::unordered_map block_id_to_prefetch_var_name; From 06f6c21303c527af8ffca7d3f83c1fb2240a55a8 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 19:55:07 +0800 Subject: [PATCH 48/67] bug fix --- paddle/fluid/operators/listen_and_serv_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 463677c75e2b9..3d67b2d2ea2f2 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -332,7 +332,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, SavePort(); if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, - checkpoint_point_block_id); + checkpoint_notify_id); } else { RunAsyncLoop(&executor, program); } From 5600b135120659448a3fc95d54fe22989eaadf25 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 21:09:30 +0800 Subject: [PATCH 49/67] bug fix --- python/paddle/fluid/io.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index ac91c367962d0..96311e5ef8b9c 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -472,8 +472,7 @@ def save_checkpoint(executor, main_program=None, max_num_checkpoints=3, lookup_table=None, - ps_endpoint_list=None - ): + ps_endpoint_list=None): """ Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy @@ -495,14 +494,18 @@ def save_checkpoint(executor, if not os.path.isdir(checkpoint_dir): os.makedirs(checkpoint_dir) + is_chief = trainer_id == 0 + serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 cur_dir = _get_serial_dir(checkpoint_dir, serial) save_trainer_args(cur_dir, trainer_id, trainer_args) - if trainer_id == 0: + if is_chief: save_persist_vars_without_grad(executor, cur_dir, main_program) - save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) + if is_chief and lookup_table and ps_endpoint_list: + save_pserver_vars_by_notify(executor, cur_dir, lookup_table, + ps_endpoint_list) _scroll_delete(checkpoint_dir, max_num_checkpoints) @@ -618,7 +621,8 @@ def save_persist_vars_without_grad(executor, dirname, program): _write_success(cur_dir) -def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list): +def save_pserver_vars_by_notify(executor, dirname, lookup_table, + ps_endpoint_list): """ """ cur_dir = _get_lookuptable_dir(dirname) @@ -802,4 +806,3 @@ def has_success(checkpoint_dir, cur_dir): if success_num > current_dir: current_dir = success_num return current_dir - From 32fa832b4b765a963a3b84ac90b62803a454a321 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 22:09:49 +0800 Subject: [PATCH 50/67] code style --- paddle/fluid/operators/checkpoint_notify_op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 72976e22cacb5..7b4c607c33e34 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -42,8 +42,9 @@ class CheckpointNotifyOp : public framework::OperatorBase { detail::RPCClient* rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { - VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... "; - auto serial_looku_table = string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); + VLOG(3) << "checkpoint notify sending " << dir << " to " << epmap[i]; + auto serial_looku_table = + string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); } rpc_client->Wait(); @@ -60,8 +61,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({"127.0.0.1:6164"}); AddAttr( "dir", "(string, default '') indicate the folder checkpoint will use"); - AddAttr( - "lookup_table", "(string, default '') the lookup table name"); + AddAttr("lookup_table", + "(string, default '') the lookup table name"); AddComment(R"DOC( Prefetch operator From 8af4d4c7a08dda435176ea995bf42f60b2e58562 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 23:09:48 +0800 Subject: [PATCH 51/67] code style --- paddle/fluid/operators/checkpoint_notify_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 7b4c607c33e34..31b725ec184ce 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -56,7 +56,7 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { void Make() { AddAttr>( "epmap", - "(string vector, default 127.0.0.1:6164)" + "(string vector, default 127.0.0.1:6164)" "Server endpoints in the order of input variables for mapping") .SetDefault({"127.0.0.1:6164"}); AddAttr( From db6126ca9938ad86e92b661e08c3035abbd83a78 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 20 Jun 2018 10:24:35 +0800 Subject: [PATCH 52/67] code style --- paddle/fluid/operators/load_op.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 6be8fdb0de46b..764e3428ec4a2 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -1,5 +1,3 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); From 91eae9cc916f23e7e67bf7d84d8be1025aa73ed9 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 20 Jun 2018 11:59:40 +0800 Subject: [PATCH 53/67] code style --- paddle/fluid/operators/save_op.cc | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index d43216749cd33..941bca1047760 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -69,7 +69,6 @@ class SaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto iname = Input("X"); auto *var = scope.FindVar(iname); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", @@ -87,7 +86,7 @@ class SaveOp : public framework::OperatorBase { } } - void SaveLodTensor( const platform::Place &place, + void SaveLodTensor(const platform::Place &place, framework::Variable *var) const { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); @@ -132,8 +131,11 @@ class SaveOp : public framework::OperatorBase { void SaveSelectedRows(const framework::Scope &scope, const platform::Place &place, framework::Variable *var) const { - auto *lt_var = scope.FindVar("loopup_table_path")->GetMutable(); - PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable loopup_table_path for SaveSelectedRows"); + auto *lt_var = + scope.FindVar("loopup_table_path")->GetMutable(); + PADDLE_ENFORCE( + lt_var != nullptr, + "Can not find variable loopup_table_path for SaveSelectedRows"); std::string filename = lt_var->data(); VLOG(4) << "SaveSelectedRows get File name: " << filename; @@ -195,17 +197,11 @@ class SaveOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override {} }; -} -} - -// namespace operators -// namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(save, ops::SaveOp, - paddle::framework::EmptyGradOpMaker, - ops::SaveOpProtoMaker, - ops::SaveOpVarTypeInference, +REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker, + ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference, ops::SaveOpShapeInference); - From c073bb3b2c1381d2967785b956381be231fa6583 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 20 Jun 2018 12:51:44 +0800 Subject: [PATCH 54/67] code style --- python/paddle/fluid/framework.py | 7 +++---- python/paddle/fluid/transpiler/distribute_transpiler.py | 9 +++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c389c4aeff61d..ca118452711bd 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -382,8 +382,7 @@ class Operator(object): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', 'channel_send', - 'channel_recv', 'select', 'checkpoint_notify' - , 'gen_nccl_id' + 'channel_recv', 'select', 'checkpoint_notify', 'gen_nccl_id' } def __init__(self, @@ -1022,7 +1021,7 @@ def clone_variable(self, var): name=var.name, persistable=var.persistable, type=var.type) elif var.type == core.VarDesc.VarType.RAW: ret_var = self.create_var( - name=var.name, persistable=var.persistable, type=var.type) + name=var.name, persistable=var.persistable, type=var.type) elif var.type == core.VarDesc.VarType.SELECTED_ROWS: ret_var = self.create_var( name=var.name, @@ -1465,7 +1464,7 @@ def get_var(name, program=None): Args: name(str): name of the variable program(Program|None): program object. - If None, default_global_program() will be used. + If None, default_global_program() will be used. Returns: Variable diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index b9c67dbf95bdb..5bbeeeaed6666 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -865,16 +865,17 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): """ import os - pserver_program.global_block().create_var(name="loopup_table_path", persistable=True, type=core.VarDesc.VarType.RAW) + pserver_program.global_block().create_var( + name="loopup_table_path", + persistable=True, + type=core.VarDesc.VarType.RAW) checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block.append_op( type='save', inputs={'X': [self.table_name]}, outputs={}, - attrs={ - 'file_path': self.table_name - }) + attrs={'file_path': self.table_name}) return checkpoint_save_block.idx From 05bd9db84bfb6b0a2beea4c4c79306c5eb127ff7 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 20 Jun 2018 17:25:16 +0800 Subject: [PATCH 55/67] add comments in io.py --- python/paddle/fluid/io.py | 94 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e59ac11fd4118..32f53ebe388b9 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -840,6 +840,12 @@ def save_checkpoint(executor, max_num_checkpoints(int): The max number of total number of existing checkpoints. Default: 3 + lookup_table(string|None): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list|None): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. Returns: None @@ -856,15 +862,21 @@ def save_checkpoint(executor, prog = fluid.default_main_program() trainer_args = {"epoch_id": 200, "step_id": 20} # just an example + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + fluid.io.save_checkpoint(executor=exe, checkpoint_dir=path, trainer_id=0, trainer_args=trainer_args, main_program=prog, - max_num_checkpoints=3) + max_num_checkpoints=3, + lookup_table=table_name, + ps_endpoint_list = ps_endpoints) """ if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") + assert checkpoint_dir if trainer_args: assert isinstance(trainer_args, dict) @@ -881,6 +893,7 @@ def save_checkpoint(executor, if is_chief: save_persist_vars_without_grad(executor, cur_dir, main_program) + if is_chief and lookup_table and ps_endpoint_list: save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) @@ -1020,6 +1033,31 @@ def load_persist_vars_without_grad(executor, def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): + """ + The parameter server will load lookup table's local file in + selectedrows variable. + + Args: + executor(Executor): The executor to run for loading persistable variables + dirname(str): The directory path + main_program(Program): Find the variable named table_name in main_program + pserver_id(int): the serial number in pserver_endpoints list + table_name(str): lookup table name + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + dirname = "./checkpoints/checkpoint_9/__model__" + prog = fluid.default_main_program() + pserver_id = 1 + table_name = "share_w" + fluid.io.load_lookup_table_vars(executor=exe, + dirname=dirname, program=prog, pserver_id=pserver_id, + table_name=table_name) + """ for var in program.list_vars(): if var.name == table_name: @@ -1092,6 +1130,35 @@ def save_persist_vars_without_grad(executor, dirname, program): def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list): """ + This function will send checkpoint notify message from Trainer 0 + to all the pservers. + The checkpoint notify message contains lookup table name, + the absolute path on pserver to save lookup_table. + + Args: + executor(Executor): The executor to run for send checkpoint notify. + dirname(str): The folder where to save checkpoints. + lookup_table(string): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. + Return: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + prog = fluid.default_main_program() + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + + fluid.io.save_pserver_vars_by_notify(executor=exe, + dirname=param_path, lookup_table=table_name, + ps_endpoint_list=ps_endpoints) """ cur_dir = _get_lookuptable_dir(dirname) @@ -1121,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): + """ + trainer will load some args from it's independent directory, + such as epoch_id and step_id. + + Args: + checkpoint_dir(str): The folder where all checkpoints are. + serial(int): The serial of checkpoint you would like to load. + trainer_id(int): current trainer id. + trainer_args(list): list about load trainer args + Return: + None + + Examples: + .. code-block:: python + + param_path = "./checkpoint/" + serial = 7 + trainer_id = 2 + trainer_args = ["epoch_id", "step_id"] + + fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial, + trainer_id=trainer_id, trainer_args=trainer_args) + """ assert isinstance(trainer_args, list) cur_dir = _get_serial_dir(checkpoint_dir, serial) @@ -1141,7 +1231,7 @@ def _is_checkpoint_var(var): the checkpoint will not save or load all the variables. var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. - : param var + : param var(Variable) """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ From 97648442cd3124b49f4453bf86a8f7f715f9b734 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 21 Jun 2018 11:37:04 +0800 Subject: [PATCH 56/67] merge develop --- paddle/fluid/operators/distributed/grpc_server.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index b0f2704250424..218a1f85625e8 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -195,7 +195,8 @@ class RequestCheckpointNotify final : public RequestBase { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { request_.reset(new VariableResponse(request_handler->scope(), request_handler->dev_ctx(), true)); - int method_id = static_cast(detail::GrpcMethod::kCheckpointNotify); + int method_id = + static_cast(distributed::GrpcMethod::kCheckpointNotify); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); From 620999c917fa6e948d238a80ea9c774e72b28dbb Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 21 Jun 2018 15:44:39 +0800 Subject: [PATCH 57/67] save checkpoint bug fix --- paddle/fluid/operators/checkpoint_notify_op.cc | 4 ++-- paddle/fluid/operators/listen_and_serv_op.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 31b725ec184ce..e7a65b76a496c 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -39,8 +39,8 @@ class CheckpointNotifyOp : public framework::OperatorBase { std::string dir = Attr("dir"); std::string lookup_table_name = Attr("lookup_table"); - detail::RPCClient* rpc_client = - detail::RPCClient::GetInstance(); + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { VLOG(3) << "checkpoint notify sending " << dir << " to " << epmap[i]; auto serial_looku_table = diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 420c4e9e4ed2e..df9cdae97df91 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -268,7 +268,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestPrefetch, request_prefetch_handler_.get()); - rpc_service_->RegisterRPC(detail::kRequestCheckpoint, + rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, request_checkpoint_handler_.get()); auto *optimize_block = Attr(kOptimizeBlock); From 8e01f3b94873a72ddde223fee1f8f69c6fb8881e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 21 Jun 2018 21:50:23 +0800 Subject: [PATCH 58/67] bug fix --- python/paddle/fluid/io.py | 44 ++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 32f53ebe388b9..8cc25e8623717 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import errno import time import shutil @@ -881,11 +882,9 @@ def save_checkpoint(executor, if trainer_args: assert isinstance(trainer_args, dict) - if not os.path.isdir(checkpoint_dir): - os.makedirs(checkpoint_dir) - is_chief = trainer_id == 0 + _make_chekcpoint_dirs(checkpoint_dir) serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 cur_dir = _get_serial_dir(checkpoint_dir, serial) @@ -1251,6 +1250,20 @@ def _is_checkpoint_var(var): return var.persistable +def _make_chekcpoint_dirs(dirs): + assert dirs is not None + + if os.path.isfile(dirs): + raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs) + + if not os.path.isdir(dirs): + try: + os.makedirs(dirs) + except OSError as err: + if err.errno != errno.EEXIST: + raise err + + def _get_dir_serial(dirname): _, serial = dirname.split(CHECKPOINT_SEPARATOR) @@ -1264,38 +1277,27 @@ def _get_dir_serial(dirname): def _get_serial_dir(dirname, serial): serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_dir = os.path.join(dirname, serial_folder) - - if not os.path.isdir(serial_dir): - os.makedirs(serial_dir) + _make_chekcpoint_dirs(serial_dir) return serial_dir def _get_model_dir(dirname): model_dir = os.path.join(dirname, MODEL_DIR) - - if not os.path.isdir(model_dir): - os.makedirs(model_dir) - + _make_chekcpoint_dirs(model_dir) return model_dir def _get_lookuptable_dir(dirname): lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - - if not os.path.isdir(lookuptable_dir): - os.makedirs(lookuptable_dir) - + _make_chekcpoint_dirs(lookuptable_dir) return lookuptable_dir def _get_trainer_dir(dirname, trainer_id): trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) trainer_dir = os.path.join(dirname, trainer_folder) - - if not os.path.isdir(trainer_dir): - os.makedirs(trainer_dir) - + _make_chekcpoint_dirs(trainer_dir) return trainer_dir @@ -1314,7 +1316,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3): serials = serials[max_num_checkpoints:] for serial in serials: cur_dir = _get_serial_dir(dirname, serial) - shutil.rmtree(cur_dir) + try: + shutil.rmtree(cur_dir) + except OSError as err: + if err.errno != errno.ENOENT: + raise err def _write_success(dirname): From 2229db523ba6391e7d4b0831bee68fa5f38a3584 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 22 Jun 2018 13:44:26 +0800 Subject: [PATCH 59/67] pserver_id init value to None --- python/paddle/fluid/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index ad3872ab0d6ff..f191ef7df5caa 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -131,7 +131,7 @@ def __init__(self, self.epoch_id = 0 self.step_id = 0 self.load_serial = None - self.pserver_id = -1, + self.pserver_id = None self.lookup_table_name = None @@ -283,7 +283,7 @@ def __init__(self, self.checkpoint_cfg.load_serial, self.startup_program) - if self.checkpoint_cfg.pserver_id != -1: + if not self.checkpoint_cfg.pserver_id: epoch_id, step_id = io.load_trainer_args( self.checkpoint_cfg.checkpoint_dir, self.checkpoint_cfg.load_serial, self.trainer_id, From e684575f662c13fd0f8c732671c77420c2aedefe Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 22 Jun 2018 14:55:16 +0800 Subject: [PATCH 60/67] checkpoint feature optimized --- paddle/fluid/operators/checkpoint_notify_op.cc | 13 +++++++------ paddle/fluid/operators/detail/macros.h | 4 ++++ paddle/fluid/operators/distributed/grpc_server.cc | 11 ++++------- .../operators/distributed/request_handler_impl.cc | 5 +++-- paddle/fluid/operators/listen_and_serv_op.cc | 12 ++++++------ paddle/fluid/operators/load_op.cc | 6 ++++-- paddle/fluid/operators/save_op.cc | 10 +++++----- python/paddle/fluid/io.py | 3 ++- .../fluid/transpiler/distribute_transpiler.py | 7 +++++-- 9 files changed, 40 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index e7a65b76a496c..7fc5b5e62214c 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { - VLOG(3) << "checkpoint notify sending " << dir << " to " << epmap[i]; - auto serial_looku_table = + auto lookup_table_save_dir = string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); - rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); + rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir); + VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name + << " and dir:" << dir << " to " << epmap[i]; } rpc_client->Wait(); } @@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("lookup_table", "(string, default '') the lookup table name"); AddComment(R"DOC( -Prefetch operator +CheckpointNotify operator -This operator will send Ids variables to listen_and_serve op at -the parameter server and fetch result back. +This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at +the parameter server. )DOC"); } }; diff --git a/paddle/fluid/operators/detail/macros.h b/paddle/fluid/operators/detail/macros.h index b9e385994efce..6e9f7beb93b95 100644 --- a/paddle/fluid/operators/detail/macros.h +++ b/paddle/fluid/operators/detail/macros.h @@ -25,3 +25,7 @@ #define RPCSERVER_T distributed::AsyncBRPCServer #define RPCCLIENT_T distributed::BRPCClient #endif + +// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables +// to directory specified. +constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index 218a1f85625e8..363614df4f92c 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase { RequestHandler* request_handler, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { request_.reset(new VariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); + request_handler->dev_ctx())); int method_id = static_cast(distributed::GrpcMethod::kCheckpointNotify); service_->RequestAsyncUnary( @@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase { std::string checkpoint_notify = request_->Varname(); std::string checkpoint_dir = request_->OutVarname(); - framework::Variable* invar = nullptr; - framework::Variable* outvar = nullptr; - VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify << ", dir: " << checkpoint_dir; - request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, + request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr, checkpoint_dir); Finish(reply_, &responder_); } @@ -320,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, return; } - LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name - << " REQ ID: " << req_id; + VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name + << " REQ ID: " << req_id; auto& reqs = rpc_reqs_[rpc_name]; auto& handler = rpc_call_map_[rpc_name]; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index b6e4e156080a6..cd8059a96d385 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/printf.h" @@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, checkpoint_notify_id != -1, "when checkpoint_notify_id = -1, there should be no RPC invoke."); - auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable(); + auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " + VLOG(4) << "RequestCheckpointHandler update var lookup_table_path to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index df9cdae97df91..87a501eaa25a4 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - int checkpoint_notify_id = Attr(kCheckpointBlockId); + int checkpoint_notify_block_id = Attr(kCheckpointBlockId); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in << ", end_point:" << endpoint - << ", CheckpointNotify Id: " << checkpoint_notify_id; + << ", CheckpointNotify Id: " << checkpoint_notify_block_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_prefetch_handler_.reset( new distributed::RequestPrefetchHandler(sync_mode)); request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( - sync_mode, checkpoint_notify_id)); + sync_mode, checkpoint_notify_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get()); @@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, framework::Executor executor(dev_place); std::shared_ptr ckpt_pre_context = nullptr; - if (checkpoint_notify_id != -1) { - auto ctx = executor.Prepare(*program, checkpoint_notify_id); + if (checkpoint_notify_block_id != -1) { + auto ctx = executor.Prepare(*program, checkpoint_notify_block_id); ckpt_pre_context = std::move(ctx); } @@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, SavePort(); if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, - checkpoint_notify_id); + checkpoint_notify_block_id); } else { RunAsyncLoop(&executor, program); } diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 764e3428ec4a2..ac35cf0b89bfa 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase { class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddOutput("Out", "The tensor need to be loaded"); + AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded"); AddAttr( "load_as_fp16", "If true, the tensor will be first loaded and then " @@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { R"(Variable will be loaded from "file_path")") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddComment("Load operator will load a tensor variable from disk file."); + AddComment( + "Load operator will load a LoDTensor / SelectedRows variable from disk " + "file."); } }; } // namespace operators diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 941bca1047760..bf8553ed55744 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase { void SaveSelectedRows(const framework::Scope &scope, const platform::Place &place, framework::Variable *var) const { - auto *lt_var = - scope.FindVar("loopup_table_path")->GetMutable(); + auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable(); PADDLE_ENFORCE( lt_var != nullptr, - "Can not find variable loopup_table_path for SaveSelectedRows"); + "Can not find variable lookup_table_path for SaveSelectedRows"); std::string filename = lt_var->data(); VLOG(4) << "SaveSelectedRows get File name: " << filename; @@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Save operator -This operator will serialize and write a tensor/selected rows variable to file on disk. +This operator will serialize and write LoDTensor / SelectedRows variable to file on disk. )DOC"); AddAttr("overwrite", "(boolean, default true)" @@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference { public: void operator()(const framework::OpDesc &op_desc, framework::BlockDesc *block) const override { - auto out_var_name = op_desc.Output("loopup_table_path").front(); + auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front(); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); auto var_type = framework::proto::VarType::RAW; out_var.SetType(var_type); diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8cc25e8623717..d7b42ef351560 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): main_program(Program): Find the variable named table_name in main_program pserver_id(int): the serial number in pserver_endpoints list table_name(str): lookup table name + Returns: None @@ -1188,7 +1189,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): """ - trainer will load some args from it's independent directory, + trainer will load some args from it's independent directory, such as epoch_id and step_id. Args: diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a3f0a4ffe28f1..d9578af2a93d5 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -914,7 +914,7 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): import os pserver_program.global_block().create_var( - name="loopup_table_path", + name="lookup_table_path", persistable=True, type=core.VarDesc.VarType.RAW) @@ -923,7 +923,10 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): type='save', inputs={'X': [self.table_name]}, outputs={}, - attrs={'file_path': self.table_name}) + attrs={ + 'file_path': + "this 'file_path' do not be used in save lookup table variable" + }) return checkpoint_save_block.idx From 7fae9e0a7bb3663dc75f2f2dd9881fb8b675af9d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 22 Jun 2018 15:08:07 +0800 Subject: [PATCH 61/67] checkpoint feature optimized --- paddle/fluid/operators/detail/macros.h | 4 ---- paddle/fluid/operators/distributed/request_handler_impl.cc | 5 ++++- paddle/fluid/operators/save_op.cc | 5 ++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/detail/macros.h b/paddle/fluid/operators/detail/macros.h index 6e9f7beb93b95..b9e385994efce 100644 --- a/paddle/fluid/operators/detail/macros.h +++ b/paddle/fluid/operators/detail/macros.h @@ -25,7 +25,3 @@ #define RPCSERVER_T distributed::AsyncBRPCServer #define RPCCLIENT_T distributed::BRPCClient #endif - -// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables -// to directory specified. -constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index cd8059a96d385..b0d42be388f74 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -20,7 +20,6 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/printf.h" @@ -29,6 +28,10 @@ namespace paddle { namespace operators { namespace distributed { +// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables +// to directory specified. +constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; + bool RequestSendHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index bf8553ed55744..493bb2ec89021 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -24,12 +24,15 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { +// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables +// to directory specified. +constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; + // TODO(yuyang18): If the functions below are needed by other files, move them // to paddle::filesystem namespace. constexpr char kSEP = '/'; From 4388ce112e6920a406387157cb6a3ab5c17e8f72 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 22 Jun 2018 15:19:05 +0800 Subject: [PATCH 62/67] checkpoint notify op optimized --- paddle/fluid/operators/checkpoint_notify_op.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 7fc5b5e62214c..c4219a429a53e 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -55,10 +55,9 @@ class CheckpointNotifyOp : public framework::OperatorBase { class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddAttr>( - "epmap", - "(string vector, default 127.0.0.1:6164)" - "Server endpoints in the order of input variables for mapping") + AddAttr>("epmap", + "(string vector, default 127.0.0.1:6164)" + "Parameter Server endpoints in the order") .SetDefault({"127.0.0.1:6164"}); AddAttr( "dir", "(string, default '') indicate the folder checkpoint will use"); From b519bf05d059b5aea9bde2a6ddca7fe8170b3bf9 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 25 Jun 2018 10:02:13 +0800 Subject: [PATCH 63/67] log level optimize --- paddle/fluid/operators/distributed/grpc_server.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index 363614df4f92c..9289139b133d8 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -263,8 +263,7 @@ void AsyncGRPCServer::StartServer() { reqs.reserve(kRequestBufSize); for (int i = 0; i < kRequestBufSize; i++) { - LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name - << " I: " << i; + VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i; TryToRegisterNewOne(rpc_name, i); } @@ -351,7 +350,7 @@ void AsyncGRPCServer::HandleRequest( while (true) { VLOG(3) << "HandleRequest " << rpc_name << " wait next"; if (!cq->Next(&tag, &ok)) { - LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!"; + VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!"; break; } From dc847f129eac535f5738c107f8723fb1b3a461de Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 25 Jun 2018 20:35:45 +0800 Subject: [PATCH 64/67] bug fix and code optimize --- .../distributed/request_handler_impl.cc | 4 ++-- paddle/fluid/operators/listen_and_serv_op.cc | 19 ++++++++++--------- paddle/fluid/operators/save_op.cc | 4 ++-- .../fluid/transpiler/distribute_transpiler.py | 8 +++----- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index b0d42be388f74..163154c678f65 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -30,7 +30,7 @@ namespace distributed { // define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables // to directory specified. -constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; +constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; bool RequestSendHandler::Handle(const std::string& varname, framework::Scope* scope, @@ -136,7 +136,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update var lookup_table_path to: " + VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 6663459939789..66c9ed6ddf7df 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -206,7 +206,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, VLOG(3) << "RunAsyncLoop into while"; while (true) { if (rpc_service_->IsExit()) { - LOG(INFO) << "get exit!rpc_processor break!"; + VLOG(4) << "get exit!rpc_processor break!"; break; } @@ -245,11 +245,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - int checkpoint_notify_block_id = Attr(kCheckpointBlockId); + int checkpoint_block_id = Attr(kCheckpointBlockId); - LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in - << ", end_point:" << endpoint - << ", CheckpointNotify Id: " << checkpoint_notify_block_id; + VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in + << ", end_point:" << endpoint + << ", checkpoint_block_id: " << checkpoint_block_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -258,7 +258,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_prefetch_handler_.reset( new distributed::RequestPrefetchHandler(sync_mode)); request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( - sync_mode, checkpoint_notify_block_id)); + sync_mode, checkpoint_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get()); @@ -277,8 +277,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, framework::Executor executor(dev_place); std::shared_ptr ckpt_pre_context = nullptr; - if (checkpoint_notify_block_id != -1) { - auto ctx = executor.Prepare(*program, checkpoint_notify_block_id); + if (checkpoint_block_id != -1) { + auto ctx = executor.Prepare(*program, checkpoint_block_id); + // see: https://stackoverflow.com/a/14856553 ckpt_pre_context = std::move(ctx); } @@ -335,7 +336,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, SavePort(); if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, - checkpoint_notify_block_id); + checkpoint_block_id); } else { RunAsyncLoop(&executor, program); } diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 493bb2ec89021..201a51130d6b6 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -31,7 +31,7 @@ namespace operators { // define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables // to directory specified. -constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; +constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; // TODO(yuyang18): If the functions below are needed by other files, move them // to paddle::filesystem namespace. @@ -138,7 +138,7 @@ class SaveOp : public framework::OperatorBase { auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable(); PADDLE_ENFORCE( lt_var != nullptr, - "Can not find variable lookup_table_path for SaveSelectedRows"); + "Can not find variable kLookupTablePath for SaveSelectedRows"); std::string filename = lt_var->data(); VLOG(4) << "SaveSelectedRows get File name: " << filename; diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d74fdbf17f326..b15ef4cb4c709 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -920,19 +920,17 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): import os pserver_program.global_block().create_var( - name="lookup_table_path", + name="kLookupTablePath", persistable=True, type=core.VarDesc.VarType.RAW) checkpoint_save_block = pserver_program.create_block(pre_block_idx) + # this 'file_path' do not be used in save lookup table variable checkpoint_save_block.append_op( type='save', inputs={'X': [self.table_name]}, outputs={}, - attrs={ - 'file_path': - "this 'file_path' do not be used in save lookup table variable" - }) + attrs={'file_path': ""}) return checkpoint_save_block.idx From 33ff69b6218b18383da591d95b35b2105b34d56d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 25 Jun 2018 20:41:36 +0800 Subject: [PATCH 65/67] file path can not be empty --- python/paddle/fluid/transpiler/distribute_transpiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index b15ef4cb4c709..ad9631c72985c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -930,7 +930,7 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): type='save', inputs={'X': [self.table_name]}, outputs={}, - attrs={'file_path': ""}) + attrs={'file_path': "none"}) return checkpoint_save_block.idx From 88cb5d79f23906217c8f58b2dcc2771a1ac06ea1 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 26 Jun 2018 10:57:08 +0800 Subject: [PATCH 66/67] add doc --- python/paddle/fluid/io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index d7b42ef351560..d94564e11f982 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1252,6 +1252,9 @@ def _is_checkpoint_var(var): def _make_chekcpoint_dirs(dirs): + """ + _make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it. + """ assert dirs is not None if os.path.isfile(dirs): From f57978e6b5f4af4a8d9a8d41d18e369ddbb89892 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 26 Jun 2018 12:51:40 +0800 Subject: [PATCH 67/67] renae --- paddle/fluid/operators/distributed/grpc_client.h | 5 ++--- paddle/fluid/operators/distributed/rpc_client.h | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 102b4e5bcf72c..eab9fc7e866ef 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -211,9 +211,8 @@ class GRPCClient : public RPCClient { void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) override; - void AsyncCheckpointNotify( - const std::string& ep, const std::string& dir, - int64_t time_out = RPCClient::rpc_time_out) override; + void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_grpc_deadline) override; void Wait() override; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 84bef0ab2341a..4ce01287aa847 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -56,9 +56,9 @@ class RPCClient { virtual void AsyncSendFetchBarrier( const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0; - virtual void AsyncCheckpointNotify(const std::string& ep, - const std::string& dir, - int64_t time_out = rpc_time_out) = 0; + virtual void AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_grpc_deadline) = 0; // SendComplete tells all the server that current trainer have no more data // to train, so that the pserver can reduce it's barrier count, and continue