Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch sample k #2

Merged
merged 17 commits into from
Mar 16, 2021
106 changes: 77 additions & 29 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,91 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_per_server;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::sample(
uint32_t table_id, uint64_t node_id, int sample_size,
std::vector<std::pair<uint64_t, float>> &res) {
int server_index = get_server_index_by_id(node_id);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
std::future<int32_t> GraphBrpcClient::batch_sample(uint32_t table_id,
std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float> > > &res) {

std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
int server_index = get_server_index_by_id(node_ids[query_idx]);
if(server2request[server_index] == -1){
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
//res.push_back(std::vector<GraphNode>());
res.push_back(std::vector<std::pair<uint64_t, float>>());
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t> > node_id_buckets(request_call_num);
std::vector<std::vector<int> > query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
if (closure->check_response(0, PS_GRAPH_SAMPLE) != 0) {
ret = -1;
} else {
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int offset = 0;
while (offset < bytes_size) {
res.push_back({*(uint64_t *)(buffer + offset),
*(float *)(buffer + offset + GraphNode::id_size)});
offset += GraphNode::id_size + GraphNode::weight_size;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
++fail_num;
} else {
VLOG(0) << "check sample response: "
<< " " << closure->check_response(request_idx, PS_GRAPH_SAMPLE);
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;

int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx){
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back({*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start + GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
}
offset += actual_size;
}
}
if (fail_num == request_call_num){
ret = -1;
}
}
closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
;
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&node_id, sizeof(uint64_t));
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);

for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
// std::string type_str = GraphNode::node_type_to_string(type);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)node_id_buckets[request_idx].data(), sizeof(uint64_t)*node_num);
closure->request(request_idx)->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx),
closure);
}

return fut;
}
Expand Down Expand Up @@ -124,4 +172,4 @@ int32_t GraphBrpcClient::initialize() {
return 0;
}
}
}
}
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ class GraphBrpcClient : public BrpcPsClient {
public:
GraphBrpcClient() {}
virtual ~GraphBrpcClient() {}
virtual std::future<int32_t> sample(
uint32_t table_id, uint64_t node_id, int sample_size,
std::vector<std::pair<uint64_t, float>> &res);
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
Expand Down
27 changes: 22 additions & 5 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,29 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
"graph_random_sample request requires at least 2 arguments");
return 0;
}
uint64_t node_id = *(uint64_t *)(request.params(0).c_str());
size_t num_nodes = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
char *buffer;
int actual_size;
table->random_sample(node_id, sample_size, buffer, actual_size);
cntl->response_attachment().append(buffer, actual_size);

std::vector<std::future<int>*> tasks;
std::vector<char*> buffers(num_nodes);
std::vector<int> actual_sizes(num_nodes);

for (size_t idx = 0; idx < num_nodes; ++idx){
//std::future<int> task = table->random_sample(node_data[idx], sample_size,
//buffers[idx], actual_sizes[idx]);
table->random_sample(node_data[idx], sample_size,
buffers[idx], actual_sizes[idx]);
//tasks.push_back(&task);
}
//for (size_t idx = 0; idx < num_nodes; ++idx){
//tasks[idx]->get();
//}
cntl->response_attachment().append(&num_nodes, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*num_nodes);
for (size_t idx = 0; idx < num_nodes; ++idx){
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
}
return 0;
}

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
status.wait();
}
}
std::vector<std::pair<uint64_t, float>> GraphPyClient::sample_k(
std::string name, uint64_t node_id, int sample_size) {
std::vector<std::pair<uint64_t, float>> v;
// Python-facing wrapper: samples up to `sample_size` neighbors for every id
// in `node_ids` from the graph table registered under `name`.
// Returns one (neighbor_id, weight) vector per input node.
std::vector<std::vector<std::pair<uint64_t, float> > > GraphPyClient::batch_sample_k(
std::string name, std::vector<uint64_t> node_ids, int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float> > > v;
// Unregistered table names are silently ignored; v stays empty.
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
// Issue the distributed batch_sample RPC and block until it completes.
auto status = worker_ptr->batch_sample(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
Expand Down
5 changes: 2 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::vector<std::pair<uint64_t, float>> sample_k(std::string name,
uint64_t node_id,
int sample_size);
std::vector<std::vector<std::pair<uint64_t, float> > > batch_sample_k(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
int start, int size);
::paddle::distributed::PSParameter GetWorkerProto();
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ class PSClient {
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> sample(
uint32_t table_id, uint64_t node_id, int sample_size,
std::vector<std::pair<uint64_t, float>> &res) {
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ GraphNode *GraphTable::find_node(uint64_t id) {
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num_per_table % task_pool_size_;
}
//std::future<int> GraphTable::random_sample(uint64_t node_id, int sample_size,
//char *&buffer, int &actual_size) {
int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
char *&buffer, int &actual_size) {
return _shards_task_pool[get_thread_pool_index(node_id)]
Expand All @@ -226,6 +228,7 @@ int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
memcpy(buffer + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
}
return 0;
})
.get();
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class GraphTable : public SparseTable {
virtual ~GraphTable() {}
virtual int32_t pull_graph_list(int start, int size, char *&buffer,
int &actual_size);
//virtual std::future<int> random_sample(uint64_t node_id, int sampe_size, char *&buffer,
//int &actual_size);
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
int &actual_size);
virtual int32_t initialize();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/table/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ class Table {
int &actual_size) {
return 0;
}
//virtual std::future<int> random_sample(uint64_t node_id, int sampe_size, char *&buffer,
//int &actual_size) {
//return std::future<int>();
//}
virtual int32_t pour() { return 0; }

virtual void clear() = 0;
Expand Down
60 changes: 21 additions & 39 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,26 @@ void RunBrpcPushSparse() {

/*-----------------------Test Server Init----------------------------------*/
auto pull_status =
worker_ptr_->load(0, std::string(file_name), std::string(""));
worker_ptr_->load(0, std::string(file_name), std::string("edge"));

pull_status.wait();
std::vector<std::pair<uint64_t, float>> v;
pull_status = worker_ptr_->sample(0, 37, 4, v);
std::vector<std::vector<std::pair<uint64_t, float> > > vs;
//std::vector<std::pair<uint64_t, float>> v;
//pull_status = worker_ptr_->sample(0, 37, 4, v);
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 37), 4, vs);
pull_status.wait();
ASSERT_EQ(v.size(), 3);
v.clear();
pull_status = worker_ptr_->sample(0, 96, 4, v);
ASSERT_EQ(vs[0].size(), 3);
vs.clear();
//pull_status = worker_ptr_->sample(0, 96, 4, v);
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 96), 4, vs);
pull_status.wait();
std::unordered_set<int> s = {111, 48, 247};
ASSERT_EQ(3, v.size());
for (auto g : v) {
ASSERT_EQ(3, vs[0].size());
for (auto g : vs[0]) {
// std::cout << g.first << std::endl;
ASSERT_EQ(true, s.find(g.first) != s.end());
}
v.clear();
vs.clear();
std::vector<distributed::GraphNode> nodes;
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes);
pull_status.wait();
Expand Down Expand Up @@ -276,38 +279,17 @@ void RunBrpcPushSparse() {
nodes = client2.pull_graph_list(std::string("user2item"), 0, 1, 4);
ASSERT_EQ(nodes[0].get_id(), 59);
nodes.clear();
v = client1.sample_k(std::string("user2item"), 96, 4);
ASSERT_EQ(v.size(), 3);
std::cout << "sample result" << std::endl;
for (auto p : v) {
vs = client1.batch_sample_k(std::string("user2item"), std::vector<uint64_t>(1, 96), 4);
ASSERT_EQ(vs[0].size(), 3);
std::cout << "batch sample result" << std::endl;
for (auto p : vs[0]) {
std::cout << p.first << " " << p.second << std::endl;
}
/*
from paddle.fluid.core import GraphPyService
ips_str = "127.0.0.1:4211;127.0.0.1:4212"
server1 = GraphPyServer()
server2 = GraphPyServer()
client1 = GraphPyClient()
client2 = GraphPyClient()
edge_types = ["user2item"]
server1.set_up(ips_str,127,edge_types,0);
server2.set_up(ips_str,127,edge_types,1);
client1.set_up(ips_str,127,edge_types,0);
client2.set_up(ips_str,127,edge_types,1);
server1.start_server();
server2.start_server();
client1.start_client();
client2.start_client();
client1.load_edge_file(user2item", "input.txt", 0);
list = client2.pull_graph_list("user2item",0,1,4)
for x in list:
print(x.get_id())

list = client1.sample_k("user2item",96, 4);
for x in list:
print(x.get_id())
*/

std::vector<uint64_t> node_ids;
node_ids.push_back(96);
node_ids.push_back(37);
vs = client1.batch_sample_k(std::string("user2item"), node_ids, 4);
ASSERT_EQ(vs.size(), 2);
// to test in python,try this:
// from paddle.fluid.core import GraphPyService
// ips_str = "127.0.0.1:4211;127.0.0.1:4212"
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ void BindGraphPyClient(py::module* m) {
.def("load_node_file", &GraphPyClient::load_node_file)
.def("set_up", &GraphPyClient::set_up)
.def("pull_graph_list", &GraphPyClient::pull_graph_list)
.def("sample_k", &GraphPyClient::sample_k)
.def("start_client", &GraphPyClient::start_client);
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_k", &GraphPyClient::batch_sample_k);
}

} // end namespace pybind
Expand Down