
Commit

Merge pull request #16 from qingshui/paddlebox
1. Add deduplication to pull/push sparse; 2. Optimize the pack/unpack performance of sample-reading shuffle; 3. Add inf/nan detection that prints op parameter inputs and outputs.
qingshui authored Aug 25, 2021
2 parents f7d08b4 + 9db0933 commit edf364c
Showing 17 changed files with 1,245 additions and 340 deletions.
2 changes: 1 addition & 1 deletion cmake/external/box_ps.cmake
@@ -20,7 +20,7 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL))
SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE)
SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE)
#SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE)
SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.16" CACHE STRING "" FORCE)
SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.17" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}")
SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps")
51 changes: 25 additions & 26 deletions paddle/fluid/framework/data_feed.h
@@ -1783,40 +1783,39 @@ class InputIndexDataFeed : public DataFeed {
template <class AR, class T>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
const SlotValues<T>& r) {
ar << r.slot_values;

uint16_t slot_num = (uint16_t)r.slot_offsets.size();
ar << slot_num;
if (slot_num > 0 && !r.slot_values.empty()) {
// remove first 0 and end data len
for (uint16_t i = 1; i < slot_num - 1; ++i) {
ar << r.slot_offsets[i];
uint16_t value_len = static_cast<uint16_t>(r.slot_values.size());
ar << value_len;
if (value_len > 0) {
ar.Write(&r.slot_values[0], value_len * sizeof(T));

uint16_t slot_num = static_cast<uint16_t>(r.slot_offsets.size());
ar << slot_num;
if (slot_num > 2) {
// remove first 0 and end data len
ar.Write(&r.slot_offsets[1], (slot_num - 2) * sizeof(uint32_t));
}
}
return ar;
}
template <class AR, class T>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
SlotValues<T>& r) {
ar >> r.slot_values;

uint16_t slot_num = 0;
ar >> slot_num;
if (slot_num > 0) {
size_t value_len = r.slot_values.size();
if (value_len > 0) {
r.slot_offsets.resize(slot_num);
// fill first 0
r.slot_offsets[0] = 0;
for (uint16_t i = 1; i < slot_num - 1; ++i) {
ar >> r.slot_offsets[i];
}
// fill end data len
r.slot_offsets[slot_num - 1] = value_len;
} else {
// empty values set zero
r.slot_offsets.assign(slot_num, 0);
uint16_t value_len = 0;
ar >> value_len;
if (value_len > 0) {
r.slot_values.resize(value_len);
ar.Read(&r.slot_values[0], value_len * sizeof(T));

uint16_t slot_num = 0;
ar >> slot_num;
r.slot_offsets.resize(slot_num);
// fill first 0
r.slot_offsets[0] = 0;
if (slot_num > 2) {
ar.Read(&r.slot_offsets[1], (slot_num - 2) * sizeof(uint32_t));
}
// fill end data len
r.slot_offsets[slot_num - 1] = value_len;
}
return ar;
}
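
Note on the change above: the old path serialized slot_values element by element through the Archive operators and wrote every interior offset in a loop; the new path writes the value count, the raw value block, and only the interior offsets as single block writes, rebuilding the leading 0 and trailing total length on read. A minimal sketch of that wire format, not part of this commit's diff, assuming a plain byte buffer in place of paddle::framework::Archive (SlotValuesLite, Pack, and Unpack are illustrative stand-ins, not Paddle APIs):

// Illustrative stand-ins only; the real code uses paddle::framework::Archive
// and the SlotValues type defined in data_feed.h.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

template <class T>
struct SlotValuesLite {
  std::vector<T> slot_values;
  std::vector<uint32_t> slot_offsets;  // [0, ..., slot_values.size()]
};

template <class T>
void Pack(const SlotValuesLite<T>& r, std::vector<char>* out) {
  auto put = [out](const void* p, size_t n) {
    const char* c = static_cast<const char*>(p);
    out->insert(out->end(), c, c + n);
  };
  // One block write for the values instead of a per-element loop.
  uint16_t value_len = static_cast<uint16_t>(r.slot_values.size());
  put(&value_len, sizeof(value_len));
  if (value_len > 0) {
    put(r.slot_values.data(), value_len * sizeof(T));
    // Only the interior offsets go on the wire; the leading 0 and the
    // trailing total length are implicit.
    uint16_t slot_num = static_cast<uint16_t>(r.slot_offsets.size());
    put(&slot_num, sizeof(slot_num));
    if (slot_num > 2) {
      put(&r.slot_offsets[1], (slot_num - 2) * sizeof(uint32_t));
    }
  }
}

template <class T>
size_t Unpack(const char* buf, SlotValuesLite<T>* r) {
  const char* p = buf;
  auto get = [&p](void* dst, size_t n) {
    std::memcpy(dst, p, n);
    p += n;
  };
  uint16_t value_len = 0;
  get(&value_len, sizeof(value_len));
  if (value_len > 0) {
    r->slot_values.resize(value_len);
    get(r->slot_values.data(), value_len * sizeof(T));
    uint16_t slot_num = 0;
    get(&slot_num, sizeof(slot_num));
    r->slot_offsets.assign(slot_num, 0);  // first offset is always 0
    if (slot_num > 2) {
      get(&r->slot_offsets[1], (slot_num - 2) * sizeof(uint32_t));
    }
    if (slot_num > 0) {
      r->slot_offsets.back() = value_len;  // rebuild the implicit end offset
    }
  }
  return static_cast<size_t>(p - buf);  // bytes consumed
}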
14 changes: 12 additions & 2 deletions paddle/fluid/framework/data_set.cc
@@ -1400,7 +1400,8 @@ class PadBoxSlotDataConsumer : public boxps::DataConsumer {
CHECK_GE(_service_id, 0);
}
virtual ~PadBoxSlotDataConsumer() {
CHECK_GE(BoxWrapper::data_shuffle_->register_handler(this), 0);
// CHECK_GE(BoxWrapper::data_shuffle_->register_handler(this), 0);
BoxWrapper::data_shuffle_->unregister_consumer(_service_id);
}
virtual void on_receive(const int client_id, const char* buff, int len) {
_dataset->ReceiveSuffleData(client_id, buff, len);
@@ -1413,6 +1414,9 @@
BoxWrapper::data_shuffle_->send_message_callback(client_id, buf, len,
callback);
}
void wait_message_done(void) {
BoxWrapper::data_shuffle_->wait_done(_service_id);
}

private:
PadBoxSlotDataset* _dataset;
@@ -1823,6 +1827,7 @@ void PadBoxSlotDataset::ShuffleData(int thread_num) {
<< ", span: " << span;
// only one thread send finish notify
if (--shuffle_counter_ == 0) {
timer.Start();
// send closed
wg.add(mpi_size_);
for (int i = 0; i < mpi_size_; ++i) {
@@ -1833,10 +1838,15 @@
handler->send_message_callback(i, NULL, 0, &wg);
}
wg.wait();
// wait message done
handler->wait_message_done();
timer.Pause();

// end shuffle thread
LOG(WARNING) << "passid = " << pass_id_
<< ", end shuffle span max:" << max_shuffle_span_
<< ", min:" << min_shuffle_span_;
<< ", min:" << min_shuffle_span_
<< ", wait:" << timer.ElapsedSec();
// local closed channel
if (--finished_counter_ == 0) {
while (receiver_cnt_ > 0) {
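
Note on the shutdown path above: the last shuffle thread (the one that drives shuffle_counter_ to zero) sends a closed message to every rank, waits on wg for the send callbacks, and the new handler->wait_message_done() call appears to additionally block until the transport reports all in-flight messages flushed; the wait time is now logged alongside the span statistics. For readers unfamiliar with the wg object, below is a minimal wait-group sketch with the usual add/done/wait semantics. It is a generic stand-in, not the boxps or Paddle implementation.

#include <condition_variable>
#include <mutex>

class WaitGroup {
 public:
  // Register n pending operations.
  void add(int n) {
    std::lock_guard<std::mutex> lock(mu_);
    count_ += n;
  }
  // Mark one operation finished; wakes waiters when the count reaches zero.
  void done() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ <= 0) cv_.notify_all();
  }
  // Block until every registered operation has called done().
  void wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ <= 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_ = 0;
};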
1 change: 1 addition & 0 deletions paddle/fluid/framework/data_set.h
@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/threadpool.h"
DECLARE_int32(padbox_dataset_shuffle_thread_num);
DECLARE_int32(padbox_dataset_merge_thread_num);
DECLARE_int32(padbox_max_shuffle_wait_count);
namespace boxps {
class PSAgentBase;
}
3 changes: 3 additions & 0 deletions paddle/fluid/framework/details/nan_inf_utils.h
@@ -33,6 +33,9 @@ void CheckVarHasNanOrInf(const std::string& op_type,
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::Scope& scope,
const platform::Place& place);
bool CheckOpHasNanOrInfRet(const framework::OperatorBase& op,
const framework::Scope& scope,
const platform::Place& place);
} // namespace details
} // namespace framework
} // namespace paddle
64 changes: 64 additions & 0 deletions paddle/fluid/framework/details/nan_inf_utils_detail.cc
@@ -389,6 +389,70 @@ void CheckOpHasNanOrInf(const framework::OperatorBase& op,
}
}

bool CheckVarHasNanOrInfRet(const std::string& op_type,
const framework::Scope& scope,
const std::string& var_name,
const platform::Place& place) {
auto* var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("In op=%s, can't find var:%s", op_type,
var_name));
const Tensor* tensor{nullptr};
if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<framework::SelectedRows>()) {
tensor = &var->Get<framework::SelectedRows>().value();
} else {
return false;
}

if (tensor->memory_size() == 0) {
return false;
}
VLOG(10) << "begin check " << op_type << " var_name:" << var_name
<< ", place:" << tensor->place() << ", numel:" << tensor->numel();

if (!platform::is_gpu_place(tensor->place())) {
return false;
}
return CudaTensorCheckNanInf(op_type, var_name, *tensor);
}
bool CheckOpHasNanOrInfRet(const framework::OperatorBase& op,
const framework::Scope& exec_scope,
const platform::Place& place) {
std::call_once(white_list_init_flag, InitWhiteListFormEnv);

if (IsSkipOp(op)) return false;

if (op_var_nan_inf_white_list().count(op.Type()) == 0) {
// NOTE. vname may destruct in the end of this func.
for (auto& vname : op.OutputVars(true)) {
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
if (CheckVarHasNanOrInfRet(op.Type(), exec_scope, vname, place)) {
return true;
}
}
} else {
for (auto& vname : op.OutputVars(true)) {
bool need_check = true;
for (auto& white_vname : op_var_nan_inf_white_list().at(op.Type())) {
if (vname.find(white_vname) != std::string::npos) {
need_check = false;
break;
}
}
if (!need_check) continue;
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
if (CheckVarHasNanOrInfRet(op.Type(), exec_scope, vname, place)) {
return true;
}
}
}
return false;
}

} // namespace details
} // namespace framework
} // namespace paddle
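
Note: unlike CheckOpHasNanOrInf above, which enforces and aborts on the first bad value, the new CheckOpHasNanOrInfRet / CheckVarHasNanOrInfRet variants return a bool so the caller can decide how to react. A hypothetical caller sketch follows; RunOpsWithNanInfCheck is made up for illustration, and only the *Ret functions come from this patch.

#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {

// Hypothetical helper, not part of the patch: run a list of ops and report
// (rather than abort on) any nan/inf found in their outputs.
void RunOpsWithNanInfCheck(const std::vector<OperatorBase*>& ops,
                           const Scope& scope,
                           const platform::Place& place) {
  for (auto* op : ops) {
    op->Run(scope, place);
    // Returns true if any GPU output tensor of this op contains nan/inf.
    if (details::CheckOpHasNanOrInfRet(*op, scope, place)) {
      LOG(WARNING) << "nan/inf detected in outputs of op " << op->Type();
    }
  }
}

}  // namespace framework
}  // namespace paddle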
77 changes: 77 additions & 0 deletions paddle/fluid/framework/details/nan_inf_utils_detail.cu
@@ -184,6 +184,83 @@ void tensor_check<platform::CUDADeviceContext>(const std::string& op_type,
VisitDataType(tensor.type(), vistor);
}
template <typename T>
__global__ void CountNanInfNumKernel(const size_t len, const T* val,
unsigned int* nan_num,
unsigned int* inf_num) {
/* Per block accumulator */
__shared__ unsigned int block_nan, block_inf;
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.x == 0) {
block_nan = 0;
block_inf = 0;
}
__syncthreads();
if (i < len) {
unsigned int count = 0;
if (isnan(val[i])) {
count = atomicAdd(&block_nan, 1);
} else if (isinf(val[i])) {
count = atomicAdd(&block_inf, 1);
}
// for cuda, print in every block
if (count > 0) {
printf("numel:%lu idx:%lu value:%f\n", static_cast<uint64_t>(len),
static_cast<uint64_t>(i), static_cast<float>(val[i]));
}
}
__syncthreads();
if (threadIdx.x == 0) {
atomicAdd(nan_num, block_nan);
atomicAdd(inf_num, block_inf);
}
}
bool CudaTensorCheckNanInf(const std::string& op_type,
const std::string& var_name,
const framework::Tensor& tensor) {
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, tensor.place()).device;
auto stream = dev_ctx->stream();
auto gpu_tensor = paddle::memory::Alloc(*dev_ctx, sizeof(unsigned int) * 2);
unsigned int* num_ptr = reinterpret_cast<unsigned int*>(gpu_tensor->ptr());
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemsetAsync(num_ptr, 0, sizeof(unsigned int) * 2, stream));
size_t len = static_cast<size_t>(tensor.numel());
const size_t threads = 1024;
size_t blocks = std::min(static_cast<size_t>(128),
static_cast<size_t>((len + threads - 1) / threads));
if (tensor.type() == proto::VarType::FP32) {
CountNanInfNumKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
len, tensor.data<float>(), &num_ptr[0], &num_ptr[1]);
} else if (tensor.type() == proto::VarType::INT64) {
CountNanInfNumKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
len, tensor.data<int64_t>(), &num_ptr[0], &num_ptr[1]);
} else if (tensor.type() == proto::VarType::FP64) {
CountNanInfNumKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
len, tensor.data<double>(), &num_ptr[0], &num_ptr[1]);
} else {
return false;
}
unsigned int nan_inf_num[2] = {0};
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(&nan_inf_num[0], num_ptr,
sizeof(unsigned int) * 2,
cudaMemcpyDeviceToHost, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
if (nan_inf_num[0] > 0 || nan_inf_num[1] > 0) {
printf("device [%d], op %s, name: %s, there has %u,%u,%u nan,inf,num\n",
dev_id, op_type.c_str(), var_name.c_str(), nan_inf_num[0],
nan_inf_num[1], len);
return true;
}
return false;
}
} // namespace details
} // namespace framework
} // namespace paddle
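
Note: the host side above zeroes a two-counter device buffer, launches CountNanInfNumKernel with at most 128 blocks of 1024 threads, then copies the nan/inf counts back and synchronizes the stream. A small standalone example of the launch-size arithmetic, plain C++ and not part of the commit:

#include <algorithm>
#include <cstdio>

int main() {
  // Example tensor length; the kernel itself guards each thread with i < len.
  size_t len = 1000000;
  const size_t threads = 1024;
  size_t blocks = std::min(static_cast<size_t>(128),
                           static_cast<size_t>((len + threads - 1) / threads));
  std::printf("blocks=%zu threads=%zu total=%zu\n", blocks, threads,
              blocks * threads);  // blocks=128 threads=1024 total=131072
  return 0;
}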
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/nan_inf_utils_detail.h
@@ -59,7 +59,9 @@ template <typename DeviceContext>
void tensor_check(const std::string& op_type, const std::string& var_name,
const framework::Tensor& tensor,
const platform::Place& place);

bool CudaTensorCheckNanInf(const std::string& op_type,
const std::string& var_name,
const framework::Tensor& tensor);
} // namespace details
} // namespace framework
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.cc
@@ -435,6 +435,7 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,

CheckEmbedSizeIsValid(hidden_size - cvm_offset_, expand_embed_dim);
switch (embedx_dim_) {
EMBEDX_CASE(0, PULLSPARSE_CASE(0););
EMBEDX_CASE(8, PULLSPARSE_CASE(0); PULLSPARSE_CASE(1); PULLSPARSE_CASE(2);
PULLSPARSE_CASE(3); PULLSPARSE_CASE(4); PULLSPARSE_CASE(5);
PULLSPARSE_CASE(6); PULLSPARSE_CASE(7); PULLSPARSE_CASE(8);
@@ -500,6 +501,7 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,

CheckEmbedSizeIsValid(hidden_size - cvm_offset_, expand_embed_dim);
switch (embedx_dim_) {
EMBEDX_CASE(0, PUSHSPARSE_CASE(0););
EMBEDX_CASE(8, PUSHSPARSE_CASE(0); PUSHSPARSE_CASE(1); PUSHSPARSE_CASE(2);
PUSHSPARSE_CASE(3); PUSHSPARSE_CASE(4); PUSHSPARSE_CASE(5);
PUSHSPARSE_CASE(6); PUSHSPARSE_CASE(7); PUSHSPARSE_CASE(8);
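
Note: the two added lines register an embedx_dim_ == 0 branch in the dispatch switches of both PullSparse and PushSparseGrad. A simplified sketch of the switch-plus-macro dispatch pattern these EMBEDX_CASE / PULLSPARSE_CASE macros implement; the macro bodies below are stand-ins, not the real expansions in box_wrapper, and are not part of this commit.

#include <cstdio>
#include <stdexcept>

// Stand-in macros for illustration; the real ones instantiate templated
// pull/push kernels for each (embedx_dim, expand_embed_dim) pair.
#define EMBEDX_CASE(i, ...)       \
  case i: {                       \
    constexpr int kEmbedxDim = i; \
    __VA_ARGS__;                  \
    break;                        \
  }

#define PULLSPARSE_CASE(i)                                            \
  if (expand_embed_dim == i) {                                        \
    std::printf("pull sparse: embedx=%d expand=%d\n", kEmbedxDim, i); \
    return;                                                           \
  }

void PullSparseDispatch(int embedx_dim, int expand_embed_dim) {
  switch (embedx_dim) {
    EMBEDX_CASE(0, PULLSPARSE_CASE(0););  // new: embedx_dim == 0 branch
    EMBEDX_CASE(8, PULLSPARSE_CASE(0); PULLSPARSE_CASE(1););
    default:
      throw std::runtime_error("unsupported embedx_dim");
  }
  throw std::runtime_error("unsupported expand_embed_dim");
}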