support return of degree (PaddlePaddle#188)
DesmonDay authored and zmxdream committed Dec 24, 2022
1 parent 20852e4 commit 91d951f
Showing 8 changed files with 243 additions and 0 deletions.
40 changes: 40 additions & 0 deletions paddle/fluid/framework/data_feed.cu
@@ -616,6 +616,10 @@ int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance,
int index_offset = 3 + slot_num_ * 2 + 5 * samples_.size();
index_tensor_ptr_ = feed_vec_[index_offset]->mutable_data<int>(
{total_instance}, this->place_);
if (get_degree_) {
degree_tensor_ptr_ = feed_vec_[index_offset + 1]->mutable_data<int>(
{uniq_instance * edge_to_id_len_}, this->place_);
}

int len_samples = samples_.size();
int *num_nodes_tensor_ptr_[len_samples];
@@ -682,6 +686,13 @@ int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance,
sizeof(int) * total_instance,
cudaMemcpyDeviceToDevice,
train_stream_);
if (get_degree_) {
cudaMemcpyAsync(degree_tensor_ptr_,
node_degree_vec_[index]->ptr(),
sizeof(int) * uniq_instance * edge_to_id_len_,
cudaMemcpyDeviceToDevice,
train_stream_);
}
GraphFillCVMKernel<<<GET_BLOCKS(uniq_instance),
CUDA_NUM_THREADS,
0,
@@ -1843,6 +1854,21 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GenerateSampleGraph(
return final_nodes_vec[len_samples - 1];
}

std::shared_ptr<phi::Allocation> GraphDataGenerator::GetNodeDegree(
uint64_t* node_ids, int len) {
auto node_degree = memory::AllocShared(
place_,
len * edge_to_id_len_ * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
auto edge_to_id = gpu_graph_ptr->edge_to_id;
for (auto& iter : edge_to_id) {
int edge_idx = iter.second;
gpu_graph_ptr->get_node_degree(gpuid_, edge_idx, node_ids, len, node_degree);
}
return node_degree;
}

uint64_t GraphDataGenerator::CopyUniqueNodes() {
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
uint64_t h_uniq_node_num = 0;
@@ -1933,6 +1959,13 @@ void GraphDataGenerator::DoWalkandSage() {
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
auto final_sage_nodes = GenerateSampleGraph(
ins_cursor, total_instance, &uniq_instance, inverse);
uint64_t* final_sage_nodes_ptr =
reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
if (get_degree_) {
auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
node_degree_vec_.emplace_back(node_degrees);
}
cudaStreamSynchronize(sample_stream_);
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
uint64_t *final_sage_nodes_ptr =
reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
@@ -1983,6 +2016,12 @@ void GraphDataGenerator::DoWalkandSage() {
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
auto final_sage_nodes = GenerateSampleGraph(
node_buf_ptr_, total_instance, &uniq_instance, inverse);
uint64_t* final_sage_nodes_ptr =
reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
if (get_degree_) {
auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
node_degree_vec_.emplace_back(node_degrees);
}
cudaStreamSynchronize(sample_stream_);
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
uint64_t *final_sage_nodes_ptr =
@@ -2857,6 +2896,7 @@ void GraphDataGenerator::SetConfig(
once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_;
train_table_cap_ = graph_config.train_table_cap();
infer_table_cap_ = graph_config.infer_table_cap();
get_degree_ = graph_config.get_degree();
epoch_finish_ = false;
VLOG(1) << "Confirm GraphConfig, walk_degree : " << walk_degree_
<< ", walk_len : " << walk_len_ << ", window : " << window_
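
Note on the layout of the degree output wired up above: GetNodeDegree allocates len * edge_to_id_len_ ints and writes each edge type's degrees at offset edge_idx * len, so the flat buffer copied into degree_tensor_ptr_ appears to be grouped by edge type first. A minimal Python sketch of how a consumer might view that buffer (numpy is used only for illustration; the sizes are placeholders, not values from this commit):

import numpy as np

edge_to_id_len = 3   # placeholder for edge_to_id_len_ (number of edge types)
uniq_instance = 5    # placeholder for the number of unique nodes in the batch

# Flat buffer as produced on the GPU: one block of `uniq_instance` degrees per edge type.
flat_degree = np.zeros(edge_to_id_len * uniq_instance, dtype=np.int32)

# Assumed layout: degree_by_type[e, i] is the degree of unique node i under edge type e.
degree_by_type = flat_degree.reshape(edge_to_id_len, uniq_instance)
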
4 changes: 4 additions & 0 deletions paddle/fluid/framework/data_feed.h
@@ -968,6 +968,7 @@ class GraphDataGenerator {
int len,
int* uniq_len,
std::shared_ptr<phi::Allocation>& inverse); // NOLINT
std::shared_ptr<phi::Allocation> GetNodeDegree(uint64_t* node_ids, int len);
int InsertTable(const uint64_t* d_keys,
uint64_t len,
std::shared_ptr<phi::Allocation> d_uniq_node_num);
@@ -990,6 +991,7 @@
int* index_tensor_ptr_;
int64_t* show_tensor_ptr_;
int64_t* clk_tensor_ptr_;
int* degree_tensor_ptr_;

cudaStream_t train_stream_;
cudaStream_t sample_stream_;
@@ -1037,6 +1039,7 @@
// sage mode batch data
std::vector<std::shared_ptr<phi::Allocation>> inverse_vec_;
std::vector<std::shared_ptr<phi::Allocation>> final_sage_nodes_vec_;
std::vector<std::shared_ptr<phi::Allocation>> node_degree_vec_;
std::vector<int> uniq_instance_vec_;
std::vector<int> total_instance_vec_;
std::vector<std::vector<std::shared_ptr<phi::Allocation>>> graph_edges_vec_;
@@ -1074,6 +1077,7 @@
size_t infer_node_end_;
std::set<int> infer_node_type_index_set_;
std::string infer_node_type_;
bool get_degree_;
};

class DataFeed {
1 change: 1 addition & 0 deletions paddle/fluid/framework/data_feed.proto
@@ -44,6 +44,7 @@ message GraphConfig {
optional int64 infer_table_cap = 14 [ default = 80000 ];
optional string excluded_train_pair = 15;
optional string infer_node_type = 16;
optional bool get_degree = 17 [ default = false ];
}

message DataFeedDesc {
7 changes: 7 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -109,6 +109,8 @@ class GpuPsGraphTable
std::vector<std::shared_ptr<phi::Allocation>> edge_type_graphs);
std::vector<std::shared_ptr<phi::Allocation>> get_edge_type_graph(
int gpu_id, int edge_type_len);
void get_node_degree(int gpu_id, int edge_idx, uint64_t* key, int len,
std::shared_ptr<phi::Allocation> node_degree);
int get_feature_of_nodes(int gpu_id,
uint64_t *d_walk,
uint64_t *d_offset,
@@ -146,6 +148,11 @@ class GpuPsGraphTable
uint32_t *actual_feature_size,
uint64_t *feature_list,
uint8_t *slot_list);
void move_degree_to_source_gpu(int gpu_id,
int gpu_num,
int *h_left,
int *h_right,
int *node_degree);
void move_result_to_source_gpu_all_edge_type(int gpu_id,
int gpu_num,
int sample_size,
180 changes: 180 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -155,6 +155,15 @@ __global__ void get_features_kernel(GpuPsCommGraphFea graph,
}
}

__global__ void get_node_degree_kernel(GpuPsNodeInfo* node_info_list,
int* node_degree,
int n) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
node_degree[i] = node_info_list[i].neighbor_size;
}
}

template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample_kernel_walking(GpuPsCommGraph graph,
GpuPsNodeInfo* node_info_list,
@@ -455,6 +464,44 @@ void GpuPsGraphTable::move_result_to_source_gpu(int start_index,
}
}

void GpuPsGraphTable::move_degree_to_source_gpu(int start_index,
int gpu_num,
int* h_left,
int* h_right,
int* node_degree) {
int shard_len[gpu_num];
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
shard_len[i] = h_right[i] - h_left[i] + 1;
int cur_step = (int)path_[start_index][i].nodes_.size() - 1;
for (int j = cur_step; j > 0; j--) {
CUDA_CHECK(
cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage,
path_[start_index][i].nodes_[j].val_storage,
path_[start_index][i].nodes_[j - 1].val_bytes_len,
cudaMemcpyDefault,
path_[start_index][i].nodes_[j - 1].out_stream));
}
auto& node = path_[start_index][i].nodes_.front();
CUDA_CHECK(cudaMemcpyAsync(
reinterpret_cast<char*>(node_degree + h_left[i]),
node.val_storage + sizeof(int64_t) * shard_len[i],
sizeof(int) * shard_len[i],
cudaMemcpyDefault,
node.out_stream));
}

for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto& node = path_[start_index][i].nodes_.front();
CUDA_CHECK(cudaStreamSynchronize(node.out_stream));
}
}

void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type(
int start_index,
int gpu_num,
@@ -570,6 +617,16 @@ __global__ void fill_dvalues(uint64_t* d_shard_vals,
}
}

__global__ void fill_dvalues(int* d_shard_degree,
int* d_degree,
int* idx,
int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_degree[idx[i]] = d_shard_degree[i];
}
}

__global__ void fill_dvalues_with_edge_type(uint64_t* d_shard_vals,
uint64_t* d_vals,
int* d_shard_actual_sample_size,
@@ -1538,6 +1595,129 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type(
return result;
}
void GpuPsGraphTable::get_node_degree(
int gpu_id, int edge_idx, uint64_t* key, int len,
std::shared_ptr<phi::Allocation> node_degree) {
int* node_degree_ptr =
reinterpret_cast<int *>(node_degree->ptr()) + edge_idx * len;
int total_gpu = resource_->total_device();
platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
auto stream = resource_->local_stream(gpu_id, 0);
int grid_size = (len - 1) / block_size_ + 1;
int h_left[total_gpu]; // NOLINT
int h_right[total_gpu]; // NOLINT
auto d_left =
memory::Alloc(place,
total_gpu * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
auto d_right =
memory::Alloc(place,
total_gpu * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream));
CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream));
auto d_idx =
memory::Alloc(place,
len * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
auto d_shard_keys =
memory::Alloc(place,
len * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
uint64_t* d_shard_keys_ptr = reinterpret_cast<uint64_t*>(d_shard_keys->ptr());
auto d_shard_degree =
memory::Alloc(place,
len * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_shard_degree_ptr = reinterpret_cast<int *>(d_shard_degree->ptr());
split_input_to_shard(
(uint64_t*)(key), d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
heter_comm_kernel_->fill_shard_key(
d_shard_keys_ptr, key, d_idx_ptr, len, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpyAsync(h_left,
d_left_ptr,
total_gpu * sizeof(int),
cudaMemcpyDeviceToHost,
stream));
CUDA_CHECK(cudaMemcpyAsync(h_right,
d_right_ptr,
total_gpu * sizeof(int),
cudaMemcpyDeviceToHost,
stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
device_mutex_[gpu_id]->lock();
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
create_storage(gpu_id,
i,
shard_len * sizeof(uint64_t),
shard_len * sizeof(uint64_t) +
sizeof(int) * shard_len + shard_len % 2);
}
walk_to_dest(
gpu_id, total_gpu, h_left, h_right, (uint64_t*)(d_shard_keys_ptr), NULL);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.back();
CUDA_CHECK(cudaMemsetAsync(node.val_storage,
0,
shard_len * sizeof(uint64_t),
node.in_stream));
CUDA_CHECK(cudaStreamSynchronize(node.in_stream));
platform::CUDADeviceGuard guard(resource_->dev_id(i));
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, edge_idx);
tables_[table_offset]->get(reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<uint64_t*>(node.val_storage),
(size_t)(h_right[i] - h_left[i] + 1),
resource_->remote_stream(i, gpu_id));
GpuPsNodeInfo* node_info_list =
reinterpret_cast<GpuPsNodeInfo*>(node.val_storage);
int* node_degree_array = (int*)(node_info_list + shard_len);
int grid_size_ = (shard_len - 1) / block_size_ + 1;
get_node_degree_kernel<<<
grid_size_, block_size_, 0, resource_->remote_stream(i, gpu_id)>>>(
node_info_list,
node_degree_array,
shard_len);
}
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)));
}
move_degree_to_source_gpu(gpu_id,
total_gpu,
h_left,
h_right,
d_shard_degree_ptr);
fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
d_shard_degree_ptr,
node_degree_ptr,
d_idx_ptr,
len);
CUDA_CHECK(cudaStreamSynchronize(stream));
for (int i = 0; i < total_gpu; i++) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
destroy_storage(gpu_id, i);
}
device_mutex_[gpu_id]->unlock();
}
NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id,
int sample_size) {
return NodeQueryResult();
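
For reference, get_node_degree above follows the same shard-and-gather pattern as the existing sampling paths: keys are split per owning GPU (split_input_to_shard / fill_shard_key), each owner looks up GpuPsNodeInfo and reads neighbor_size (get_node_degree_kernel), the per-shard results are copied back (move_degree_to_source_gpu), and fill_dvalues restores the caller's original key order. A CPU-only Python sketch of that control flow, with illustrative names (num_shards, local_degree_table) that are not Paddle APIs:

def get_node_degree_sketch(keys, num_shards, local_degree_table):
    # 1. Group keys by owning shard, remembering each key's original position
    #    (stand-in for split_input_to_shard + fill_shard_key / d_idx).
    shard_keys = [[] for _ in range(num_shards)]
    shard_idx = [[] for _ in range(num_shards)]
    for pos, k in enumerate(keys):
        s = k % num_shards
        shard_keys[s].append(k)
        shard_idx[s].append(pos)

    degrees = [0] * len(keys)
    for s in range(num_shards):
        # 2. Per-shard lookup on the owner (stand-in for tables_[...]->get
        #    followed by get_node_degree_kernel reading neighbor_size).
        shard_degrees = [local_degree_table[s][k] for k in shard_keys[s]]
        # 3. Copy back and scatter into the caller's original key order
        #    (stand-in for move_degree_to_source_gpu + fill_dvalues).
        for pos, d in zip(shard_idx[s], shard_degrees):
            degrees[pos] = d
    return degrees
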
7 changes: 7 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -630,6 +630,13 @@ GraphGpuWrapper::get_edge_type_graph(int gpu_id, int edge_type_len) {
->get_edge_type_graph(gpu_id, edge_type_len);
}

void GraphGpuWrapper::get_node_degree(
int gpu_id, int edge_idx, uint64_t* key, int len,
std::shared_ptr<phi::Allocation> node_degree) {
return ((GpuPsGraphTable *)graph_table)
->get_node_degree(gpu_id, edge_idx, key, len, node_degree);
}

int GraphGpuWrapper::get_feature_info_of_nodes(
int gpu_id,
uint64_t *d_nodes,
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -121,6 +121,8 @@ class GraphGpuWrapper {
int sample_size,
int len,
std::vector<std::shared_ptr<phi::Allocation>> edge_type_graphs);
void get_node_degree(int gpu_id, int edge_idx, uint64_t* key, int len,
std::shared_ptr<phi::Allocation> node_degree);
gpuStream_t get_local_stream(int gpuid);
std::vector<uint64_t> graph_neighbor_sample(
int gpu_id,
2 changes: 2 additions & 0 deletions python/paddle/fluid/dataset.py
@@ -1127,6 +1127,8 @@ def set_graph_config(self, config):
"excluded_train_pair", "")
self.proto_desc.graph_config.infer_node_type = config.get(
"infer_node_type", "")
self.proto_desc.graph_config.get_degree = config.get(
"get_degree", False)
self.dataset.set_gpu_graph_mode(True)

def set_pass_id(self, pass_id):
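
A hedged usage sketch for the new option from the Python side, assuming an InMemoryDataset as in the existing gpu-graph setup; keys other than "get_degree" follow the existing set_graph_config interface, and the values shown are placeholders, not part of this commit:

import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_graph_config({
    "walk_len": 24,
    "walk_degree": 10,
    "window": 3,
    "batch_size": 800,
    "meta_path": "u2i-i2u",
    "gpu_graph_training": True,
    "get_degree": True,  # new flag added by this commit; defaults to False
})
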
