add federated learning parameter server(fl-ps) mode #42682

Merged: 41 commits, Jun 2, 2022

Changes from all commits (41 commits):
d9bb853 back fl (ziyoujiyi, Mar 25, 2022)
6073452 delete ssl cert (ziyoujiyi, Mar 25, 2022)
66fa8c8 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 25, 2022)
4bb3d3f Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 25, 2022)
7a02e84 . (ziyoujiyi, Mar 25, 2022)
883b55a make warning (ziyoujiyi, Mar 26, 2022)
f917402 . (ziyoujiyi, Mar 26, 2022)
fa4ab2e unittest paral degree (ziyoujiyi, Mar 28, 2022)
a129afc solve unittest (ziyoujiyi, Mar 28, 2022)
a54e061 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 29, 2022)
ed7e38f heter & multi cloud commm ready (ziyoujiyi, Mar 29, 2022)
3e86455 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 29, 2022)
b5a34fc . (ziyoujiyi, Mar 29, 2022)
0e4b998 Merge branch 'develop' of https://github.com/ziyoujiyi/Paddle into de… (ziyoujiyi, Mar 29, 2022)
eeec283 . (ziyoujiyi, Mar 29, 2022)
d293d97 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 29, 2022)
c1759b5 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 30, 2022)
d9aa775 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Mar 31, 2022)
7105730 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Apr 2, 2022)
73ea318 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Apr 11, 2022)
7dc2091 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Apr 19, 2022)
2019a5f Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Apr 24, 2022)
f22bbcd Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, Apr 26, 2022)
5019c73 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, May 9, 2022)
9b92deb fl-ps v1.0 (ziyoujiyi, May 9, 2022)
31f330c merge dev (ziyoujiyi, May 9, 2022)
f2fa8ee . (ziyoujiyi, May 9, 2022)
6c76994 Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, May 11, 2022)
7aadb99 support N + N mode (ziyoujiyi, May 11, 2022)
001c11c Merge branch 'develop' of https://github.com/ziyoujiyi/Paddle into fl_ps (ziyoujiyi, May 11, 2022)
5f7b4fd . (ziyoujiyi, May 11, 2022)
a6f7f29 . (ziyoujiyi, May 11, 2022)
cbbd5e9 . (ziyoujiyi, May 12, 2022)
2873622 . (ziyoujiyi, May 13, 2022)
16ad3c1 delete print (ziyoujiyi, May 24, 2022)
9a89ba3 . (ziyoujiyi, May 25, 2022)
2469beb Merge branch 'PaddlePaddle:develop' into develop (ziyoujiyi, May 25, 2022)
acc3898 merge dev (ziyoujiyi, May 25, 2022)
3c5374d . (ziyoujiyi, May 30, 2022)
07bf8ab . (ziyoujiyi, May 30, 2022)
25f38c1 . (ziyoujiyi, May 30, 2022)
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -254,6 +254,7 @@ option(WITH_POCKETFFT "Compile with pocketfft support" ON)
option(WITH_RECORD_BUILDTIME "Compile PaddlePaddle with record all targets build time" OFF)
option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF)
option(WITH_ARM_BRPC "Supprot Brpc in Arm" OFF)
option(WITH_FLPS "FL PS mode" OFF)

if(WITH_RECORD_BUILDTIME)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh ${CMAKE_CURRENT_BINARY_DIR}")
4 changes: 4 additions & 0 deletions cmake/configure.cmake
@@ -78,6 +78,10 @@ if(WITH_ARM_BRPC)
add_definitions(-DPADDLE_WITH_ARM_BRPC)
endif()

if(WITH_FLPS)
add_definitions(-DPADDLE_WITH_FLPS)
endif()

if(WITH_GLOO)
add_definitions(-DPADDLE_WITH_GLOO)
endif()
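Taken together, the CMakeLists.txt and cmake/configure.cmake changes add an off-by-default WITH_FLPS option and map it to a PADDLE_WITH_FLPS preprocessor definition. A minimal sketch of the intended usage, assuming a standard CMake invocation such as `cmake -DWITH_FLPS=ON ..`, is shown below; the guarded function is purely illustrative.

```cpp
// Illustrative only: FL-PS-specific code paths added later in this PR can be
// guarded with the new define, so default builds stay unaffected.
#ifdef PADDLE_WITH_FLPS
// compiled only when the build is configured with -DWITH_FLPS=ON
void RunFlPsSpecificSetup() { /* federated-learning PS setup */ }
#else
void RunFlPsSpecificSetup() { /* no-op in regular parameter-server builds */ }
#endif
```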
paddle/fluid/distributed/ps/service/brpc_ps_server.cc: file mode changed 100644 → 100755 (no content changes)
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/ps/service/heter_client.cc
@@ -139,8 +139,9 @@ void HeterClient::SendAndRecvAsync(
message_name, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);

int micro_id = GetMicroId(ctx, p_scope);
int micro_id = GetMicroId(ctx, p_scope); // global
auto minibatch_id = micro_id / 10;
VLOG(4) << "micro_id: " << micro_id;
// select channel according to micro id
if (mode == "forward") {
int num = minibatch_id % xpu_channels_.size();
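For context on the comment added above: the micro id received here is global, and the code recovers a mini-batch id from it by integer division before picking a channel. A standalone sketch of that arithmetic, assuming (as the code does) at most 10 micro-batches per mini-batch; the struct and function names are illustrative, not part of HeterClient.

```cpp
// Simplified sketch, not the real HeterClient: decode the global micro id and
// select a forward channel the same way the diff above does.
#include <cstddef>

struct Route {
  int minibatch_id;      // micro_id / 10, as in the diff
  int microbatch_index;  // micro_id % 10, the slot inside the mini-batch
  std::size_t channel;   // channel chosen for the "forward" message
};

Route RouteForward(int micro_id, std::size_t num_xpu_channels) {
  Route r;
  r.minibatch_id = micro_id / 10;
  r.microbatch_index = micro_id % 10;
  r.channel = static_cast<std::size_t>(r.minibatch_id) % num_xpu_channels;
  return r;
}
```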
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/ps/service/heter_client.h
100644 → 100755
@@ -155,13 +155,13 @@ class HeterClient {

// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint,
const std::vector<std::string>& previous_endpoint,
const std::vector<std::string>& endpoints,
const std::vector<std::string>& previous_endpoints,
const int& trainer_id) {
if (NULL == s_instance_) {
s_instance_.reset(new HeterClient());
s_instance_->SetXpuList(endpoint);
s_instance_->SetPreviousXpuList(previous_endpoint);
s_instance_->SetXpuList(endpoints);
s_instance_->SetPreviousXpuList(previous_endpoints);
s_instance_->SetTrainerID(trainer_id);
s_instance_->CreateClient2XpuConnection();
}
Expand Down
4 changes: 0 additions & 4 deletions paddle/fluid/distributed/ps/service/heter_server.cc
@@ -94,7 +94,6 @@ void HeterServer::StartHeterInterService(bool neeed_encrypt) {
VLOG(4) << "switch inter server server start success! listen on "
<< endpoint_inter_;
}

{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
stoped_ = false;
@@ -115,9 +114,6 @@ void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
while (!this->ready_) {
sleep(1);
}
}

int SendAndRecvVariableHandler::SaveInSwitchWithShard(
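The removal above is safe because the predicated condition_variable::wait only returns once ready_ is already 1, so the old sleep loop could never spin. A self-contained sketch of the pattern that remains; member names mirror the diff, but the class itself is an illustrative stand-in rather than the real HeterServer.

```cpp
// Minimal sketch of a "wait until ready" gate using a predicated wait.
#include <condition_variable>
#include <mutex>

class ServerReadyGate {
 public:
  void WaitServerReady() {
    std::unique_lock<std::mutex> lock(mutex_ready_);
    // Blocks until ready_ == 1; no extra polling loop is needed afterwards.
    condition_ready_.wait(lock, [this] { return ready_ == 1; });
  }

  void SetReady() {
    {
      std::lock_guard<std::mutex> lock(mutex_ready_);
      ready_ = 1;
    }
    condition_ready_.notify_all();
  }

 private:
  std::mutex mutex_ready_;
  std::condition_variable condition_ready_;
  int ready_ = 0;
};
```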
13 changes: 9 additions & 4 deletions paddle/fluid/distributed/ps/service/heter_server.h
@@ -90,8 +90,10 @@ class ServiceHandlerBase {

using SharedMiniScope =
std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;

using SharedMicroScope = std::shared_ptr<std::unordered_map<
int, std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;

using SharedTaskQueue = std::shared_ptr<
std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
@@ -226,6 +228,7 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
auto* tensor = var->GetMutable<framework::LoDTensor>();
auto data = reinterpret_cast<const float*>(tensor->data());
auto micro_id = static_cast<int>(data[0]);
VLOG(4) << "micro_id in heter server: " << micro_id;
int minibatch_index = micro_id / 10;
int microbatch_index = micro_id % 10;

@@ -261,6 +264,9 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, *dev_ctx_, micro_scope);
// blocking queue handles multi thread
VLOG(4) << "Handle in HeterServer: " << message_name << ", "
<< microbatch_index;
VLOG(4) << "task_queue_ size: " << task_queue_->size();
(*task_queue_)[minibatch_index]->Push(
std::make_pair(message_name, microbatch_index));

@@ -274,6 +280,7 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name, response_var_names, empty_var_names, *dev_ctx_,
&local_scope, response, &response_io_buffer);
VLOG(4) << "Handle over";
return 0;
}

@@ -612,11 +619,9 @@ class HeterServer {

// HeterWrapper singleton
static std::shared_ptr<HeterServer> GetInstance() {
std::unique_lock<std::mutex> lock(mtx_);
if (s_instance_ == nullptr) {
std::unique_lock<std::mutex> lock(mtx_);
if (NULL == s_instance_) {
s_instance_.reset(new HeterServer());
}
s_instance_.reset(new HeterServer());
}
return s_instance_;
}
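As the GetInstance() hunk reads, the mutex is now taken before testing s_instance_, which makes the previous inner NULL re-check unnecessary. A standalone sketch of the resulting shape, using a simplified stand-in class rather than the real HeterServer:

```cpp
// Lock first, then check-and-create: a single guarded test is enough.
#include <memory>
#include <mutex>

class HeterServerLike {
 public:
  static std::shared_ptr<HeterServerLike> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServerLike());
    }
    return s_instance_;
  }

 private:
  static std::shared_ptr<HeterServerLike> s_instance_;
  static std::mutex mtx_;
};

std::shared_ptr<HeterServerLike> HeterServerLike::s_instance_ = nullptr;
std::mutex HeterServerLike::mtx_;
```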
12 changes: 10 additions & 2 deletions paddle/fluid/framework/data_feed.cc
100644 → 100755
@@ -220,6 +220,7 @@ bool DataFeed::PickOneFile(std::string* filename) {
file_idx_, platform::errors::PreconditionNotMet(
"You should call SetFileListIndex before PickOneFile"));
std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
VLOG(4) << "filelist_ size: " << filelist_.size();
if (*file_idx_ == filelist_.size()) {
VLOG(3) << "DataFeed::PickOneFile no more file to pick";
return false;
@@ -284,6 +285,7 @@ void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {

template <typename T>
bool PrivateQueueDataFeed<T>::Start() {
VLOG(4) << "entering PrivateQueueDataFeed<T>::Start()";
CheckSetFileList();
read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this);
read_thread_.detach();
@@ -295,6 +297,7 @@ bool PrivateQueueDataFeed<T>::Start() {
template <typename T>
void PrivateQueueDataFeed<T>::ReadThread() {
#ifdef _LINUX
VLOG(4) << "entering PrivateQueueDataFeed<T>::ReadThread()";
std::string filename;
while (PickOneFile(&filename)) {
int err_no = 0;
@@ -356,6 +359,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
template <typename T>
bool InMemoryDataFeed<T>::Start() {
#ifdef _LINUX
VLOG(4) << "entering InMemoryDataFeed<T>::Start()";
this->CheckSetFileList();
if (output_channel_->Size() == 0 && input_channel_->Size() != 0) {
std::vector<T> data;
@@ -664,6 +668,7 @@ void MultiSlotDataFeed::Init(

void MultiSlotDataFeed::ReadThread() {
#ifdef _LINUX
VLOG(4) << "entering MultiSlotDataFeed::ReadThread()";
std::string filename;
while (PickOneFile(&filename)) {
int err_no = 0;
@@ -831,7 +836,6 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
} else {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);

const char* str = reader.get();
std::string line = std::string(str);

@@ -971,10 +975,13 @@ void MultiSlotDataFeed::PutToFeedVec(
if (feed_vec_[i] == nullptr) {
continue;
}
VLOG(4) << "MultiSlotDataFeed::PutToFeedVec i: " << i;
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());

VLOG(4) << "total_instance: " << total_instance;
// platform::CPUPlace()
VLOG(4) << "this->place_: " << this->place_;
if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData();
float* tensor_ptr =
@@ -2573,6 +2580,7 @@ void SlotRecordInMemoryDataFeed::ExpandSlotRecord(SlotRecord* rec) {
}

bool SlotRecordInMemoryDataFeed::Start() {
VLOG(4) << "entering SlotRecordInMemoryDataFeed::Start";
#ifdef _LINUX
this->CheckSetFileList();
if (input_channel_->Size() != 0) {
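The data_feed.cc changes are diagnostic only: VLOG(4) trace lines at the entry of the various Start() and ReadThread() implementations. A minimal standalone sketch of the same pattern using raw glog is shown below; Paddle routes this through its own logging setup, so whether the output is enabled via GLOG_v=4 or a --v=4 flag depends on how the binary initializes glog/gflags.

```cpp
// Assumes glog is installed and linked (-lglog). VLOG(4) lines are compiled in
// but filtered out unless the verbosity level is raised to 4 or higher.
#include <glog/logging.h>

void Start() {
  VLOG(4) << "entering Start()";
  // ... actual work ...
}

int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  FLAGS_logtostderr = true;  // print to stderr instead of log files
  Start();
  return 0;
}
```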
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
100644 → 100755
@@ -314,6 +314,7 @@ message DistributedStrategy {
optional bool adam_d2sum = 36 [ default = false ];
optional bool auto_search = 37 [ default = false ];
optional bool heter_ccl_mode = 38 [ default = false ];
optional bool is_fl_ps_mode = 39 [ default = false ];

optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
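The new is_fl_ps_mode field extends the DistributedStrategy message; like the neighbouring flags it defaults to false and is expected to be toggled by the fleet configuration that gets serialized into this proto. A self-contained sketch of how a runtime might branch on it; the StrategyView struct below is a stand-in for the generated proto class, which would expose an is_fl_ps_mode() accessor.

```cpp
// Stand-in for the generated DistributedStrategy class, kept minimal so the
// example compiles without the Paddle source tree.
struct StrategyView {
  bool fl_ps = false;
  bool is_fl_ps_mode() const { return fl_ps; }  // protobuf-style accessor
};

const char* ChooseRuntime(const StrategyView& strategy) {
  // In FL-PS mode both the worker party and the heter/server party run
  // training stages; otherwise the classic parameter-server runtime is used.
  return strategy.is_fl_ps_mode() ? "fl-ps runtime" : "ps runtime";
}
```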
61 changes: 43 additions & 18 deletions paddle/fluid/framework/heter_pipeline_trainer.cc
@@ -32,7 +32,9 @@ using TaskQueue =
std::pair<std::string, int>>>>;

void HeterPipelineTrainer::ResetDataset(Dataset* dataset) {
#ifndef PADDLE_WITH_FLPS
if (pipeline_stage_ == 0) {
#endif
SetDataset(dataset);
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
@@ -51,40 +53,39 @@ void HeterPipelineTrainer::ResetDataset(Dataset* dataset) {
this_worker->SetDataFeed(readers[cnt]);
this_worker->SetReaderPlace(place_);
}
#ifndef PADDLE_WITH_FLPS
}
#endif
}

void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
trainer_desc_ = trainer_desc;
thread_num_ = trainer_desc.thread_num();
ParseDumpConfig(trainer_desc);
SetDebug(trainer_desc.debug());
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_;
VLOG(3) << "worker(readers) thread num: " << thread_num_;
const auto& heter_section_params = trainer_desc.heter_section_param();
num_pipeline_stages_ = heter_section_params.num_pipeline_stages();
pipeline_stage_ = heter_section_params.pipeline_stage();
num_microbatches_ = heter_section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc;
trainer_id_ = trainer_desc.trainer_id();
for (int i = 0; i < num_pipeline_stages_; ++i) {
auto trainer_num = trainer_desc.trainers(i);
trainers_.push_back(trainer_num);
}
int cpu_trainer_num = trainers_[0];
// int cur_stage_trainer_num = trainers_[pipeline_stage_];
// int global_thread_num = cpu_trainer_num * thread_num_;
// int previous_trainers = 0;
// for (int i = 0; i < pipeline_stage_; i++) previous_trainers +=
// trainers_[i];
// int stage_trainer_id =
// trainer_id_ - previous_trainers; // trainer id in current stage

VLOG(4) << "trainer_id_: " << trainer_id_;
VLOG(4) << "cpu_trainer_num: " << cpu_trainer_num
<< " xpu_trainer_num: " << trainers_[1];
#ifdef PADDLE_WITH_FLPS
thread_num_ = 1;
#endif
if (pipeline_stage_ == 0) { // for cpu trainer
int cnt = -1;
int real_thread_id = trainer_id_;
@@ -103,25 +104,33 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDeviceIndex(real_thread_id);
real_thread_id += cpu_trainer_num;
// if (pipeline_stage_ == 0) {
this_worker->SetDataFeed(readers[cnt]);
//}
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
}
} else { // for heter_trainer
// heter trainer with thread_id == -1 is not for
// real training
} else {
// for heter_trainer
// heter trainer with thread_id == -1 is not for real training, just for run
// listen op
workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[-1]);
#ifdef PADDLE_WITH_FLPS
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDataFeed(readers[0]);
#endif
this_worker->SetDeviceIndex(-1);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetDeviceIndex(-1);
}
}

@@ -159,14 +168,19 @@ void HeterPipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
for (auto& worker_pair : workers_) {
auto worker_index = worker_pair.first;
auto device_worker = worker_pair.second;
VLOG(0) << "workers index in InitTrainerEnv: " << worker_index;
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
device_worker);
this_worker->SetPlace(place);
this_worker->Initialize(trainer_desc_);
#ifdef PADDLE_WITH_FLPS
this_worker->SetReaderPlace(place);
#else
if (pipeline_stage_ == 0) {
this_worker->SetReaderPlace(place);
}
#endif
this_worker->SetRootScope(root_scope_);
// generate mini_batch scope for every worker
auto* minibatch_scope = &root_scope_->NewScope();
@@ -175,13 +189,15 @@ void HeterPipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
// after set micro num & mini batch scope
this_worker->CreateMicrobatchScopes();
(*micro_scopes_)[worker_index] = this_worker->GetMicrobatchScopes();
VLOG(4) << "worker_index: " << worker_index;
(*task_queue_)[worker_index] = this_worker->GetThreadQueue();
}
}

void HeterPipelineTrainer::Run() {
VLOG(3) << "Going to run HeterPipelineTrainer::Run()";
if (listen_ptr_ == nullptr) {
VLOG(3) << "listen_ptr_ is null";
for (auto& worker_pair : workers_) {
auto& device_worker = worker_pair.second;
auto worker_0 =
@@ -196,10 +212,14 @@ void HeterPipelineTrainer::Run() {
heter_server->WaitServerReady();
heter_server->SetMiniBatchScopes(mini_scopes_);
heter_server->SetMicroBatchScopes(micro_scopes_);
VLOG(4) << "heter_server SetTaskQueue";
heter_server->SetTaskQueue(task_queue_);

// main training logic
VLOG(3) << "pipeline_stage_ is " << pipeline_stage_;
if (pipeline_stage_ == 0) { // for cpu trainer
for (auto& worker_pair : workers_) {
VLOG(4) << "cpu worker index : " << worker_pair.first;
auto device_worker = worker_pair.second;
if (!debug_) {
threads_.push_back(
@@ -212,6 +232,7 @@ void HeterPipelineTrainer::Run() {
} else { // for heter worker
// start thread_worker with thread_id = -1
for (auto& worker_pair : workers_) {
VLOG(4) << "xpu worker index : " << worker_pair.first;
auto device_worker = worker_pair.second;
if (!debug_) {
threads_.push_back(
@@ -252,6 +273,10 @@ void HeterPipelineTrainer::Run() {
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetPlace(place_);
#ifdef PADDLE_WITH_FLPS
this_worker->SetDataFeed(workers_[-1]->device_reader_);
this_worker->SetReaderPlace(place_);
#endif
this_worker->Initialize(trainer_desc_);
this_worker->SetRootScope(root_scope_);

@@ -308,5 +333,5 @@ Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
}

} // end namespace framework
} // end namespace paddle
} // namespace paddle
#endif

Review comment (Contributor): This should not be removed.

Reply (Contributor Author): done
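Stepping back from the heter_pipeline_trainer.cc diff above: under PADDLE_WITH_FLPS both pipeline stages are given a data reader (and thread_num_ is forced to 1), whereas the default build only attaches readers on stage 0. A condensed, self-contained sketch of that branching, with hypothetical helper names rather than the real trainer methods:

```cpp
#include <iostream>

// Hypothetical stand-in for this_worker->SetDataFeed(...) + SetReaderPlace(...).
void AttachReader(int pipeline_stage) {
  std::cout << "reader attached to stage " << pipeline_stage << "\n";
}

void SetupStageReaders(int pipeline_stage) {
#ifdef PADDLE_WITH_FLPS
  // FL-PS build: the CPU party (stage 0) and the heter party (stage 1)
  // each consume their own data feed.
  AttachReader(pipeline_stage);
#else
  // Default heter pipeline: only the CPU trainer stage reads input data.
  if (pipeline_stage == 0) {
    AttachReader(pipeline_stage);
  }
#endif
}

int main() {
  SetupStageReaders(0);
  SetupStageReaders(1);
  return 0;
}
```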