Skip to content

Commit

Permalink
refactor GraphTable (PaddlePaddle#201)
Browse files Browse the repository at this point in the history
Co-authored-by: root <[email protected]>
  • Loading branch information
huwei02 and root authored Jan 13, 2023
1 parent ca67b11 commit cf23031
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 184 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/ps/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str());
GraphTableType type_id = *(GraphTableType*)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
((GraphTable *)table)->clear_nodes(type_id, idx_);
return 0;
Expand Down Expand Up @@ -375,7 +375,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response, -1, "pull_graph_list request requires at least 5 arguments");
return 0;
}
int type_id = *(int *)(request.params(0).c_str());
GraphTableType type_id = *(GraphTableType*)(request.params(0).c_str());
int idx = *(int *)(request.params(1).c_str());
int start = *(int *)(request.params(2).c_str());
int size = *(int *)(request.params(3).c_str());
Expand Down Expand Up @@ -425,7 +425,7 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str());
GraphTableType type_id = *(GraphTableType*)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
size_t size = *(uint64_t *)(request.params(2).c_str());
// size_t size = *(int64_t *)(request.params(0).c_str());
Expand Down
65 changes: 32 additions & 33 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
paddle::framework::GpuPsFeaInfo x;
std::vector<uint64_t> feature_ids;
for (size_t j = 0; j < bags[i].size(); j++) {
// TODO use FEATURE_TABLE instead
Node *v = find_node(1, bags[i][j]);
Node *v = find_node(GraphTableType::FEATURE_TABLE, bags[i][j]);
node_id = bags[i][j];
if (v == NULL) {
x.feature_size = 0;
Expand Down Expand Up @@ -193,7 +192,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
for (size_t j = 0; j < bags[i].size(); j++) {
auto node_id = bags[i][j];
node_array[i][j] = node_id;
Node *v = find_node(0, idx, node_id);
Node *v = find_node(GraphTableType::EDGE_TABLE, idx, node_id);
if (v != nullptr) {
info_array[i][j].neighbor_offset = edge_array[i].size();
info_array[i][j].neighbor_size = v->get_neighbor_size();
Expand Down Expand Up @@ -1472,14 +1471,14 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files,
}

int32_t GraphTable::get_nodes_ids_by_ranges(
int type_id,
GraphTableType table_type,
int idx,
std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res) {
std::mutex mutex;
int start = 0, end, index = 0, total_size = 0;
res.clear();
auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
end = total_size + shards[i]->get_size();
Expand Down Expand Up @@ -1829,14 +1828,14 @@ int32_t GraphTable::load_edges(const std::string &path,
return 0;
}

Node *GraphTable::find_node(int type_id, uint64_t id) {
Node *GraphTable::find_node(GraphTableType table_type, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
Node *node = nullptr;
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
for (auto &search_shard : search_shards) {
PADDLE_ENFORCE_NOT_NULL(search_shard[index],
paddle::platform::errors::InvalidArgument(
Expand All @@ -1849,13 +1848,13 @@ Node *GraphTable::find_node(int type_id, uint64_t id) {
return node;
}

Node *GraphTable::find_node(int type_id, int idx, uint64_t id) {
Node *GraphTable::find_node(GraphTableType table_type, int idx, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
PADDLE_ENFORCE_NOT_NULL(search_shards[index],
paddle::platform::errors::InvalidArgument(
"search_shard[%d] should not be null.", index));
Expand All @@ -1871,21 +1870,21 @@ uint32_t GraphTable::get_thread_pool_index_by_shard_index(
return shard_index % shard_num_per_server % task_pool_size_;
}

int32_t GraphTable::clear_nodes(int type_id, int idx) {
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
int32_t GraphTable::clear_nodes(GraphTableType table_type, int idx) {
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
for (size_t i = 0; i < search_shards.size(); i++) {
search_shards[i]->clear();
}
return 0;
}

int32_t GraphTable::random_sample_nodes(int type_id,
int32_t GraphTable::random_sample_nodes(GraphTableType table_type,
int idx,
int sample_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
int total_size = 0;
auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
for (int i = 0; i < (int)shards.size(); i++) {
total_size += shards[i]->get_size();
}
Expand Down Expand Up @@ -1940,7 +1939,7 @@ int32_t GraphTable::random_sample_nodes(int type_id,
}
for (auto &pair : first_half) second_half.push_back(pair);
std::vector<uint64_t> res;
get_nodes_ids_by_ranges(type_id, idx, second_half, res);
get_nodes_ids_by_ranges(table_type, idx, second_half, res);
actual_size = res.size() * sizeof(uint64_t);
buffer.reset(new char[actual_size]);
char *pointer = buffer.get();
Expand Down Expand Up @@ -1989,7 +1988,7 @@ int32_t GraphTable::random_sample_neighbors(
index++;
} else {
node_id = id_list[i][k].node_key;
Node *node = find_node(0, idx, node_id);
Node *node = find_node(GraphTableType::EDGE_TABLE, idx, node_id);
int idy = seq_id[i][k];
int &actual_size = actual_sizes[idy];
if (node == nullptr) {
Expand Down Expand Up @@ -2060,7 +2059,7 @@ int32_t GraphTable::get_node_feat(int idx,
uint64_t node_id = node_ids[idy];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, idy, node_id]() -> int {
Node *node = find_node(1, idx, node_id);
Node *node = find_node(GraphTableType::FEATURE_TABLE, idx, node_id);

if (node == nullptr) {
return 0;
Expand Down Expand Up @@ -2259,11 +2258,11 @@ class MergeShardVector {
std::vector<std::vector<uint64_t>> *_shard_keys;
};

int GraphTable::get_all_id(int type_id,
int GraphTable::get_all_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
std::vector<std::future<size_t>> tasks;
for (size_t idx = 0; idx < search_shards.size(); idx++) {
for (size_t j = 0; j < search_shards[idx].size(); j++) {
Expand All @@ -2285,9 +2284,9 @@ int GraphTable::get_all_id(int type_id,
}

int GraphTable::get_all_neighbor_id(
int type_id, int slice_num, std::vector<std::vector<uint64_t>> *output) {
GraphTableType table_type, int slice_num, std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
std::vector<std::future<size_t>> tasks;
for (size_t idx = 0; idx < search_shards.size(); idx++) {
for (size_t j = 0; j < search_shards[idx].size(); j++) {
Expand All @@ -2308,12 +2307,12 @@ int GraphTable::get_all_neighbor_id(
return 0;
}

int GraphTable::get_all_id(int type_id,
int GraphTable::get_all_id(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < search_shards.size(); i++) {
Expand All @@ -2334,12 +2333,12 @@ int GraphTable::get_all_id(int type_id,
}

int GraphTable::get_all_neighbor_id(
int type_id,
GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < search_shards.size(); i++) {
Expand All @@ -2361,12 +2360,12 @@ int GraphTable::get_all_neighbor_id(
}

int GraphTable::get_all_feature_ids(
int type_id,
GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
for (size_t i = 0; i < search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
Expand All @@ -2388,14 +2387,14 @@ int GraphTable::get_all_feature_ids(
int GraphTable::get_node_embedding_ids(
int slice_num, std::vector<std::vector<uint64_t>> *output) {
if (is_load_reverse_edge and !FLAGS_graph_get_neighbor_id) {
return get_all_id(0, slice_num, output);
return get_all_id(GraphTableType::EDGE_TABLE, slice_num, output);
} else {
get_all_id(0, slice_num, output);
return get_all_neighbor_id(0, slice_num, output);
get_all_id(GraphTableType::EDGE_TABLE, slice_num, output);
return get_all_neighbor_id(GraphTableType::EDGE_TABLE, slice_num, output);
}
}

int32_t GraphTable::pull_graph_list(int type_id,
int32_t GraphTable::pull_graph_list(GraphTableType table_type,
int idx,
int start,
int total_size,
Expand All @@ -2405,7 +2404,7 @@ int32_t GraphTable::pull_graph_list(int type_id,
int step) {
if (start < 0) start = 0;
int size = 0, cur_size;
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < search_shards.size() && total_size > 0; i++) {
cur_size = search_shards[i]->get_size();
Expand Down Expand Up @@ -2634,7 +2633,7 @@ void GraphTable::build_graph_type_keys() {
for (auto &it : this->feature_to_id) {
auto node_idx = it.second;
std::vector<std::vector<uint64_t>> keys;
this->get_all_id(1, node_idx, 1, &keys);
this->get_all_id(GraphTableType::FEATURE_TABLE, node_idx, 1, &keys);
type_to_index_[node_idx] = cnt;
graph_type_keys_[cnt++] = std::move(keys[0]);
}
Expand All @@ -2645,7 +2644,7 @@ void GraphTable::build_graph_type_keys() {
for (auto &it : this->feature_to_id) {
auto node_idx = it.second;
std::vector<std::vector<uint64_t>> keys;
this->get_all_feature_ids(1, node_idx, 1, &keys);
this->get_all_feature_ids(GraphTableType::FEATURE_TABLE, node_idx, 1, &keys);
graph_total_keys_.insert(
graph_total_keys_.end(), keys[0].begin(), keys[0].end());
}
Expand Down
24 changes: 13 additions & 11 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ class GraphSampler {
#endif
*/

enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };

class GraphTable : public Table {
public:
GraphTable() {
Expand Down Expand Up @@ -524,7 +526,7 @@ class GraphTable : public Table {
return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
}

virtual int32_t pull_graph_list(int type_id,
virtual int32_t pull_graph_list(GraphTableType table_type,
int idx,
int start,
int size,
Expand All @@ -541,14 +543,14 @@ class GraphTable : public Table {
std::vector<int> &actual_sizes,
bool need_weight);

int32_t random_sample_nodes(int type_id,
int32_t random_sample_nodes(GraphTableType table_type,
int idx,
int sample_size,
std::unique_ptr<char[]> &buffers,
int &actual_sizes);

virtual int32_t get_nodes_ids_by_ranges(
int type_id,
GraphTableType table_type,
int idx,
std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res);
Expand Down Expand Up @@ -581,21 +583,21 @@ class GraphTable : public Table {
int32_t load_edges(const std::string &path,
bool reverse,
const std::string &edge_type);
int get_all_id(int type,
int get_all_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type,
int get_all_neighbor_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_id(int type,
int get_all_id(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type_id,
int get_all_neighbor_id(GraphTableType table_type,
int id,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_feature_ids(int type,
int get_all_feature_ids(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
Expand All @@ -617,13 +619,13 @@ class GraphTable : public Table {
int32_t remove_graph_node(int idx, std::vector<uint64_t> &id_list);

int32_t get_server_index_by_id(uint64_t id);
Node *find_node(int type_id, int idx, uint64_t id);
Node *find_node(int type_id, uint64_t id);
Node *find_node(GraphTableType table_type, int idx, uint64_t id);
Node *find_node(GraphTableType table_type, uint64_t id);

virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }

virtual int32_t clear_nodes(int type, int idx);
virtual int32_t clear_nodes(GraphTableType table_type, int idx);
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string &param) { return 0; }
Expand Down
15 changes: 11 additions & 4 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,24 @@ DECLARE_double(gpugraph_hbm_table_load_factor);

namespace paddle {
namespace framework {
enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };

typedef paddle::distributed::GraphTableType GraphTableType;

class GpuPsGraphTable
: public HeterComm<uint64_t, uint64_t, int, CommonFeatureValueAccessor> {
public:
int get_table_offset(int gpu_id, GraphTableType type, int idx) const {
inline int get_table_offset(int gpu_id, GraphTableType type, int idx) const {
int type_id = type;
return gpu_id * (graph_table_num_ + feature_table_num_) +
type_id * graph_table_num_ + idx;
}
inline int get_graph_list_offset(int gpu_id, int edge_idx) const {
return gpu_id * graph_table_num_ + edge_idx;
}
inline int get_graph_fea_list_offset(int gpu_id) const {
return gpu_id * feature_table_num_;
}

GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource,
int graph_table_num)
: HeterComm<uint64_t, uint64_t, int, CommonFeatureValueAccessor>(
Expand Down Expand Up @@ -83,8 +92,6 @@ class GpuPsGraphTable
void clear_feature_info(int index);
void build_graph_from_cpu(const std::vector<GpuPsCommGraph> &cpu_node_list,
int idx);
void build_graph_fea_from_cpu(
const std::vector<GpuPsCommGraphFea> &cpu_node_list, int idx);
NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
bool cpu_switch,
Expand Down
Loading

0 comments on commit cf23031

Please sign in to comment.