add dump_walk_path (PaddlePaddle#193)
* add dump_walk_path; test=develop

* add dump_walk_path; test=develop

* add dump_walk_path; test=develop
danleifeng authored Dec 30, 2022
1 parent 60b6484 commit 2541de4
Showing 9 changed files with 166 additions and 32 deletions.
10 changes: 10 additions & 0 deletions paddle/fluid/framework/data_feed.cc
@@ -2714,6 +2714,16 @@ void SlotRecordInMemoryDataFeed::DoWalkandSage() {
#endif
}

+void SlotRecordInMemoryDataFeed::DumpWalkPath(std::string dump_path,
+                                              size_t dump_rate) {
+  VLOG(3) << "INTO SlotRecordInMemoryDataFeed::DumpWalkPath";
+#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
+  std::string path =
+      string::format_string("%s/part-%03d", dump_path.c_str(), thread_id_);
+  gpu_graph_data_generator_.DumpWalkPath(path, dump_rate);
+#endif
+}

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
int offset_cols_size = (ins_num + 1);
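A note on the wrapper above: formatting the dump path as "%s/part-%03d" with thread_id_ gives every feed thread its own shard file (part-000, part-001, ...), so concurrent dumps never share a file handle. A minimal sketch of the naming scheme, with plain snprintf standing in for the Paddle-internal string::format_string and a hypothetical MakeShardPath helper:

#include <cstdio>
#include <string>

// Hypothetical helper mirroring the "%s/part-%03d" naming used above:
// thread 0 -> dump/part-000, thread 7 -> dump/part-007, and so on.
std::string MakeShardPath(const std::string &dump_path, int thread_id) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), "%s/part-%03d", dump_path.c_str(), thread_id);
  return std::string(buf);
}

int main() {
  for (int tid = 0; tid < 3; ++tid) {
    std::printf("%s\n", MakeShardPath("dump", tid).c_str());
  }
  return 0;  // prints dump/part-000, dump/part-001, dump/part-002
}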
109 changes: 78 additions & 31 deletions paddle/fluid/framework/data_feed.cu
@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/phi/kernels/gpu/graph_reindex_funcs.h"
#include "paddle/phi/kernels/graph_reindex_kernel.h"

@@ -458,7 +459,7 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
int step,
int len,
int col_num,
-uint8_t* excluded_train_pair,
+uint8_t *excluded_train_pair,
int excluded_train_pair_len) {
__shared__ uint64_t local_key[CUDA_NUM_THREADS * 2];
__shared__ int local_num;
@@ -477,8 +478,8 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
int src = row[idx] * col_num + central_word;
if (walk[src] != 0 && walk[src + step] != 0) {
for (int i = 0; i < excluded_train_pair_len; i += 2) {
-if (walk_ntype[src] == excluded_train_pair[i]
-    && walk_ntype[src + step] == excluded_train_pair[i + 1]) {
+if (walk_ntype[src] == excluded_train_pair[i] &&
+    walk_ntype[src + step] == excluded_train_pair[i + 1]) {
// filter this pair
need_filter = true;
break;
@@ -733,7 +734,8 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
uint8_t *excluded_train_pair = NULL;
if (excluded_train_pair_len_ > 0) {
walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
-excluded_train_pair = reinterpret_cast<uint8_t *>(d_excluded_train_pair_->ptr());
+excluded_train_pair =
+    reinterpret_cast<uint8_t *>(d_excluded_train_pair_->ptr());
}
uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
int *random_row = reinterpret_cast<int *>(d_random_row_->ptr());
@@ -768,7 +770,8 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
VLOG(2) << "h_pair_num = " << h_pair_num
<< ", ins_buf_pair_len = " << ins_buf_pair_len_;
for (int xx = 0; xx < ins_buf_pair_len_; xx++) {
VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", " << h_ins_buf[xx * 2 + 1];
VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", "
<< h_ins_buf[xx * 2 + 1];
}
}
return ins_buf_pair_len_;
@@ -1855,16 +1858,17 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GenerateSampleGraph(
}

std::shared_ptr<phi::Allocation> GraphDataGenerator::GetNodeDegree(
-uint64_t* node_ids, int len) {
+uint64_t *node_ids, int len) {
auto node_degree = memory::AllocShared(
place_,
len * edge_to_id_len_ * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
auto edge_to_id = gpu_graph_ptr->edge_to_id;
-for (auto& iter : edge_to_id) {
+for (auto &iter : edge_to_id) {
int edge_idx = iter.second;
-gpu_graph_ptr->get_node_degree(gpuid_, edge_idx, node_ids, len, node_degree);
+gpu_graph_ptr->get_node_degree(
+    gpuid_, edge_idx, node_ids, len, node_degree);
}
return node_degree;
}
@@ -1959,10 +1963,11 @@ void GraphDataGenerator::DoWalkandSage() {
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
auto final_sage_nodes = GenerateSampleGraph(
ins_cursor, total_instance, &uniq_instance, inverse);
-uint64_t* final_sage_nodes_ptr =
+uint64_t *final_sage_nodes_ptr =
reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
if (get_degree_) {
-auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
+auto node_degrees =
+    GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
node_degree_vec_.emplace_back(node_degrees);
}
cudaStreamSynchronize(sample_stream_);
@@ -2016,10 +2021,11 @@ void GraphDataGenerator::DoWalkandSage() {
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
auto final_sage_nodes = GenerateSampleGraph(
node_buf_ptr_, total_instance, &uniq_instance, inverse);
-uint64_t* final_sage_nodes_ptr =
+uint64_t *final_sage_nodes_ptr =
reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
if (get_degree_) {
auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
auto node_degrees =
GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
node_degree_vec_.emplace_back(node_degrees);
}
cudaStreamSynchronize(sample_stream_);
@@ -2084,7 +2090,8 @@ int GraphDataGenerator::FillInferBuf() {
}
if (!infer_node_type_index_set_.empty()) {
while (infer_cursor < h_device_keys_len_.size()) {
-if (infer_node_type_index_set_.find(infer_cursor) == infer_node_type_index_set_.end()) {
+if (infer_node_type_index_set_.find(infer_cursor) ==
+    infer_node_type_index_set_.end()) {
VLOG(2) << "Skip cursor[" << infer_cursor << "]";
infer_cursor++;
continue;
@@ -2723,21 +2730,21 @@ void GraphDataGenerator::AllocResource(int thread_id,
excluded_train_pair_len_ = gpu_graph_ptr->excluded_train_pair_.size();
if (excluded_train_pair_len_ > 0) {
d_excluded_train_pair_ = memory::AllocShared(
        place_,
        excluded_train_pair_len_ * sizeof(uint8_t),
        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
-    CUDA_CHECK(cudaMemcpyAsync(
-        d_excluded_train_pair_->ptr(), gpu_graph_ptr->excluded_train_pair_.data(),
-        excluded_train_pair_len_ * sizeof(uint8_t),
-        cudaMemcpyHostToDevice,
-        sample_stream_));
+    CUDA_CHECK(cudaMemcpyAsync(d_excluded_train_pair_->ptr(),
+                               gpu_graph_ptr->excluded_train_pair_.data(),
+                               excluded_train_pair_len_ * sizeof(uint8_t),
+                               cudaMemcpyHostToDevice,
+                               sample_stream_));

d_walk_ntype_ = memory::AllocShared(
        place_,
        buf_size_ * sizeof(uint8_t),
        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
    cudaMemsetAsync(
        d_walk_ntype_->ptr(), 0, buf_size_ * sizeof(uint8_t), sample_stream_);
}

d_sample_keys_ = memory::AllocShared(
@@ -2835,11 +2842,12 @@ void GraphDataGenerator::AllocResource(int thread_id,
(batch_size_ * 2 * 2) * sizeof(uint32_t),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
}

// parse infer_node_type
auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
if (!gpu_graph_training_) {
-auto node_types = paddle::string::split_string<std::string>(infer_node_type_, ";");
+auto node_types =
+    paddle::string::split_string<std::string>(infer_node_type_, ";");
auto node_to_id = gpu_graph_ptr->node_to_id;
for (auto &type : node_types) {
auto iter = node_to_id.find(type);
Expand All @@ -2849,11 +2857,13 @@ void GraphDataGenerator::AllocResource(int thread_id,
platform::errors::NotFound("(%s) is not found in node_to_id.", type));
int node_type = iter->second;
int type_index = type_to_index[node_type];
VLOG(2) << "add node[" << type << "] into infer_node_type, type_index(cursor)["
<< type_index << "]";
VLOG(2) << "add node[" << type
<< "] into infer_node_type, type_index(cursor)[" << type_index
<< "]";
infer_node_type_index_set_.insert(type_index);
}
VLOG(2) << "infer_node_type_index_set_num: " << infer_node_type_index_set_.size();
VLOG(2) << "infer_node_type_index_set_num: "
<< infer_node_type_index_set_.size();
}

cudaStreamSynchronize(sample_stream_);
@@ -2910,7 +2920,8 @@ void GraphDataGenerator::SetConfig(
std::string str_samples = graph_config.samples();
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
debug_gpu_memory_info("init_conf start");
-gpu_graph_ptr->init_conf(first_node_type, meta_path, graph_config.excluded_train_pair());
+gpu_graph_ptr->init_conf(
+    first_node_type, meta_path, graph_config.excluded_train_pair());
debug_gpu_memory_info("init_conf end");

auto edge_to_id = gpu_graph_ptr->edge_to_id;
@@ -2928,6 +2939,42 @@
}
};

+void GraphDataGenerator::DumpWalkPath(std::string dump_path, size_t dump_rate) {
+#ifdef _LINUX
+  PADDLE_ENFORCE_LT(
+      dump_rate,
+      10000000,
+      platform::errors::InvalidArgument(
+          "dump_rate can't be larger than 10000000. Please check the dump "
+          "rate[1, 10000000]"));
+  PADDLE_ENFORCE_GT(dump_rate,
+                    1,
+                    platform::errors::InvalidArgument(
+                        "dump_rate can't be less than 1. Please check "
+                        "the dump rate[1, 10000000]"));
+  int err_no = 0;
+  std::shared_ptr<FILE> fp = fs_open_append_write(dump_path, &err_no, "");
+  uint64_t *h_walk = new uint64_t[buf_size_];
+  uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
+  cudaMemcpy(
+      h_walk, walk, buf_size_ * sizeof(uint64_t), cudaMemcpyDeviceToHost);
+  VLOG(1) << "DumpWalkPath all buf_size_:" << buf_size_;
+  std::string ss = "";
+  size_t write_count = 0;
+  for (int xx = 0; xx < buf_size_ / dump_rate; xx += walk_len_) {
+    ss = "";
+    for (int yy = 0; yy < walk_len_; yy++) {
+      ss += std::to_string(h_walk[xx + yy]) + "-";
+    }
+    write_count = fwrite_unlocked(ss.data(), 1, ss.length(), fp.get());
+    if (write_count != ss.length()) {
+      VLOG(1) << "dump walk path " << ss << " failed";
+    }
+    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
+  }
+  delete[] h_walk;  // release the host staging buffer
+#endif
+}

} // namespace framework
} // namespace paddle
#endif
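To make the dump loop in GraphDataGenerator::DumpWalkPath concrete: d_walk_ stores buf_size_ / walk_len_ fixed-length walks back to back, and the loop bound xx < buf_size_ / dump_rate means only the first 1/dump_rate fraction of the buffer is written, one dash-joined walk per line. A host-only sketch under those assumptions, with a plain vector instead of device memory and stdout instead of fwrite_unlocked:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Emulates the dump loop above on a host buffer: walks are laid out
// contiguously, and only walks starting below size/dump_rate are emitted.
void DumpWalks(const std::vector<uint64_t> &walks,
               size_t walk_len,
               size_t dump_rate) {
  for (size_t xx = 0; xx < walks.size() / dump_rate; xx += walk_len) {
    std::string line;
    for (size_t yy = 0; yy < walk_len; ++yy) {
      line += std::to_string(walks[xx + yy]) + "-";
    }
    std::cout << line << "\n";  // e.g. "101-205-318-" for walk_len == 3
  }
}

int main() {
  std::vector<uint64_t> walks = {101, 205, 318, 42, 77, 93, 5, 6, 7, 8, 9, 10};
  DumpWalks(walks, /*walk_len=*/3, /*dump_rate=*/2);  // writes 2 of 4 walks
  return 0;
}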
9 changes: 8 additions & 1 deletion paddle/fluid/framework/data_feed.h
@@ -916,7 +916,7 @@ class GraphDataGenerator {
void FillOneStep(uint64_t* start_ids,
int etype_id,
uint64_t* walk,
-uint8_t *walk_ntype,
+uint8_t* walk_ntype,
int len,
NeighborSampleResult& sample_res,
int cur_degree,
@@ -940,6 +940,7 @@
void ResetPathNum() { total_row_ = 0; }
void ResetEpochFinish() { epoch_finish_ = false; }
void ClearSampleState();
+  void DumpWalkPath(std::string dump_path, size_t dump_rate);
void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
// type_to_index_[type] = h_device_keys_.size();
// h_device_keys_.push_back(device_keys);
@@ -1222,6 +1223,11 @@ class DataFeed {
}
virtual const paddle::platform::Place& GetPlace() const { return place_; }

+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "This function(DumpWalkPath) is not implemented."));
+  }

protected:
// The following three functions are used to check if it is executed in this
// order:
@@ -1828,6 +1834,7 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
const UsedSlotGpuType* used_slots);
#endif
virtual void DoWalkandSage();
+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);

float sample_rate_ = 1.0f;
int use_slot_size_ = 0;
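The header changes follow a throwing-default virtual pattern: DataFeed::DumpWalkPath raises Unimplemented, so every existing feed compiles unchanged, and only SlotRecordInMemoryDataFeed overrides it. A minimal sketch of the same pattern, with hypothetical class names and std::runtime_error in place of PADDLE_THROW:

#include <cstdio>
#include <stdexcept>
#include <string>

// The base class ships a throwing default, mirroring DataFeed above.
class FeedBase {
 public:
  virtual ~FeedBase() = default;
  virtual void DumpWalkPath(std::string, size_t) {
    throw std::runtime_error("DumpWalkPath is not implemented.");
  }
};

// Only the graph-capable feed overrides it, like SlotRecordInMemoryDataFeed.
class GraphFeed : public FeedBase {
 public:
  void DumpWalkPath(std::string path, size_t rate) override {
    std::printf("dump ~1/%zu of walks to %s\n", rate, path.c_str());
  }
};

int main() {
  GraphFeed feed;
  feed.DumpWalkPath("dump/part-000", 100);  // ok: overridden
  try {
    FeedBase().DumpWalkPath("x", 1);  // throws: default is unimplemented
  } catch (const std::exception &e) {
    std::printf("%s\n", e.what());
  }
  return 0;
}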
20 changes: 20 additions & 0 deletions paddle/fluid/framework/data_set.cc
@@ -657,6 +657,26 @@ void DatasetImpl<T>::LocalShuffle() {
<< timeline.ElapsedSec() << " seconds";
}

+template <typename T>
+void DatasetImpl<T>::DumpWalkPath(std::string dump_path, size_t dump_rate) {
+  VLOG(3) << "DatasetImpl<T>::DumpWalkPath() begin";
+#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
+  std::vector<std::thread> dump_threads;
+  if (gpu_graph_mode_) {
+    for (int64_t i = 0; i < thread_num_; ++i) {
+      dump_threads.push_back(
+          std::thread(&paddle::framework::DataFeed::DumpWalkPath,
+                      readers_[i].get(),
+                      dump_path,
+                      dump_rate));
+    }
+    for (std::thread& t : dump_threads) {
+      t.join();
+    }
+  }
+#endif
+}

// do tdm sample
void MultiSlotDataset::TDMSample(const std::string tree_name,
const std::string tree_path,
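The dataset-level dump above is a plain fork-join: one std::thread per reader, each writing its own shard, then a join() barrier before returning. A self-contained sketch of the same pattern, with a hypothetical Reader type standing in for paddle::framework::DataFeed:

#include <cstdio>
#include <string>
#include <thread>
#include <vector>

// Hypothetical reader; the diff's real target is DataFeed::DumpWalkPath.
struct Reader {
  int id;
  void Dump(const std::string &path, size_t rate) {
    std::printf("reader %d -> %s (rate %zu)\n", id, path.c_str(), rate);
  }
};

int main() {
  std::vector<Reader> readers = {{0}, {1}, {2}};
  std::vector<std::thread> dump_threads;
  for (auto &r : readers) {
    // Same member-pointer form as the diff: thread(&T::f, obj, args...).
    dump_threads.emplace_back(&Reader::Dump, &r, "dump_dir", size_t(100));
  }
  for (auto &t : dump_threads) {
    t.join();  // barrier: all shards are flushed before we return
  }
  return 0;
}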
3 changes: 3 additions & 0 deletions paddle/fluid/framework/data_set.h
@@ -174,6 +174,8 @@ class Dataset {
virtual void SetPassId(uint32_t pass_id) = 0;
virtual uint32_t GetPassID() = 0;

+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) = 0;

protected:
virtual int ReceiveFromClient(int msg_type,
int client_id,
@@ -268,6 +270,7 @@ class DatasetImpl : public Dataset {
virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<std::string> GetSlots();
virtual bool GetEpochFinish();
+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);

std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
return multi_output_channel_;
34 changes: 34 additions & 0 deletions paddle/fluid/framework/io/fs.cc
@@ -131,6 +131,21 @@ std::shared_ptr<FILE> localfs_open_write(std::string path,
return fs_open_internal(path, is_pipe, "w", localfs_buffer_size());
}

+std::shared_ptr<FILE> localfs_open_append_write(std::string path,
+                                                const std::string& converter) {
+  shell_execute(
+      string::format_string("mkdir -p $(dirname \"%s\")", path.c_str()));
+
+  bool is_pipe = false;
+
+  if (fs_end_with_internal(path, ".gz")) {
+    fs_add_write_converter_internal(path, is_pipe, "gzip");
+  }
+
+  fs_add_write_converter_internal(path, is_pipe, converter);
+  return fs_open_internal(path, is_pipe, "a", localfs_buffer_size());
+}

int64_t localfs_file_size(const std::string& path) {
struct stat buf;
if (0 != stat(path.c_str(), &buf)) {
@@ -432,6 +447,25 @@ std::shared_ptr<FILE> fs_open_write(const std::string& path,
return {};
}

+std::shared_ptr<FILE> fs_open_append_write(const std::string& path,
+                                           int* err_no,
+                                           const std::string& converter) {
+  switch (fs_select_internal(path)) {
+    case 0:
+      return localfs_open_append_write(path, converter);
+
+    case 1:
+      return hdfs_open_write(path, err_no, converter);
+
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported file system. Now only supports local file system and "
+          "HDFS."));
+  }
+
+  return {};
+}

std::shared_ptr<FILE> fs_open(const std::string& path,
const std::string& mode,
int* err_no,
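A usage sketch for the new append-mode open: the returned std::shared_ptr<FILE> closes the handle from its deleter once the last reference drops, so repeated dumps keep appending to the same part file. Standard fopen/fwrite stand in here for the Paddle wrappers, which additionally create parent directories and can route through gzip or HDFS:

#include <cstdio>
#include <memory>
#include <string>

// Approximates fs_open_append_write for a plain local path: open in "a"
// mode and let the shared_ptr deleter close the handle.
std::shared_ptr<FILE> OpenAppend(const std::string &path) {
  FILE *fp = std::fopen(path.c_str(), "a");
  return std::shared_ptr<FILE>(fp, [](FILE *f) {
    if (f != nullptr) std::fclose(f);
  });
}

int main() {
  auto fp = OpenAppend("part-000");
  if (fp == nullptr) return 1;
  const std::string line = "101-205-318-\n";
  std::fwrite(line.data(), 1, line.size(), fp.get());  // one walk per line
  return 0;  // file handle closed by the deleter
}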
4 changes: 4 additions & 0 deletions paddle/fluid/framework/io/fs.h
@@ -103,6 +103,10 @@ extern std::shared_ptr<FILE> fs_open_write(const std::string& path,
int* err_no,
const std::string& converter);

+extern std::shared_ptr<FILE> fs_open_append_write(const std::string& path,
+                                                  int* err_no,
+                                                  const std::string& converter);

extern std::shared_ptr<FILE> fs_open(const std::string& path,
const std::string& mode,
int* err_no,
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/data_set_py.cc
@@ -374,6 +374,9 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_pass_id",
&framework::Dataset::SetPassId,
py::call_guard<py::gil_scoped_release>())
.def("dump_walk_path",
&framework::Dataset::DumpWalkPath,
py::call_guard<py::gil_scoped_release>());

py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
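The binding keeps the existing chain style; py::call_guard<py::gil_scoped_release> releases the GIL for the duration of the C++ call, so Python threads are not blocked while walk paths are written. A minimal pybind11 sketch of the same shape, with a hypothetical Dataset stub and module name demo:

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Hypothetical stand-in for framework::Dataset.
struct Dataset {
  void DumpWalkPath(const std::string &path, size_t rate) {
    // long-running file I/O happens here, with the GIL released
  }
};

PYBIND11_MODULE(demo, m) {
  py::class_<Dataset>(m, "Dataset")
      .def(py::init<>())
      .def("dump_walk_path",
           &Dataset::DumpWalkPath,
           py::call_guard<py::gil_scoped_release>());
}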