fix sage train hang && add dump neighbors (PaddlePaddle#306)
* fix sage train hang

* add zero_key for deepwalk mode

* add dump_neighbors

* fix flag

* change vlog

* delete unused log

* delete unused code
DesmonDay authored and danleifeng committed Sep 12, 2023
1 parent 8a88322 commit ff96cb3
Showing 10 changed files with 249 additions and 47 deletions.
9 changes: 9 additions & 0 deletions paddle/fluid/framework/data_feed.cc
@@ -2816,6 +2816,15 @@ void SlotRecordInMemoryDataFeed::DumpWalkPath(std::string dump_path,
#endif
}

void SlotRecordInMemoryDataFeed::DumpSampleNeighbors(std::string dump_path) {
VLOG(1) << "INTO SlotRecordInMemoryDataFeed::DumpSampleNeighbors";
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
std::string path =
string::format_string("%s/part-%03d", dump_path.c_str(), thread_id_);
gpu_graph_data_generator_.DumpSampleNeighbors(path);
#endif
}

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num,
MiniBatchGpuPack* pack) {
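Each data-feed thread gets its own part file here, and the GPU graph generator later appends a random part suffix of its own (see DumpSampleNeighbors in data_feed.cu further down). A minimal sketch of the resulting file name, using a hypothetical dump directory, thread id, and part number rather than values from this commit:

#include <cstdio>
#include <string>

// Mirrors the two format_string calls: "%s/part-%03d" with the thread id in
// data_feed.cc, then "%s-%03d" with a random part number in data_feed.cu.
std::string NeighborDumpFile(const std::string& dump_dir, int thread_id, int part_num) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), "%s/part-%03d-%03d",
                dump_dir.c_str(), thread_id, part_num);
  return std::string(buf);
}

int main() {
  // Hypothetical thread 7 and random part 42 under /tmp/sage_neighbors.
  std::printf("%s\n", NeighborDumpFile("/tmp/sage_neighbors", 7, 42).c_str());
  // prints: /tmp/sage_neighbors/part-007-042
  return 0;
}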
190 changes: 147 additions & 43 deletions paddle/fluid/framework/data_feed.cu
@@ -1033,7 +1033,8 @@ int GraphDataGenerator::GenerateBatch() {
VLOG(1)
<< "reset buf state to make batch num equal in multi node";
} else {
VLOG(0) << "total row in buf state is 0";
VLOG(1) << "total row in buf state is 0";
// Fill 0 ins kernel
GraphZeroIdKernel<<<GET_BLOCKS(fill_zero_num), CUDA_NUM_THREADS, 0, train_stream_>>>(
reinterpret_cast<uint64_t *>(d_ins_buf_[tensor_pair_idx]->ptr()),
fill_zero_num);
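GraphZeroIdKernel itself is not part of this diff. A minimal sketch of what such a zero-key fill can look like, assuming the kernel simply writes id 0 into the first fill_zero_num slots of the instance buffer; the name ZeroIdFillSketch, the launch configuration, and the value 10 (taken from DoSageForTrain further down) are illustrative only:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative zero-key fill: every thread writes one dummy id of 0.
__global__ void ZeroIdFillSketch(uint64_t* ids, int len) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < len) {
    ids[idx] = 0;  // dummy "zero key" instance
  }
}

int main() {
  const int fill_zero_num = 10;  // assumed padding size
  uint64_t* d_ids = nullptr;
  cudaMalloc(&d_ids, fill_zero_num * sizeof(uint64_t));
  ZeroIdFillSketch<<<1, 128>>>(d_ids, fill_zero_num);
  cudaDeviceSynchronize();

  uint64_t h_ids[10];
  cudaMemcpy(h_ids, d_ids, sizeof(h_ids), cudaMemcpyDeviceToHost);
  std::printf("first padded id: %llu\n",
              static_cast<unsigned long long>(h_ids[0]));  // 0
  cudaFree(d_ids);
  return 0;
}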
@@ -1055,7 +1056,7 @@ int GraphDataGenerator::GenerateBatch() {
if (conf_.is_multi_node && total_row_[0] == 0) {
total_instance = fill_zero_num;
ins_buf_pair_len_[0] = fill_zero_num;
VLOG(2) << "gpu id: " << conf_.gpuid << "set total ins num: " << fill_zero_num;
VLOG(1) << "gpu id: " << conf_.gpuid << "set total ins num: " << fill_zero_num;
}

total_instance *= 2;
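When a multi-node rank has no rows at all, its batch is padded with dummy pairs rather than skipped, so the rank stays in step with the collective sage sampling on the other ranks. A small host-side sketch of the resulting instance count, assuming fill_zero_num is 10 as set in DoSageForTrain below:

#include <cstdio>

int main() {
  const bool is_multi_node = true;
  const int total_row = 0;       // this rank produced no walk pairs
  const int fill_zero_num = 10;  // padding size assumed from DoSageForTrain

  int total_instance = 0;        // would otherwise come from the pair buffer
  if (is_multi_node && total_row == 0) {
    total_instance = fill_zero_num;
  }
  total_instance *= 2;  // two node ids per pair, as in the code above
  std::printf("padded batch carries %d ids\n", total_instance);  // 20
  return 0;
}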
@@ -2614,6 +2615,40 @@ int multi_node_sync_sample(int flag,
return ret;
}

int get_multi_node_global_flag(int local_flag,
const ncclRedOp_t &op,
const paddle::platform::Place &place,
cudaStream_t stream) {
auto send_buff = memory::Alloc(
place,
2 * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int *send_buff_ptr = reinterpret_cast<int *>(send_buff->ptr());
cudaMemcpyAsync(send_buff_ptr,
&local_flag,
sizeof(int),
cudaMemcpyHostToDevice,
stream);
cudaStreamSynchronize(stream);
auto comm =
platform::NCCLCommContext::Instance().Get(0, place.GetDeviceId());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&send_buff_ptr[0],
&send_buff_ptr[1],
1,
ncclInt,
op,
comm->comm(),
stream));
int global_flag = 0;
cudaMemcpyAsync(&global_flag,
&send_buff_ptr[1],
sizeof(int),
cudaMemcpyDeviceToHost,
stream);
cudaStreamSynchronize(stream);
return global_flag;
}
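get_multi_node_global_flag folds one integer flag per trainer into a single global value via ncclAllReduce. With 0/1 flags the reduction op fixes the semantics: ncclProd behaves like a logical AND (every rank must raise the flag) while ncclMax behaves like a logical OR (one rank is enough). DoWalkandSage and FillInferBuf below pass ncclProd, and the sage_pass_end sync in DoSageForTrain (via multi_node_sync_sample) uses ncclMax. A host-side analogy with hypothetical per-rank flags, not the NCCL call itself:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical local flags reported by four trainer ranks.
  std::vector<int> flags = {1, 1, 0, 1};

  // ncclProd over 0/1 flags behaves like logical AND.
  int prod = 1;
  for (int f : flags) prod *= f;

  // ncclMax over 0/1 flags behaves like logical OR.
  int max_flag = *std::max_element(flags.begin(), flags.end());

  std::printf("AND-style (ncclProd): %d, OR-style (ncclMax): %d\n", prod, max_flag);
  return 0;
}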

int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
const std::vector<std::shared_ptr<phi::Allocation>>
&d_device_keys, // input
@@ -2977,6 +3012,7 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
<< ", row:" << *jump_rows_ptr << ", total_step:" << step
<< ", device_key_size:" << device_key_size;
}

platform::CUDADeviceGuard guard2(conf_.gpuid);
buf_state->Reset(*total_row_ptr);
paddle::memory::ThrustAllocator<cudaStream_t> allocator(place, stream);
@@ -3305,9 +3341,23 @@ void GraphDataGenerator::DoWalkandSage() {
platform::CUDADeviceGuard guard(conf_.gpuid);
sage_batch_num_ = 0;
if (conf_.gpu_graph_training) {
bool train_flag = DoWalkForTrain();
if (train_flag && conf_.sage_mode) {
DoSageForTrain();
int local_train_flag = DoWalkForTrain();
if (!conf_.is_multi_node) {
if (local_train_flag && conf_.sage_mode) {
DoSageForTrain();
}
} else {
if (conf_.sage_mode) {
global_train_flag_ = get_multi_node_global_flag(local_train_flag, ncclProd,
place_, sample_stream_);
VLOG(1) << "gpu_id: " << conf_.gpuid
<< ", local_train_flag: " << local_train_flag
<< ", global_train_flag: " << global_train_flag_;
if (global_train_flag_) {
// When global_train_flag is true, we need to go ahead in multi-node scenario.
DoSageForTrain();
}
}
}
} else {
bool infer_flag = DoWalkForInfer();
@@ -3391,6 +3441,7 @@ void GraphDataGenerator::DoSageForTrain() {
int sage_pass_end = 0;
uint64_t *ins_buf, *ins_cursor;
while (is_sage_pass_continue) {
int fill_zero_num = 10;
for (int tensor_pair_idx = 0;
tensor_pair_idx < conf_.tensor_pair_num && is_sage_pass_continue;
++tensor_pair_idx) {
@@ -3411,14 +3462,22 @@
reinterpret_cast<int *>(d_pair_num_[tensor_pair_idx]->ptr()),
&ins_buf_pair_len_[tensor_pair_idx],
sample_stream_);

if (res == -1) {
if (ins_buf_pair_len_[tensor_pair_idx] == 0) {
if (conf_.is_multi_node) {
sage_pass_end = 1;
if (total_row_[tensor_pair_idx] != 0) {
buf_state_[tensor_pair_idx].Reset(total_row_[tensor_pair_idx]);
VLOG(1) << "reset buf state to make batch num equal in "
VLOG(1) << conf_.gpuid << ": reset buf state to make batch num equal in "
"multi node";
} else {
VLOG(1) << conf_.gpuid << ": total row in buf state is 0";
GraphZeroIdKernel<<<GET_BLOCKS(fill_zero_num), CUDA_NUM_THREADS, 0, train_stream_>>>(
reinterpret_cast<uint64_t *>(d_ins_buf_[tensor_pair_idx]->ptr()),
fill_zero_num);
VLOG(1) << conf_.gpuid << ": end set seq ins";
break;
}
} else {
is_sage_pass_continue = false;
Expand All @@ -3433,15 +3492,22 @@ void GraphDataGenerator::DoSageForTrain() {
// check whether reach sage pass end
if (conf_.is_multi_node) {
int res = multi_node_sync_sample(
sage_pass_end, ncclProd, place_, &multi_node_sync_stat_);
sage_pass_end, ncclMax, place_, &multi_node_sync_stat_);
VLOG(1) << conf_.gpuid << " get global sage_pass_end: " << res;
if (res) {
VLOG(1) << conf_.gpuid << ": reach sage pass end";
is_sage_pass_continue = false;
break;
}
}

total_instance = ins_buf_pair_len_[tensor_pair_idx] < conf_.batch_size ?
ins_buf_pair_len_[tensor_pair_idx] : conf_.batch_size;
if (conf_.is_multi_node && total_row_[0] == 0) {
total_instance = fill_zero_num;
ins_buf_pair_len_[0] = fill_zero_num;
VLOG(1) << "gpu id: " << conf_.gpuid << " set total ins num: " << fill_zero_num;
}
total_instance *= 2;

if (total_instance == 0) {
@@ -3506,7 +3572,9 @@ void GraphDataGenerator::DoSageForTrain() {
if (is_sage_pass_continue) {
sage_batch_num_ += 1;
}
} // end while (is_sage_pass_continue)
} // end while (is_sage_pass_continue)
VLOG(1) << "gpuid: " << conf_.gpuid
<< " train_sage_batch_num: " << sage_batch_num_;
}

void GraphDataGenerator::DoSageForInfer() {
@@ -3580,48 +3648,16 @@ void GraphDataGenerator::DoSageForInfer() {
sage_batch_num_ += 1;
} // end while (total_instance != 0)
} // end for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;

sage_batch_num_ /= conf_.tensor_pair_num;
VLOG(1) << "gpuid: " << conf_.gpuid
<< " infer_sage_batch_num: " << sage_batch_num_;
}

void GraphDataGenerator::clear_gpu_mem() {
platform::CUDADeviceGuard guard(conf_.gpuid);
delete table_;
}

int dynamic_adjust_total_row_for_infer(int local_reach_end,
const paddle::platform::Place &place,
cudaStream_t stream) {
auto send_buff = memory::Alloc(
place,
2 * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int *send_buff_ptr = reinterpret_cast<int *>(send_buff->ptr());
cudaMemcpyAsync(send_buff_ptr,
&local_reach_end,
sizeof(int),
cudaMemcpyHostToDevice,
stream);
cudaStreamSynchronize(stream);
auto comm =
platform::NCCLCommContext::Instance().Get(0, place.GetDeviceId());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&send_buff_ptr[0],
&send_buff_ptr[1],
1,
ncclInt,
ncclProd,
comm->comm(),
stream));
int global_reach_end = 0;
cudaMemcpyAsync(&global_reach_end,
&send_buff_ptr[1],
sizeof(int),
cudaMemcpyDeviceToHost,
stream);
cudaStreamSynchronize(stream);
return global_reach_end;
}

bool FillInferBuf(const std::vector<uint64_t> &h_device_keys_len, // input
const std::vector<std::shared_ptr<phi::Allocation>> &d_device_keys,
const GraphDataGeneratorConfig &conf,
@@ -3688,7 +3724,8 @@ bool FillInferBuf(const std::vector<uint64_t> &h_device_keys_len, // input
if (conf.is_multi_node) {
int local_reach_end = global_infer_node_type_start[infer_cursor] + conf.buf_size >=
device_key_size;
int global_reach_end = dynamic_adjust_total_row_for_infer(local_reach_end, place, stream);
int global_reach_end = get_multi_node_global_flag(local_reach_end, ncclProd,
place, stream);
int remain = device_key_size - global_infer_node_type_start[infer_cursor];
if (global_reach_end) {
*total_row_ptr = remain;
@@ -4200,6 +4237,73 @@ void GraphDataGenerator::DumpWalkPath(std::string dump_path, size_t dump_rate) {
}
write_count = fwrite_unlocked("\n", 1, 1, fp.get());
}
delete[] h_walk;
#endif
}

void GraphDataGenerator::DumpSampleNeighbors(std::string dump_path) {
#ifdef _LINUX
int err_no = 0;
int part_num = rand() % 100; // set 100 part files
std::string part_path =
string::format_string("%s-%03d", dump_path.c_str(), part_num);
std::shared_ptr<FILE> fp = fs_open_append_write(part_path, &err_no, "");
for (int i = 0; i < sage_batch_num_; i++) {
int uniq_instance = uniq_instance_vec_[i];
uint64_t *h_id_tensor = new uint64_t[uniq_instance];
cudaMemcpy(h_id_tensor,
final_sage_nodes_vec_[i]->ptr(),
sizeof(uint64_t) * uniq_instance,
cudaMemcpyDeviceToHost);
std::string ss = "id:";
for (int xx = 0; xx < uniq_instance; xx++) {
ss += std::to_string(h_id_tensor[xx]) + ",";
}
ss += "\t";
int len_samples = conf_.samples.size();
std::vector<std::vector<int>> edges_split_num_for_graph =
edges_split_num_vec_[i];
std::vector<std::shared_ptr<phi::Allocation>> graph_edges =
graph_edges_vec_[i];
int graph_edges_index = 0;
for (int j = 0; j < len_samples; j++) {
ss += std::to_string(j) + ":[";
std::vector<int> edges_split_num = edges_split_num_for_graph[j];
int neighbor_len = edges_split_num[conf_.edge_to_id_len + 2];
int64_t *h_edge_src_tensor = new int64_t[neighbor_len];
int64_t *h_edge_dst_tensor = new int64_t[neighbor_len];
cudaMemcpy(h_edge_src_tensor,
graph_edges[graph_edges_index++]->ptr(),
sizeof(int64_t) * neighbor_len,
cudaMemcpyDeviceToHost);
cudaMemcpy(h_edge_dst_tensor,
graph_edges[graph_edges_index++]->ptr(),
sizeof(int64_t) * neighbor_len,
cudaMemcpyDeviceToHost);
ss += "src:";
for (int yy = 0; yy < neighbor_len; yy++) {
ss += std::to_string(h_edge_src_tensor[yy]) + ",";
}
ss += "\tdst:";
for (int yy = 0; yy < neighbor_len; yy++) {
ss += std::to_string(h_edge_dst_tensor[yy]) + ",";
}
ss += "\tsplit:";
for (int yy = 0; yy < conf_.edge_to_id_len; yy++) {
ss += std::to_string(edges_split_num[yy]) + ",";
}
ss += "]\t";

delete[] h_edge_src_tensor;
delete[] h_edge_dst_tensor;
}
size_t write_count = fwrite_unlocked(ss.data(), 1, ss.length(), fp.get());
if (write_count != ss.length()) {
VLOG(1) << "dump sample neighbors: " << ss << " failed!";
}
write_count = fwrite_unlocked("\n", 1, 1, fp.get());
delete[] h_id_tensor;
}
#endif
}
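Each record written by DumpSampleNeighbors is one tab-separated line: the unique node ids, then for every sample depth a bracketed block with source ids, destination ids, and per-edge-type split counts. A small sketch that assembles one such line for made-up data; the separators mirror the code above, while all ids and counts are hypothetical:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Hypothetical batch: three unique nodes, one sample depth, one edge type.
  std::vector<unsigned long long> ids = {101, 202, 303};
  std::vector<long long> src = {101, 202};
  std::vector<long long> dst = {202, 303};
  std::vector<int> split = {2};

  std::string ss = "id:";
  for (auto id : ids) ss += std::to_string(id) + ",";
  ss += "\t";

  ss += "0:[src:";  // depth index 0
  for (auto s : src) ss += std::to_string(s) + ",";
  ss += "\tdst:";
  for (auto d : dst) ss += std::to_string(d) + ",";
  ss += "\tsplit:";
  for (auto c : split) ss += std::to_string(c) + ",";
  ss += "]\t";

  std::printf("%s\n", ss.c_str());
  // id:101,202,303,   0:[src:101,202,   dst:202,303,   split:2,]
  return 0;
}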

(The remaining 8 changed files are not shown here.)
