Skip to content

Commit

Permalink
add SyncBatchNorm,test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
ceci3 committed Aug 7, 2020
2 parents 93f3510 + 3eee046 commit d09a0d6
Show file tree
Hide file tree
Showing 72 changed files with 5,319 additions and 450 deletions.
2 changes: 1 addition & 1 deletion cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ function(op_library TARGET)
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op")
"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
Expand Down
13 changes: 6 additions & 7 deletions paddle/fluid/framework/fleet/gloo_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
paddle::framework::fs_remove(tmp);
if (i == retry_times_) {
VLOG(0) << "fs_open_write failed, retry times reaches limit";
// PADDLE_THROW(platform::errors::PreconditionNotMet(
// "fs_open_write failed, retry times reaches"
// " limit ",
// retry_times_));
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"fs_open_write failed, retry times reaches %d limit.",
retry_times_));
}
} else {
break;
Expand Down Expand Up @@ -143,9 +142,9 @@ void HdfsStore::wait(const std::vector<std::string>& keys,
break;
}
}
// PADDLE_THROW(platform::errors::ExecutionTimeout(
VLOG(0) << "TIMEOUT self_rank = " << self_rank_
<< " pair_rank = " << last_check_rank;
PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
"TIMEOUT self_rank = %d pair_rank = %d", self_rank_,
last_check_rank));
}
std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
}
Expand Down
69 changes: 36 additions & 33 deletions paddle/fluid/framework/unused_var_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,6 @@ DEFINE_bool(enable_unused_var_check, false,
"Checking whether operator contains unused inputs, "
"especially for grad operator. It should be in unittest.");

// NOTE(zhiqiu): Currently, there are some operators which involve unused
// inputs and cannot be removed from the allow_list below.
// They can be mainly divided into three categories:
// 0: the inputs of which are only used in if branch, or used in cuda kernel but
// not in cpu kernel;
// 1: the inputs of which are used to indicate dtype of outputs;
// 2: the inputs of which are used in fused operators.
// The category number is presented in the comments after each operator.

// Operator types that the unused-variable check skips (see the lookup in
// CheckUnusedVar).
const std::unordered_set<std::string> op_with_unsed_vars_allow_list = {
"batch_norm", // 0
"batch_norm_grad", // 0
"sync_batch_norm", // 0
"sync_batch_norm_grad", // 0
"inplace_abn", // 0
"inplace_abn_grad", // 0
"dgc_momentum", // 0
"fake_quantize_range_abs_max", // 0
"rmsprop", // 0
"sequence_conv_grad", // 0
"roi_perspective_transform_grad", // 0
"fill_zeros_like", // 1
"fill_any_like", // 1
"nce_grad", // 1
"precision_recall", // 1
"fusion_seqpool_cvm_concat", // 2
"fused_batch_norm_act", // 2
"fused_batch_norm_act_grad", // 2
"data_norm", // 0
"data_norm_grad", // 0
};

namespace paddle {
namespace framework {

Expand All @@ -75,9 +43,44 @@ void LogVarUsageIfUnusedVarCheckEnabled(const std::string &name) {
}
}

// Returns the set of operator types that are exempt from the unused-variable
// check.
//
// NOTE(zhiqiu): Currently, there are some operators which involve unused
// inputs and cannot be removed from the allow list below. They fall into
// three categories:
//   0: inputs that are only used in an if branch, or used in the cuda kernel
//      but not in the cpu kernel;
//   1: inputs that are used to indicate the dtype of outputs;
//   2: inputs that are used in fused operators.
// The category number is given in the comment after each operator.
static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
  // The set lives on the heap and is never deleted on purpose, so that no
  // destructor runs for it during static deinitialization. The initializer
  // is evaluated only on the first call.
  static const auto &exempt_ops = *new std::unordered_set<std::string>{
      "batch_norm",                      // 0
      "batch_norm_grad",                 // 0
      "sync_batch_norm",                 // 0
      "sync_batch_norm_grad",            // 0
      "inplace_abn",                     // 0
      "inplace_abn_grad",                // 0
      "dgc_momentum",                    // 0
      "fake_quantize_range_abs_max",     // 0
      "rmsprop",                         // 0
      "sequence_conv_grad",              // 0
      "roi_perspective_transform_grad",  // 0
      "fill_zeros_like",                 // 1
      "fill_any_like",                   // 1
      "nce_grad",                        // 1
      "precision_recall",                // 1
      "fusion_seqpool_cvm_concat",       // 2
      "fused_batch_norm_act",            // 2
      "fused_batch_norm_act_grad",       // 2
      "data_norm",                       // 0
      "data_norm_grad",                  // 0
  };
  return exempt_ops;
}

void CheckUnusedVar(const OperatorBase &op, const Scope &scope) {
// skip op in allow list.
if (op_with_unsed_vars_allow_list.count(op.Type()) != 0) {
if (GetOpWithUnusedVarAllowSet().count(op.Type()) != 0) {
return;
}
auto *used_set = GetThreadLocalUsedVarNameSet();
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/inference/tensorrt/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
}
static nvinfer1::IPluginRegistry* getPluginRegistry() {
#if IS_TRT_VERSION_GE(6000)
static nvinfer1::IPluginRegistry* GetPluginRegistry() {
return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry());
}
#endif

// A logger for create TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger {
Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,16 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
std::string name_space_;
std::string plugin_base_;
};
#endif

template <typename T>
class TrtPluginRegistrarV2 {
public:
TrtPluginRegistrarV2() { getPluginRegistry()->registerCreator(creator, ""); }
TrtPluginRegistrarV2() {
static auto func_ptr = GetPluginRegistry();
if (func_ptr != nullptr) {
func_ptr->registerCreator(creator, "");
}
}

private:
T creator;
Expand All @@ -193,6 +197,8 @@ class TrtPluginRegistrarV2 {
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {}

#endif

} // namespace plugin
} // namespace tensorrt
} // namespace inference
Expand Down
183 changes: 183 additions & 0 deletions paddle/fluid/operators/flatten_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,156 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
}
};

// Forward operator that collapses the dimensions of Input(X) lying in the
// inclusive range [start_axis, stop_axis] into a single dimension.
class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers the shapes of Output(Out) and Output(XShape) from the shape of
  // Input(X) and the start_axis / stop_axis attributes. Negative axes are
  // resolved by adding the rank of Input(X).
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FlattenContiguousRange");
    const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
    const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");
    const auto &in_dims = ctx->GetInputDim("X");
    int in_dims_size = in_dims.size();
    int real_start_axis = start_axis, real_stop_axis = stop_axis;
    if (start_axis < 0) {
      real_start_axis = start_axis + in_dims_size;
    }
    if (stop_axis < 0) {
      real_stop_axis = stop_axis + in_dims_size;
    }
    // Reject axes that do not address a dimension of X before they are used
    // to index in_dims below.
    PADDLE_ENFORCE_GE(
        real_start_axis, 0,
        platform::errors::InvalidArgument(
            "The start_axis should be in range [-%d, %d), but received %d.",
            in_dims_size, in_dims_size, start_axis));
    PADDLE_ENFORCE_GE(
        in_dims_size - 1, real_stop_axis,
        platform::errors::InvalidArgument(
            "The stop_axis should be in range [-%d, %d), but received %d.",
            in_dims_size, in_dims_size, stop_axis));
    PADDLE_ENFORCE_GE(
        real_stop_axis, real_start_axis,
        platform::errors::InvalidArgument("The stop_axis should be greater "
                                          "than or equal to start_axis."));

    const auto &out_dims =
        GetOutputShape(real_start_axis, real_stop_axis, in_dims);
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
    if (in_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }

    OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape",
                   "FlattenContiguousRange");
    // XShape holds the dims of X shifted by one slot, with a leading 0.
    std::vector<int64_t> xshape_dims(in_dims.size() + 1);
    xshape_dims[0] = 0;
    for (int i = 0; i < in_dims.size(); ++i) {
      xshape_dims[i + 1] = in_dims[i];
    }
    ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
    ctx->ShareLoD("X", "XShape");
  }

  // Computes the output shape obtained by merging dimensions
  // [start_axis, stop_axis] (both inclusive, already resolved to non-negative
  // indices) of in_dims into one dimension. Dimensions before start_axis and
  // after stop_axis are copied unchanged.
  static std::vector<int32_t> GetOutputShape(const int start_axis,
                                             const int stop_axis,
                                             const framework::DDim &in_dims) {
    int64_t outer = 1;
    std::vector<int32_t> out_shape;
    int in_dims_size = in_dims.size();
    // Output rank = rank - (stop_axis - start_axis + 1) + 1.
    out_shape.reserve(in_dims_size - stop_axis + start_axis);

    for (int i = 0; i < start_axis; ++i) {
      out_shape.push_back(in_dims[i]);
    }
    for (int i = start_axis; i <= stop_axis; i++) {
      // NOTE(review): -1 is assumed to mark a dimension whose extent is not
      // known yet (framework convention, not visible in this file). Such a
      // dimension makes the merged extent unknown as well, instead of being
      // multiplied into a meaningless product.
      if (in_dims[i] == -1 || outer == -1) {
        outer = -1;
      } else {
        outer *= in_dims[i];
      }
    }
    out_shape.push_back(outer);
    for (int i = stop_axis + 1; i < in_dims_size; i++) {
      out_shape.push_back(in_dims[i]);
    }

    return out_shape;
  }
};

// Declares the input, outputs, attributes and documentation of the
// flatten_contiguous_range operator.
class FlattenContiguousRangeOpMaker : public FlattenOpMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input tensor. Its rank must be greater than "
             "stop_axis.");
    AddOutput("Out",
              "The reshaped input tensor. The input dimensions from "
              "start_axis to stop_axis (both inclusive) are flattened into a "
              "single dimension of the output and the remaining input "
              "dimensions are kept unchanged.");
    AddAttr<int>("start_axis",
                 "(int) "
                 "Indicate the input start dimension (inclusive) to flatten. "
                 "A negative value counts from the last dimension.")
        .SetDefault(1);
    AddAttr<int>("stop_axis",
                 "(int) "
                 "Indicate the input stop dimension (inclusive) to flatten. "
                 "A negative value counts from the last dimension.")
        .SetDefault(1);
    AddComment(R"DOC(
Flatten Operator

Flattens the dimensions of the input tensor from start_axis to stop_axis
(both inclusive) into a single dimension.

Examples:
Case 1:
  Given
    X.shape = (3, 100, 100, 4)
  and
    start_axis = 2, stop_axis = -1
  We get:
    Out.shape = (3, 100, 400)

Case 2:
  Given
    X.shape = (3, 100, 100, 4)
  and
    start_axis = 0, stop_axis = -1
  We get:
    Out.shape = (3 * 100 * 100 * 4)
)DOC");
    AddOutput("XShape",
              "XShape is just used to store the shape and lod of X, which will "
              "be used in FlattenContiguousRangeGradOp.")
        .AsIntermediate();
  }
};

// Describes the backward operator of flatten_contiguous_range: it consumes
// the forward Output(XShape) and the gradient of Out, produces the gradient
// of X, and carries over the forward attributes.
template <typename T>
class FlattenContiguousRangeGradOpMaker
    : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    grad_op->SetType("flatten_contiguous_range_grad");
    // Inputs of the grad op.
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(out_grad_name, this->OutputGrad("Out"));
    // Output of the grad op.
    grad_op->SetOutput(x_grad_name, this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

// Backward operator of flatten_contiguous_range.
class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Gives the gradient of X the dims stored in Input(XShape) minus the
  // leading entry, and shares the LoD of XShape with it.
  void InferShape(framework::InferShapeContext *context) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    OP_INOUT_CHECK(context->HasInput("XShape"), "Input", "XShape",
                   "FlattenContiguousRangeGrad");
    OP_INOUT_CHECK(context->HasInput(out_grad_name), "Input", out_grad_name,
                   "FlattenContiguousRangeGrad");

    const auto stored_dims = context->GetInputDim("XShape");
    // Drop element 0 of XShape; the remaining entries are the dims of X.
    const auto restored_dims =
        framework::slice_ddim(stored_dims, 1, stored_dims.size());
    context->SetOutputDim(x_grad_name, restored_dims);
    context->ShareLoD("XShape", x_grad_name);
  }

 protected:
  // The kernel data type follows the gradient of Out; the device is taken
  // from the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto grad_dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(grad_dtype, ctx.device_context());
  }
};
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
{framework::GradVarName("Out"),
Expand All @@ -266,6 +416,16 @@ REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::FlattenGradInplaceInferer);

// Register the forward operator flatten_contiguous_range with its operator
// class, proto maker, the grad-op maker instantiated for both
// framework::OpDesc and imperative::OpBase, and the forward in-place inferer.
REGISTER_OPERATOR(
flatten_contiguous_range, ops::FlattenContiguousRangeOp,
ops::FlattenContiguousRangeOpMaker,
ops::FlattenContiguousRangeGradOpMaker<paddle::framework::OpDesc>,
ops::FlattenContiguousRangeGradOpMaker<paddle::imperative::OpBase>,
ops::FlattenOpInplaceInferer);
// Register the backward operator flatten_contiguous_range_grad with its
// operator class and the gradient in-place inferer.
REGISTER_OPERATOR(flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradOp,
ops::FlattenGradInplaceInferer);

REGISTER_OP_CPU_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
ops::FlattenKernel<paddle::platform::CPUDeviceContext, double>,
Expand All @@ -292,3 +452,26 @@ REGISTER_OP_CPU_KERNEL(
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// CPU kernels of flatten_contiguous_range, instantiated for float, double,
// int, int8_t and int64_t.
REGISTER_OP_CPU_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
double>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
int64_t>);
// CPU kernels of flatten_contiguous_range_grad, instantiated for the same
// element types as the forward kernels.
REGISTER_OP_CPU_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
23 changes: 23 additions & 0 deletions paddle/fluid/operators/flatten_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,26 @@ REGISTER_OP_CUDA_KERNEL(
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
// CUDA kernels of flatten_contiguous_range, instantiated for float, double,
// int, int8_t and int64_t.
REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
int8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
int64_t>);
// CUDA kernels of flatten_contiguous_range_grad, instantiated for the same
// element types as the forward kernels.
REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
Loading

0 comments on commit d09a0d6

Please sign in to comment.