Skip to content

Commit

Permalink
Merge pull request #18 from qingshui/paddlebox
Browse files Browse the repository at this point in the history
fix dataset read ins pipe
  • Loading branch information
qingshui authored Nov 9, 2021
2 parents 638d3f1 + 4e577e7 commit a8439c3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
29 changes: 17 additions & 12 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,14 @@ void PadBoxSlotDataset::CheckThreadPool(void) {
void PadBoxSlotDataset::PreLoadIntoMemory() {
CheckThreadPool();
LoadIndexIntoMemory();
// dualbox global data shuffle
if (!FLAGS_padbox_dataset_disable_shuffle && mpi_size_ > 1) {
finished_counter_ = mpi_size_;
mpi_flags_.assign(mpi_size_, 1);
VLOG(3) << "RegisterClientToClientMsgHandler";
data_consumer_ = reinterpret_cast<void*>(new PadBoxSlotDataConsumer(this));
VLOG(3) << "RegisterClientToClientMsgHandler done";
}

read_ins_ref_ = thread_num_;
for (int64_t i = 0; i < thread_num_; ++i) {
Expand All @@ -1539,14 +1547,9 @@ void PadBoxSlotDataset::PreLoadIntoMemory() {
}
}));
}

// dualbox global data shuffle
if (!FLAGS_padbox_dataset_disable_shuffle && mpi_size_ > 1) {
finished_counter_ = mpi_size_;
mpi_flags_.assign(mpi_size_, 1);
VLOG(3) << "RegisterClientToClientMsgHandler";
data_consumer_ = reinterpret_cast<void*>(new PadBoxSlotDataConsumer(this));
VLOG(3) << "RegisterClientToClientMsgHandler done";

ShuffleData(shuffle_thread_num_);
MergeInsKeys(shuffle_channel_);
} else {
Expand Down Expand Up @@ -1577,6 +1580,14 @@ void PadBoxSlotDataset::LoadIntoMemory() {

platform::Timer timeline;
timeline.Start();
// dualbox global data shuffle
if (!FLAGS_padbox_dataset_disable_shuffle && mpi_size_ > 1) {
finished_counter_ = mpi_size_;
mpi_flags_.assign(mpi_size_, 1);
VLOG(3) << "RegisterClientToClientMsgHandler";
data_consumer_ = reinterpret_cast<void*>(new PadBoxSlotDataConsumer(this));
VLOG(3) << "RegisterClientToClientMsgHandler done";
}

read_ins_ref_ = thread_num_;
for (int64_t i = 0; i < thread_num_; ++i) {
Expand All @@ -1590,12 +1601,6 @@ void PadBoxSlotDataset::LoadIntoMemory() {

// dualbox global data shuffle
if (!FLAGS_padbox_dataset_disable_shuffle && mpi_size_ > 1) {
finished_counter_ = mpi_size_;
mpi_flags_.assign(mpi_size_, 1);
VLOG(3) << "RegisterClientToClientMsgHandler";
data_consumer_ = reinterpret_cast<void*>(new PadBoxSlotDataConsumer(this));
VLOG(3) << "RegisterClientToClientMsgHandler done";

ShuffleData(shuffle_thread_num_);
MergeInsKeys(shuffle_channel_);
} else {
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/operators/pull_box_sparse_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"

DECLARE_bool(enable_pull_box_padding_zero);

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -123,8 +125,10 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
const auto *slot = inputs[i];
auto *output = outputs[i];
if (slot->numel() == 0) {
// only support GPU
PaddingZeros<T>(ctx, output, batch_size, hidden_size);
if (FLAGS_enable_pull_box_padding_zero) {
// only support GPU
PaddingZeros<T>(ctx, output, batch_size, hidden_size);
}
continue;
}
output->mutable_data<T>(ctx.GetPlace());
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,5 @@ DEFINE_bool(enable_pullpush_dedup_keys, false,
"enable pull push dedup keys, default false");
DEFINE_bool(enable_shuffle_by_searchid, false,
"enable dualbox shuffle by searchid, default false");
DEFINE_bool(enable_pull_box_padding_zero, true,
"enable pull box padding zero, default true");
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __bootstrap__():
'enable_slotrecord_reset_shrink',
'enable_pullpush_dedup_keys',
'enable_shuffle_by_searchid',
'enable_pull_box_padding_zero',
]
core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
Expand Down

0 comments on commit a8439c3

Please sign in to comment.