From 91d951f9fd3a6c0dc98a580485f4ec512a0848f8 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Sun, 11 Dec 2022 17:41:06 +0800
Subject: [PATCH] support return of degree (#188)

---
 paddle/fluid/framework/data_feed.cu           |  40 ++++
 paddle/fluid/framework/data_feed.h            |   4 +
 paddle/fluid/framework/data_feed.proto        |   1 +
 .../fleet/heter_ps/graph_gpu_ps_table.h       |   7 +
 .../fleet/heter_ps/graph_gpu_ps_table_inl.cu  | 180 ++++++++++++++++++
 .../fleet/heter_ps/graph_gpu_wrapper.cu       |   7 +
 .../fleet/heter_ps/graph_gpu_wrapper.h        |   2 +
 python/paddle/fluid/dataset.py                |   2 +
 8 files changed, 243 insertions(+)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 35f0d0efde982..78b4862bd2e95 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -616,6 +616,10 @@ int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance,
   int index_offset = 3 + slot_num_ * 2 + 5 * samples_.size();
   index_tensor_ptr_ = feed_vec_[index_offset]->mutable_data<int>(
       {total_instance}, this->place_);
+  if (get_degree_) {
+    degree_tensor_ptr_ = feed_vec_[index_offset + 1]->mutable_data<int>(
+        {uniq_instance * edge_to_id_len_}, this->place_);
+  }
 
   int len_samples = samples_.size();
   int *num_nodes_tensor_ptr_[len_samples];
@@ -682,6 +686,13 @@ int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance,
                   sizeof(int) * total_instance,
                   cudaMemcpyDeviceToDevice,
                   train_stream_);
+  if (get_degree_) {
+    cudaMemcpyAsync(degree_tensor_ptr_,
+                    node_degree_vec_[index]->ptr(),
+                    sizeof(int) * uniq_instance * edge_to_id_len_,
+                    cudaMemcpyDeviceToDevice,
+                    train_stream_);
+  }
   GraphFillCVMKernel<<<GET_BLOCKS(total_instance),
                        CUDA_NUM_THREADS,
                        0,
@@ ... @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GenerateSampleGraph(
   return final_nodes_vec[len_samples - 1];
 }
 
+std::shared_ptr<phi::Allocation> GraphDataGenerator::GetNodeDegree(
+    uint64_t* node_ids, int len) {
+  auto node_degree = memory::AllocShared(
+      place_,
+      len * edge_to_id_len_ * sizeof(int),
+      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
+  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+  auto edge_to_id = gpu_graph_ptr->edge_to_id;
+  for (auto& iter : edge_to_id) {
+    int edge_idx = iter.second;
+    gpu_graph_ptr->get_node_degree(gpuid_, edge_idx, node_ids, len, node_degree);
+  }
+  return node_degree;
+}
+
 uint64_t GraphDataGenerator::CopyUniqueNodes() {
   if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
     uint64_t h_uniq_node_num = 0;
@@ -1933,6 +1959,13 @@ void GraphDataGenerator::DoWalkandSage() {
               phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
           auto final_sage_nodes = GenerateSampleGraph(
               ins_cursor, total_instance, &uniq_instance, inverse);
+          uint64_t* final_sage_nodes_ptr =
+              reinterpret_cast<uint64_t*>(final_sage_nodes->ptr());
+          if (get_degree_) {
+            auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
+            node_degree_vec_.emplace_back(node_degrees);
+          }
+          cudaStreamSynchronize(sample_stream_);
           if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
             uint64_t *final_sage_nodes_ptr =
                 reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
@@ -1983,6 +2016,12 @@ void GraphDataGenerator::DoWalkandSage() {
               phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
           auto final_sage_nodes = GenerateSampleGraph(
               node_buf_ptr_, total_instance, &uniq_instance, inverse);
+          uint64_t* final_sage_nodes_ptr =
+              reinterpret_cast<uint64_t*>(final_sage_nodes->ptr());
+          if (get_degree_) {
+            auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
+            node_degree_vec_.emplace_back(node_degrees);
+          }
           cudaStreamSynchronize(sample_stream_);
           if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
             uint64_t *final_sage_nodes_ptr =
@@ -2857,6 +2896,7 @@ void GraphDataGenerator::SetConfig(
       once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_;
   train_table_cap_ = graph_config.train_table_cap();
   infer_table_cap_ = graph_config.infer_table_cap();
+  get_degree_ = graph_config.get_degree();
   epoch_finish_ = false;
   VLOG(1) << "Confirm GraphConfig, walk_degree : " << walk_degree_
           << ", walk_len : " << walk_len_ << ", window : " << window_
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index 89accff4f19e3..0a316a9313aa3 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -968,6 +968,7 @@ class GraphDataGenerator {
       int len,
       int* uniq_len,
       std::shared_ptr<phi::Allocation>& inverse);  // NOLINT
+  std::shared_ptr<phi::Allocation> GetNodeDegree(uint64_t* node_ids, int len);
   int InsertTable(const uint64_t* d_keys,
                   uint64_t len,
                   std::shared_ptr<phi::Allocation> d_uniq_node_num);
@@ -990,6 +991,7 @@ class GraphDataGenerator {
   int* index_tensor_ptr_;
   int64_t* show_tensor_ptr_;
   int64_t* clk_tensor_ptr_;
+  int* degree_tensor_ptr_;
 
   cudaStream_t train_stream_;
   cudaStream_t sample_stream_;
@@ -1037,6 +1039,7 @@ class GraphDataGenerator {
   // sage mode batch data
   std::vector<std::shared_ptr<phi::Allocation>> inverse_vec_;
   std::vector<std::shared_ptr<phi::Allocation>> final_sage_nodes_vec_;
+  std::vector<std::shared_ptr<phi::Allocation>> node_degree_vec_;
   std::vector<int> uniq_instance_vec_;
   std::vector<int> total_instance_vec_;
   std::vector<std::vector<std::shared_ptr<phi::Allocation>>> graph_edges_vec_;
@@ -1074,6 +1077,7 @@ class GraphDataGenerator {
   size_t infer_node_end_;
   std::set<int> infer_node_type_index_set_;
   std::string infer_node_type_;
+  bool get_degree_;
 };
 
 class DataFeed {
diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto
index e20f1a0f1602d..7f81711b7c8e5 100644
--- a/paddle/fluid/framework/data_feed.proto
+++ b/paddle/fluid/framework/data_feed.proto
@@ -44,6 +44,7 @@ message GraphConfig {
   optional int64 infer_table_cap = 14 [ default = 80000 ];
   optional string excluded_train_pair = 15;
   optional string infer_node_type = 16;
+  optional bool get_degree = 17 [ default = false ];
 }
 
 message DataFeedDesc {
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
index 50895c2645853..c90ddf6f04aa1 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -109,6 +109,8 @@ class GpuPsGraphTable
       std::vector<std::shared_ptr<phi::Allocation>> edge_type_graphs);
   std::vector<std::shared_ptr<phi::Allocation>> get_edge_type_graph(
       int gpu_id, int edge_type_len);
+  void get_node_degree(int gpu_id, int edge_idx, uint64_t* key, int len,
+                       std::shared_ptr<phi::Allocation> node_degree);
   int get_feature_of_nodes(int gpu_id,
                            uint64_t *d_walk,
                            uint64_t *d_offset,
@@ -146,6 +148,11 @@ class GpuPsGraphTable
                                 uint32_t *actual_feature_size,
                                 uint64_t *feature_list,
                                 uint8_t *slot_list);
+  void move_degree_to_source_gpu(int gpu_id,
+                                 int gpu_num,
+                                 int *h_left,
+                                 int *h_right,
+                                 int *node_degree);
   void move_result_to_source_gpu_all_edge_type(int gpu_id,
                                                int gpu_num,
                                                int sample_size,
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
index 9df347fd070e0..bacdffe080e13 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -155,6 +155,15 @@ __global__ void get_features_kernel(GpuPsCommGraphFea graph,
   }
 }
 
+__global__ void get_node_degree_kernel(GpuPsNodeInfo* node_info_list,
+                                       int* node_degree,
+                                       int n) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < n) {
+    node_degree[i] = node_info_list[i].neighbor_size;
+  }
+}
+
 template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
 __global__ void neighbor_sample_kernel_walking(GpuPsCommGraph graph,
                                                GpuPsNodeInfo* node_info_list,
@@ -455,6 +464,44 @@ void GpuPsGraphTable::move_result_to_source_gpu(int start_index,
   }
 }
 
+void GpuPsGraphTable::move_degree_to_source_gpu(int start_index,
+                                                int gpu_num,
+                                                int* h_left,
+                                                int* h_right,
+                                                int* node_degree) {
+  int shard_len[gpu_num];
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    shard_len[i] = h_right[i] - h_left[i] + 1;
+    int cur_step = (int)path_[start_index][i].nodes_.size() - 1;
+    for (int j = cur_step; j > 0; j--) {
+      CUDA_CHECK(
+          cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage,
+                          path_[start_index][i].nodes_[j].val_storage,
+                          path_[start_index][i].nodes_[j - 1].val_bytes_len,
+                          cudaMemcpyDefault,
+                          path_[start_index][i].nodes_[j - 1].out_stream));
+    }
+    auto& node = path_[start_index][i].nodes_.front();
+    CUDA_CHECK(cudaMemcpyAsync(
+        reinterpret_cast<char*>(node_degree + h_left[i]),
+        node.val_storage + sizeof(int64_t) * shard_len[i],
+        sizeof(int) * shard_len[i],
+        cudaMemcpyDefault,
+        node.out_stream));
+  }
+
+  for (int i = 0; i < gpu_num; ++i) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    auto& node = path_[start_index][i].nodes_.front();
+    CUDA_CHECK(cudaStreamSynchronize(node.out_stream));
+  }
+}
+
 void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type(
     int start_index,
     int gpu_num,
@@ -570,6 +617,16 @@ __global__ void fill_dvalues(uint64_t* d_shard_vals,
   }
 }
 
+__global__ void fill_dvalues(int* d_shard_degree,
+                             int* d_degree,
+                             int* idx,
+                             int len) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    d_degree[idx[i]] = d_shard_degree[i];
+  }
+}
+
 __global__ void fill_dvalues_with_edge_type(uint64_t* d_shard_vals,
                                             uint64_t* d_vals,
                                             int* d_shard_actual_sample_size,
@@ -1538,6 +1595,129 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type(
   return result;
 }
 
+void GpuPsGraphTable::get_node_degree(
+    int gpu_id, int edge_idx, uint64_t* key, int len,
+    std::shared_ptr<phi::Allocation> node_degree) {
+  int* node_degree_ptr =
+      reinterpret_cast<int*>(node_degree->ptr()) + edge_idx * len;
+  int total_gpu = resource_->total_device();
+  platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
+  platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
+  auto stream = resource_->local_stream(gpu_id, 0);
+  int grid_size = (len - 1) / block_size_ + 1;
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
+  auto d_left =
+      memory::Alloc(place,
+                    total_gpu * sizeof(int),
+                    phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  auto d_right =
+      memory::Alloc(place,
+                    total_gpu * sizeof(int),
+                    phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
+  int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
+  CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream));
+  CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream));
+  auto d_idx =
+      memory::Alloc(place,
+                    len * sizeof(int),
+                    phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
+  auto d_shard_keys =
+      memory::Alloc(place,
+                    len * sizeof(uint64_t),
+                    phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  uint64_t* d_shard_keys_ptr = reinterpret_cast<uint64_t*>(d_shard_keys->ptr());
+  auto d_shard_degree =
+      memory::Alloc(place,
+                    len * sizeof(int),
+                    phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  int* d_shard_degree_ptr = reinterpret_cast<int*>(d_shard_degree->ptr());
+  split_input_to_shard(
+      (uint64_t*)(key), d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
+  heter_comm_kernel_->fill_shard_key(
+      d_shard_keys_ptr, key, d_idx_ptr, len, stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  CUDA_CHECK(cudaMemcpyAsync(h_left,
+                             d_left_ptr,
+                             total_gpu * sizeof(int),
+                             cudaMemcpyDeviceToHost,
+                             stream));
+  CUDA_CHECK(cudaMemcpyAsync(h_right,
+                             d_right_ptr,
+                             total_gpu * sizeof(int),
+                             cudaMemcpyDeviceToHost,
+                             stream));
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  device_mutex_[gpu_id]->lock();
+  for (int i = 0; i < total_gpu; ++i) {
+    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
+    if (shard_len == 0) {
+      continue;
+    }
+    create_storage(gpu_id,
+                   i,
+                   shard_len * sizeof(uint64_t),
+                   shard_len * sizeof(uint64_t) +
+                       sizeof(int) * shard_len + shard_len % 2);
+  }
+  walk_to_dest(
+      gpu_id, total_gpu, h_left, h_right, (uint64_t*)(d_shard_keys_ptr), NULL);
+  for (int i = 0; i < total_gpu; ++i) {
+    if (h_left[i] == -1) {
+      continue;
+    }
+    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
+    auto& node = path_[gpu_id][i].nodes_.back();
+    CUDA_CHECK(cudaMemsetAsync(node.val_storage,
+                               0,
+                               shard_len * sizeof(uint64_t),
+                               node.in_stream));
+    CUDA_CHECK(cudaStreamSynchronize(node.in_stream));
+    platform::CUDADeviceGuard guard(resource_->dev_id(i));
+    int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, edge_idx);
+    tables_[table_offset]->get(reinterpret_cast<uint64_t*>(node.key_storage),
+                               reinterpret_cast<uint64_t*>(node.val_storage),
+                               (size_t)(h_right[i] - h_left[i] + 1),
+                               resource_->remote_stream(i, gpu_id));
+    GpuPsNodeInfo* node_info_list =
+        reinterpret_cast<GpuPsNodeInfo*>(node.val_storage);
+    int* node_degree_array = (int*)(node_info_list + shard_len);
+    int grid_size_ = (shard_len - 1) / block_size_ + 1;
+    get_node_degree_kernel<<<
+        grid_size_, block_size_, 0, resource_->remote_stream(i, gpu_id)>>>(
+        node_info_list,
+        node_degree_array,
+        shard_len);
+  }
+  for (int i = 0; i < total_gpu; ++i) {
+    if (h_left[i] == -1) {
+      continue;
+    }
+    CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)));
+  }
+  move_degree_to_source_gpu(gpu_id,
+                            total_gpu,
+                            h_left,
+                            h_right,
+                            d_shard_degree_ptr);
+  fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
+      d_shard_degree_ptr,
+      node_degree_ptr,
+      d_idx_ptr,
+      len);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  for (int i = 0; i < total_gpu; i++) {
+    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
+    if (shard_len == 0) {
+      continue;
+    }
+    destroy_storage(gpu_id, i);
+  }
+  device_mutex_[gpu_id]->unlock();
+}
+
 NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id,
                                                    int sample_size) {
   return NodeQueryResult();
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
index 97fd5ca84495c..83d8d511dfd6d 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -630,6 +630,13 @@ GraphGpuWrapper::get_edge_type_graph(int gpu_id, int edge_type_len) {
       ->get_edge_type_graph(gpu_id, edge_type_len);
 }
 
+void GraphGpuWrapper::get_node_degree(
+    int gpu_id, int edge_idx, uint64_t* key, int len,
+    std::shared_ptr<phi::Allocation> node_degree) {
+  return ((GpuPsGraphTable *)graph_table)
+      ->get_node_degree(gpu_id, edge_idx, key, len, node_degree);
+}
+
 int GraphGpuWrapper::get_feature_info_of_nodes(
     int gpu_id,
     uint64_t *d_nodes,
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
index a3b6225b87f99..a9d70a0b646e0 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -121,6 +121,8 @@ class GraphGpuWrapper {
       int sample_size,
       int len,
      std::vector<std::shared_ptr<phi::Allocation>> edge_type_graphs);
+  void get_node_degree(int gpu_id, int edge_idx, uint64_t* key, int len,
+                       std::shared_ptr<phi::Allocation> node_degree);
   gpuStream_t get_local_stream(int gpuid);
   std::vector<uint64_t> graph_neighbor_sample(
       int gpu_id,
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 3095fc591a1bd..764a225da5d30 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -1127,6 +1127,8 @@ def set_graph_config(self, config):
             "excluded_train_pair", "")
         self.proto_desc.graph_config.infer_node_type = config.get(
             "infer_node_type", "")
+        self.proto_desc.graph_config.get_degree = config.get(
+            "get_degree", False)
         self.dataset.set_gpu_graph_mode(True)
 
     def set_pass_id(self, pass_id):
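
Note (not part of the patch): a minimal Python sketch of how the new switch would be turned on from user code, assuming an InMemoryDataset that is already configured for GPU graph (gpugraph) training. Only "get_degree" is introduced by this change; the other keys and their values are illustrative placeholders that follow the existing set_graph_config contract.

    import paddle.fluid as fluid

    # Build the in-memory dataset used for GPU graph training.
    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")

    graph_config = {
        "walk_len": 24,       # placeholder sampling options, not from this patch
        "walk_degree": 10,
        "get_degree": True,   # new key: also feed per-edge-type node degrees
    }
    dataset.set_graph_config(graph_config)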