Skip to content

Commit

Permalink
【Pglbox】merge gpugraph to develop (#50091)
Browse files Browse the repository at this point in the history
* add dump_walk_path  (#193)

* add dump_walk_path; test=develop

* add dump_walk_path; test=develop

* add dump_walk_path; test=develop

* Add multiple CPU communication, parameter query and merging functions, support batch alignment between multiple cards (#194)

* compatible with edge_type of src2dst and src2etype2dst (#195)

* do not merge_feature_shard when using metapath_split_opt (#198)

* support only load reverse_edge (#199)

* refactor GraphTable (#201)

* fix

* fix

* fix code style

* fix code style

* fix test_dataset

* fix hogwild worker

* fix code style

* fix code style

* fix code style

* fix code style

* fix code style.

* fix code style.

---------

Co-authored-by: danleifeng <[email protected]>
Co-authored-by: qingshui <[email protected]>
Co-authored-by: Webbley <[email protected]>
Co-authored-by: huwei02 <[email protected]>
  • Loading branch information
5 people authored Feb 6, 2023
1 parent 5a13280 commit caf2008
Show file tree
Hide file tree
Showing 45 changed files with 2,705 additions and 892 deletions.
49 changes: 49 additions & 0 deletions paddle/fluid/distributed/ps/service/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ brpc_library(

get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

proto_library(simple_brpc_proto SRCS simple_brpc.proto)
set_source_files_properties(
simple_rpc/rpc_server.cc simple_rpc/baidu_rpc_server.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
communicator/communicator.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
Expand All @@ -60,6 +64,8 @@ set_source_files_properties(
brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ps_local_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ps_graph_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

set_source_files_properties(
brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
Expand All @@ -85,11 +91,17 @@ set_source_files_properties(
set_source_files_properties(
ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})

cc_library(
brpc_utils
SRCS brpc_utils.cc
DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})

cc_library(
simple_rpc
SRCS simple_rpc/rpc_server.cc simple_rpc/baidu_rpc_server.cc
DEPS simple_brpc_proto ${RPC_DEPS})

cc_library(
ps_service
SRCS graph_brpc_server.cc
Expand All @@ -98,6 +110,7 @@ cc_library(
graph_brpc_client.cc
brpc_ps_client.cc
ps_local_client.cc
ps_graph_client.cc
coordinator_client.cc
ps_client.cc
communicator/communicator.cc
Expand All @@ -107,11 +120,42 @@ cc_library(
table
brpc_utils
simple_threadpool
simple_rpc
scope
math_function
selected_rows_functor
ps_gpu_wrapper
${RPC_DEPS})

#cc_library(
# downpour_server
# SRCS graph_brpc_server.cc brpc_ps_server.cc
# DEPS eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})

#cc_library(
# downpour_client
# SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
# ps_graph_client.cc coordinator_client.cc
# DEPS eigen3 table brpc_utils simple_threadpool ps_gpu_wrapper simple_rpc ${RPC_DEPS})

#cc_library(
# client
# SRCS ps_client.cc
# DEPS downpour_client ${RPC_DEPS})
#cc_library(
# server
# SRCS server.cc
# DEPS downpour_server ${RPC_DEPS})

#cc_library(
# communicator
# SRCS communicator/communicator.cc
# DEPS scope client table math_function selected_rows_functor ${RPC_DEPS})
#cc_library(
# ps_service
# SRCS ps_service/service.cc
# DEPS communicator client server ${RPC_DEPS})

cc_library(
heter_client
SRCS heter_client.cc
Expand All @@ -120,3 +164,8 @@ cc_library(
heter_server
SRCS heter_server.cc
DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})

#cc_library(
# graph_py_service
# SRCS ps_service/graph_py_service.cc
# DEPS ps_service)
22 changes: 11 additions & 11 deletions paddle/fluid/distributed/ps/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
(reinterpret_cast<GraphTable *>(table))->clear_nodes(type_id, idx_);
GraphTableType type_id = *(GraphTableType *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
((GraphTable *)table)->clear_nodes(type_id, idx_);
return 0;
}

Expand Down Expand Up @@ -380,11 +380,11 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response, -1, "pull_graph_list request requires at least 5 arguments");
return 0;
}
int type_id = std::stoi(request.params(0).c_str());
int idx = std::stoi(request.params(1).c_str());
int start = std::stoi(request.params(2).c_str());
int size = std::stoi(request.params(3).c_str());
int step = std::stoi(request.params(4).c_str());
GraphTableType type_id = *(GraphTableType *)(request.params(0).c_str());
int idx = *(int *)(request.params(1).c_str());
int start = *(int *)(request.params(2).c_str());
int size = *(int *)(request.params(3).c_str());
int step = *(int *)(request.params(4).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
(reinterpret_cast<GraphTable *>(table))
Expand Down Expand Up @@ -432,9 +432,9 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
size_t size = std::stoull(request.params(2).c_str());
GraphTableType type_id = *(GraphTableType *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
size_t size = *(uint64_t *)(request.params(2).c_str());
// size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
Expand Down
23 changes: 21 additions & 2 deletions paddle/fluid/distributed/ps/service/ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#include "paddle/fluid/distributed/ps/service/ps_graph_client.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
REGISTER_PSCORE_CLASS(PSClient, PsGraphClient);
#endif

int32_t PSClient::Configure( // called in FleetWrapper::InitWorker
const PSParameter &config,
Expand Down Expand Up @@ -77,8 +84,20 @@ PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
}

const auto &service_param = config.downpour_server_param().service_param();
PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
const auto &client_name = service_param.client_class();

PSClient *client = NULL;
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
auto gloo = paddle::framework::GlooWrapper::GetInstance();
if (client_name == "PsLocalClient" && gloo->Size() > 1) {
client = CREATE_PSCORE_CLASS(PSClient, "PsGraphClient");
LOG(WARNING) << "change PsLocalClient to PsGraphClient";
} else {
client = CREATE_PSCORE_CLASS(PSClient, client_name);
}
#else
client = CREATE_PSCORE_CLASS(PSClient, client_name);
#endif
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/sparse_shard_value.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
Expand Down Expand Up @@ -72,7 +73,7 @@ class PSClient {
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, // NOLINT
size_t client_id) final;
size_t client_id);

virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
Expand Down Expand Up @@ -153,7 +154,8 @@ class PSClient {
size_t table_id,
const uint64_t *keys,
size_t num,
uint16_t pass_id) {
uint16_t pass_id,
const uint16_t &dim_id = 0) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
Expand Down Expand Up @@ -329,6 +331,12 @@ class PSClient {
promise.set_value(-1);
return fut;
}
// add
virtual std::shared_ptr<SparseShardValues> TakePassSparseReferedValues(
const size_t &table_id, const uint16_t &pass_id, const uint16_t &dim_id) {
VLOG(0) << "Did not implement";
return nullptr;
}

protected:
virtual int32_t Initialize() = 0;
Expand Down
Loading

0 comments on commit caf2008

Please sign in to comment.