Skip to content

Commit

Permalink
merge from gpugraph (#75)
Browse files (browse the repository at this point in the history)
* fix hpi (PaddlePaddle#204)

* make infer & train same logic (PaddlePaddle#196)

* make infer & train same logic

* make infer & train same logic

* fix nccl sync (PaddlePaddle#205)

---------

Co-authored-by: Huang Zhengjie <[email protected]>
  • Loading branch information
qingshui and Yelrose authored Feb 2, 2023
1 parent 89f76da commit ab8d3f2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2108,9 +2108,9 @@ int GraphDataGenerator::FillInferBuf() {

size_t device_key_size = h_device_keys_len_[infer_cursor];
total_row_ =
(global_infer_node_type_start[infer_cursor] + infer_table_cap_ <=
(global_infer_node_type_start[infer_cursor] + buf_size_ <=
device_key_size)
? infer_table_cap_
? buf_size_
: device_key_size - global_infer_node_type_start[infer_cursor];

uint64_t *d_type_keys =
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/framework/trainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) {
dump_fields_path_ = desc.dump_fields_path();
need_dump_field_ = false;
need_dump_param_ = false;
dump_fields_mode_ = desc.dump_fields_mode();

if (dump_fields_path_ == "") {
VLOG(2) << "dump_fields_path_ is empty";
return;
Expand Down Expand Up @@ -58,7 +60,15 @@ void TrainerBase::DumpWork(int tid) {
int err_no = 0;
// GetDumpPath is implemented in each Trainer
std::string path = GetDumpPath(tid);
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
std::shared_ptr<FILE> fp;
if (dump_fields_mode_ == "a") {
VLOG(3) << "dump field mode append";
fp = fs_open_append_write(path, &err_no, dump_converter_);
}
else {
VLOG(3) << "dump field mode overwrite";
fp = fs_open_write(path, &err_no, dump_converter_);
}
while (1) {
std::string out_str;
if (!queue_->Get(out_str)) {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class TrainerBase {
std::string dump_converter_;
std::vector<std::string> dump_param_;
std::vector<std::string> dump_fields_;
std::string dump_fields_mode_;
int dump_thread_num_;
std::vector<std::thread> dump_thread_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/trainer_desc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ message TrainerDesc {
// add for gpu
optional string fleet_desc = 37;
optional bool is_dump_in_simple_mode = 38 [ default = false ];
optional string dump_fields_mode = 39 [ default = "w" ];
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/trainer_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def _set_dump_param(self, dump_param):
for param in dump_param:
self.proto_desc.dump_param.append(param)

# Store the dump-file open mode on the trainer proto (`dump_fields_mode`).
# The value is passed through without validation. On the C++ side shown in
# this commit, TrainerBase::DumpWork opens the dump file for append when the
# mode is "a" and overwrites it for any other value; the proto default is "w".
def _set_dump_fields_mode(self, mode):
self.proto_desc.dump_fields_mode = mode

def _set_worker_places(self, worker_places):
for place in worker_places:
self.proto_desc.worker_places.append(place)
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/trainer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def _create_trainer(self, opt_info=None):
if opt_info.get("dump_fields_path") is not None and len(
opt_info.get("dump_fields_path")) != 0:
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
if opt_info.get("dump_fields_mode") is not None:
trainer._set_dump_fields_mode(opt_info["dump_fields_mode"])
if opt_info.get(
"user_define_dump_filename") is not None and len(
opt_info.get("user_define_dump_filename")) != 0:
Expand Down

0 comments on commit ab8d3f2

Please sign in to comment.