diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index 308f8a23df29ed..c108822040e3fb 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -2714,6 +2714,16 @@ void SlotRecordInMemoryDataFeed::DoWalkandSage() {
 #endif
 }
 
+void SlotRecordInMemoryDataFeed::DumpWalkPath(std::string dump_path,
+                                              size_t dump_rate) {
+  VLOG(3) << "INTO SlotRecordInMemoryDataFeed::DumpWalkPath";
+#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
+  std::string path =
+      string::format_string("%s/part-%03d", dump_path.c_str(), thread_id_);
+  gpu_graph_data_generator_.DumpWalkPath(path, dump_rate);
+#endif
+}
+
 #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
 void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
   int offset_cols_size = (ins_num + 1);
diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index b4ab4300ae9282..7451a038493c21 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -28,6 +28,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
 #include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
 #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
+#include "paddle/fluid/framework/io/fs.h"
 #include "paddle/phi/kernels/gpu/graph_reindex_funcs.h"
 #include "paddle/phi/kernels/graph_reindex_kernel.h"
 
@@ -458,7 +459,7 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
                                   int step,
                                   int len,
                                   int col_num,
-                                  uint8_t* excluded_train_pair,
+                                  uint8_t *excluded_train_pair,
                                   int excluded_train_pair_len) {
   __shared__ uint64_t local_key[CUDA_NUM_THREADS * 2];
   __shared__ int local_num;
@@ -477,8 +478,8 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
     int src = row[idx] * col_num + central_word;
     if (walk[src] != 0 && walk[src + step] != 0) {
       for (int i = 0; i < excluded_train_pair_len; i += 2) {
-        if (walk_ntype[src] == excluded_train_pair[i]
-            && walk_ntype[src + step] == excluded_train_pair[i + 1]) {
+        if (walk_ntype[src] == excluded_train_pair[i] &&
+            walk_ntype[src + step] == excluded_train_pair[i + 1]) {
           // filter this pair
           need_filter = true;
           break;
@@ -733,7 +734,8 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
   uint8_t *excluded_train_pair = NULL;
   if (excluded_train_pair_len_ > 0) {
     walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
-    excluded_train_pair = reinterpret_cast<uint8_t *>(d_excluded_train_pair_->ptr());
+    excluded_train_pair =
+        reinterpret_cast<uint8_t *>(d_excluded_train_pair_->ptr());
   }
   uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
   int *random_row = reinterpret_cast<int *>(d_random_row_->ptr());
@@ -768,7 +770,8 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
     VLOG(2) << "h_pair_num = " << h_pair_num
             << ", ins_buf_pair_len = " << ins_buf_pair_len_;
     for (int xx = 0; xx < ins_buf_pair_len_; xx++) {
-      VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", " << h_ins_buf[xx * 2 + 1];
+      VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", "
+              << h_ins_buf[xx * 2 + 1];
     }
   }
   return ins_buf_pair_len_;
@@ -1855,16 +1858,17 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GenerateSampleGraph(
 }
 
 std::shared_ptr<phi::Allocation> GraphDataGenerator::GetNodeDegree(
-    uint64_t* node_ids, int len) {
+    uint64_t *node_ids, int len) {
   auto node_degree = memory::AllocShared(
       place_,
      len * edge_to_id_len_ * sizeof(int),
       phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
   auto edge_to_id = gpu_graph_ptr->edge_to_id;
-  for (auto& iter : edge_to_id) {
+  for (auto &iter : edge_to_id) {
     int edge_idx = iter.second;
-    gpu_graph_ptr->get_node_degree(gpuid_, edge_idx, node_ids, len, node_degree);
+    gpu_graph_ptr->get_node_degree(
+        gpuid_, edge_idx, node_ids, len, node_degree);
   }
   return node_degree;
 }
@@ -1959,10 +1963,11 @@ void GraphDataGenerator::DoWalkandSage() {
             phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
         auto final_sage_nodes = GenerateSampleGraph(
             ins_cursor, total_instance, &uniq_instance, inverse);
-        uint64_t* final_sage_nodes_ptr =
+        uint64_t *final_sage_nodes_ptr =
            reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
         if (get_degree_) {
-          auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
+          auto node_degrees =
+              GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
           node_degree_vec_.emplace_back(node_degrees);
         }
         cudaStreamSynchronize(sample_stream_);
@@ -2016,10 +2021,11 @@ void GraphDataGenerator::DoWalkandSage() {
             phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
         auto final_sage_nodes = GenerateSampleGraph(
             node_buf_ptr_, total_instance, &uniq_instance, inverse);
-        uint64_t* final_sage_nodes_ptr =
+        uint64_t *final_sage_nodes_ptr =
            reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
         if (get_degree_) {
-          auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
+          auto node_degrees =
+              GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
           node_degree_vec_.emplace_back(node_degrees);
         }
         cudaStreamSynchronize(sample_stream_);
@@ -2084,7 +2090,8 @@ int GraphDataGenerator::FillInferBuf() {
   }
   if (!infer_node_type_index_set_.empty()) {
     while (infer_cursor < h_device_keys_len_.size()) {
-      if (infer_node_type_index_set_.find(infer_cursor) == infer_node_type_index_set_.end()) {
+      if (infer_node_type_index_set_.find(infer_cursor) ==
+          infer_node_type_index_set_.end()) {
         VLOG(2) << "Skip cursor[" << infer_cursor << "]";
         infer_cursor++;
         continue;
@@ -2723,21 +2730,21 @@ void GraphDataGenerator::AllocResource(int thread_id,
   excluded_train_pair_len_ = gpu_graph_ptr->excluded_train_pair_.size();
   if (excluded_train_pair_len_ > 0) {
     d_excluded_train_pair_ = memory::AllocShared(
-      place_,
-      excluded_train_pair_len_ * sizeof(uint8_t),
-      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
-    CUDA_CHECK(cudaMemcpyAsync(
-      d_excluded_train_pair_->ptr(), gpu_graph_ptr->excluded_train_pair_.data(),
-      excluded_train_pair_len_ * sizeof(uint8_t),
-      cudaMemcpyHostToDevice,
-      sample_stream_));
+        place_,
+        excluded_train_pair_len_ * sizeof(uint8_t),
+        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
+    CUDA_CHECK(cudaMemcpyAsync(d_excluded_train_pair_->ptr(),
+                               gpu_graph_ptr->excluded_train_pair_.data(),
+                               excluded_train_pair_len_ * sizeof(uint8_t),
+                               cudaMemcpyHostToDevice,
+                               sample_stream_));
     d_walk_ntype_ = memory::AllocShared(
-      place_,
-      buf_size_ * sizeof(uint8_t),
-      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
+        place_,
+        buf_size_ * sizeof(uint8_t),
+        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
     cudaMemsetAsync(
-      d_walk_ntype_->ptr(), 0, buf_size_ * sizeof(uint8_t), sample_stream_);
+        d_walk_ntype_->ptr(), 0, buf_size_ * sizeof(uint8_t), sample_stream_);
   }
 
   d_sample_keys_ = memory::AllocShared(
@@ -2835,11 +2842,12 @@ void GraphDataGenerator::AllocResource(int thread_id,
         (batch_size_ * 2 * 2) * sizeof(uint32_t),
         phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
   }
-  
+
   // parse infer_node_type
   auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
   if (!gpu_graph_training_) {
-    auto node_types = paddle::string::split_string<std::string>(infer_node_type_, ";");
+    auto node_types =
+        paddle::string::split_string<std::string>(infer_node_type_, ";");
     auto node_to_id = gpu_graph_ptr->node_to_id;
     for (auto &type : node_types) {
       auto iter = node_to_id.find(type);
@@ -2849,11 +2857,13 @@ void GraphDataGenerator::AllocResource(int thread_id,
           platform::errors::NotFound("(%s) is not found in node_to_id.", type));
       int node_type = iter->second;
       int type_index = type_to_index[node_type];
-      VLOG(2) << "add node[" << type << "] into infer_node_type, type_index(cursor)["
-              << type_index << "]";
+      VLOG(2) << "add node[" << type
+              << "] into infer_node_type, type_index(cursor)[" << type_index
+              << "]";
       infer_node_type_index_set_.insert(type_index);
     }
-    VLOG(2) << "infer_node_type_index_set_num: " << infer_node_type_index_set_.size();
+    VLOG(2) << "infer_node_type_index_set_num: "
+            << infer_node_type_index_set_.size();
   }
 
   cudaStreamSynchronize(sample_stream_);
@@ -2910,7 +2920,8 @@ void GraphDataGenerator::SetConfig(
   std::string str_samples = graph_config.samples();
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
   debug_gpu_memory_info("init_conf start");
-  gpu_graph_ptr->init_conf(first_node_type, meta_path, graph_config.excluded_train_pair());
+  gpu_graph_ptr->init_conf(
+      first_node_type, meta_path, graph_config.excluded_train_pair());
   debug_gpu_memory_info("init_conf end");
 
   auto edge_to_id = gpu_graph_ptr->edge_to_id;
@@ -2928,6 +2939,45 @@ void GraphDataGenerator::SetConfig(
   }
 };
 
+void GraphDataGenerator::DumpWalkPath(std::string dump_path, size_t dump_rate) {
+#ifdef _LINUX
+  PADDLE_ENFORCE_LT(
+      dump_rate,
+      10000000,
+      platform::errors::InvalidArgument(
+          "dump_rate can't be larger than 10000000. Please check the dump "
+          "rate[1, 10000000]"));
+  PADDLE_ENFORCE_GT(dump_rate,
+                    1,
+                    platform::errors::InvalidArgument(
+                        "dump_rate can't be less than 1. Please check "
+                        "the dump rate[1, 10000000]"));
+  int err_no = 0;
+  std::shared_ptr<FILE> fp = fs_open_append_write(dump_path, &err_no, "");
+  uint64_t *h_walk = new uint64_t[buf_size_];
+  uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
+  // copy the device-side walk buffer back to the host
+  cudaMemcpy(
+      h_walk, walk, buf_size_ * sizeof(uint64_t), cudaMemcpyDeviceToHost);
+  VLOG(1) << "DumpWalkPath all buf_size_:" << buf_size_;
+  std::string ss = "";
+  size_t write_count = 0;
+  // append roughly 1/dump_rate of the walks, one '-'-joined path per line
+  for (int xx = 0; xx < buf_size_ / dump_rate; xx += walk_len_) {
+    ss = "";
+    for (int yy = 0; yy < walk_len_; yy++) {
+      ss += std::to_string(h_walk[xx + yy]) + "-";
+    }
+    write_count = fwrite_unlocked(ss.data(), 1, ss.length(), fp.get());
+    if (write_count != ss.length()) {
+      VLOG(1) << "dump walk path " << ss << " failed";
+    }
+    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
+  }
+  delete[] h_walk;
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
 #endif
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index 0a42f8e6fcf29c..0fd6e8953d31e7 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -916,7 +916,7 @@ class GraphDataGenerator {
   void FillOneStep(uint64_t* start_ids,
                    int etype_id,
                    uint64_t* walk,
-                   uint8_t *walk_ntype,
+                   uint8_t* walk_ntype,
                    int len,
                    NeighborSampleResult& sample_res,
                    int cur_degree,
@@ -940,6 +940,7 @@ class GraphDataGenerator {
   void ResetPathNum() { total_row_ = 0; }
   void ResetEpochFinish() { epoch_finish_ = false; }
   void ClearSampleState();
+  void DumpWalkPath(std::string dump_path, size_t dump_rate);
   void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
     // type_to_index_[type] = h_device_keys_.size();
     // h_device_keys_.push_back(device_keys);
@@ -1222,6 +1223,11 @@ class DataFeed {
   }
   virtual const paddle::platform::Place& GetPlace() const { return place_; }
 
+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "This function(DumpWalkPath) is not implemented."));
+  }
+
  protected:
   // The following three functions are used to check if it is executed in this
   // order:
@@ -1828,6 +1834,7 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
                             const UsedSlotGpuType* used_slots);
 #endif
   virtual void DoWalkandSage();
+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);
 
   float sample_rate_ = 1.0f;
   int use_slot_size_ = 0;
diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index d2fff7c5cdf11e..afe686ea48dd67 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -657,6 +657,26 @@ void DatasetImpl<T>::LocalShuffle() {
           << timeline.ElapsedSec() << " seconds";
 }
 
+template <typename T>
+void DatasetImpl<T>::DumpWalkPath(std::string dump_path, size_t dump_rate) {
+  VLOG(3) << "DatasetImpl<T>::DumpWalkPath() begin";
+#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
+  std::vector<std::thread> dump_threads;
+  if (gpu_graph_mode_) {
+    // one dump thread per reader; each reader writes its own part file
+    for (int64_t i = 0; i < thread_num_; ++i) {
+      dump_threads.push_back(
+          std::thread(&paddle::framework::DataFeed::DumpWalkPath,
+                      readers_[i].get(),
+                      dump_path,
+                      dump_rate));
+    }
+    for (std::thread& t : dump_threads) {
+      t.join();
+    }
+  }
+#endif
+}
+
 // do tdm sample
 void MultiSlotDataset::TDMSample(const std::string tree_name,
                                  const std::string tree_path,
diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h
index 9e1998a35fd649..6d9a1d20f64059 100644
--- a/paddle/fluid/framework/data_set.h
+++ b/paddle/fluid/framework/data_set.h
@@ -174,6 +174,8 @@ class Dataset {
   virtual void SetPassId(uint32_t pass_id) = 0;
   virtual uint32_t GetPassID() = 0;
 
+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) = 0;
+
  protected:
   virtual int ReceiveFromClient(int msg_type,
                                 int client_id,
@@ -268,6 +270,7 @@ class DatasetImpl : public Dataset {
   virtual void SetFleetSendSleepSeconds(int seconds);
   virtual std::vector<std::string> GetSlots();
   virtual bool GetEpochFinish();
+  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);
 
   std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
     return multi_output_channel_;
diff --git a/paddle/fluid/framework/io/fs.cc b/paddle/fluid/framework/io/fs.cc
index 285ce2ddb2791f..aa909136f40dc3 100644
--- a/paddle/fluid/framework/io/fs.cc
+++ b/paddle/fluid/framework/io/fs.cc
@@ -131,6 +131,21 @@ std::shared_ptr<FILE> localfs_open_write(std::string path,
   return fs_open_internal(path, is_pipe, "w", localfs_buffer_size());
 }
 
+std::shared_ptr<FILE> localfs_open_append_write(std::string path,
+                                                const std::string& converter) {
+  shell_execute(
+      string::format_string("mkdir -p $(dirname \"%s\")", path.c_str()));
+
+  bool is_pipe = false;
+
+  if (fs_end_with_internal(path, ".gz")) {
+    fs_add_write_converter_internal(path, is_pipe, "gzip");
+  }
+
+  fs_add_write_converter_internal(path, is_pipe, converter);
+  return fs_open_internal(path, is_pipe, "a", localfs_buffer_size());
+}
+
 int64_t localfs_file_size(const std::string& path) {
   struct stat buf;
   if (0 != stat(path.c_str(), &buf)) {
@@ -432,6 +447,25 @@ std::shared_ptr<FILE> fs_open_write(const std::string& path,
   return {};
 }
 
+std::shared_ptr<FILE> fs_open_append_write(const std::string& path,
+                                           int* err_no,
+                                           const std::string& converter) {
+  switch (fs_select_internal(path)) {
+    case 0:
+      return localfs_open_append_write(path, converter);
+
+    case 1:
+      return hdfs_open_write(path, err_no, converter);
+
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported file system. Now only supports local file system and "
+          "HDFS."));
+  }
+
+  return {};
+}
+
 std::shared_ptr<FILE> fs_open(const std::string& path,
                               const std::string& mode,
                               int* err_no,
diff --git a/paddle/fluid/framework/io/fs.h b/paddle/fluid/framework/io/fs.h
index 0ebc7fea089fbe..842f816d857923 100644
--- a/paddle/fluid/framework/io/fs.h
+++ b/paddle/fluid/framework/io/fs.h
@@ -103,6 +103,10 @@ extern std::shared_ptr<FILE> fs_open_write(const std::string& path,
                                            int* err_no,
                                            const std::string& converter);
 
+extern std::shared_ptr<FILE> fs_open_append_write(const std::string& path,
+                                                  int* err_no,
+                                                  const std::string& converter);
+
 extern std::shared_ptr<FILE> fs_open(const std::string& path,
                                      const std::string& mode,
                                      int* err_no,
diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc
index bc60d536d19cee..07872c744f5e22 100644
--- a/paddle/fluid/pybind/data_set_py.cc
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -374,6 +374,9 @@ void BindDataset(py::module *m) {
            py::call_guard<py::gil_scoped_release>())
       .def("set_pass_id",
            &framework::Dataset::SetPassId,
+           py::call_guard<py::gil_scoped_release>())
+      .def("dump_walk_path",
+           &framework::Dataset::DumpWalkPath,
            py::call_guard<py::gil_scoped_release>());
 
   py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index bf1e7eb1eecefd..eae9b8e3784b4c 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -1111,6 +1111,14 @@ def get_pass_id(self):
         """
         return self.pass_id
 
+    def dump_walk_path(self, path, dump_rate=1000):
+        """
+        Dump a sample of the walk paths generated in gpu graph mode to
+        files under ``path`` (one part file per reader thread); roughly
+        1/dump_rate of the walk buffer is written.
+        """
+        self.dataset.dump_walk_path(path, dump_rate)
+
 
 class QueueDataset(DatasetBase):
     """
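
A minimal usage sketch of the new Python API (assumptions: a GPU-graph training
job whose InMemoryDataset has already generated walks on device; the output
path and dump_rate value below are illustrative, not part of this patch):

    import paddle.fluid as fluid

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    # ... configure the graph config, load data, and run walks in gpu graph mode ...

    # Each reader thread appends to <path>/part-<thread_id>; roughly 1/dump_rate
    # of its walk buffer is written, one '-'-joined walk per line.
    dataset.dump_walk_path("./walk_dump", dump_rate=1000)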