Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Aug 25, 2021
2 parents 06e2a07 + 587235c commit a88ed02
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 86 deletions.
60 changes: 41 additions & 19 deletions paddle/fluid/distributed/fleet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,15 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
}

// check zcb
VLOG(3) << "pull sparse check zcb\n";
for (int i = 0; i < fea_keys.size(); ++i) {
VLOG(3) << "key " << fea_keys[i] << ": ";
for (int j = 0; j < fea_dim; ++j) {
VLOG(3) << pull_result_ptr[i][j] << " ";
/*
std::cout << "pull sparse check zcb\n";
for (int i = 0; i < fea_keys.size(); ++ i) {
std::cout << "key " << fea_keys[i] << ": ";
for (int j = 0; j < fea_dim; ++ j) {
std::cout << pull_result_ptr[i][j] << " ";
}
VLOG(3) << "\n";
}
std::cout << "\n";
}*/
}

void FleetWrapper::PullDenseVarsAsync(
Expand Down Expand Up @@ -380,12 +381,31 @@ void FleetWrapper::PushDenseVarsAsync(
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status, float scale_datanorm,
int batch_size) {
auto* communicator = Communicator::GetInstance();
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
communicator->Send(var_names, scope);
auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(1) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
<< table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
<< g[tensor->numel() - 1];
}

auto* communicator =
dynamic_cast<AsyncCommunicator*>(Communicator::GetInstance());
// PADDLE_ENFORCE_EQ(
// communicator->Check(table_id), true,
// platform::errors::InvalidArgument(
// "can not find table: %s, please check your config", table_id));
// communicator->Send(var_names, scope);
auto push_status = communicator->_worker_ptr->push_dense(
regions.data(), regions.size(), table_id);

communicator->PushDensePostProcessing();
}

void FleetWrapper::PushSparseVarsAsync(
Expand Down Expand Up @@ -519,18 +539,20 @@ void FleetWrapper::PushSparseFromTensorAsync(
}
}
}
VLOG(0) << "output_len: " << output_len << " g.size(): " << g.size();
VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size();
CHECK(output_len == g.size());

std::vector<float*> push_g_vec(input_idx, nullptr);
VLOG(3) << "zcb debug push sparse\n";
/*
std::cout << "zcb debug push sparse\n";
for (auto i = 0u; i < push_keys.size(); ++i) {
push_g_vec[i] = push_values.at(i).data();
VLOG(3) << "key: " << push_keys[i] << " ";
for (int j = 0; j < fea_dim + 3; ++j) VLOG(3) << push_g_vec[i][j] << " ";
VLOG(3) << "\n";
}
std::cout << "key: " << push_keys[i] << " ";
for (int j = 0; j < fea_dim + 3; ++ j)
std::cout << push_g_vec[i][j] << " ";
std::cout << "\n";
}*/
auto* communicator = Communicator::GetInstance();
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
Expand Down
117 changes: 76 additions & 41 deletions paddle/fluid/distributed/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1433,9 +1433,9 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
std::make_shared<CostTimer>("pslib_downpour_client_push_dense_parse");
int push_dense_async_num = _push_dense_task_queue_map[table_id]->size();
while (push_dense_async_num > FLAGS_pslib_max_async_call_num) {
// LOG(INFO) << "push_dense Waiting for async_call_num comsume, task_num:"
// << push_dense_async_num << ", max_task_limit:" <<
// FLAGS_pslib_max_async_call_num;
LOG(INFO) << "push_dense Waiting for async_call_num comsume, task_num:"
<< push_dense_async_num
<< ", max_task_limit:" << FLAGS_pslib_max_async_call_num;
usleep(5000); // 5ms
push_dense_async_num = _push_dense_task_queue_map[table_id]->size();
}
Expand Down Expand Up @@ -1532,53 +1532,88 @@ void BrpcPsClient::push_dense_task_consume() {
for (int i = 0; i < merge_count; ++i) {
merge_status[i].wait();
}

VLOG(1) << "BrpcPsClient::push_dense_task_consume before merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
<< total_send_data[0] << " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1];
if (scale_gradient && merge_count > 1) {
Eigen::Map<Eigen::MatrixXf> mat(total_send_data, 1,
total_send_data_size);
mat *= (1.0 / merge_count);
mat *= (1.0 / (merge_count + 1));
}
}

push_dense_raw_gradient(task, total_send_data, total_send_data_size,
closure);
}
auto wait_ms =
FLAGS_pslib_async_push_dense_interval_ms - (timeline.ElapsedMS());
if (wait_ms > 0) {
usleep(wait_ms * 1000);
VLOG(1) << "BrpcPsClient::push_dense_task_consume after merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
<< " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1] << " merge_count "
<< merge_count;

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}
}
push_dense_raw_gradient(task, total_send_data, total_send_data_size,
closure);
}
if (scale_gradient && merge_count > 1) {
Eigen::Map<Eigen::MatrixXf> mat(total_send_data, 1,
total_send_data_size);
mat *= (1.0 / merge_count);
}
}

push_dense_raw_gradient(task, total_send_data, total_send_data_size,
closure);
}
auto wait_ms =
FLAGS_pslib_async_push_dense_interval_ms - (timeline.ElapsedMS());
if (wait_ms > 0) {
usleep(wait_ms * 1000);
}
}
}
}

void BrpcPsClient::push_dense_raw_gradient(
std::shared_ptr<DenseAsyncTask> &task, float *total_send_data,
size_t total_send_data_size, DownpourBrpcClosure *closure) {
auto *accessor = table_accessor(task->table_id());
size_t request_call_num = _server_channels.size();
//将数据拷贝到请求buffer区
auto timer =
std::make_shared<CostTimer>("pslib_downpour_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(task->table_id());
closure->request(i)->set_client_id(_client_id);
auto *push_data = closure->request(i)->mutable_data();
push_data->clear();
push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
memcpy(push_data_ptr + sizeof(uint32_t),
total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
closure->cntl(i)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
void BrpcPsClient::push_dense_raw_gradient(
std::shared_ptr<DenseAsyncTask> & task, float *total_send_data,
size_t total_send_data_size, DownpourBrpcClosure *closure) {
auto *accessor = table_accessor(task->table_id());
size_t request_call_num = _server_channels.size();
//将数据拷贝到请求buffer区
auto timer =
std::make_shared<CostTimer>("pslib_downpour_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(task->table_id());
closure->request(i)->set_client_id(_client_id);
auto *push_data = closure->request(i)->mutable_data();
push_data->clear();
push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
memcpy(push_data_ptr + sizeof(uint32_t),
total_send_data + i * num_per_shard,
num_per_shard * sizeof(float));
closure->cntl(i)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
}
}

} // namespace distributed
} // namespace paddle
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
VLOG(3) << "debug zcb: client deconstructor begin\n";
std::cout << "debug zcb: client deconstructor begin\n";
// finalize_worker();
_running = false;
try {
_async_push_dense_thread.join();
_async_push_sparse_thread.join();
VLOG(3) << "debug zcb: client deconstructor done\n";
//_async_push_dense_thread.join();
//_async_push_sparse_thread.join();
std::cout << "debug zcb: client deconstructor done\n";
} catch (...) {
}
}
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/distributed/service/brpc_ps_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
return;
}

std::cout << "zcb debug service cmd_id--: " << request->cmd_id() << "\n";
// std::cout << "zcb debug service cmd_id: " << request->cmd_id() << "\n";
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
Expand Down Expand Up @@ -203,6 +203,9 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_dense(res_data->data(), num);
VLOG(1) << "BrpcPsService::pull_dense num " << num << " data[0] "
<< res_data->data()[0] << " data[-2] " << res_data->data()[num - 2]
<< " data[-1] " << res_data->data()[num - 1];

cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
Expand Down Expand Up @@ -232,6 +235,9 @@ int32_t BrpcPsService::push_dense_param(Table *table,
uint32_t num = *(const uint32_t *)data;

const float *values = (const float *)(data + sizeof(uint32_t));
VLOG(1) << "BrpcPsService::push_dense_param num " << num << " data[0] "
<< values[0] << " data[-2] " << values[num - 2] << " data[-1] "
<< values[num - 1];
if (table->push_dense_param(values, num) != 0) {
set_response_code(response, -1, "push_dense_param failed");
}
Expand All @@ -257,6 +263,8 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
uint32_t num = *(const uint32_t *)(request.data().data());
const float *values =
(const float *)(request.data().data() + sizeof(uint32_t));
VLOG(1) << "BrpcPsService::push_dense num " << num << " data[0] " << values[0]
<< " data[-2] " << values[num - 2] << " data[-1] " << values[num - 1];
if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed");
}
Expand Down Expand Up @@ -405,7 +413,6 @@ int32_t BrpcPsService::push_sparse(Table *table,
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
std::cout << "debug zcb, server::push_sparse\n";
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/service/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
LoDTensor *tensor = var->GetMutable<LoDTensor>();
VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
<< platform::is_gpu_place(tensor->place());

// TODO: zcb del this later
float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_recv_data[0]
<< " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
LoDTensor *temp_tensor =
Expand Down Expand Up @@ -524,6 +530,13 @@ void AsyncCommunicator::SendByCommunicator() {
return;
}

// Bookkeeping hook invoked after a dense push completes.
// When the communicator runs its recv loop on an independent thread,
// bump the sent-gradient counter (presumably consulted by that thread
// to decide when enough gradients have gone out to pull fresh
// parameters -- confirm against the recv-thread logic).
void AsyncCommunicator::PushDensePostProcessing() {
  if (!independent_recv_) {
    return;  // nothing to track when recv is not on its own thread
  }
  // Relaxed ordering suffices: this is a plain event counter, not a
  // synchronization point.
  grad_num_.fetch_add(1, std::memory_order_relaxed);
}

void AsyncCommunicator::MainThread() {
VLOG(3) << "AsyncCommunicator MainThread start and wait";

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/service/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ class AsyncCommunicator : public Communicator {
void InitEnvs() {
independent_recv_ = static_cast<bool>(
std::stoi(envs.at("communicator_independent_recv_thread")));
std::cout << "debug zcb: communicator_independent_recv_thread " << independent_recv_ << "\n";
min_send_grad_num_before_recv_ =
std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
Expand Down Expand Up @@ -398,6 +399,8 @@ class AsyncCommunicator : public Communicator {

virtual void BarrierWeakUp() {}

void PushDensePostProcessing();

protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
Expand Down
26 changes: 14 additions & 12 deletions paddle/fluid/distributed/table/ctr_sparse_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,14 @@ int32_t CtrSparseTable::push_sparse(const uint64_t* keys, const float* values,
float* value_data = const_cast<float*>(feature_value->data());
size_t value_size = feature_value->size();

VLOG(3) << "push sparse, key: " << key << " value: ";
for (int i = 0; i < value_size; ++i)
VLOG(3) << value_data[i] << " ";
VLOG(3) << "\n";
VLOG(3) << "update_data: ";
for (int i = 0; i < update_value_col; ++i)
VLOG(3) << update_data[i] << " ";
VLOG(3) << "\n";
// VLOG(3) << "push sparse, key: " << key << " value: ";
// for (int i = 0; i < value_size; ++i)
// VLOG(3) << value_data[i] << " ";
// VLOG(3) << "\n";
// VLOG(3) << "update_data: ";
// for (int i = 0; i < update_value_col; ++i)
// VLOG(3) << update_data[i] << " ";
// VLOG(3) << "\n";

if (value_size == value_col) { //已拓展到最大size, 则就地update
_value_accesor->update(&value_data, &update_data, 1);
Expand All @@ -486,10 +486,12 @@ int32_t CtrSparseTable::push_sparse(const uint64_t* keys, const float* values,
}
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
}
VLOG(3) << "after update key:" << key << "\n";
for (int i = 0; i < feature_value->size(); ++i)
VLOG(3) << value_data[i] << " ";
VLOG(3) << "\n";
/*
std::cout << "after update key:" << key << "\n";
for(int i = 0; i < feature_value->size(); ++ i)
std::cout << value_data[i] << " ";
std::cout << "\n";
*/
}
return 0;
});
Expand Down
Loading

0 comments on commit a88ed02

Please sign in to comment.