Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checkpoint M2: lookup table checkpoint #11490

Merged
merged 75 commits into from
Jun 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
4170196
[wip] ckpt m2 develop
seiriosPlus Jun 13, 2018
a895916
[wip] add load lookup table in io and trianer
seiriosPlus Jun 14, 2018
8a17816
add lookuo table in python
seiriosPlus Jun 14, 2018
12de20f
add checkpoint_notify_op for trainer to notify pserver, update listen…
seiriosPlus Jun 14, 2018
b089b80
update rpc to add checkpoint notify
seiriosPlus Jun 14, 2018
bb17604
bug fix
seiriosPlus Jun 15, 2018
1cb0ab3
bug fix
seiriosPlus Jun 15, 2018
fb27c9a
bug fix
seiriosPlus Jun 15, 2018
fe76244
bug fix
seiriosPlus Jun 15, 2018
98c30c7
bug fix
seiriosPlus Jun 15, 2018
f224948
bug fix
seiriosPlus Jun 15, 2018
8d46d1d
bug fix
seiriosPlus Jun 15, 2018
860360d
bug fix
seiriosPlus Jun 15, 2018
1c2e9bd
fix cmakelist
seiriosPlus Jun 15, 2018
985026c
add checkpoint_notify in python
seiriosPlus Jun 15, 2018
925e232
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
a9ac200
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
36d17d1
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
74384b7
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
050b66e
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
54013a9
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
15532c7
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
bbb349f
add RequestCheckpointNotify in grpc
seiriosPlus Jun 18, 2018
527b86b
bug fix
seiriosPlus Jun 18, 2018
85215df
move checkpoint message to variable message
seiriosPlus Jun 18, 2018
8af8da4
move checkpoint message to variable message
seiriosPlus Jun 18, 2018
5553adf
move checkpoint message to variable message
seiriosPlus Jun 18, 2018
752eb08
move checkpoint message to variable message
seiriosPlus Jun 18, 2018
3088084
merge develop
seiriosPlus Jun 18, 2018
ae12281
checkpoint notify
seiriosPlus Jun 18, 2018
af0a6a1
checkpoint notify
seiriosPlus Jun 18, 2018
bb10c37
merge
seiriosPlus Jun 18, 2018
549f0aa
load op add seletedRows
seiriosPlus Jun 18, 2018
a501766
load op add seletedRows
seiriosPlus Jun 18, 2018
ca27f78
load op add seletedRows
seiriosPlus Jun 19, 2018
ee64f57
load op add seletedRows
seiriosPlus Jun 19, 2018
1296d96
add raw clone
seiriosPlus Jun 19, 2018
620698e
bug fux
seiriosPlus Jun 19, 2018
459690a
bug fux
seiriosPlus Jun 19, 2018
5250ca8
bug fux
seiriosPlus Jun 19, 2018
bccf8df
bug fix
seiriosPlus Jun 19, 2018
7efd73a
code clean
seiriosPlus Jun 19, 2018
1571c25
code style fix
seiriosPlus Jun 19, 2018
d93dc81
add handle when checkpoint_notify_id = -1
seiriosPlus Jun 19, 2018
8c0e1d5
unittest case fix
seiriosPlus Jun 19, 2018
49c2d0c
bug fix
seiriosPlus Jun 19, 2018
16ecead
load op optimize
seiriosPlus Jun 19, 2018
6abf076
checkpoint_notify_id rename
seiriosPlus Jun 19, 2018
28482f8
bug fix
seiriosPlus Jun 19, 2018
06f6c21
bug fix
seiriosPlus Jun 19, 2018
5600b13
bug fix
seiriosPlus Jun 19, 2018
32fa832
code style
seiriosPlus Jun 19, 2018
8af4d4c
code style
seiriosPlus Jun 19, 2018
db6126c
code style
seiriosPlus Jun 20, 2018
5a4a24c
Merge branch 'develop' into ckpt_m2
seiriosPlus Jun 20, 2018
91eae9c
code style
seiriosPlus Jun 20, 2018
298588f
Merge branch 'ckpt_m2' of github.com:seiriosPlus/Paddle into ckpt_m2
seiriosPlus Jun 20, 2018
c073bb3
code style
seiriosPlus Jun 20, 2018
05bd9db
add comments in io.py
seiriosPlus Jun 20, 2018
e589005
merge
seiriosPlus Jun 21, 2018
9764844
merge develop
seiriosPlus Jun 21, 2018
620999c
save checkpoint bug fix
seiriosPlus Jun 21, 2018
8e01f3b
bug fix
seiriosPlus Jun 21, 2018
2229db5
pserver_id init value to None
seiriosPlus Jun 22, 2018
e684575
checkpoint feature optimized
seiriosPlus Jun 22, 2018
7fae9e0
checkpoint feature optimized
seiriosPlus Jun 22, 2018
4388ce1
checkpoint notify op optimized
seiriosPlus Jun 22, 2018
b519bf0
log level optimize
seiriosPlus Jun 25, 2018
fb7e479
merger paddle develop
seiriosPlus Jun 25, 2018
dc847f1
bug fix and code optimize
seiriosPlus Jun 25, 2018
33ff69b
file path can not be empty
seiriosPlus Jun 25, 2018
88cb5d7
add doc
seiriosPlus Jun 26, 2018
b6e6355
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into ckpt_m2
seiriosPlus Jun 26, 2018
fa3d470
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into ckpt_m2
seiriosPlus Jun 26, 2018
f57978e
renae
seiriosPlus Jun 26, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE)
endif()

set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
foreach(dist_op "prefetch_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
Expand All @@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()

op_library(cross_entropy_op DEPS cross_entropy)
Expand Down
88 changes: 88 additions & 0 deletions paddle/fluid/operators/checkpoint_notify_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future> // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {

class CheckpointNotifyOp : public framework::OperatorBase {
public:
CheckpointNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");

distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
}
rpc_client->Wait();
}
};

class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddComment(R"DOC(
CheckpointNotify operator

This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server.
)DOC");
}
};

class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointNotifyOpMaker,
ops::CheckpointNotifyOpShapeInference);
17 changes: 17 additions & 0 deletions paddle/fluid/operators/distributed/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_++;
}

void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);

CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);

auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}

void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/operators/distributed/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};

class CheckpointNotifyProcessor : public BaseProcessor {
public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}

virtual ~CheckpointNotifyProcessor() {}

virtual void Process() {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};

class GRPCClient : public RPCClient {
public:
GRPCClient() {}
Expand All @@ -197,6 +211,9 @@ class GRPCClient : public RPCClient {
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_grpc_deadline) override;

void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_grpc_deadline) override;

void Wait() override;

void SendComplete() override;
Expand Down
48 changes: 45 additions & 3 deletions paddle/fluid/operators/distributed/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope_;
};

class RequestCheckpointNotify final : public RequestBase {
public:
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}

virtual ~RequestCheckpointNotify() {}

std::string GetReqName() override { return request_->Varname(); }

void Process() override {
auto scope = request_->GetMutableLocalScope();

std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();

VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;

request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir);
Finish(reply_, &responder_);
}

protected:
std::shared_ptr<VariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
Expand Down Expand Up @@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() {
reqs.reserve(kRequestBufSize);

for (int i = 0; i < kRequestBufSize; i++) {
VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
TryToRegisterNewOne(rpc_name, i);
}

Expand Down Expand Up @@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return;
}

VLOG(4) << "register send rpc_name:" << rpc_name
<< ", handler:" << rpc_call_map_[kRequestSend];
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id;

auto& reqs = rpc_reqs_[rpc_name];
auto& handler = rpc_call_map_[rpc_name];
Expand All @@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
Expand All @@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest(
while (true) {
VLOG(4) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!";
VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
break;
}

Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/operators/distributed/grpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ enum class GrpcMethod {
kSendVariable,
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
};

static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
static_cast<int>(GrpcMethod::kCheckpointNotify) + 1;

inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
Expand All @@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
}

// Shouldn't be reached.
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/distributed/request_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ namespace distributed {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"

#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

class RPCServer;

class RequestHandler {
Expand Down Expand Up @@ -69,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_ = g;
}

void SetCheckpointNotifyPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
checkpoint_prepared_ctx_ = g;
}

// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
Expand Down Expand Up @@ -115,6 +124,8 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// used for checkpoint notify
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;

// Used for async.
std::unordered_map<std::string,
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/operators/distributed/request_handler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {
namespace distributed {

// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
Expand Down Expand Up @@ -119,6 +124,24 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}

bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");

auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
}

} // namespace distributed
} // namespace operators
} // namespace paddle
15 changes: 15 additions & 0 deletions paddle/fluid/operators/distributed/request_handler_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override;
};

class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;

private:
int checkpoint_notify_id;
};

} // namespace distributed
} // namespace operators
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/operators/distributed/rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;

virtual void AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_grpc_deadline) = 0;

// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/distributed/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ service SendRecvService {
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}

rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
}

// VariableMessage is serialized paddle variable message.
Expand Down
Loading