Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#14 from WeiyueSu/FeatureNode
Browse files Browse the repository at this point in the history
get_node_feat return  py:bytes
  • Loading branch information
seemingwang authored Mar 25, 2021
2 parents bb48ece + 6f4223c commit 578e305
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
9 changes: 3 additions & 6 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,14 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {

auto node = shards[index].add_feature_node(id);

//auto mutable_feature = node->get_mutable_feature();

//mutable_feature.clear();
//mutable_feature.resize(this->feat_name.size());
node->set_feature_size(feat_name.size());

for (size_t slice = 2; slice < values.size(); slice++) {
auto feat = this->parse_feature(values[slice]);
if(feat.first > 0) {
//mutable_feature[feat.first] = feat.second;
if (feat.first >= 0) {
node->set_feature(feat.first, feat.second);
} else{
VLOG(4) << "Node feature: " << values[slice] << " not in feature_map.";
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,22 @@ void RunBrpcPushSparse() {
std::cout << "get_node_feat: " << node_feat[1][0] << std::endl;
std::cout << "get_node_feat: " << node_feat[1][1] << std::endl;

// Test string
node_ids.clear();
node_ids.push_back(37);
node_ids.push_back(96);
//std::vector<std::string> feature_names;
feature_names.clear();
feature_names.push_back(std::string("a"));
feature_names.push_back(std::string("b"));
node_feat = client1.get_node_feat(std::string("user"), node_ids, feature_names);
ASSERT_EQ(node_feat.size(), 2);
ASSERT_EQ(node_feat[0].size(), 2);
std::cout << "get_node_feat: " << node_feat[0][0].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[0][1].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[1][0].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[1][1].size() << std::endl;

std::remove(edge_file_name);
std::remove(node_file_name);
LOG(INFO) << "Run stop_server";
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,17 @@ void BindGraphPyClient(py::module* m) {
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("get_node_feat", &GraphPyClient::get_node_feat)
.def("get_node_feat", [](GraphPyClient& self, std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names){
auto feats = self.get_node_feat(node_type, node_ids, feature_names);
std::vector<std::vector<py::bytes> > bytes_feats(feats.size());
for (int i = 0; i < feats.size(); ++i ){
for (int j = 0; j < feats[i].size(); ++j ){
bytes_feats[i].push_back(py::bytes(feats[i][j]));
}
}
return bytes_feats;
})
.def("bind_local_server", &GraphPyClient::bind_local_server);
}

Expand Down

0 comments on commit 578e305

Please sign in to comment.