diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 5b03cbf8c7f844..ecf2dbc81762a5 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -118,7 +118,7 @@ function(op_library TARGET)
       "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
       "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
       "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
-      "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op")
+      "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op")
   if ("${TARGET}" STREQUAL "${manual_pybind_op}")
     set(pybind_flag 1)
   endif()
diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.cc b/paddle/fluid/framework/fleet/gloo_wrapper.cc
index 49181cd05f3fac..bb958f1ac015bf 100644
--- a/paddle/fluid/framework/fleet/gloo_wrapper.cc
+++ b/paddle/fluid/framework/fleet/gloo_wrapper.cc
@@ -54,10 +54,9 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
       paddle::framework::fs_remove(tmp);
       if (i == retry_times_) {
         VLOG(0) << "fs_open_write failed, retry times reaches limit";
-        // PADDLE_THROW(platform::errors::PreconditionNotMet(
-        //     "fs_open_write failed, retry times reaches"
-        //     " limit ",
-        //     retry_times_));
+        PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
+            "fs_open_write failed, retry times reaches the limit of %d.",
+            retry_times_));
       }
     } else {
       break;
@@ -143,9 +142,9 @@ void HdfsStore::wait(const std::vector<std::string>& keys,
           break;
         }
       }
-      // PADDLE_THROW(platform::errors::ExecutionTimeout(
-      VLOG(0) << "TIMEOUT self_rank = " << self_rank_
-              << " pair_rank = " << last_check_rank;
+      PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
+          "TIMEOUT self_rank = %d pair_rank = %d", self_rank_,
+          last_check_rank));
     }
     std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
   }
diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc
index eee100bc81c337..e7e964b4181859 100644
--- a/paddle/fluid/framework/unused_var_check.cc
+++ b/paddle/fluid/framework/unused_var_check.cc
@@ -28,38 +28,6 @@ DEFINE_bool(enable_unused_var_check, false,
             "Checking whether operator contains unused inputs, "
             "especially for grad operator. It should be in unittest.");
 
-// NOTE(zhiqiu): Currently, there are some operators which involves unused
-// inputs and cannot be removed from the allow_list below.
-// They can be mainly divided into four categories:
-// 0: the inputs of which are only used in if branch, or used in cuda kernel but
-// not in cpu kernel;
-// 1: the inputs of which are used to indicate dtype of outputs;
-// 2: the inputs of which are used in fused operators.
-// The category number is presented in the comments after each operator.
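// (review note) To make category 1 concrete, a hypothetical sketch, not code
// from this patch: a kernel such as fill_any_like touches only the metadata
// of its input, so the unused-var checker would flag X even though the op is
// correct:
//
//   auto* x = ctx.Input<framework::Tensor>("X");    // used for dtype/shape only
//   auto* out = ctx.Output<framework::Tensor>("Out");
//   out->mutable_data(ctx.GetPlace(), x->type());   // allocate with X's dtype
//   // ... fill `out` with the constant; x->data<T>() is never read.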
-
-const std::unordered_set<std::string> op_with_unsed_vars_allow_list = {
-    "batch_norm",                      // 0
-    "batch_norm_grad",                 // 0
-    "sync_batch_norm",                 // 0
-    "sync_batch_norm_grad",            // 0
-    "inplace_abn",                     // 0
-    "inplace_abn_grad",                // 0
-    "dgc_momentum",                    // 0
-    "fake_quantize_range_abs_max",     // 0
-    "rmsprop",                         // 0
-    "sequence_conv_grad",              // 0
-    "roi_perspective_transform_grad",  // 0
-    "fill_zeros_like",                 // 1
-    "fill_any_like",                   // 1
-    "nce_grad",                        // 1
-    "precision_recall",                // 1
-    "fusion_seqpool_cvm_concat",       // 2
-    "fused_batch_norm_act",            // 2
-    "fused_batch_norm_act_grad",       // 2
-    "data_norm",                       // 0
-    "data_norm_grad",                  // 0
-};
-
 namespace paddle {
 namespace framework {
 
@@ -75,9 +43,44 @@ void LogVarUsageIfUnusedVarCheckEnabled(const std::string &name) {
   }
 }
 
+static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
+  // NOTE(zhiqiu): Currently, there are some operators which involve unused
+  // inputs and cannot be removed from the allow_list below.
+  // They can be mainly divided into three categories:
+  // 0: the inputs of which are only used in if branch, or used in cuda kernel
+  //    but not in cpu kernel;
+  // 1: the inputs of which are used to indicate dtype of outputs;
+  // 2: the inputs of which are used in fused operators.
+  // The category number is presented in the comments after each operator.
+  // Use pointer here for safe static deinitialization
+  static auto *allow_set = new std::unordered_set<std::string>({
+      // called once
+      "batch_norm",                      // 0
+      "batch_norm_grad",                 // 0
+      "sync_batch_norm",                 // 0
+      "sync_batch_norm_grad",            // 0
+      "inplace_abn",                     // 0
+      "inplace_abn_grad",                // 0
+      "dgc_momentum",                    // 0
+      "fake_quantize_range_abs_max",     // 0
+      "rmsprop",                         // 0
+      "sequence_conv_grad",              // 0
+      "roi_perspective_transform_grad",  // 0
+      "fill_zeros_like",                 // 1
+      "fill_any_like",                   // 1
+      "nce_grad",                        // 1
+      "precision_recall",                // 1
+      "fusion_seqpool_cvm_concat",       // 2
+      "fused_batch_norm_act",            // 2
+      "fused_batch_norm_act_grad",       // 2
+      "data_norm",                       // 0
+      "data_norm_grad",                  // 0
+  });
+  return *allow_set;
+}
+
 void CheckUnusedVar(const OperatorBase &op, const Scope &scope) {
   // skip op in allow list.
-  if (op_with_unsed_vars_allow_list.count(op.Type()) != 0) {
+  if (GetOpWithUnusedVarAllowSet().count(op.Type()) != 0) {
     return;
   }
   auto *used_set = GetThreadLocalUsedVarNameSet();
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
index 55a57caf9a0d6e..971f99e6919722 100644
--- a/paddle/fluid/inference/tensorrt/helper.h
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
       dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-static nvinfer1::IPluginRegistry* getPluginRegistry() {
+#if IS_TRT_VERSION_GE(6000)
+static nvinfer1::IPluginRegistry* GetPluginRegistry() {
   return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry());
 }
+#endif
 
 // A logger for create TensorRT infer builder.
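// (review note) A minimal usage sketch for the guarded accessor above,
// illustrative only; callers must tolerate a null registry when the TensorRT
// runtime cannot be dynamically loaded:
//
//   #if IS_TRT_VERSION_GE(6000)
//   if (auto* registry = GetPluginRegistry()) {
//     registry->registerCreator(creator, "");
//   }
//   #endif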
 class NaiveLogger : public nvinfer1::ILogger {
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
index f4424b8b7851fb..528adacb27c989 100644
--- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
@@ -178,12 +178,16 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
   std::string name_space_;
   std::string plugin_base_;
 };
-#endif
 
 template <typename T>
 class TrtPluginRegistrarV2 {
  public:
-  TrtPluginRegistrarV2() { getPluginRegistry()->registerCreator(creator, ""); }
+  TrtPluginRegistrarV2() {
+    static auto func_ptr = GetPluginRegistry();
+    if (func_ptr != nullptr) {
+      func_ptr->registerCreator(creator, "");
+    }
+  }
 
  private:
   T creator;
@@ -193,6 +197,8 @@ class TrtPluginRegistrarV2 {
   static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
       plugin_registrar_##name {}
 
+#endif
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index aff5f49c73bdce..d23beea7e4e62e 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -241,6 +241,156 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
   }
 };
 
+class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
+                   "FlattenContiguousRange");
+    const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
+    const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");
+    const auto &in_dims = ctx->GetInputDim("X");
+    int in_dims_size = in_dims.size();
+    int real_start_axis = start_axis, real_stop_axis = stop_axis;
+    if (start_axis < 0) {
+      real_start_axis = start_axis + in_dims_size;
+    }
+    if (stop_axis < 0) {
+      real_stop_axis = stop_axis + in_dims_size;
+    }
+    PADDLE_ENFORCE_GE(
+        real_stop_axis, real_start_axis,
+        platform::errors::InvalidArgument("The stop_axis should be greater "
+                                          "than or equal to start_axis."));
+
+    const auto &out_dims =
+        GetOutputShape(real_start_axis, real_stop_axis, in_dims);
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+    if (in_dims[0] == out_dims[0]) {
+      // Only pass LoD when the first dimension of output and Input(X)
+      // are the same.
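+      // (review note) Worked example of the shape logic above: with
+      // X.shape = (3, 100, 100, 4), start_axis = 1 and stop_axis = 2,
+      // dims 1..2 collapse into 100 * 100, so Out.shape = (3, 10000, 4);
+      // negative axes are wrapped by the rank first, e.g. stop_axis = -1
+      // becomes 3 here.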
+ ctx->ShareLoD("X", "Out"); + } + + OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2"); + std::vector xshape_dims(in_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < in_dims.size(); ++i) { + xshape_dims[i + 1] = in_dims[i]; + } + ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims)); + ctx->ShareLoD("X", "XShape"); + } + + static std::vector GetOutputShape(const int start_axis, + const int stop_axis, + const framework::DDim &in_dims) { + int64_t outer = 1; + std::vector out_shape; + int in_dims_size = in_dims.size(); + out_shape.reserve(in_dims_size - stop_axis + start_axis); + + for (int i = 0; i < start_axis; ++i) { + out_shape.push_back(in_dims[i]); + } + for (int i = start_axis; i <= stop_axis; i++) { + outer *= in_dims[i]; + } + out_shape.push_back(outer); + for (int i = stop_axis + 1; i < in_dims_size; i++) { + out_shape.push_back(in_dims[i]); + } + + return out_shape; + } +}; + +class FlattenContiguousRangeOpMaker : public FlattenOpMaker { + public: + void Make() override { + AddInput("X", "(Tensor) A tensor of rank >= axis."); + AddOutput("Out", + "A 2D tensor is reshaped input tensor. The input dimensions" + "up to axis are flattened to the outer dimension of the output" + "and the remaining input dimensions are flattened into the inner" + "dimension of the output."); + AddAttr("start_axis", + "(int)" + "Indicate the input start dimension (exclusive) to flatten") + .SetDefault(1); + AddAttr("stop_axis", + "(int)" + "Indicate the input stop dimension (exclusive) to flatten") + .SetDefault(1); + AddComment(R"DOC( +Flatten Operator + +Flattens the input tensor into a new matrix according to start_axis and stop_axis. + +Examples: +Case 1: + Given + X.shape = (3, 100, 100, 4) + and + start_axis = 2, stop_axis = -1 + We get: + Out.shape = (3, 100, 400) + +Case 2: + Given + X.shape = (3, 100, 100, 4) + and + start_axis = 0, stop_axis = -1 + We get: + Out.shape = (3 * 100 * 100 * 4) +)DOC"); + AddOutput("XShape", + "XShape is just used to store the shape and lod of X, which will " + "be used in FlattenGradOp.") + .AsIntermediate(); + } +}; + +template +class FlattenContiguousRangeGradOpMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("flatten_contiguous_range_grad"); + grad_op->SetInput("XShape", this->Output("XShape")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *context) const override { + OP_INOUT_CHECK(context->HasInput("XShape"), "Input", "XShape", + "FlattenContiguousRangeGrad"); + OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "FlattenContiguousRangeGrad"); + auto xshape_dims = context->GetInputDim("XShape"); + auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + context->SetOutputDim(framework::GradVarName("X"), x_dims); + context->ShareLoD("XShape", framework::GradVarName("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + 
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
 
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
                            {framework::GradVarName("Out"),
@@ -266,6 +416,16 @@ REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
 REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
                   ops::FlattenGradInplaceInferer);
 
+REGISTER_OPERATOR(
+    flatten_contiguous_range, ops::FlattenContiguousRangeOp,
+    ops::FlattenContiguousRangeOpMaker,
+    ops::FlattenContiguousRangeGradOpMaker<paddle::framework::OpDesc>,
+    ops::FlattenContiguousRangeGradOpMaker<paddle::imperative::OpBase>,
+    ops::FlattenOpInplaceInferer);
+REGISTER_OPERATOR(flatten_contiguous_range_grad,
+                  ops::FlattenContiguousRangeGradOp,
+                  ops::FlattenGradInplaceInferer);
+
 REGISTER_OP_CPU_KERNEL(
     flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FlattenKernel<paddle::platform::CPUDeviceContext, double>,
@@ -292,3 +452,26 @@ REGISTER_OP_CPU_KERNEL(
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    flatten_contiguous_range,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
+                                      float>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
+                                      double>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
+                                      int8_t>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
+                                      int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    flatten_contiguous_range_grad,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
+                                          float>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
+                                          double>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
+                                          int>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
+                                          int8_t>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
+                                          int64_t>);
diff --git a/paddle/fluid/operators/flatten_op.cu.cc b/paddle/fluid/operators/flatten_op.cu.cc
index ac4ad8e2dc1c09..40fda804eaab9d 100644
--- a/paddle/fluid/operators/flatten_op.cu.cc
+++ b/paddle/fluid/operators/flatten_op.cu.cc
@@ -42,3 +42,26 @@ REGISTER_OP_CUDA_KERNEL(
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    flatten_contiguous_range,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      float>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      double>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      int>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      int8_t>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    flatten_contiguous_range_grad,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          float>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          double>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          int>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          int8_t>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          int64_t>);
diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h
index 165832c0e68bde..08efaedccd4f40 100644
--- a/paddle/fluid/operators/flatten_op.h
+++ b/paddle/fluid/operators/flatten_op.h
@@ -112,5 +112,73 @@ class Flatten2GradKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto &start_axis = context.Attr<int>("start_axis");
+    auto &stop_axis = context.Attr<int>("stop_axis");
+
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto x_dims = in->dims();
+    int in_dims_size = x_dims.size();
+    int real_start_axis = start_axis, real_stop_axis = stop_axis;
+    if (start_axis < 0) {
+      real_start_axis = start_axis + in_dims_size;
+    }
+    if (stop_axis < 0) {
+      real_stop_axis = stop_axis + in_dims_size;
+    }
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto out_dims = framework::make_ddim(
+        GetOutputShape(real_start_axis, real_stop_axis, x_dims));
+
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<DeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+  static std::vector<int32_t> GetOutputShape(const int start_axis,
+                                             const int stop_axis,
+                                             const framework::DDim &in_dims) {
+    int64_t outer = 1;
+    std::vector<int32_t> out_shape;
+    int in_dims_size = in_dims.size();
+    out_shape.reserve(in_dims_size - stop_axis + start_axis);
+
+    for (int i = 0; i < start_axis; ++i) {
+      out_shape.push_back(in_dims[i]);
+    }
+    for (int i = start_axis; i <= stop_axis; i++) {
+      outer *= in_dims[i];
+    }
+    out_shape.push_back(outer);
+    for (int i = stop_axis + 1; i < in_dims_size; i++) {
+      out_shape.push_back(in_dims[i]);
+    }
+
+    return out_shape;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 24f656140f42df..3fc5f3bfc6b163 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -7,7 +7,12 @@ register_operators(EXCLUDES
 	fused_fc_elementwise_layernorm_op
 	multihead_matmul_op
 	fused_embedding_eltwise_layernorm_op
-	fusion_group_op)
+	fusion_group_op
+	fusion_gru_op)
+
+# fusion_gru_op does not have CUDA kernel
+op_library(fusion_gru_op)
+file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n")
 
 if (WITH_GPU)
     # fused_bn_activation_op needs cudnn 7.4.1 above
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index f6c8316e2e9fa0..d0920098f606e4 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -19,6 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -122,8 +125,17 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library = framework::LibraryType::kPlain;
+  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+#ifdef PADDLE_WITH_MKLDNN
+  if (platform::CanMKLDNNBeUsed(ctx)) {
+    library = framework::LibraryType::kMKLDNN;
+    layout = framework::DataLayout::kMKLDNN;
+  }
+#endif
   return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
+      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+      layout, library);
 }
 
 void FusionGRUOpMaker::Make() {
@@ -187,6 +199,9 @@ void FusionGRUOpMaker::Make() {
             "bool"
             "use origin mode in article https://arxiv.org/abs/1412.3555")
       .SetDefault(false);
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuse the fully-connected operator into GRU,
diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
new file mode 100644
index 00000000000000..3940aae53b8ef7
--- /dev/null
+++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
@@ -0,0 +1,439 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fusion_gru_op.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::LoDTensor;
+using paddle::framework::Tensor;
+using paddle::platform::CPUDeviceContext;
+using paddle::platform::MKLDNNGetDataType;
+using paddle::platform::MKLDNNMemDesc;
+using platform::to_void_cast;
+
+template <typename T>
+class GRUMKLDNNHandler
+    : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
+ public:
+  GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
+                   const platform::MKLDNNDeviceContext& dev_ctx,
+                   const mkldnn::engine mkldnn_engine,
+                   platform::Place cpu_place, const LoDTensor* input,
+                   const Tensor* weight_h, const Tensor* h0,
+                   const bool is_reverse, const int64_t N, const int64_t Ti,
+                   const int64_t IC, const int64_t OC,
+                   const std::string& unique_name)
+      : platform::MKLDNNHandlerT<T, dnnl::gru_forward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(unique_name, Ti)),
+        N(N),
+        Ti(Ti),
+        IC(IC),
+        OC(OC) {
+    // Create memory key without Ti because weights, bias and h0 memories
+    // do not depend on Ti size but primitive and input/output memory do
+    if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
+        platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
+      memory_key_ = unique_name;
+    } else {
+      memory_key_ = unique_name + "-t:" + platform::ThreadIDasStr();
+    }
+
+    if (!this->isCached()) {
+      // oneDNN kernel has hardcoded activation functions
+      PADDLE_ENFORCE_EQ(
+          ctx.Attr<std::string>("gate_activation"), "sigmoid",
+          platform::errors::Unimplemented(
+              "oneDNN fusion_gru supports only sigmoid as a gate "
+              "activation."));
+      PADDLE_ENFORCE_EQ(
+          ctx.Attr<std::string>("activation"), "tanh",
+          platform::errors::Unimplemented(
+              "oneDNN fusion_gru supports only tanh as an activation."));
+
+      // oneDNN RNN dimensions
+      const int64_t D = 1;  // Directions
+      const int64_t L = 1;  // Layers (PP supports only 1 stacked layer)
+      const int64_t G = 3;  // Number of Gates, 3 for GRU
+
+      // Create memory descriptors
+      auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(),
+                                    MKLDNNMemoryFormat::any);
+      auto weight_x_md = MKLDNNMemDesc(
+          {L, D, IC, G, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+      auto weight_h_md = MKLDNNMemDesc(
+          {L, D, OC, G, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+      auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(),
+                                   MKLDNNMemoryFormat::ldgo);
+      auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T>(),
+                                     MKLDNNMemoryFormat::any);
+      auto h0_md = dnnl::memory::desc();
+      if (h0) {
+        h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
+                              MKLDNNMemoryFormat::ldnc);
+      }
+
+      // Create GRU oneDNN primitive
+      const auto direction =
+          is_reverse ? dnnl::rnn_direction::unidirectional_right2left
+                     : dnnl::rnn_direction::unidirectional_left2right;
+
+      this->AcquireForwardPrimitiveDescriptor(
+          dnnl::prop_kind::forward_inference, direction, input_md, h0_md,
+          weight_x_md, weight_h_md, bias_md, hidden_md, dnnl::memory::desc());
+    }
+  }
+
+  bool is_NTC() {
+    return (platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc()) ==
+            dnnl::memory::format_tag::ntc);
+  }
+
+  void reorderRNNdata(const T* input_data, T* output_data,
+                      std::vector<size_t> lod, const bool is_reverse,
+                      platform::RNNReorderType reorder_type) {
+    switch (reorder_type) {
+      // Reorder input memory [WORDS, C] + LoD -> [N, T, C]
+      case platform::RNNReorderType::PP_NTC: {
+        auto* input_data_iter = input_data;
+        for (int n = 0; n < N; ++n) {
+          const auto num_elements = (lod[n + 1] - lod[n]) * IC;
+          const auto offset = is_reverse ? (Ti * IC - num_elements) : 0;
+          memcpy(output_data + n * Ti * IC + offset, input_data_iter,
+                 sizeof(T) * num_elements);
+          input_data_iter += num_elements;
+        }
+      } break;
+      // Reorder input memory [WORDS, C] + LoD -> [T, N, C]
+      case platform::RNNReorderType::PP_TNC: {
+        auto* input_data_iter = input_data;
+        for (int n = 0; n < N; ++n) {
+          const auto num_elements = (lod[n + 1] - lod[n]);
+          const auto offset = is_reverse ? (Ti - num_elements) : 0;
+          for (size_t t = 0; t < num_elements; ++t) {
+            memcpy(output_data + (t + offset) * N * IC + n * IC,
+                   input_data_iter, sizeof(T) * IC);
+            input_data_iter += IC;
+          }
+        }
+      } break;
+      // Reorder output values to PP format [N, T, C] -> [WORDS, C]
+      case platform::RNNReorderType::NTC_PP: {
+        auto* output_data_iter = output_data;
+        for (int n = 0; n < N; ++n) {
+          const auto num_elements = (lod[n + 1] - lod[n]) * OC;
+          const auto offset = is_reverse ? (Ti * OC - num_elements) : 0;
+          memcpy(output_data_iter, input_data + n * Ti * OC + offset,
+                 sizeof(T) * num_elements);
+          output_data_iter += num_elements;
+        }
+      } break;
+      // Reorder output values to PP format [T, N, C] -> [WORDS, C]
+      case platform::RNNReorderType::TNC_PP: {
+        auto* output_data_iter = output_data;
+        for (int n = 0; n < N; ++n) {
+          const auto num_elements = lod[n + 1] - lod[n];
+          const auto offset = is_reverse ? (Ti - num_elements) : 0;
+          for (size_t t = 0; t < num_elements; ++t) {
+            memcpy(output_data_iter,
+                   input_data + (t + offset) * N * OC + n * OC,
+                   sizeof(T) * OC);
+            output_data_iter += OC;
+          }
+        }
+      } break;
+    }
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireInputMemoryWithReorder(
+      const LoDTensor* input, const bool is_reverse) {
+    const auto name = this->key_ + "@input_mem";
+    auto memory_p =
+        std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(name));
+
+    if (!memory_p) {
+      memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_desc(),
+                                                this->engine_);
+      this->dev_ctx_.SetBlob(name, memory_p);
+    }
+
+    const auto& input_lod = input->lod()[0];
+    auto* x_data = input->data<T>();
+
+    auto* x_onednn_data = reinterpret_cast<T*>(memory_p->get_data_handle());
+    memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);
+
+    if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) ==
+        dnnl::memory::format_tag::ntc) {
+      reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse,
+                     platform::RNNReorderType::PP_NTC);
+    } else {
+      reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse,
+                     platform::RNNReorderType::PP_TNC);
+    }
+    return memory_p;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireOutputMemory() {
+    const auto name = this->key_ + "@output_mem";
+    auto memory_p =
+        std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(name));
+
+    if (!memory_p) {
+      memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->dst_desc(),
+                                                this->engine_);
+      this->dev_ctx_.SetBlob(name, memory_p);
+    }
+    return memory_p;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
+    const std::string h0_key = memory_key_ + "@h0";
+    auto memory_p =
+        std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(h0_key));
+
+    auto* h0_data = to_void_cast(h0->data<T>());
+
+    if (!memory_p) {
+      memory_p = std::make_shared<dnnl::memory>(
+          this->fwd_pd_->src_iter_desc(), this->engine_, h0_data);
+      this->dev_ctx_.SetBlob(h0_key, memory_p);
+    } else {
+      memory_p->set_data_handle(h0_data);
+    }
+    return memory_p;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireWeightXMemory(const Tensor* weight_x,
+                                                     const bool origin_mode) {
+    const std::string wx_key = memory_key_ + "@weight_x";
+    auto memory_p =
+        std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(wx_key));
+
+    if (!memory_p) {
+      auto user_md =
+          MKLDNNMemDesc({1, 1, IC, 3, OC}, MKLDNNGetDataType<float>(),
+                        MKLDNNMemoryFormat::ldigo);
+      auto user_memory = dnnl::memory(user_md, this->engine_);
+
+      auto* weight_x_data =
+          reinterpret_cast<float*>(user_memory.get_data_handle());
+      memcpy(weight_x_data, weight_x->data<float>(),
+             sizeof(float) * IC * 3 * OC);
+
+      if (origin_mode == false) {
+        for (int64_t i = 0; i < IC; ++i) {
+          for (int64_t j = 0; j < OC; ++j) {
+            weight_x_data[j] *= -1;
+          }
+          weight_x_data += 3 * OC;
+        }
+      }
+
+      memory_p = std::make_shared<dnnl::memory>(
+          this->fwd_pd_->weights_layer_desc(), this->engine_);
+
+      dnnl::stream astream(this->engine_);
+      dnnl::reorder(user_memory, *memory_p)
+          .execute(astream, user_memory, *memory_p);
+
+      this->dev_ctx_.SetBlob(wx_key, memory_p);
+    }
+    return memory_p;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireWeightHMemory(const Tensor* weight_h,
+                                                     const bool origin_mode) {
+    const std::string wh_key = memory_key_ + "@weight_h";
+    auto memory_p =
+        std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(wh_key));
+
+    if (!memory_p) {
+      auto user_md =
+          MKLDNNMemDesc({1, 1, OC, 3, OC}, MKLDNNGetDataType<float>(),
+                        MKLDNNMemoryFormat::ldigo);
+      auto user_memory = dnnl::memory(user_md, this->engine_);
+
+      // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to
+      // oneDNN format [OC, 3OC]
+      auto* weight_h_data =
+          reinterpret_cast<float*>(user_memory.get_data_handle());
+      auto* user_weight_h_data = weight_h->data<float>();
+
+      auto src1_iter = user_weight_h_data;
+      auto src2_iter = user_weight_h_data + 2 * OC * OC;
+
+      for (int64_t c = 0; c < OC; ++c) {
+        memcpy(weight_h_data, src1_iter, 2 * OC * sizeof(float));
+        memcpy(weight_h_data + 2 * OC, src2_iter, OC * sizeof(float));
+
+        src1_iter += 2 * OC;
+        src2_iter += OC;
+        weight_h_data += 3 * OC;
+      }
+
+      weight_h_data = reinterpret_cast<float*>(user_memory.get_data_handle());
+
+      if (origin_mode == false) {
+        for (int64_t i = 0; i < OC; ++i) {
+          for (int64_t j = 0; j < OC; ++j) {
+            weight_h_data[j] *= -1;
+          }
+          weight_h_data += 3 * OC;
+        }
+      }
+
+      memory_p = std::make_shared<dnnl::memory>(
+          this->fwd_pd_->weights_iter_desc(), this->engine_);
+
+      dnnl::stream astream(this->engine_);
+      dnnl::reorder(user_memory, *memory_p)
+          .execute(astream, user_memory, *memory_p);
+
+      this->dev_ctx_.SetBlob(wh_key, memory_p);
+    }
+    return memory_p;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireBiasMemory(const Tensor* bias,
+                                                  const bool origin_mode) {
+    const std::string bias_key = memory_key_ + "@bias";
+    auto memory_p = std::static_pointer_cast<dnnl::memory>(
+        this->dev_ctx_.GetBlob(bias_key));
+
+    if (!memory_p) {
+      memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->bias_desc(),
+                                                this->engine_);
+      auto* bias_data = reinterpret_cast<float*>(memory_p->get_data_handle());
+      if (bias) {
+        const float* user_bias_data =
+            bias->data<float>();  // Bias in oneDNN is always float
+        memcpy(bias_data, user_bias_data, sizeof(float) * 3 * OC);
+      } else {
+        // oneDNN always need bias memory, if it's not provided in PP, let
+        // oneDNN allocate memory and set it to 0
+        memset(bias_data, 0, sizeof(float) * 3 * OC);
+      }
+
+      if (origin_mode == false && bias) {
+        for (int64_t i = 0; i < OC; ++i) {
+          bias_data[i] *= -1;
+        }
+      }
+      this->dev_ctx_.SetBlob(bias_key, memory_p);
+    }
+    return memory_p;
+  }
+
+ private:
+  // RNN dimensions
+  // N - Batch Size
+  // Ti - Max sentence length
+  // IC - Input Channels
+  // OC - Output Channels
+  const int64_t N, Ti, IC, OC;
+
+  // Memory size of weights, bias and h0 does not depend
+  // on Ti size, thus we need another key to cache them
+  std::string memory_key_;
+};
+
+template <typename T>
+class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    // Get Tensors
+    const auto* input = ctx.Input<LoDTensor>("X");
+    const auto* h0 = ctx.Input<Tensor>("H0");
+    const auto* weight_x = ctx.Input<Tensor>("WeightX");
+    const auto* weight_h = ctx.Input<Tensor>("WeightH");
+    const auto* bias = ctx.Input<Tensor>("Bias");
+    auto* hidden = ctx.Output<LoDTensor>("Hidden");
+
+    // Get attributes
+    const bool is_reverse = ctx.Attr<bool>("is_reverse");
+    const bool origin_mode = ctx.Attr<bool>("origin_mode");
+
+    // Get tensor dimensions
+    const auto x_dims = framework::vectorize(input->dims());
+    const auto weight_h_dims = framework::vectorize(weight_h->dims());
+    const auto& input_lod = input->lod()[0];
+
+    // Calculate RNN dimensions
+    const int64_t N = input_lod.size() - 1;  // Number of sentences (batches)
+    const int64_t Ti =  // Max length of the sentence in a batch
+        [&input_lod]() {
+          size_t res = 0;
+          for (size_t i = 0; i < (input_lod.size() - 1); ++i) {
+            res = std::max(res, input_lod[i + 1] - input_lod[i]);
+          }
+          return res;
+        }();
+    const int64_t IC = x_dims[1];         // Input channels
+    const int64_t OC = weight_h_dims[0];  // Output channels
+
+    GRUMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(),
+                                input, weight_h, h0, is_reverse, N, Ti, IC,
+                                OC,
ctx.InputName("WeightH")); + + auto input_memory_p = + handler.AcquireInputMemoryWithReorder(input, is_reverse); + auto weight_x_memory_p = + handler.AcquireWeightXMemory(weight_x, origin_mode); + auto weight_h_memory_p = + handler.AcquireWeightHMemory(weight_h, origin_mode); + auto bias_memory_p = handler.AcquireBiasMemory(bias, origin_mode); + auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); + + std::unordered_map gru_args = { + {DNNL_ARG_SRC_LAYER, *input_memory_p}, + {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, + {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, + {DNNL_ARG_BIAS, *bias_memory_p}, + {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; + + if (h0) { + auto h0_memory_p = handler.AcquireH0Memory(h0); + gru_args.insert({DNNL_ARG_SRC_ITER, *h0_memory_p}); + } + + auto gru_forward_p = handler.AcquireForwardPrimitive(); + + dnnl::stream astream(mkldnn_engine); + gru_forward_p->execute(astream, gru_args); + astream.wait(); + + auto* hidden_onednn_data = + reinterpret_cast(hidden_onednn_memory_p->get_data_handle()); + auto* hidden_data = hidden->mutable_data(ctx.GetPlace()); + if (handler.is_NTC()) { + handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, + is_reverse, platform::RNNReorderType::NTC_PP); + } else { + handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, + is_reverse, platform::RNNReorderType::TNC_PP); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace, + ops::FusionGRUMKLDNNKernel); diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h index 60e299385d6a64..67a79ce4bb1594 100644 --- a/paddle/fluid/platform/dynload/tensorrt.h +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -36,26 +36,29 @@ extern void* tensorrt_dso_handle; struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {        \
-      using tensorrt_func = decltype(&::__name);                            \
       std::call_once(tensorrt_dso_flag, []() {                              \
         tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \
-        PADDLE_ENFORCE_NOT_NULL(tensorrt_dso_handle,                        \
-                                platform::errors::Unavailable(              \
-                                    "Load tensorrt %s failed", #__name));   \
       });                                                                   \
       static void* p_##__name = dlsym(tensorrt_dso_handle, #__name);        \
-      PADDLE_ENFORCE_NOT_NULL(                                              \
-          p_##__name,                                                       \
-          platform::errors::Unavailable("Load tensorrt %s failed", #__name)); \
+      if (p_##__name == nullptr) {                                          \
+        return nullptr;                                                     \
+      }                                                                     \
+      using tensorrt_func = decltype(&::__name);                            \
       return reinterpret_cast<tensorrt_func>(p_##__name)(args...);          \
     }                                                                       \
   };                                                                        \
   extern DynLoad__##__name __name
 
+#if (NV_TENSORRT_MAJOR >= 6)
 #define TENSORRT_RAND_ROUTINE_EACH(__macro) \
   __macro(createInferBuilder_INTERNAL);     \
   __macro(createInferRuntime_INTERNAL);     \
   __macro(getPluginRegistry);
+#else
+#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
+  __macro(createInferBuilder_INTERNAL);     \
+  __macro(createInferRuntime_INTERNAL);
+#endif
 
 TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index c147bdccbe99e5..60588d89db803f 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -181,6 +181,8 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   if (inner_nblks == 0) {
     if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
       return mkldnn::memory::format_tag::ncw;
+    } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
+      return mkldnn::memory::format_tag::ntc;
     } else {
       return mkldnn::memory::format_tag::nwc;
     }
@@ -420,5 +422,7 @@ inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
   }
 }
 
+enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/python/paddle/fleet/__init__.py b/python/paddle/fleet/__init__.py
index b25c362ce9301c..cc5ce0f2b74b61 100644
--- a/python/paddle/fleet/__init__.py
+++ b/python/paddle/fleet/__init__.py
@@ -16,10 +16,13 @@
 from .base.distributed_strategy import DistributedStrategy
 from .base.fleet_base import Fleet
 from .base.util_factory import UtilBase
-
+from .dataset import *
 #from .base.role_maker import PaddleCloudRoleMaker
 
-__all__ = ["DistributedStrategy", "UtilBase"]
+__all__ = [
+    "DistributedStrategy", "UtilBase", "DatasetFactory", "DatasetBase",
+    "InMemoryDataset", "QueueDataset"
+]
 
 fleet = Fleet()
 init = fleet.init
diff --git a/python/paddle/fleet/base/role_maker.py b/python/paddle/fleet/base/role_maker.py
index f6b5c8ac12e92d..b3e8120af6f855 100644
--- a/python/paddle/fleet/base/role_maker.py
+++ b/python/paddle/fleet/base/role_maker.py
@@ -12,5 +12,523 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Defination of Role Makers."""
+import os
+import time
+import numpy as np
+from multiprocessing import Process, Manager
+import paddle.fluid as fluid
 
-# __all__ = ['RoleMakerBase', 'UserDefinedRoleMaker', 'PaddleCloudRoleMaker']
+__all__ = ['RoleMakerBase', 'UserDefinedRoleMaker', 'PaddleCloudRoleMaker']
+
+
+class Role:
+    WORKER = 1
+    SERVER = 2
+
+
+class RoleMakerBase(object):
+    """
+    RoleMakerBase is a base class for assigning a role to current process
+    in distributed training.
+    A paddle developer can implement RoleMakerBase to design a role maker
+    for worker or pserver assignment.
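+
+    Example (a hypothetical subclass, for illustration only):
+
+        class MyRoleMaker(RoleMakerBase):
+            def is_worker(self):
+                return os.getenv("TRAINING_ROLE", "TRAINER") == "TRAINER"
+
+            def is_server(self):
+                return os.getenv("TRAINING_ROLE") == "PSERVER"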
+ """ + + def __init__(self): + self._worker_endpoints = [] + self._server_endpoints = [] + self._role_is_generated = False + self._role = None + self._current_id = -1 + + self._node_type = None + self._node_type_comm = None + self._all_comm = None + + def is_worker(self): + """ + return is_worker() of current process + """ + raise NotImplementedError("Please implement this method in child class") + + def is_server(self): + """ + return is_server() of current process + """ + raise NotImplementedError("Please implement this method in child class") + + def is_first_worker(self): + """ + Check whether the node is the first instance of worker. + Returns: + bool: True if this is the first node of worker, + False if not. + """ + raise NotImplementedError("Please implement this method in child class") + + def worker_num(self): + """ + Get current total worker number. + + Returns: + int: worker number + """ + raise NotImplementedError("Please implement this method in child class") + + def server_num(self): + """ + Get current total server number. + + Returns: + int: server number + """ + raise NotImplementedError("Please implement this method in child class") + + def worker_index(self): + """ + Get current worker id. + + Returns: + int: node id + """ + raise NotImplementedError("Please implement this method in child class") + + def server_index(self): + """ + Get current server id. + + Returns: + int: node id + """ + raise NotImplementedError("Please implement this method in child class") + + def role_id(self): + """ + Get current id. + + Returns: + int: node id + """ + raise NotImplementedError("Please implement this method in child class") + + def get_trainer_endpoints(self): + """ + return trainer endpoints + """ + return self._worker_endpoints + + def get_pserver_endpoints(self): + """ + return pserver endpoints + """ + return self._server_endpoints + + def to_string(self): + return "role: {}, current_id: {}, worker_endpoints: {}, server_endpoints: {}".format( + self._role, self._current_id, self._worker_endpoints, + self._server_endpoints) + + def _all_gather(self, comm_world, input): + """ + + Args: + input(int|float): input value + + Returns: + return a list of values + """ + print("warning: RoleMakerBase does not have all gather.") + return None + + def _all_reduce(self, comm_world, input, mode="sum"): + """ + Args: + input(list/numpy.array): array of one dim + output(list/numpy.array): array of one dim + mode(str): "sum" or "min" or "max" + """ + print("warning: RoleMakerBase does not have all reduce worker.") + return None + + def _barrier(self, comm_world): + """ + barrier between trainers if current role is TRAINER + """ + print("warning: RoleMakerBase does not have barrier worker.") + + +class PaddleCloudRoleMaker(RoleMakerBase): + def __init__(self, is_collective=False, init_gloo=True, **kwargs): + super(PaddleCloudRoleMaker, self).__init__() + self._is_collective = is_collective + self._init_gloo = init_gloo + self._kwargs = kwargs + + self._role_is_generated = False + + self._server_endpoints = None + self._worker_endpoints = None + + self._node_type_comm = None + self._all_comm = None + + if not self._is_collective: + self._hdfs_name = kwargs.get("hdfs_name", "") + self._hdfs_ugi = kwargs.get("hdfs_ugi", "") + self._hdfs_path = kwargs.get("path", "").rstrip("/") + self._init_timeout_seconds = kwargs.get("init_timeout_seconds", + 3600) + self._run_timeout_seconds = kwargs.get("run_timeout_seconds", + 9999999) + ip_port = kwargs.get("http_ip_port", "") + self._http_ip_port = [] + 
+            self._http_server = None
+            # if ip_port is not empty, it will use http instead of hdfs
+            if ip_port != "":
+                self._http_ip_port = ip_port.split(":")
+                # it's for communication between processes
+                self._manager = Manager()
+                # global dict to store status
+                self._http_server_d = self._manager.dict()
+                # set running status of http server
+                self._http_server_d["running"] = False
+            self._iface = self.__get_default_iface()
+            # this environment variable can be empty
+            self._prefix = os.getenv("SYS_JOB_ID", "")
+
+    def _barrier(self, comm_world):
+        if comm_world:
+            comm_world.barrier()
+
+    def _all_gather(self, comm_world, input):
+        if comm_world:
+            self._barrier(comm_world)
+            output = comm_world.all_gather(input)
+            return output
+        else:
+            return None
+
+    def _all_reduce(self, comm_world, input, mode="sum"):
+        if not comm_world:
+            return None
+
+        input = np.array(input)
+
+        input_shape = input.shape
+        input_list = input.reshape(-1).tolist()
+
+        self._barrier(comm_world)
+        ans = comm_world.all_reduce(input_list, mode)
+        output = np.array(ans).reshape(input_shape)
+        return output
+
+    def is_worker(self):
+        """
+        whether current process is worker
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._role == Role.WORKER
+
+    def is_server(self):
+        """
+        whether current process is server
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._role == Role.SERVER
+
+    def is_first_worker(self):
+        """
+        whether current process is worker of rank 0
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._role == Role.WORKER and self._current_id == 0
+
+    def worker_index(self):
+        """
+        get index of current worker
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._current_id
+
+    def server_index(self):
+        """
+        get index of current server
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._current_id
+
+    def role_id(self):
+        """
+        get index of current node
+        """
+        if self.is_server():
+            return self.server_index()
+        elif self.is_worker():
+            return self.worker_index()
+
+    def worker_num(self):
+        """
+        return the current number of workers
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._trainers_num
+
+    def server_num(self):
+        """
+        return the current number of servers
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return len(self._server_endpoints)
+
+    def get_trainer_endpoints(self):
+        """
+        get endpoint of all trainers
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._worker_endpoints
+
+    def get_pserver_endpoints(self):
+        """
+        get endpoint of all pservers
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._server_endpoints
+
+    def _get_rank(self):
+        """
+        get current rank in all workers and pservers
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._rank
+
+    def _get_size(self):
+        """
+        get total num of all workers and pservers
+        """
+        if not self._role_is_generated:
+            self.generate_role()
+        return self._size
+
+    def _ps_env(self):
+        try:
+            # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
+            # format: string(ip:port), e.g. 127.0.0.1:6001
+            self._server_endpoints = os.environ[
+                "PADDLE_PSERVERS_IP_PORT_LIST"].split(",")
+            self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
+                                               "").split(",")
+
+            trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
+            training_role = os.environ["TRAINING_ROLE"]
+
+            if training_role not in ["TRAINER", "PSERVER"]:
+                raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
+
+            if training_role == "TRAINER":
+                role = Role.WORKER
+                current_id = int(os.environ["PADDLE_TRAINER_ID"])
+                if len(self._worker_endpoints) > 0:
+                    self._cur_endpoint = self._worker_endpoints[current_id]
+            elif training_role == "PSERVER":
+                role = Role.SERVER
+                port = os.environ["PADDLE_PORT"]
+                ip = os.environ["POD_IP"]
+                self._cur_endpoint = ip + ":" + port
+                current_id = self._server_endpoints.index(self._cur_endpoint)
+            else:
+                raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
+        except ValueError as ve:
+            raise ValueError(
+                "something wrong with PaddleCloud, please check environment")
+
+        self._trainers_num = trainers_num
+        self._role = role
+        self._current_id = current_id
+
+    def _collective_env(self):
+        self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
+        assert (self._training_role == "TRAINER")
+        self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
+        self._cur_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
+        assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS"
+        self._worker_endpoints = self._worker_endpoints.split(",")
+        self._trainers_num = len(self._worker_endpoints)
+
+    def _init_gloo_env(self):
+        def init_gloo_instance(role="trainer"):
+            role = role.lower()
+            assert role in ["trainer", "pserver", "all"]
+            if role == "trainer":
+                all_list = self._worker_endpoints
+                rank = self._current_id
+            elif role == "pserver":
+                all_list = self._server_endpoints
+                rank = self._current_id
+            else:
+                all_list = self._worker_endpoints + self._server_endpoints
+                rank = all_list.index(self._cur_endpoint)
+            gloo = fluid.core.Gloo()
+            gloo.set_rank(rank)
+            gloo.set_size(len(all_list))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            if len(self._http_ip_port) != 0:
+                gloo.set_http_store(self._http_ip_port[0],
+                                    int(self._http_ip_port[1]), role)
+            else:
+                gloo.set_hdfs_store(self._hdfs_path + "/" + role,
+                                    self._hdfs_name, self._hdfs_ugi)
+            gloo.init()
+            return gloo
+
+        # paddlecloud support gloo
+        if self._role == Role.WORKER:
+            if self._current_id == 0 and len(self._http_ip_port) != 0:
+                size_d = {
+                    "trainer": len(self._worker_endpoints),
+                    "pserver": len(self._server_endpoints),
+                    "all":
+                    len(self._worker_endpoints) + len(self._server_endpoints)
+                }
+                # child process for http server
+                self._http_server = Process(
+                    target=self.__start_kv_server,
+                    args=(self._http_server_d, size_d))
+                self._http_server.daemon = True
+                # set running status to True
+                self._http_server_d["running"] = True
+                # start child process
+                self._http_server.start()
+            self._node_type = 1
+            gloo = init_gloo_instance("trainer")
+            self._node_type_comm = gloo
+        else:
+            assert self._role == Role.SERVER
+            self._node_type = 0
+            gloo = init_gloo_instance("pserver")
+            self._node_type_comm = gloo
+
+        all_list = self._worker_endpoints + self._server_endpoints
+        self._rank = all_list.index(self._cur_endpoint)
+        self._size = len(all_list)
+
+        gloo = init_gloo_instance("all")
+        self._all_comm = gloo
+
+        if self._http_server is not None:
+            # set running status to False
+            self._http_server_d["running"] = False
+            # wait until child process exits
+            self._http_server.join()
+
+    def generate_role(self):
+        """
+        generate role for role maker
+        """
+        if not self._role_is_generated:
+            if not self._is_collective:
+                self._ps_env()
+                if self._init_gloo:
+                    self._init_gloo_env()
+            else:
+                self._collective_env()
+            self._role_is_generated = True
+
+    def __get_default_iface(self):
+        """
+        get default physical interface
+        """
+        default1 = self.__get_default_iface_from_gateway()
+        default2 = self.__get_default_iface_from_interfaces()
+        return default2 if default1 == "lo" else default1
+
+    def __get_default_iface_from_gateway(self):
+        """
+        get default physical interface
+        """
+        import netifaces
+        gateways = netifaces.gateways()
+        if gateways.get(netifaces.AF_INET) != None:
+            gateway = gateways[netifaces.AF_INET]
+            if len(gateway) > 0 and len(gateway[0]) > 1:
+                return gateway[0][1]
+        return "lo"
+
+    def __get_default_iface_from_interfaces(self):
+        """
+        get default physical interface
+        """
+        import netifaces
+        for intf_name in netifaces.interfaces():
+            addresses = netifaces.ifaddresses(intf_name)
+            if netifaces.AF_INET in addresses:
+                ipv4_addresses = addresses[netifaces.AF_INET]
+                for ipv4_address in ipv4_addresses:
+                    if 'broadcast' in ipv4_address:
+                        return intf_name
+        return "lo"
+
+    def __start_kv_server(self, http_server_d, size_d):
+        from paddle.fleet.utils import KVServer
+        http_server = KVServer(int(self._http_ip_port[1]), size_d)
+        http_server.start()
+        wait_seconds = 5
+        while http_server_d.get("running",
+                                False) and not http_server.shoud_stop():
+            time.sleep(wait_seconds)
+        http_server.stop()
+
+
+class UserDefinedRoleMaker(PaddleCloudRoleMaker):
+    def __init__(self, is_collective=False, init_gloo=False, **kwargs):
+        super(UserDefinedRoleMaker, self).__init__(
+            is_collective=is_collective, init_gloo=init_gloo, **kwargs)
+
+    def _user_defined_ps_env(self):
+        self._server_endpoints = self._kwargs.get("server_endpoints")
+        self._worker_endpoints = self._kwargs.get("worker_endpoints", [])
+        self._trainers_num = self._kwargs.get("worker_num", 0)
+
+        if self._trainers_num == 0:
+            assert (len(self._worker_endpoints) > 0)
+            self._trainers_num = len(self._worker_endpoints)
+
+        self._role = self._kwargs.get("role")
+        self._current_id = self._kwargs.get("current_id")
+
+        if self._role == Role.WORKER and len(
+                self._worker_endpoints) > self._current_id:
+            self._cur_endpoint = self._worker_endpoints[self._current_id]
+        elif self._role == Role.SERVER:
+            self._cur_endpoint = self._server_endpoints[self._current_id]
+
+    def _user_defined_collective_env(self):
+        self._worker_endpoints = self._kwargs.get("worker_endpoints")
+        self._current_id = self._kwargs.get("current_id")
+        self._trainers_num = len(self._worker_endpoints)
+        self._training_role = Role.WORKER
+
+    def generate_role(self):
+        """
+        generate role for role maker
+        """
+        if not self._role_is_generated:
+            if not self._is_collective:
+                self._user_defined_ps_env()
+                if self._init_gloo:
+                    self._init_gloo_env()
+            else:
+                self._user_defined_collective_env()
+            self._role_is_generated = True
diff --git a/python/paddle/fleet/base/util_factory.py b/python/paddle/fleet/base/util_factory.py
index 385500de8c0188..ed2a8db586aa9c 100644
--- a/python/paddle/fleet/base/util_factory.py
+++ b/python/paddle/fleet/base/util_factory.py
@@ -18,12 +18,27 @@
 
 __all__ = ['UtilBase']
 
+import numpy as np
+import os
+
+import subprocess
+from paddle.fluid import core
+from collections import OrderedDict
+import paddle.fluid as fluid
+from google.protobuf import text_format
+from paddle.fluid import debugger
+from paddle.fluid.framework import Program
+from paddle.fluid.proto import framework_pb2
+from ..utils.fs import FS, LocalFS, HDFSClient
+
 
 class UtilFactory(object):
-    def _create_util(self, context):
+    def _create_util(self, context=None):
         util = UtilBase()
-        util._set_strategy(context["valid_strategy"])
-        util._set_role_maker(context["role_maker"])
+        if context is not None and "valid_strategy" in context:
+            util._set_strategy(context["valid_strategy"])
+        if context is not None and "role_maker" in context:
+            util._set_role_maker(context["role_maker"])
         return util
 
@@ -38,43 +53,390 @@
     def _set_strategy(self, dist_strategy):
         self.strategy = dist_strategy
 
     def _set_role_maker(self, role_maker):
         self.role_maker = role_maker
 
-    '''
     def set_file_system(self, fs_client):
+        assert isinstance(
+            fs_client,
+            FS), "fs_client must be the instance of paddle.fleet.utils.FS"
         self.fs_client = fs_client
 
-    def broadcast(self):
-        pass
+    def __check_comm_world(self, comm_world="worker"):
+        if not self.role_maker._role_is_generated:
+            self.role_maker.generate_role()
 
-    def all_gather(self):
-        pass
+        _comm_world = None
+        comm_world_upper = comm_world.upper()
+        if comm_world_upper == "WORKER":
+            if not self.role_maker.is_worker():
+                print(
+                    "warning: current role is not worker in collective_func(comm_world=\"worker\")"
+                )
+            _comm_world = self.role_maker._node_type_comm
+        elif comm_world_upper == "SERVER":
+            if not self.role_maker.is_server():
+                print(
+                    "warning: current role is not server in collective_func(comm_world=\"server\")"
+                )
+            _comm_world = self.role_maker._node_type_comm
+        elif comm_world_upper == "ALL":
+            _comm_world = self.role_maker._all_comm
+        else:
+            raise ValueError(
+                "not support comm_world, please choose one from [worker, server, all]"
+            )
 
-    def all_reduce(self):
-        pass
+        return _comm_world
 
-    def reduce_scatter(self):
+    def all_reduce(self, input, mode, comm_world="worker"):
+        _comm_world = self.__check_comm_world(comm_world)
+        return self.role_maker._all_reduce(_comm_world, input, mode)
+
+    def barrier(self, comm_world="worker"):
+        _comm_world = self.__check_comm_world(comm_world)
+        self.role_maker._barrier(_comm_world)
+
+    def all_gather(self, input, comm_world="worker"):
+        _comm_world = self.__check_comm_world(comm_world)
+        return self.role_maker._all_gather(_comm_world, input)
+
+    def broadcast(self):
         pass
 
-    def reduce(self):
+    def scatter(self):
         pass
 
     def get_file_shard(self, files):
-        pass
+        """
+        split files before distributed training,
+        example 1: files is [a, b, c, d, e] and trainer_num = 2, then trainer
+        0 gets [a, b, c] and trainer 1 gets [d, e].
+        example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
+        [a], trainer 1 gets [b], trainer 2 gets [].
 
-    def feed_gen(self, batch_size, feed_vars_dims, feeded_vars_filelist):
-        pass
+        Args:
+            files(list): file list need to be read.
 
-    def save_program(program, output_dir):
-        pass
+        Returns:
+            list: files belonging to this worker.
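+
+        The first (len(files) % trainer_num) workers each take one extra
+        file; e.g. 5 files over 2 trainers split into sizes [3, 2].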
+ """ + if not isinstance(files, list): + raise TypeError("files should be a list of file need to be read.") - def load_program(input_dir): - pass + trainer_id = self.role_maker.worker_index() + trainers = self.role_maker.worker_num() - def load_var(): - pass + remainder = len(files) % trainers + blocksize = int(len(files) / trainers) - def save_var(): - pass + blocks = [blocksize] * trainers + for i in range(remainder): + blocks[i] += 1 - def print_on_rank(self): - pass - ''' + trainer_files = [[]] * trainers + begin = 0 + for i in range(trainers): + trainer_files[i] = files[begin:begin + blocks[i]] + begin += blocks[i] + + return trainer_files[trainer_id] + + def print_on_rank(self, message, rank_id): + if self.role_maker.worker_index() != rank_id: + return + print(message) + + def _save_program(self, program, model_filename='__model__', is_text=False): + if is_text: + with open(model_filename, "w") as f: + f.write(str(program)) + else: + with open(model_filename, "wb") as f: + f.write(program.desc.serialize_to_string()) + + def _load_program(self, path, is_text): + def load_program_binary(path): + """load program from binary string file""" + with open(path, "rb") as f: + program_desc_str = f.read() + return Program.parse_from_string(program_desc_str) + + def load_program_text(path): + """load program from human-readable text file""" + with open(path, "r") as f: + program_desc_text = f.read() + + prog_desc = framework_pb2.ProgramDesc() + text_format.Merge(program_desc_text, prog_desc) + return Program.parse_from_string(prog_desc.SerializeToString()) + + if is_text: + return load_program_text(path) + else: + return load_program_binary(path) + + def _program_type_trans(self, prog_dir, prog_fn, is_text): + prog = self._load_program(os.path.join(prog_dir, prog_fn), is_text) + prog_out_fn = prog_fn + ".bin" if is_text else prog_fn + ".pbtxt" + self._save_program(prog, + os.path.join(prog_dir, prog_out_fn), 1 - is_text) + return prog_out_fn + + def _visualize_graphviz(self, program, output_dir, output_filename): + block = program.global_block() + dot_path = os.path.join(output_dir, output_filename + '.dot') + pdf_path = os.path.join(output_dir, output_filename + '.pdf') + debugger.draw_block_graphviz(block, path=dot_path) + cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path] + p = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + p.wait() + + def _proto_check(self, config): + train_prog = self._load_program(config.train_prog_path, + config.is_text_train_program) + pruned_prog = self._load_program(config.pruned_prog_path, + config.is_text_pruned_program) + + is_match = True + + pruned_vars = [(v.name, v) for v in pruned_prog.list_vars() + if fluid.io.is_persistable(v)] + pruned_vars = OrderedDict(pruned_vars) + pruned_vars_name = [name for name in pruned_vars] + print("persistable vars in pruned program: {}".format(pruned_vars_name)) + + # feed and fetch op is added in pruned program when pruning, not need to be found in train program + feed_fetch_type_list = [ + core.VarDesc.VarType.FEED_MINIBATCH, core.VarDesc.VarType.FETCH_LIST + ] + + for var_name in pruned_vars: + var = pruned_vars[var_name] + # feed and fetch op is added in pruned program when pruning, not need to be found in train program + if var.type in feed_fetch_type_list: + break + try: + train_prog_var = train_prog.global_block().var(var_name) + except ValueError as e: + print( + "Not find variable '%s' in train program. please check pruning." 
+ % var_name) + is_match = False + continue + if var.shape != train_prog_var.shape or var.dtype != train_prog_var.dtype: + print( + "variable: {} not match. in pruned program shape: {} dtype:{}, in train program shape: {} dtype: {}". + format(var_name, var.shape, var.dtype, train_prog_var.shape, + train_prog_var.dtype)) + is_match = False + return is_match + + def _params_check(self, config): + def feed_gen(batch_size, feeded_vars_dims, feeded_vars_filelist): + def reader(batch_size, fn, dim): + data = [] + if isinstance(dim, list) or isinstance(dim, tuple): + shape = list(dim) + _temp = 1 + for x in dim: + _temp = _temp * x + dim = _temp + else: + shape = [dim] + + shape = [batch_size] + shape + dim = dim * batch_size + + for line in open(fn, 'r'): + fields = line.strip().split(' ') + fields = [float(d) for d in fields] + while len(fields) >= dim: + tmp = fields[:dim] + fields = fields[dim:] + data.append(np.array(tmp).reshape(shape)) + return data + + batch_feed = [] + for i, fn in enumerate(feeded_vars_filelist): + batch_feed.append(reader(batch_size, fn, feeded_vars_dims[i])) + return batch_feed + + prog = self._load_program( + os.path.join(config.dump_model_dir, config.dump_program_filename), + config.is_text_dump_program) + if config.is_text_dump_program: + model_filename = self._program_type_trans( + config.dump_model_dir, config.dump_program_filename, + config.is_text_dump_program) + + saved_params = [ + v for v in prog.list_vars() if fluid.io.is_persistable(v) + ] + print("persistable vars in dump program: {}".format( + [v.name for v in saved_params])) + + def check_not_expected_ops(prog, not_expected_op_types): + op_types_set = set() + for op in prog.global_block().ops: + if op.type in not_expected_op_types and op.type not in op_types_set: + op_types_set.add(op.type) + return op_types_set + + not_expected_op_types = check_not_expected_ops(prog, ["lookup_table"]) + if len(not_expected_op_types) > 0: + print( + "find op type '{}' in program, please check if your program is pruned correctly !". + format(list(not_expected_op_types))) + return False + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + inference_program, feed_target_names, fetch_targets = \ + fluid.io.load_inference_model(config.dump_model_dir, exe, model_filename=model_filename, + params_filename=config.save_params_filename) + + # check program vars and saved vars shape + orig_para_shape = { + each_var.name: tuple(each_var.desc.shape()) + for each_var in saved_params + } + for each_var in saved_params: + var_temp = fluid.global_scope().find_var(each_var.name) + assert var_temp != None, "can't not find var: " + each_var.name + new_shape = (np.array(var_temp.get_tensor())).shape + assert each_var.name in orig_para_shape, each_var.name + "MUST in var list" + orig_shape = orig_para_shape.get(each_var.name) + if new_shape != orig_shape: + raise RuntimeError( + "Shape not matching: the Program requires a parameter with a shape of ({}), " + "while the loaded parameter (namely [ {} ]) has a shape of ({}).". + format(orig_shape, each_var.name, new_shape)) + + # check feed/fetch vars in program and config + feed_config = config.feed_config + fetch_config = config.fetch_config + fetch_targets_names = [v.name for v in fetch_targets] + if not feed_target_names: + print("warning! no feed targets in program.") + if not fetch_targets_names: + print("warning! 
no fetch targets in program.") + fetch_list = fetch_targets + feed_name_list = feed_target_names + if feed_config.feeded_vars_names is not None and feed_target_names != feed_config.feeded_vars_names: + print( + "warning! feed vars in program and config are diff: feed in program: {}. feed in config {}.". + format(feed_target_names, feed_config.feeded_vars_names)) + feed_name_list = feed_config.feeded_vars_names + # remove feed op in inference_program. new feed op will be added in exe.run + global_block = inference_program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed": # only remove feed op here + need_to_remove_op_index.append(i) + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + if fetch_config.fetch_vars_names is not None and fetch_targets_names != fetch_config.fetch_vars_names: + print( + "warning! fetch vars in program and config are diff: fetch in program: {}. fetch in config {}.". + format(fetch_targets_names, fetch_config.fetch_vars_names)) + fetch_list = [ + inference_program.global_block().var(i) + for i in fetch_config.fetch_vars_names + ] + # remove fetch op in inference_program. new fetch op will be added in exe.run + global_block = inference_program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "fetch": # only remove fetch op here + need_to_remove_op_index.append(i) + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + + # if fetch_list have lod tensor + return_numpy = all([v.lod_level == 0 for v in fetch_list]) + + # try dump fetch_targets + feed_tensors = [] + assert len(feed_config.feeded_vars_names) == len( + feed_config.feeded_vars_dims) == len( + feed_config.feeded_vars_types) + # check program vars and feed tensor shape in config + for i in range(len(feed_config.feeded_vars_names)): + var = inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + if not isinstance(feed_config.feeded_vars_dims[i], + (list, tuple)): + tensor_shape = (feed_config.feeded_vars_dims[i], ) + else: + tensor_shape = tuple(feed_config.feeded_vars_dims[i]) + feed_config.feeded_vars_dims[i] = tensor_shape + var_shape = var.shape[1:] + if tensor_shape != var_shape: + raise RuntimeError( + "feed variable '{}' shape not match. infer program shape: {}. feed tensor shape: {}". + format(feed_config.feeded_vars_names[i], var_shape, + tensor_shape)) + + if not feed_config.feeded_vars_filelist: + print("generate random feed vars.") + for i in range(len(feed_config.feeded_vars_names)): + var = inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + # create fake feed tensor. if lod_level > 1, should create_lod_tensor() + if var.lod_level == 0: + feed_tensors.append( + np.array( + np.random.random( + tuple([config.batch_size] + list( + feed_config.feeded_vars_dims[i]))), + dtype=feed_config.feeded_vars_types[i])) + elif var.lod_level == 1: + t = np.array( + np.random.random( + tuple([config.batch_size] + list( + feed_config.feeded_vars_dims[i]))), + dtype=feed_config.feeded_vars_types[i]) + feed_tensors.append( + fluid.create_lod_tensor(t, [[1] * config.batch_size + ], place)) + else: + raise RuntimeError( + "vars with lod_level >= 2 is not supported now in this infer program check tool." 
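+
+    Instances are normally obtained through DatasetFactory rather than
+    constructed directly; asking the factory for an undefined class name
+    raises ValueError (an illustrative sketch):
+
+    .. code-block:: python
+
+        import paddle.fluid as fluid
+        try:
+            # the name string is looked up in this module's globals()
+            fluid.DatasetFactory().create_dataset("NoSuchDataset")
+        except ValueError as e:
+            print(e)  # datafeed class NoSuchDataset does not exist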
+ ) + results = exe.run(inference_program, + feed={ + name: feed_tensors[i] + for i, name in enumerate(feed_name_list) + }, + fetch_list=fetch_list, + return_numpy=return_numpy) + else: + print("load feed vars from files: {}.".format( + feed_config.feeded_vars_filelist)) + feed_vars = [ + inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + for i in range(len(feed_config.feeded_vars_names)) + ] + feeder = fluid.DataFeeder(feed_list=feed_vars, place=place) + batch_feed = feed_gen(config.batch_size, + feed_config.feeded_vars_dims, + feed_config.feeded_vars_filelist) + slots = [batch_feed] + results = exe.run(inference_program, + feed=feeder.feed(slots), + fetch_list=fetch_list, + return_numpy=return_numpy) + for i, v in enumerate(fetch_list): + print("fetch_targets name: %s" % v.name) + print("fetch_targets: {}".format(results[i])) + return results + + +fleet_util = UtilFactory()._create_util(None) diff --git a/python/paddle/fleet/dataset/__init__.py b/python/paddle/fleet/dataset/__init__.py index 8647330f3290f3..af33c4eafb3968 100644 --- a/python/paddle/fleet/dataset/__init__.py +++ b/python/paddle/fleet/dataset/__init__.py @@ -10,3 +10,5 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and + +from .dataset import * diff --git a/python/paddle/fleet/dataset/dataset.py b/python/paddle/fleet/dataset/dataset.py new file mode 100644 index 00000000000000..f6504cacd96808 --- /dev/null +++ b/python/paddle/fleet/dataset/dataset.py @@ -0,0 +1,1103 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This is definition of dataset class, which is high performance IO.""" + +import paddle +import paddle.fluid as fluid +from paddle.fluid.proto import data_feed_pb2 +from google.protobuf import text_format +import paddle.fluid.core as core + + +class DatasetFactory(object): + """ + DatasetFactory is a factory which create dataset by its name, + you can create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset", + the default is "QueueDataset". + + Example: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + + """ + + def __init__(self): + """ Init. """ + pass + + def create_dataset(self, datafeed_class="QueueDataset"): + """ + Create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset", + the default is "QueueDataset". + + Args: + datafeed_class(str): datafeed class name, QueueDataset or InMemoryDataset. + Default is QueueDataset. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + + """ + try: + dataset = globals()[datafeed_class]() + return dataset + except: + raise ValueError("datafeed class %s does not exist" % + datafeed_class) + + +class DatasetBase(object): + """ Base dataset class. 
""" + + def __init__(self): + """ Init. """ + # define class name here + # to decide whether we need create in memory instance + self.proto_desc = data_feed_pb2.DataFeedDesc() + self.proto_desc.pipe_command = "cat" + self.dataset = core.Dataset("MultiSlotDataset") + self.thread_num = 1 + self.filelist = [] + + def set_pipe_command(self, pipe_command): + """ + Set pipe command of current dataset + A pipe command is a UNIX pipeline command that can be used only + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_pipe_command("python my_script.py") + + Args: + pipe_command(str): pipe command + + """ + self.proto_desc.pipe_command = pipe_command + + def set_rank_offset(self, rank_offset): + """ + Set rank_offset for merge_pv. It set the message of Pv. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_rank_offset("rank_offset") + + Args: + rank_offset(str): rank_offset's name + + """ + self.proto_desc.rank_offset = rank_offset + + def set_fea_eval(self, record_candidate_size, fea_eval=True): + """ + set fea eval mode for slots shuffle to debug the importance level of + slots(features), fea_eval need to be set True for slots shuffle. + + Args: + record_candidate_size(int): size of instances candidate to shuffle + one slot + fea_eval(bool): whether enable fea eval mode to enable slots shuffle. + default is True. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_fea_eval(1000000, True) + + """ + if fea_eval: + self.dataset.set_fea_eval(fea_eval, record_candidate_size) + self.fea_eval = fea_eval + + def slots_shuffle(self, slots): + """ + Slots Shuffle + Slots Shuffle is a shuffle method in slots level, which is usually used + in sparse feature with large scale of instances. To compare the metric, i.e. + auc while doing slots shuffle on one or several slots with baseline to + evaluate the importance level of slots(features). + + Args: + slots(list[string]): the set of slots(string) to do slots shuffle. + + Examples: + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_merge_by_lineid() + #suppose there is a slot 0 + dataset.slots_shuffle(['0']) + """ + if self.fea_eval: + slots_set = set(slots) + self.dataset.slots_shuffle(slots_set) + + def set_batch_size(self, batch_size): + """ + Set batch size. Will be effective during training + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_batch_size(128) + + Args: + batch_size(int): batch size + + """ + self.proto_desc.batch_size = batch_size + + def set_pv_batch_size(self, pv_batch_size): + """ + Set pv batch size. It will be effective during enable_pv_merge + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_pv_batch(128) + Args: + pv_batch_size(int): pv batch size + + """ + self.proto_desc.pv_batch_size = pv_batch_size + + def set_thread(self, thread_num): + """ + Set thread num, it is the num of readers. + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_thread(12) + + Args: + thread_num(int): thread num + """ + self.dataset.set_thread_num(thread_num) + self.thread_num = thread_num + + def set_filelist(self, filelist): + """ + Set file list in current worker. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_filelist(['a.txt', 'b.txt']) + + Args: + filelist(list): file list + """ + self.dataset.set_filelist(filelist) + self.filelist = filelist + + def set_input_type(self, input_type): + self.proto_desc.input_type = input_type + + def set_use_var(self, var_list): + """ + Set Variables which you will use. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_use_var([data, label]) + + Args: + var_list(list): variable list + """ + multi_slot = self.proto_desc.multi_slot_desc + for var in var_list: + slot_var = multi_slot.slots.add() + slot_var.is_used = True + slot_var.name = var.name + if var.lod_level == 0: + slot_var.is_dense = True + slot_var.shape.extend(var.shape) + if var.dtype == core.VarDesc.VarType.FP32: + slot_var.type = "float" + elif var.dtype == core.VarDesc.VarType.INT64: + slot_var.type = "uint64" + else: + raise ValueError( + "Currently, fluid.dataset only supports dtype=float32 and dtype=int64" + ) + + def set_hdfs_config(self, fs_name, fs_ugi): + """ + Set hdfs config: fs name ad ugi + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_hdfs_config("my_fs_name", "my_fs_ugi") + + Args: + fs_name(str): fs name + fs_ugi(str): fs ugi + """ + self.dataset.set_hdfs_config(fs_name, fs_ugi) + + def set_download_cmd(self, download_cmd): + """ + Set customized download cmd: download_cmd + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_download_cmd("./read_from_afs") + + Args: + download_cmd(str): customized download command + """ + self.dataset.set_download_cmd(download_cmd) + + def _prepare_to_run(self): + """ + Set data_feed_desc before load or shuffle, + user no need to call this function. + """ + if self.thread_num > len(self.filelist): + self.thread_num = len(self.filelist) + self.dataset.set_thread_num(self.thread_num) + self.dataset.set_data_feed_desc(self.desc()) + self.dataset.create_readers() + + def _finish_to_run(self): + self.dataset.destroy_readers() + + def desc(self): + """ + Returns a protobuf message for this DataFeedDesc + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + print(dataset.desc()) + + Returns: + A string message + """ + return text_format.MessageToString(self.proto_desc) + + def _dynamic_adjust_before_train(self, thread_num): + pass + + def _dynamic_adjust_after_train(self): + pass + + +class InMemoryDataset(DatasetBase): + """ + InMemoryDataset, it will load data into memory + and shuffle data before training. + This class should be created by DatasetFactory + + Example: + dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset") + """ + + def __init__(self): + """ Init. 
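+
+        A minimal end-to-end configuration sketch (illustrative; the feed
+        variables are declared with fluid.data only for the example):
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            data = fluid.data(name="words", shape=[None, 1],
+                              dtype="int64", lod_level=1)
+            label = fluid.data(name="label", shape=[None, 1], dtype="int64")
+            dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+            dataset.set_batch_size(128)
+            dataset.set_thread(4)
+            dataset.set_filelist(["a.txt", "b.txt"])
+            dataset.set_use_var([data, label])
+            print(dataset.desc())  # protobuf text of the DataFeedDesc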
""" + super(InMemoryDataset, self).__init__() + self.proto_desc.name = "MultiSlotInMemoryDataFeed" + self.fleet_send_batch_size = None + self.is_user_set_queue_num = False + self.queue_num = None + self.parse_ins_id = False + self.parse_content = False + self.parse_logkey = False + self.merge_by_sid = True + self.enable_pv_merge = False + self.merge_by_lineid = False + self.fleet_send_sleep_seconds = None + + def set_feed_type(self, data_feed_type): + """ + Set data_feed_desc + """ + self.proto_desc.name = data_feed_type + + def _prepare_to_run(self): + """ + Set data_feed_desc before load or shuffle, + user no need to call this function. + """ + if self.thread_num <= 0: + self.thread_num = 1 + self.dataset.set_thread_num(self.thread_num) + if self.queue_num is None: + self.queue_num = self.thread_num + self.dataset.set_queue_num(self.queue_num) + self.dataset.set_parse_ins_id(self.parse_ins_id) + self.dataset.set_parse_content(self.parse_content) + self.dataset.set_parse_logkey(self.parse_logkey) + self.dataset.set_merge_by_sid(self.merge_by_sid) + self.dataset.set_enable_pv_merge(self.enable_pv_merge) + self.dataset.set_data_feed_desc(self.desc()) + self.dataset.create_channel() + self.dataset.create_readers() + + def _dynamic_adjust_before_train(self, thread_num): + if not self.is_user_set_queue_num: + self.dataset.dynamic_adjust_channel_num(thread_num, False) + self.dataset.dynamic_adjust_readers_num(thread_num) + + def _dynamic_adjust_after_train(self): + if not self.is_user_set_queue_num: + self.dataset.dynamic_adjust_channel_num(self.thread_num, False) + self.dataset.dynamic_adjust_readers_num(self.thread_num) + + def set_queue_num(self, queue_num): + """ + Set Dataset output queue num, training threads get data from queues + + Args: + queue_num(int): dataset output queue num + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_queue_num(12) + + """ + self.is_user_set_queue_num = True + self.queue_num = queue_num + + def set_parse_ins_id(self, parse_ins_id): + """ + Set id Dataset need to parse insid + + Args: + parse_ins_id(bool): if parse ins_id or not + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_parse_ins_id(True) + + """ + self.parse_ins_id = parse_ins_id + + def set_parse_content(self, parse_content): + """ + Set if Dataset need to parse content + + Args: + parse_content(bool): if parse content or not + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_parse_content(True) + + """ + self.parse_content = parse_content + + def set_parse_logkey(self, parse_logkey): + """ + Set if Dataset need to parse logkey + + Args: + parse_content(bool): if parse logkey or not + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_parse_logkey(True) + + """ + self.parse_logkey = parse_logkey + + def set_merge_by_sid(self, merge_by_sid): + """ + Set if Dataset need to merge sid. If not, one ins means one Pv. + + Args: + merge_by_sid(bool): if merge sid or not + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_merge_by_sid(True) + + """ + self.merge_by_sid = merge_by_sid + + def set_enable_pv_merge(self, enable_pv_merge): + """ + Set if Dataset need to merge pv. + + Args: + enable_pv_merge(bool): if enable_pv_merge or not + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_enable_pv_merge(True) + + """ + self.enable_pv_merge = enable_pv_merge + + def preprocess_instance(self): + """ + Merge pv instance and convey it from input_channel to input_pv_channel. + It will be effective when enable_pv_merge_ is True. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.preprocess_instance() + + """ + self.dataset.preprocess_instance() + + def set_current_phase(self, current_phase): + """ + Set current phase in train. It is useful for untest. + current_phase : 1 for join, 0 for update. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.set_current_phase(1) + + """ + self.dataset.set_current_phase(current_phase) + + def postprocess_instance(self): + """ + Divide pv instance and convey it to input_channel. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.preprocess_instance() + exe.train_from_dataset(dataset) + dataset.postprocess_instance() + + """ + self.dataset.postprocess_instance() + + def set_fleet_send_batch_size(self, fleet_send_batch_size=1024): + """ + Set fleet send batch size, default is 1024 + + Args: + fleet_send_batch_size(int): fleet send batch size + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_fleet_send_batch_size(800) + + """ + self.fleet_send_batch_size = fleet_send_batch_size + + def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0): + """ + Set fleet send sleep time, default is 0 + + Args: + fleet_send_sleep_seconds(int): fleet send sleep time + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_fleet_send_sleep_seconds(2) + + """ + self.fleet_send_sleep_seconds = fleet_send_sleep_seconds + + def set_merge_by_lineid(self, merge_size=2): + """ + Set merge by line id, instances of same line id will be merged after + shuffle, you should parse line id in data generator. + + Args: + merge_size(int): ins size to merge. default is 2. + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_merge_by_lineid() + + """ + self.dataset.set_merge_by_lineid(merge_size) + self.merge_by_lineid = True + self.parse_ins_id = True + + def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num): + self.dataset.set_generate_unique_feasigns(generate_uni_feasigns) + self.gen_uni_feasigns = generate_uni_feasigns + self.local_shard_num = shard_num + + def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num, + consume_thread_num, shard_num): + self.dataset.generate_local_tables_unlock( + table_id, fea_dim, read_thread_num, consume_thread_num, shard_num) + + def load_into_memory(self): + """ + Load data into memory + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + """ + self._prepare_to_run() + self.dataset.load_into_memory() + + def preload_into_memory(self, thread_num=None): + """ + Load data into memory in async mode + + Args: + thread_num(int): preload thread num + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.preload_into_memory() + dataset.wait_preload_done() + """ + self._prepare_to_run() + if thread_num is None: + thread_num = self.thread_num + self.dataset.set_preload_thread_num(thread_num) + self.dataset.create_preload_readers() + self.dataset.preload_into_memory() + + def wait_preload_done(self): + """ + Wait preload_into_memory done + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.preload_into_memory() + dataset.wait_preload_done() + """ + self.dataset.wait_preload_done() + self.dataset.destroy_preload_readers() + + def local_shuffle(self): + """ + Local shuffle + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.local_shuffle() + """ + self.dataset.local_shuffle() + + def global_shuffle(self, fleet=None, thread_num=12): + """ + Global shuffle. + Global shuffle can be used only in distributed mode. i.e. multiple + processes on single machine or multiple machines training together. + If you run in distributed mode, you should pass fleet instead of None. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.global_shuffle(fleet) + + Args: + fleet(Fleet): fleet singleton. Default None. + thread_num(int): shuffle thread num. Default is 12. 
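+
+        When fleet is None, trainer_num stays 1 and the call degenerates to
+        a single-trainer shuffle with no worker barriers (an illustrative
+        single-machine sketch):
+
+        .. code-block:: python
+
+            dataset.load_into_memory()
+            dataset.global_shuffle()  # no fleet: trainer_num == 1
+            print(dataset.get_shuffle_data_size())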
+ + """ + trainer_num = 1 + if fleet is not None: + fleet._role_maker.barrier_worker() + trainer_num = fleet.worker_num() + if self.fleet_send_batch_size is None: + self.fleet_send_batch_size = 1024 + if self.fleet_send_sleep_seconds is None: + self.fleet_send_sleep_seconds = 0 + self.dataset.register_client2client_msg_handler() + self.dataset.set_trainer_num(trainer_num) + self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size) + self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds) + if fleet is not None: + fleet._role_maker.barrier_worker() + self.dataset.global_shuffle(thread_num) + if fleet is not None: + fleet._role_maker.barrier_worker() + if self.merge_by_lineid: + self.dataset.merge_by_lineid() + if fleet is not None: + fleet._role_maker.barrier_worker() + + def release_memory(self): + """ + :api_attr: Static Graph + + Release InMemoryDataset memory data, when data will not be used again. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.global_shuffle(fleet) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + exe.train_from_dataset(fluid.default_main_program(), dataset) + dataset.release_memory() + + """ + self.dataset.release_memory() + + def get_pv_data_size(self): + """ + Get memory data size of Pv, user can call this function to know the pv num + of ins in all workers after load into memory. + + Note: + This function may cause bad performance, because it has barrier + + Returns: + The size of memory pv data. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + print dataset.get_pv_data_size() + + """ + return self.dataset.get_pv_data_size() + + def get_memory_data_size(self, fleet=None): + """ + Get memory data size, user can call this function to know the num + of ins in all workers after load into memory. + + Note: + This function may cause bad performance, because it has barrier + + Args: + fleet(Fleet): Fleet Object. + + Returns: + The size of memory data. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + print dataset.get_memory_data_size(fleet) + + """ + import numpy as np + local_data_size = self.dataset.get_memory_data_size() + local_data_size = np.array([local_data_size]) + if fleet is not None: + global_data_size = local_data_size * 0 + fleet._role_maker.all_reduce_worker(local_data_size, + global_data_size) + return global_data_size[0] + return local_data_size[0] + + def get_shuffle_data_size(self, fleet=None): + """ + Get shuffle data size, user can call this function to know the num + of ins in all workers after local/global shuffle. + + Note: + This function may cause bad performance to local shuffle, + because it has barrier. It does not affect global shuffle. + + Args: + fleet(Fleet): Fleet Object. + + Returns: + The size of shuffle data. + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + dataset.global_shuffle(fleet) + print dataset.get_shuffle_data_size(fleet) + + """ + import numpy as np + local_data_size = self.dataset.get_shuffle_data_size() + local_data_size = np.array([local_data_size]) + if fleet is not None: + global_data_size = local_data_size * 0 + fleet._role_maker.all_reduce_worker(local_data_size, + global_data_size) + return global_data_size[0] + return local_data_size[0] + + +class QueueDataset(DatasetBase): + """ + QueueDataset, it will process data streamly. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + + """ + + def __init__(self): + """ + Initialize QueueDataset + This class should be created by DatasetFactory + """ + super(QueueDataset, self).__init__() + self.proto_desc.name = "MultiSlotDataFeed" + + def _prepare_to_run(self): + """ + Set data_feed_desc/thread num/filelist before run, + user no need to call this function. + """ + if self.thread_num > len(self.filelist): + self.thread_num = len(self.filelist) + if self.thread_num == 0: + self.thread_num = 1 + self.dataset.set_thread_num(self.thread_num) + self.dataset.set_filelist(self.filelist) + self.dataset.set_data_feed_desc(self.desc()) + self.dataset.create_readers() + + def local_shuffle(self): + """ + Local shuffle data. + + Local shuffle is not supported in QueueDataset + NotImplementedError will be raised + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset.local_shuffle() + + Raises: + NotImplementedError: QueueDataset does not support local shuffle + + """ + raise NotImplementedError( + "QueueDataset does not support local shuffle, " + "please use InMemoryDataset for local_shuffle") + + def global_shuffle(self, fleet=None): + """ + Global shuffle data. + + Global shuffle is not supported in QueueDataset + NotImplementedError will be raised + + Args: + fleet(Fleet): fleet singleton. Default None. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset.global_shuffle(fleet) + + Raises: + NotImplementedError: QueueDataset does not support global shuffle + + """ + raise NotImplementedError( + "QueueDataset does not support global shuffle, " + "please use InMemoryDataset for global_shuffle") + + +class FileInstantDataset(DatasetBase): + """ + FileInstantDataset, it will process data streamly. + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory.create_dataset("FileInstantDataset") + """ + + def __init__(self): + """ + Initialize FileInstantDataset + This class should be created by DatasetFactory + """ + super(FileInstantDataset, self).__init__() + self.proto_desc.name = "MultiSlotFileInstantDataFeed" + + def local_shuffle(self): + """ + Local shuffle + FileInstantDataset does not support local shuffle + """ + raise NotImplementedError( + "FileInstantDataset does not support local shuffle, " + "please use InMemoryDataset for local_shuffle") + + def global_shuffle(self, fleet=None): + """ + Global shuffle + FileInstantDataset does not support global shuffle + """ + raise NotImplementedError( + "FileInstantDataset does not support global shuffle, " + "please use InMemoryDataset for global_shuffle") + + +class BoxPSDataset(InMemoryDataset): + """ + BoxPSDataset: derived from InMemoryDataset. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + """ + + def __init__(self): + """ + Initialize BoxPSDataset + This class should be created by DatasetFactory + """ + super(BoxPSDataset, self).__init__() + self.boxps = core.BoxPS(self.dataset) + self.proto_desc.name = "PaddleBoxDataFeed" + + def set_date(self, date): + """ + Workaround for date + """ + year = int(date[:4]) + month = int(date[4:6]) + day = int(date[6:]) + self.boxps.set_date(year, month, day) + + def begin_pass(self): + """ + Begin Pass + Notify BoxPS to load sparse parameters of next pass to GPU Memory + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + dataset.begin_pass() + """ + self.boxps.begin_pass() + + def end_pass(self, need_save_delta): + """ + End Pass + Notify BoxPS that current pass ended + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + dataset.end_pass(True) + """ + self.boxps.end_pass(need_save_delta) + + def wait_preload_done(self): + """ + Wait async preload done + Wait Until Feed Pass Done + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.preload_into_memory() + dataset.wait_preload_done() + """ + self.boxps.wait_feed_pass_done() + + def load_into_memory(self): + """ + Load next pass into memory and notify boxps to fetch its emb from SSD + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.load_into_memory() + """ + self._prepare_to_run() + self.boxps.load_into_memory() + + def preload_into_memory(self): + """ + Begin async preload next pass while current pass may be training + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + filelist = ["a.txt", "b.txt"] + dataset.set_filelist(filelist) + dataset.preload_into_memory() + """ + self._prepare_to_run() + self.boxps.preload_into_memory() + + def _dynamic_adjust_before_train(self, thread_num): + if not self.is_user_set_queue_num: + self.dataset.dynamic_adjust_channel_num(thread_num, True) + self.dataset.dynamic_adjust_readers_num(thread_num) + + def _dynamic_adjust_after_train(self): + pass + + def slots_shuffle(self, slots): + """ + Slots Shuffle + Slots Shuffle is a shuffle method in slots level, which is usually used + in sparse feature with large scale of instances. To compare the metric, i.e. + auc while doing slots shuffle on one or several slots with baseline to + evaluate the importance level of slots(features). + + Args: + slots(list[string]): the set of slots(string) to do slots shuffle. + + Examples: + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_merge_by_lineid() + #suppose there is a slot 0 + dataset.slots_shuffle(['0']) + """ + slots_set = set(slots) + self.boxps.slots_shuffle(slots_set) diff --git a/python/paddle/fleet/utils/__init__.py b/python/paddle/fleet/utils/__init__.py new file mode 100644 index 00000000000000..212308159aabb1 --- /dev/null +++ b/python/paddle/fleet/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fs import * +from .http_server import KVHandler, KVHTTPServer, KVServer + +__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__ diff --git a/python/paddle/fleet/utils/fs.py b/python/paddle/fleet/utils/fs.py new file mode 100644 index 00000000000000..3fec773f273180 --- /dev/null +++ b/python/paddle/fleet/utils/fs.py @@ -0,0 +1,382 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
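+#
+# A usage sketch of the LocalFS implementation defined below (an editor's
+# example; paths are placeholders):
+#
+#     fs = LocalFS()
+#     fs.mkdirs("/tmp/paddle_fs_demo")
+#     fs.touch("/tmp/paddle_fs_demo/a.txt")
+#     dirs, files = fs.ls_dir("/tmp/paddle_fs_demo")  # ([], ['a.txt'])
+#     fs.mv("/tmp/paddle_fs_demo/a.txt", "/tmp/paddle_fs_demo/b.txt")
+#     fs.delete("/tmp/paddle_fs_demo")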
+ +import os +import sys +import subprocess +import multiprocessing +from datetime import datetime + +import re +import copy +import errno +import time +import logging +import six +import abc +import paddle.fluid as fluid +import functools + +from pathlib import PurePosixPath, Path +import shutil + +__all__ = [ + 'FS', 'LocalFS', 'HDFSClient', 'ExecuteError', 'FSTimeOut', + 'FSFileExistsError', 'FSFileNotExistsError' +] + + +class ExecuteError(Exception): + pass + + +class FSFileExistsError(Exception): + pass + + +class FSFileNotExistsError(Exception): + pass + + +class FSTimeOut(Exception): + pass + + +class FS(object): + @abc.abstractmethod + def ls_dir(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def is_file(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def is_dir(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def is_exist(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def upload(self, local_path, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def download(self, fs_path, local_path): + raise NotImplementedError + + @abc.abstractmethod + def mkdirs(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def delete(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def need_upload_download(self): + raise NotImplementedError + + @abc.abstractmethod + def rename(self, fs_src_path, fs_dst_path): + raise NotImplementedError + + @abc.abstractmethod + def mv(self, fs_src_path, fs_dst_path): + raise NotImplementedError + + @abc.abstractmethod + def upload_dir(self, local_dir, dest_dir): + raise NotImplementedError + + @abc.abstractmethod + def glob(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def stat(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def walk(self, fs_path): + raise NotImplementedError + + +class LocalFS(FS): + def ls_dir(self, fs_path): + if not self.is_exist(fs_path): + return [], [] + + dirs = [] + files = [] + for f in os.listdir(fs_path): + if os.path.isdir(fs_path + "/" + f): + dirs.append(f) + else: + files.append(f) + + return dirs, files + + def mkdirs(self, fs_path): + assert not os.path.isfile(fs_path), "{} is already a file".format( + fs_path) + os.system("mkdir -p {}".format(fs_path)) + + def is_file(self, fs_path): + return os.path.isfile(fs_path) + + def is_dir(self, fs_path): + return os.path.isdir(fs_path) + + def is_exist(self, fs_path): + return os.path.exists(fs_path) + + def _rmr(self, fs_path): + shutil.rmtree(fs_path) + + def _rm(self, fs_path): + os.remove(fs_path) + + def delete(self, fs_path): + if not self.is_exist(fs_path): + return + + if os.path.isfile(fs_path): + return self._rm(fs_path) + + return self._rmr(fs_path) + + def rename(self, fs_src_path, fs_dst_path): + os.rename(fs_src_path, fs_dst_path) + + def need_upload_download(self): + return False + + def touch(self, fs_path): + return Path(fs_path).touch() + + def mv(self, src_path, dst_path): + if not self.is_exist(src_path): + raise FSFileNotExistsError + + if self.is_exist(dst_path): + raise FSFileExistsError + + return self.rename(src_path, dst_path) + + +"""HDFS Utils.""" + + +def _handle_errors(f): + def handler(*args, **kwargs): + start = time.time() + while True: + try: + return f(*args, **kwargs) + except ExecuteError as e: + o = args[0] + time_out = float(o._time_out) / 1000.0 + inter = float(o._sleep_inter) / 1000.0 + if time.time() - start >= time_out: + raise FSTimeOut + time.sleep(inter) + + 
return functools.wraps(f)(handler) + + +class HDFSClient(FS): + def __init__( + self, + hadoop_home, + configs, + time_out=5 * 60 * 1000, #ms + sleep_inter=1000): #ms + # Raise exception if JAVA_HOME not exists. + java_home = os.environ["JAVA_HOME"] + + self.pre_commands = [] + hadoop_bin = '%s/bin/hadoop' % hadoop_home + self.pre_commands.append(hadoop_bin) + dfs = 'fs' + self.pre_commands.append(dfs) + + if configs: + for k, v in six.iteritems(configs): + self.pre_commands.append('-D%s=%s' % (k, v)) + + self._time_out = time_out + self._sleep_inter = sleep_inter + self._base_cmd = " ".join(self.pre_commands) + self._bd_err_re = re.compile( + r'\s?responseErrorMsg\s?\:.*, errorCode\:\s?[0-9]+, path\:') + + def _run_cmd(self, cmd, redirect_stderr=False): + ret, output = fluid.core.shell_execute_cmd(cmd, 0, 0, redirect_stderr) + return int(ret), output.splitlines() + + @_handle_errors + def ls_dir(self, fs_path): + """ + list directory under fs_path, and only give the pure name, not include the fs_path + """ + if not self.is_exist(fs_path): + return [], [] + + cmd = "{} -ls {}".format(self._base_cmd, fs_path) + ret, lines = self._run_cmd(cmd) + + if ret != 0: + raise ExecuteError + + dirs = [] + files = [] + for line in lines: + arr = line.split() + if len(arr) != 8: + continue + + if fs_path not in arr[7]: + continue + + p = PurePosixPath(arr[7]) + if arr[0][0] == 'd': + dirs.append(p.name) + else: + files.append(p.name) + + return dirs, files + + def _test_match(self, lines): + for l in lines: + m = self._bd_err_re.match(l) + if m != None: + return m + + return None + + @_handle_errors + def is_dir(self, fs_path): + if not self.is_exist(fs_path): + return False + + cmd = "{} -test -d {}".format( + self._base_cmd, fs_path, redirect_stderr=True) + ret, lines = self._run_cmd(cmd) + if ret: + # other error + if self._test_match(lines) != None: + raise ExecuteError + + return False + + return True + + def is_file(self, fs_path): + if not self.is_exist(fs_path): + return False + + return not self.is_dir(fs_path) + + @_handle_errors + def is_exist(self, fs_path): + cmd = "{} -ls {} ".format(self._base_cmd, fs_path) + ret, out = self._run_cmd(cmd, redirect_stderr=True) + if ret != 0: + for l in out: + if "No such file or directory" in l: + return False + raise ExecuteError + + return True + + @_handle_errors + def upload(self, local_path, fs_path): + if self.is_exist(fs_path): + raise FSFileExistsError + + local = LocalFS() + if not local.is_exist(local_path): + raise FSFileNotExistsError + + cmd = "{} -put {} {}".format(self._base_cmd, local_path, fs_path) + ret, lines = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def download(self, fs_path, local_path): + if self.is_exist(local_path): + raise FSFileExistsError + + if not self.is_exist(fs_path): + raise FSFileNotExistsError + + cmd = "{} -get {} {}".format(self._base_cmd, fs_path, local_path) + ret, lines = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def mkdirs(self, fs_path): + if self.is_exist(fs_path): + return + + cmd = "{} -mkdir {}".format(self._base_cmd, fs_path) + ret, lines = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def mv(self, fs_src_path, fs_dst_path, test_exists=True): + if test_exists: + if not self.is_exist(fs_src_path): + raise FSFileNotExistsError + + if self.is_exist(fs_dst_path): + raise FSFileExistsError + + cmd = "{} -mv {} {}".format(self._base_cmd, fs_src_path, fs_dst_path) + ret, _ = self._run_cmd(cmd) + if ret != 0: + 
raise ExecuteError + + @_handle_errors + def _rmr(self, fs_path): + cmd = "{} -rmr {}".format(self._base_cmd, fs_path) + ret, _ = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def _rm(self, fs_path): + cmd = "{} -rm {}".format(self._base_cmd, fs_path) + ret, _ = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + def delete(self, fs_path): + if not self.is_exist(fs_path): + return + + is_dir = self.is_dir(fs_path) + if is_dir: + return self._rmr(fs_path) + + return self._rm(fs_path) + + def need_upload_download(self): + return True diff --git a/python/paddle/fleet/utils/http_server.py b/python/paddle/fleet/utils/http_server.py new file mode 100644 index 00000000000000..78e310b0a5a516 --- /dev/null +++ b/python/paddle/fleet/utils/http_server.py @@ -0,0 +1,195 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Http Server.""" + +import logging + +import six +# NOTE: HTTPServer has a different name in python2 and python3 +if six.PY2: + from BaseHTTPServer import HTTPServer + import SimpleHTTPServer +else: + from http.server import HTTPServer + import http.server as SimpleHTTPServer + +import time +import threading +import socket + + +def get_logger(name, level, fmt): + logger = logging.getLogger(name) + logger.setLevel(level) + handler = logging.FileHandler('http.log', mode='w') + formatter = logging.Formatter(fmt=fmt) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + +_http_server_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + """ + kv handler class for kv http server, + it defines the way to get/set kv in server. + """ + + def do_GET(self): + """ + get method for kv handler, get value according to key. + """ + log_str = "GET " + self.address_string() + self.path + paths = self.path.split('/') + if len(paths) < 3: + print('len of request path must be 3: ' + self.path) + self.send_status_code(400) + return + _, scope, key = paths + with self.server.kv_lock: + value = self.server.kv.get(scope, {}).get(key) + if value is None: + log_str += ' , key not found: ' + key + self.send_status_code(404) + else: + log_str += ' , key found: ' + key + self.send_response(200) + self.send_header("Content-Length", str(len(value))) + self.end_headers() + self.wfile.write(value) + _http_server_logger.info(log_str) + + def do_PUT(self): + """ + put method for kv handler, set value according to key. 
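+
+        The path convention is /<scope>/<key>; an illustrative client call
+        (an editor's sketch, any HTTP client works):
+
+        .. code-block:: python
+
+            import requests
+            requests.put("http://127.0.0.1:8000/trainer/rank0", data=b"ready")
+            r = requests.get("http://127.0.0.1:8000/trainer/rank0")
+            print(r.status_code, r.content)  # 200 b'ready'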
+ """ + log_str = "PUT " + self.address_string() + self.path + paths = self.path.split('/') + if len(paths) < 3: + print('len of request path must be 3: ' + self.path) + self.send_status_code(400) + return + _, scope, key = paths + content_length = int(self.headers['Content-Length']) + try: + value = self.rfile.read(content_length) + except: + print("receive error invalid request") + self.send_status_code(404) + return + with self.server.kv_lock: + if self.server.kv.get(scope) is None: + self.server.kv[scope] = {} + self.server.kv[scope][key] = value + self.send_status_code(200) + _http_server_logger.info(log_str) + + def do_DELETE(self): + """ + delete method for kv handler, set value according to key. + """ + log_str = "DELETE " + self.address_string() + self.path + paths = self.path.split('/') + if len(paths) < 3: + print('len of request path must be 3: ' + self.path) + self.send_status_code(400) + return + _, scope, key = paths + with self.server.delete_kv_lock: + if self.server.delete_kv.get(scope) is None: + self.server.delete_kv[scope] = [] + self.server.delete_kv[scope].append(key) + self.send_status_code(200) + _http_server_logger.info(log_str) + + def log_message(self, format, *args): + """ + ignore all logging messages in kv handler. + """ + pass + + def send_status_code(self, code): + """ + send status code back to client. + """ + self.send_response(code) + self.send_header("Content-Length", 0) + self.end_headers() + + +class KVHTTPServer(HTTPServer, object): + """ + it is a http server storing kv pairs. + """ + + def __init__(self, port, handler): + """Init.""" + super(KVHTTPServer, self).__init__(('', port), handler) + self.delete_kv_lock = threading.Lock() + self.delete_kv = {} + self.kv_lock = threading.Lock() + self.kv = {} + + def get_deleted_size(self, key): + """ + get deleted size in key. + """ + ret = 0 + with self.delete_kv_lock: + ret = self.delete_kv.get(key, 0) + return ret + + +class KVServer: + """ + it is a server storing kv pairs, has a http server inside. + """ + + def __init__(self, port, size={}): + """Init.""" + self.http_server = KVHTTPServer(port, KVHandler) + self.listen_thread = None + self.size = {} + + def start(self): + """ + start server until user calls stop to let it quit. + """ + self.listen_thread = threading.Thread( + target=lambda: self.http_server.serve_forever()) + self.listen_thread.start() + + def stop(self): + """ + stop server and clear its resources. + """ + self.http_server.shutdown() + self.listen_thread.join() + self.http_server.server_close() + + def shoud_stop(self): + """ + return whether the server should stop. + + Returns: + ret(bool): whether the server should stop + """ + for key in self.size: + s = self.http_server.get_deleted_size(key) + if s != self.size.get(key, 0): + return False + return True diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index de4330cf51669e..82018132cc8b86 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -16,12 +16,13 @@ import os import collections -from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase +from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer import pickle import six from . import learning_rate_scheduler import warnings from .. 
import core +from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars __all__ = [ 'save_dygraph', @@ -140,22 +141,83 @@ def load_dygraph(model_path, keep_name_table=False): elif model_prefix.endswith(".pdopt"): model_prefix = model_prefix[:-6] - params_file_path = model_prefix + ".pdparams" - if not os.path.exists(params_file_path): - raise RuntimeError("Parameter file [ {} ] not exists".format( - params_file_path)) - - with open(params_file_path, 'rb') as f: - para_dict = pickle.load(f) if six.PY2 else pickle.load( - f, encoding='latin1') - - if not keep_name_table and "StructuredToParameterName@@" in para_dict: - del para_dict["StructuredToParameterName@@"] + para_dict = None opti_dict = None + params_file_path = model_prefix + ".pdparams" opti_file_path = model_prefix + ".pdopt" - if os.path.exists(opti_file_path): - with open(opti_file_path, 'rb') as f: - opti_dict = pickle.load(f) if six.PY2 else pickle.load( - f, encoding='latin1') + if not os.path.exists(params_file_path) and not os.path.exists( + opti_file_path): + # Load state dict by `jit.save` save format + # TODO(chenweihang): [Why not support `io.save_infernece_model` save format here] + # The model saved by `save_inference_model` does not completely correspond to + # the information required by the `state_dict` under the dygraph. + # Although we reluctantly restore the `state_dict` in some scenarios, + # this may not be complete and there are some limitations, so this function + # will be considered later. The limitations include: + # 1. `save_inference_model` not save structured name, we need to remind + # the user to configure the `use_structured_name` argument when `set_dict`, + # but this argument is currently not public + # 2. if `save_inference_model` save all persistable variables in a single file, + # user need to give the variable name list to load `state_dict` + + # 1. check model path + if not os.path.isdir(model_prefix): + raise ValueError("Model saved directory '%s' is not exists." % + model_prefix) + # 2. load `__variables.info__` + var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME) + if not os.path.exists(var_info_path): + raise RuntimeError( + "No target can be loaded. Now only supports loading `state_dict` from " + "the result saved by `imperative.save` and `imperative.jit.save`." + ) + with open(var_info_path, 'rb') as f: + extra_var_info = pickle.load(f) + # 3. load `__variables__` + # TODO(chenweihang): now only supports loading from default save format: + # - all persistable vars saved in one file named `__variables__` + # for other case, we may need to modify the arguments of this API + var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME) + if not os.path.exists(var_file_path): + raise RuntimeError( + "The parameter file to be loaded was not found. " + "Now only supports loading from the default save format, " + "and does not support custom params_filename and " + "save parameters separately.") + # 4. load all persistable vars + load_var_list = [] + for name in sorted(extra_var_info): + var = _varbase_creator(name=name, persistable=True) + load_var_list.append(var) + _dygraph_tracer().trace_op( + type='load_combine', + inputs={}, + outputs={'Out': load_var_list}, + attrs={'file_path': var_file_path}) + # 5. 
construct state_dict
+        para_dict = dict()
+        for var in load_var_list:
+            structured_name = extra_var_info[var.name].get('structured_name',
+                                                           None)
+            if structured_name is None:
+                raise RuntimeError(
+                    "Cannot find saved variable (%s)'s structured name in saved model."
+                    % var.name)
+            para_dict[structured_name] = var.numpy()
+        # NOTE: `jit.save` doesn't save optimizer state
+    else:
+        # Load state dict by `save_dygraph` save format
+        if os.path.exists(params_file_path):
+            with open(params_file_path, 'rb') as f:
+                para_dict = pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+            if not keep_name_table and "StructuredToParameterName@@" in para_dict:
+                del para_dict["StructuredToParameterName@@"]
+
+        if os.path.exists(opti_file_path):
+            with open(opti_file_path, 'rb') as f:
+                opti_dict = pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
 
     return para_dict, opti_dict
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py
new file mode 100644
index 00000000000000..c99e23703d9241
--- /dev/null
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import traceback
+
+from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
+
+ERROR_DATA = "Error data about original source code information and traceback."
+
+
+def attach_error_data(error):
+    """
+    Attaches error data about the original source code and traceback to an error.
+
+    Args:
+        error(Exception): a native error.
+
+    Returns:
+        The error, with data about the original source code information and traceback attached.
+    """
+    e_type, e_value, e_traceback = sys.exc_info()
+    tb = traceback.extract_tb(e_traceback)[1:]
+
+    error_data = ErrorData(e_type, e_value, tb, global_origin_info_map)
+    setattr(error, ERROR_DATA, error_data)
+
+    return error
+
+
+class TraceBackFrame(OriginInfo):
+    """
+    Traceback frame information.
+    """
+
+    def __init__(self, location, function_name, source_code):
+        self.location = location
+        self.function_name = function_name
+        self.source_code = source_code
+
+
+class ErrorData(object):
+    """
+    Error data attached to an exception which is raised in un-transformed code.
+
+    TODO(liym27): Consider the case of op_callstack when the error is raised from C++ code.
+    """
+
+    def __init__(self, error_type, error_value, origin_traceback,
+                 origin_info_map):
+        self.error_type = error_type
+        self.error_value = error_value
+        self.origin_traceback = origin_traceback
+        self.origin_info_map = origin_info_map
+
+    def create_exception(self):
+        message = self.create_message()
+        new_exception = self.error_type(message)
+        setattr(new_exception, ERROR_DATA, self)
+        return new_exception
+
+    def create_message(self):
+        """
+        Creates a custom error message that includes the trace stack with the user's original dygraph source code information.
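+
+        Illustrative use at a try/except boundary (an editor's sketch;
+        static_func stands for the transformed static function):
+
+        .. code-block:: python
+
+            try:
+                static_func(inputs)
+            except Exception as e:
+                e = attach_error_data(e)  # must run inside the except block
+                print(getattr(e, ERROR_DATA).create_message())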
+ """ + message_lines = [] + + # Step1: Adds header message to prompt users that the following is the original information. + header_message = "In user code:" + message_lines.append(header_message) + message_lines.append("") + + # Step2: Optimizes stack information with source code information of dygraph from user. + for filepath, lineno, funcname, code in self.origin_traceback: + loc = Location(filepath, lineno) + + dygraph_func_info = self.origin_info_map.get(loc.line_location, + None) + if dygraph_func_info: + # TODO(liym27): more information to prompt users that this is the original information. + # Replaces trace stack information about transformed static code with original dygraph code. + traceback_frame = self.origin_info_map[loc.line_location] + else: + traceback_frame = TraceBackFrame(loc, funcname, code) + + message_lines.append(traceback_frame.formated_message()) + + # Step3: Adds error message like "TypeError: dtype must be int32, but received float32". + error_message = " " * 4 + traceback.format_exception_only( + self.error_type, self.error_value)[0].strip("\n") + message_lines.append(error_message) + + return '\n'.join(message_lines) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 6b9ee9cbbe21b4..c66778992c25c6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -39,32 +39,21 @@ def create_while_node(condition_name, body_name, loop_var_names): - while_args = [] - while_args.append( - gast.Name( - id=condition_name, - ctx=gast.Param(), - annotation=None, - type_comment=None)) - while_args.append( - gast.Name( - id=body_name, ctx=gast.Param(), annotation=None, type_comment=None)) - assign_targets = [ - gast.Name( - id=var_name, ctx=gast.Param(), annotation=None, type_comment=None) - for var_name in loop_var_names - ] - while_args.append(gast.List(elts=assign_targets, ctx=gast.Param())) - - while_func_id = gast.parse( - 'fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop' - ).body[0].value - while_node = gast.Call(func=while_func_id, args=while_args, keywords=[]) - assign_node = gast.Assign( - targets=[gast.Tuple( - elts=assign_targets, ctx=gast.Store())], - value=while_node) - return assign_node + # NOTE(liym27): + # It's better to parse the source code into an AST node than to customize an AST node + # including child nodes, because it is easy to mistake the ast node type when customizing the node. + # + # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name, + # but the type of `foo.x` gast.Attribute. + + while_func_name = "fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop" + while_node_str = "[{}] = {}({}, {}, [{}])".format( + ",".join(loop_var_names), while_func_name, condition_name, body_name, + ",".join(loop_var_names)) + + while_node = gast.parse(while_node_str).body[0] + + return while_node class NameVisitor(gast.NodeVisitor): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py index 429fa27f618765..d37201bfc55c47 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py @@ -21,6 +21,7 @@ # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. 
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py
index 429fa27f618765..d37201bfc55c47 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py
@@ -21,6 +21,7 @@
 # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of the . operation to get the original information of an ast node.
 ORIGI_INFO = "Original information of source code for ast node."
+ORIGI_INFO_MAP = "Original information map of source code."


 class Location(object):
@@ -64,6 +65,11 @@ def __str__(self):
         return "{} \nsource_code: {} in function {}\n  ".format(
             self.location, self.source_code, self.function_name)

+    def formatted_message(self):
+        return '  File "{}", line {}, in {}\n\t{}'.format(
+            self.location.filepath, self.location.lineno, self.function_name,
+            self.source_code.lstrip())
+

 class OriginInfoAttacher(gast.NodeTransformer):
     """
@@ -119,7 +125,12 @@ def _abs_col_offset(self, node):
         return self.col_offset + node.col_offset


-def create_origin_info_map(transformed_node, static_func):
+global_origin_info_map = {}
+
+
+def create_and_update_origin_info_map(transformed_node,
+                                      static_func,
+                                      is_global=True):
     """
     Creates an original information map between the transformed static function
     and the original dygraph function.
@@ -156,6 +167,10 @@ def create_origin_info_map(transformed_node, static_func):

         origin_info_map[static_loc] = dygraph_info

+    global_origin_info_map.update(origin_info_map)
+    if is_global:
+        return global_origin_info_map
+
     return origin_info_map
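For intuition, `formatted_message` above mirrors CPython's own traceback layout; a tiny standalone sketch with hypothetical values:

    def formatted_message(filepath, lineno, function_name, source_code):
        return '  File "{}", line {}, in {}\n\t{}'.format(
            filepath, lineno, function_name, source_code.lstrip())

    print(formatted_message("my_model.py", 12, "forward",
                            "    x = fluid.layers.relu(x)"))
    # Prints:
    #   File "my_model.py", line 12, in forward
    #       x = fluid.layers.relu(x)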
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
index 1b6b64ae1fdee8..d555c8ed28f358 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
@@ -47,8 +47,7 @@ def transform(self):
     # NOTE: deal with print in PY3
     def visit_Call(self, node):
         if isinstance(node.func, gast.Name) and node.func.id == 'print':
-            convert_print_node = self._create_print_node(node.args)
-            return gast.Expr(value=convert_print_node)
+            node = self._create_print_node(node.args)
         return node

     # NOTE: deal with print in PY2
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 79e812ff6192bc..e68719595d86b8 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -36,6 +38,8 @@
 from paddle.fluid.dygraph.base import param_guard
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
+from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
+from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA

 __all__ = ['ProgramTranslator', 'convert_to_static']

@@ -88,15 +90,23 @@ def foo(x, y):
         # with decorator directly and function.__wrapped__ holds the actual function.
         func = getattr(func, '__wrapped__', func)
         source_code = func_to_source_code(func)
+
+        # TODO(liym27):
+        # Consider this case: `source_code` is already in self._code_to_ast_caches,
+        # but the cache hits are actually methods of different classes.
+        # Maybe use (__class__, source_code) as the key.
         if source_code in self._code_to_ast_caches:
             root_wrapper = self._code_to_ast_caches[source_code]
         else:
             root = gast.parse(source_code)
+            root = attach_origin_info(root, func)
             root_wrapper = self._dygraph_to_static.get_static_ast(root)
             self._code_to_ast_caches[source_code] = root_wrapper

         # Get static function from AST
         static_func, file_name = ast_to_func(root_wrapper.node, func)
+
+        create_and_update_origin_info_map(root_wrapper.node, static_func)
         return static_func

     def exist(self, func):
@@ -125,6 +135,7 @@ def __init__(self, func, args, kwargs):
         self._args = args
         self._kwargs = kwargs

+        # TODO(liym27): `func` may be wrapped in multiple layers of decorators
         dyfunc = getattr(func, '__wrapped__', func)
         self._dyfunc_code = inspect.getsource(dyfunc)

@@ -282,7 +293,13 @@ def from_func_spec(func_spec):
             # 3. Builds program only once and returns the output Variables.
             with param_guard(func_spec.parameters(False)), param_guard(
                     func_spec.buffers(False)):
-                outputs = static_func(*inputs)
+                try:
+                    outputs = static_func(*inputs)
+                except BaseException as e:
+                    # NOTE: If e is raised at compile time, e should be attached to ERROR_DATA here.
+                    attach_error_data(e)
+                    raise
+
             if not isinstance(outputs, (tuple, list)) and outputs is not None:
                 outputs = [outputs]

@@ -483,14 +500,24 @@ def func(x):
             return dygraph_func(*args, **kwargs)

         function_spec = FunctionSpec(dygraph_func, args, kwargs)
-        _, partial_program_layer = self._program_cache[function_spec]
+        concrete_program, partial_program_layer = self._program_cache[
+            function_spec]

         if args and isinstance(args[0], layers.Layer):
             # Synchronize self.training attribute.
             partial_program_layer.training = args[0].training
             args = args[1:]
-
-        return partial_program_layer(args)
+        try:
+            return partial_program_layer(args)
+
+        except BaseException as e:
+            # NOTE:
+            # 1. If e was raised at compile time, it has already been attached to ERROR_DATA;
+            # 2. If e is raised at runtime, it is attached to ERROR_DATA here.
+            if not hasattr(e, ERROR_DATA):
+                # runtime error
+                attach_error_data(e)
+            raise

     def get_func(self, dygraph_func):
         """
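The two try/except blocks above follow a simple attach-once protocol; condensed into a standalone helper (the wrapper itself is illustrative, the imported names come from the error.py added in this patch):

    from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, attach_error_data

    def call_with_error_context(fn, *args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except BaseException as e:
            # Compile-time failures already carry ERROR_DATA, so only runtime
            # failures are annotated here; the data is attached exactly once.
            if not hasattr(e, ERROR_DATA):
                attach_error_data(e)
            raise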
diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py
index 38e4e517836ed8..7396289392affa 100644
--- a/python/paddle/fluid/dygraph/io.py
+++ b/python/paddle/fluid/dygraph/io.py
@@ -425,8 +425,7 @@ def _load_persistable_vars(model_path, params_filename=None):

     # 1. load extra var info
     with open(var_info_path, 'rb') as f:
-        extra_var_info = pickle.load(f) if six.PY2 else pickle.load(
-            f, encoding='latin1')
+        extra_var_info = pickle.load(f)

     # 2. construct var dict
     load_var_dict = dict()
diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py
index 64faae247fbf80..128c4964c45018 100644
--- a/python/paddle/fluid/dygraph/jit.py
+++ b/python/paddle/fluid/dygraph/jit.py
@@ -15,20 +15,23 @@
 from __future__ import print_function

 import os
-import six
 import pickle
-
 import warnings
+
+import six
 from paddle.fluid import core
 from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
-from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, FunctionSpec
+from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
+from paddle.fluid.dygraph.dygraph_to_static.program_translator import FunctionSpec, ProgramTranslator
+from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
 from paddle.fluid.dygraph.layers import Layer
 from paddle.fluid.executor import Executor, scope_guard
-from paddle.fluid.framework import Program, Block, Variable, ParamBase, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
+from paddle.fluid.framework import Block, ParamBase, Program, Variable
+from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
+from paddle.fluid.framework import dygraph_only, in_dygraph_mode
 from paddle.fluid.wrapped_decorator import wrap_decorator
-from paddle.fluid.dygraph.io import TranslatedLayer, VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME

 __all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']

@@ -167,7 +170,15 @@ def __impl__(*args, **kwargs):
                 "The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. "
                 "We will just return dygraph output.")
             return dygraph_func(*args, **kwargs)
-        return program_translator.get_output(dygraph_func, *args, **kwargs)
+        try:
+            return program_translator.get_output(dygraph_func, *args, **kwargs)
+        except Exception as e:
+            error_data = getattr(e, ERROR_DATA, None)
+            if error_data:
+                new_exception = error_data.create_exception()
+                raise new_exception
+            else:
+                raise

     return __impl__
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 3205c7921d246a..86651eec939d16 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -35,7 +35,7 @@
     'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding',
     'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu',
     'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm',
-    'SpectralNorm', 'TreeConv', 'SyncBatchNorm'
+    'SpectralNorm', 'TreeConv', 'Flatten', 'SyncBatchNorm'
 ]

@@ -3227,10 +3227,10 @@ class SyncBatchNorm(layers.Layer):
     .. math::

         \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
         \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\
         y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift

-    - :math:`\\epsilon` : add a smaller value to the variance to prevent division by zero
+    - :math:`\\epsilon` : a small value (the ``eps`` argument) added to the variance to prevent division by zero
     - :math:`\\gamma` : trainable proportional parameter
     - :math:`\\beta` : trainable deviation parameter
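Without the docstring escaping, the normalization above is the standard batch-norm transform; in plain LaTeX:

    \hat{x}_i \gets \frac{x_i - \mu_\beta}{\sqrt{\sigma_\beta^2 + \epsilon}}
    \qquad
    y_i \gets \gamma \hat{x}_i + \beta

where \mu_\beta and \sigma_\beta^2 are the mini-batch mean and variance, computed across all devices in the synchronized variant.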
@@ -3259,8 +3259,9 @@ class SyncBatchNorm(layers.Layer):
     Examples:
         .. code-block:: python

-          from paddle.fluid.dygraph import to_variable
           import paddle.nn as nn
+          import paddle.fluid as fluid
+          from paddle.fluid.dygraph import to_variable
           import numpy as np

           x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
@@ -3325,7 +3326,7 @@ def __init__(self,
         self._data_layout = 'NCHW'
         self._momentum = momentum
-        self._epsilon = epsilon
+        self._eps = eps
         self._track_running_stats = track_running_stats

     def forward(self, input):
@@ -3344,7 +3345,7 @@ def forward(self, input):
             trainable_statistics = False

         if in_dygraph_mode():
-            attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
+            attrs = ("momentum", self._momentum, "epsilon", self._eps,
                      "is_test", not self.training, "data_layout",
                      self._data_layout, "use_mkldnn", False, "fuse_with_relu",
                      False, "use_global_stats", use_global_stats,
@@ -3360,7 +3361,7 @@ def forward(self, input):

         attrs = {
             "momentum": self._momentum,
-            "epsilon": self._epsilon,
+            "epsilon": self._eps,
             "is_test": not self.training,
             "data_layout": self._data_layout,
             "use_mkldnn": False,
@@ -3395,3 +3396,58 @@ def forward(self, input):
         self._helper.append_op(
             type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
         return sync_batch_norm_out
+
+
+class Flatten(layers.Layer):
+    """
+    :alias_main: paddle.nn.Flatten
+    :alias: paddle.nn.Flatten,paddle.nn.layer.Flatten,paddle.nn.layer.common.Flatten
+
+    This interface is used to construct a callable object of the ``Flatten`` class.
+    For more details, refer to the code examples.
+    It flattens a contiguous range of axes of a tensor.
+
+    Parameters:
+        start_axis(int): the first axis to flatten (default: 1).
+        stop_axis(int): the last axis to flatten (default: -1).
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            from paddle.imperative import to_variable
+            import numpy as np
+
+            inp_np = np.ones([5, 2, 3, 4]).astype('float32')
+
+            paddle.enable_imperative()
+
+            inp_np = to_variable(inp_np)
+            flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
+            flatten_res = flatten(inp_np)
+
+    """
+
+    def __init__(self, start_axis=1, stop_axis=-1):
+        super(Flatten, self).__init__()
+        self.start_axis = start_axis
+        self.stop_axis = stop_axis
+
+    def forward(self, input):
+        out = self._helper.create_variable_for_type_inference(input.dtype)
+        x_shape = self._helper.create_variable_for_type_inference(input.dtype)
+
+        if in_dygraph_mode():
+            dy_out, _ = core.ops.flatten_contiguous_range(
+                input, 'start_axis', self.start_axis, 'stop_axis',
+                self.stop_axis)
+            return dy_out
+        self._helper.append_op(
+            type="flatten_contiguous_range",
+            inputs={"X": input},
+            outputs={"Out": out,
+                     "XShape": x_shape},
+            attrs={"start_axis": self.start_axis,
+                   "stop_axis": self.stop_axis})
+        return out
diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py
index 9be1fe92d1d0c7..f236a3e98c61ba 100644
--- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py
+++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py
@@ -21,7 +21,7 @@
 from paddle.fluid.optimizer import SGD

 from paddle.fluid.incubate.fleet.base.mode import Mode
-from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase
+from paddle.fleet.base.role_maker import RoleMakerBase
 from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision
 from .
import mode @@ -209,7 +209,10 @@ def init(self, role_maker=None): self._executor = Executor(fluid.CPUPlace()) if role_maker and not isinstance(role_maker, RoleMakerBase): - raise TypeError("role_maker must be an instance of RoleMakerBase") + from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase as RoleMakerBaseIncubate + if role_maker and not isinstance(role_maker, RoleMakerBaseIncubate): + raise TypeError( + "role_maker must be an instance of RoleMakerBase") self._role_maker = role_maker self._role_maker.generate_role() diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 7906f563c0009a..ea6abe2d335e66 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -471,9 +471,9 @@ def rpn_target_assign(bbox_pred, def sigmoid_focal_loss(x, label, fg_num, gamma=2.0, alpha=0.25): """ - :alias_main: paddle.nn.functional.sigmoid_focal_loss - :alias: paddle.nn.functional.sigmoid_focal_loss,paddle.nn.functional.loss.sigmoid_focal_loss - :old_api: paddle.fluid.layers.sigmoid_focal_loss + :alias_main: paddle.nn.functional.sigmoid_focal_loss + :alias: paddle.nn.functional.sigmoid_focal_loss,paddle.nn.functional.loss.sigmoid_focal_loss + :old_api: paddle.fluid.layers.sigmoid_focal_loss **Sigmoid Focal Loss Operator.** @@ -628,9 +628,9 @@ def detection_output(loc, nms_eta=1.0, return_index=False): """ - :alias_main: paddle.nn.functional.detection_output - :alias: paddle.nn.functional.detection_output,paddle.nn.functional.vision.detection_output - :old_api: paddle.fluid.layers.detection_output + :alias_main: paddle.nn.functional.detection_output + :alias: paddle.nn.functional.detection_output,paddle.nn.functional.vision.detection_output + :old_api: paddle.fluid.layers.detection_output Given the regression locations, classification confidences and prior boxes, calculate the detection outputs by performing following steps: @@ -761,9 +761,9 @@ class number, M is number of bounding boxes. 
@templatedoc() def iou_similarity(x, y, box_normalized=True, name=None): """ - :alias_main: paddle.nn.functional.iou_similarity - :alias: paddle.nn.functional.iou_similarity,paddle.nn.functional.loss.iou_similarity - :old_api: paddle.fluid.layers.iou_similarity + :alias_main: paddle.nn.functional.iou_similarity + :alias: paddle.nn.functional.iou_similarity,paddle.nn.functional.loss.iou_similarity + :old_api: paddle.fluid.layers.iou_similarity ${comment} @@ -821,9 +821,9 @@ def box_coder(prior_box, name=None, axis=0): """ - :alias_main: paddle.nn.functional.box_coder - :alias: paddle.nn.functional.box_coder,paddle.nn.functional.vision.box_coder - :old_api: paddle.fluid.layers.box_coder + :alias_main: paddle.nn.functional.box_coder + :alias: paddle.nn.functional.box_coder,paddle.nn.functional.vision.box_coder + :old_api: paddle.fluid.layers.box_coder **Box Coder Layer** @@ -1012,9 +1012,9 @@ def yolov3_loss(x, name=None, scale_x_y=1.): """ - :alias_main: paddle.nn.functional.yolov3_loss - :alias: paddle.nn.functional.yolov3_loss,paddle.nn.functional.vision.yolov3_loss - :old_api: paddle.fluid.layers.yolov3_loss + :alias_main: paddle.nn.functional.yolov3_loss + :alias: paddle.nn.functional.yolov3_loss,paddle.nn.functional.vision.yolov3_loss + :old_api: paddle.fluid.layers.yolov3_loss ${comment} @@ -1139,9 +1139,9 @@ def yolo_box(x, name=None, scale_x_y=1.): """ - :alias_main: paddle.nn.functional.yolo_box - :alias: paddle.nn.functional.yolo_box,paddle.nn.functional.vision.yolo_box - :old_api: paddle.fluid.layers.yolo_box + :alias_main: paddle.nn.functional.yolo_box + :alias: paddle.nn.functional.yolo_box,paddle.nn.functional.vision.yolo_box + :old_api: paddle.fluid.layers.yolo_box ${comment} @@ -1318,9 +1318,9 @@ def bipartite_match(dist_matrix, dist_threshold=None, name=None): """ - :alias_main: paddle.nn.functional.bipartite_match - :alias: paddle.nn.functional.bipartite_match,paddle.nn.functional.vision.bipartite_match - :old_api: paddle.fluid.layers.bipartite_match + :alias_main: paddle.nn.functional.bipartite_match + :alias: paddle.nn.functional.bipartite_match,paddle.nn.functional.vision.bipartite_match + :old_api: paddle.fluid.layers.bipartite_match This operator implements a greedy bipartite matching algorithm, which is used to obtain the matching with the maximum distance based on the input @@ -1412,9 +1412,9 @@ def target_assign(input, mismatch_value=None, name=None): """ - :alias_main: paddle.nn.functional.target_assign - :alias: paddle.nn.functional.target_assign,paddle.nn.functional.extension.target_assign - :old_api: paddle.fluid.layers.target_assign + :alias_main: paddle.nn.functional.target_assign + :alias: paddle.nn.functional.target_assign,paddle.nn.functional.extension.target_assign + :old_api: paddle.fluid.layers.target_assign This operator can be, for given the target bounding boxes or labels, to assign classification and regression targets to each prediction as well as @@ -1530,9 +1530,9 @@ def ssd_loss(location, normalize=True, sample_size=None): """ - :alias_main: paddle.nn.functional.ssd_loss - :alias: paddle.nn.functional.ssd_loss,paddle.nn.functional.loss.ssd_loss - :old_api: paddle.fluid.layers.ssd_loss + :alias_main: paddle.nn.functional.ssd_loss + :alias: paddle.nn.functional.ssd_loss,paddle.nn.functional.loss.ssd_loss + :old_api: paddle.fluid.layers.ssd_loss **Multi-box loss layer for object detection algorithm of SSD** @@ -1777,9 +1777,9 @@ def prior_box(input, name=None, min_max_aspect_ratios_order=False): """ - :alias_main: paddle.nn.functional.prior_box - 
:alias: paddle.nn.functional.prior_box,paddle.nn.functional.vision.prior_box - :old_api: paddle.fluid.layers.prior_box + :alias_main: paddle.nn.functional.prior_box + :alias: paddle.nn.functional.prior_box,paddle.nn.functional.vision.prior_box + :old_api: paddle.fluid.layers.prior_box This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm. Each position of the input produce N prior boxes, N is determined by @@ -1938,9 +1938,9 @@ def density_prior_box(input, flatten_to_2d=False, name=None): """ - :alias_main: paddle.nn.functional.density_prior_box - :alias: paddle.nn.functional.density_prior_box,paddle.nn.functional.vision.density_prior_box - :old_api: paddle.fluid.layers.density_prior_box + :alias_main: paddle.nn.functional.density_prior_box + :alias: paddle.nn.functional.density_prior_box,paddle.nn.functional.vision.density_prior_box + :old_api: paddle.fluid.layers.density_prior_box This op generates density prior boxes for SSD(Single Shot MultiBox Detector) @@ -2130,7 +2130,7 @@ def multi_box_head(inputs, name=None, min_max_aspect_ratios_order=False): """ - :api_attr: Static Graph + :api_attr: Static Graph Base on SSD ((Single Shot MultiBox Detector) algorithm, generate prior boxes, regression location and classification confidence on multiple input feature @@ -2407,9 +2407,9 @@ def anchor_generator(input, offset=0.5, name=None): """ - :alias_main: paddle.nn.functional.anchor_generator - :alias: paddle.nn.functional.anchor_generator,paddle.nn.functional.vision.anchor_generator - :old_api: paddle.fluid.layers.anchor_generator + :alias_main: paddle.nn.functional.anchor_generator + :alias: paddle.nn.functional.anchor_generator,paddle.nn.functional.vision.anchor_generator + :old_api: paddle.fluid.layers.anchor_generator **Anchor generator operator** @@ -2612,9 +2612,9 @@ def generate_proposal_labels(rpn_rois, is_cls_agnostic=False, is_cascade_rcnn=False): """ - :alias_main: paddle.nn.functional.generate_proposal_labels - :alias: paddle.nn.functional.generate_proposal_labels,paddle.nn.functional.vision.generate_proposal_labels - :old_api: paddle.fluid.layers.generate_proposal_labels + :alias_main: paddle.nn.functional.generate_proposal_labels + :alias: paddle.nn.functional.generate_proposal_labels,paddle.nn.functional.vision.generate_proposal_labels + :old_api: paddle.fluid.layers.generate_proposal_labels **Generate Proposal Labels of Faster-RCNN** @@ -2737,9 +2737,9 @@ def generate_proposal_labels(rpn_rois, def generate_mask_labels(im_info, gt_classes, is_crowd, gt_segms, rois, labels_int32, num_classes, resolution): """ - :alias_main: paddle.nn.functional.generate_mask_labels - :alias: paddle.nn.functional.generate_mask_labels,paddle.nn.functional.vision.generate_mask_labels - :old_api: paddle.fluid.layers.generate_mask_labels + :alias_main: paddle.nn.functional.generate_mask_labels + :alias: paddle.nn.functional.generate_mask_labels,paddle.nn.functional.vision.generate_mask_labels + :old_api: paddle.fluid.layers.generate_mask_labels **Generate Mask Labels for Mask-RCNN** @@ -2896,9 +2896,9 @@ def generate_proposals(scores, name=None, return_rois_num=False): """ - :alias_main: paddle.nn.functional.generate_proposals - :alias: paddle.nn.functional.generate_proposals,paddle.nn.functional.vision.generate_proposals - :old_api: paddle.fluid.layers.generate_proposals + :alias_main: paddle.nn.functional.generate_proposals + :alias: paddle.nn.functional.generate_proposals,paddle.nn.functional.vision.generate_proposals + :old_api: paddle.fluid.layers.generate_proposals 
**Generate proposal Faster-RCNN** diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a46391452bb243..1b8df4a098ff1c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9911,7 +9911,7 @@ def flatten(x, axis=1, name=None): return out -def stack(x, axis=0): +def stack(x, axis=0, name=None): """ This OP stacks all the inputs :code:`x` along axis. @@ -9991,15 +9991,16 @@ def stack(x, axis=0): data = layers.stack(x1) # stack according to axis 0, data.shape=[1, None, 1, 2] """ - - helper = LayerHelper('stack', **locals()) axis = 0 if axis is None else axis - if not isinstance(x, list) and not isinstance(x, tuple): x = [x] + + if in_dygraph_mode(): + return core.ops.stack(x, 'axis', axis) + + helper = LayerHelper('stack', **locals()) out = helper.create_variable_for_type_inference(x[0].dtype) - if not in_dygraph_mode() and \ - x[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY: + if x[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY: assert len(x) == 1, "If the elements of 'x' in stack are Variable(LoDTensorArray), " \ "number of the elements must be 1, but received %s." % len(x) out_index = helper.create_variable_for_type_inference(dtype="int32") diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 0289ecea34acf6..1f96bbc4ceeac1 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -28,7 +28,6 @@ from .unique_name import UniqueNameGenerator import logging import warnings -from .dataset import DatasetBase, InMemoryDataset ### Dygraph DataLoader configs ### import os @@ -1670,7 +1669,7 @@ def generator(): class DatasetLoader(DataLoaderBase): def __init__(self, dataset, places, drop_last): - assert isinstance(dataset, + assert isinstance(dataset, paddle.fleet.dataset. DatasetBase), "dataset must be type of DatasetBase" assert not in_dygraph_mode( ), "DatasetLoader is not supported in dygraph mode yet" @@ -1686,7 +1685,7 @@ def __init__(self, dataset, places, drop_last): dataset.set_thread(thread_num) - if isinstance(dataset, + if isinstance(dataset, paddle.fleet.dataset. InMemoryDataset) and dataset.queue_num > thread_num: logging.warn("queue_num {} which is set in Dataset is ignored". 
format(dataset.queue_num)) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index d73b9511b76ed6..686844fea76c01 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -345,7 +345,6 @@ if(WITH_DISTRIBUTE) # FIXME(typhoonzero): add these tests back list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler") - list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr") #not need list(REMOVE_ITEM DIST_TEST_OPS "test_dist_base") diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 6bf95b9d6715bf..033bc385005219 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -28,6 +28,7 @@ import ctr_dataset_reader from test_dist_fleet_base import runtime_main, FleetDistRunnerBase +from paddle.fleet.base.util_factory import fleet_util # Fix seed for test fluid.default_startup_program().random_seed = 1 @@ -181,8 +182,14 @@ def do_pyreader_training(self, fleet): loss_val = exe.run(program=compiled_prog, fetch_list=[self.avg_cost.name]) loss_val = np.mean(loss_val) - print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id, - loss_val)) + reduce_output = fleet_util.all_reduce( + np.array(loss_val), mode="sum") + loss_all_trainer = fleet_util.all_gather(float(loss_val)) + loss_val = float(reduce_output) / len(loss_all_trainer) + message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id, + loss_val) + fleet_util.print_on_rank(message, 0) + pass_time = time.time() - pass_start except fluid.core.EOFException: self.reader.reset() @@ -210,7 +217,7 @@ def do_dataset_training(self, fleet): filelist.append(train_file_path) # config dataset - dataset = fluid.DatasetFactory().create_dataset() + dataset = paddle.fleet.DatasetFactory().create_dataset() dataset.set_batch_size(batch_size) dataset.set_use_var(self.feeds) pipe_command = 'python ctr_dataset_reader.py' diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py new file mode 100644 index 00000000000000..54c82cb895c5ea --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -0,0 +1,136 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import inspect +import unittest + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.core import EnforceNotMet +from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, ErrorData +from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap +from paddle.fluid.dygraph.jit import declarative + + +def inner_func(): + fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int") + return + + +@declarative +def func_error_in_compile_time(x): + x = fluid.dygraph.to_variable(x) + inner_func() + if fluid.layers.mean(x) < 0: + x_v = x - 1 + else: + x_v = x + 1 + return x_v + + +@declarative +def func_error_in_compile_time_2(x): + x = fluid.dygraph.to_variable(x) + x = fluid.layers.reshape(x, shape=[1, 2]) + return x + + +@declarative +def func_error_in_runtime(x, iter_num=3): + x = fluid.dygraph.to_variable(x) + a = [] + iter_num = fluid.layers.fill_constant( + shape=[1], value=iter_num, dtype="int32") + for i in range(iter_num): + a.append(b) + a = fluid.layers.concat(a, axis=0) + return a + + +class TestErrorInCompileTime(unittest.TestCase): + def setUp(self): + self.set_func() + self.set_input() + self.set_exception_type() + + def set_func(self): + self.func = func_error_in_compile_time + + def set_exception_type(self): + self.exception_type = TypeError + + def set_input(self): + self.input = np.ones([3, 2]) + + def set_message(self): + self.expected_message = \ + ['File "{}", line 36, in func_error_in_compile_time'.format(self.filepath), + 'inner_func()', + 'File "{}", line 29, in inner_func'.format(self.filepath), + 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + ] + + def _test_create_message(self, error_data): + self.filepath = inspect.getfile(unwrap(self.func)) + self.set_message() + error_message = error_data.create_message() + + self.assertIn('In user code:', error_message) + for m in self.expected_message: + self.assertIn(m, error_message) + + def test(self): + with fluid.dygraph.guard(): + with self.assertRaises(self.exception_type) as cm: + self.func(self.input) + exception = cm.exception + error_data = getattr(exception, ERROR_DATA) + self.assertIsInstance(error_data, ErrorData) + self._test_create_message(error_data) + + +class TestErrorInCompileTime2(TestErrorInCompileTime): + def set_func(self): + self.func = func_error_in_compile_time_2 + + def set_exception_type(self): + self.exception_type = EnforceNotMet + + def set_message(self): + + self.expected_message = \ + [ + 'File "{}", line 47, in func_error_in_compile_time_2'.format(self.filepath), + 'x = fluid.layers.reshape(x, shape=[1, 2])' + ] + + +# TODO(liym27): Consider the case that op_callstack when error raised from c++ code +class TestErrorInRuntime(TestErrorInCompileTime): + def set_func(self): + self.func = func_error_in_runtime + + def set_exception_type(self): + self.exception_type = EnforceNotMet + + def test(self): + with fluid.dygraph.guard(): + with self.assertRaises(self.exception_type) as cm: + self.func(self.input) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py index 631655ec744283..b03777b6ebc7f3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py @@ -90,7 +90,8 @@ def _get_OriginInfo_map(self): # step3 self.static_func, _ = 
ast_to_func(transformed_ast, self.dygraph_func) - info_map = create_origin_info_map(dygraph_ast, self.static_func) + info_map = create_and_update_origin_info_map(dygraph_ast, + self.static_func) return info_map diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py new file mode 100644 index 00000000000000..cfbbf7de22087d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py @@ -0,0 +1,78 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.test_fusion_gru_op import TestFusionGRUOp + + +class TestFusionGRUMKLDNNOp(TestFusionGRUOp): + def set_confs(self): + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpNoInitial(TestFusionGRUOp): + def set_confs(self): + self.with_h0 = False + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpNoBias(TestFusionGRUOp): + def set_confs(self): + self.with_bias = False + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpReverse(TestFusionGRUOp): + def set_confs(self): + self.is_reverse = True + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpOriginMode(TestFusionGRUOp): + def set_confs(self): + self.origin_mode = True + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpMD1(TestFusionGRUOp): + def set_confs(self): + self.M = 36 + self.D = 8 + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpMD2(TestFusionGRUOp): + def set_confs(self): + self.M = 8 + self.D = 8 + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpMD3(TestFusionGRUOp): + def set_confs(self): + self.M = 17 + self.D = 15 + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpBS1(TestFusionGRUOp): + def set_confs(self): + self.lod = [[3]] + self.D = 16 + self.use_mkldnn = True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/multi_process.py b/python/paddle/fluid/tests/unittests/multi_process.py index a67634adfcc0c2..f999ce803a5124 100644 --- a/python/paddle/fluid/tests/unittests/multi_process.py +++ b/python/paddle/fluid/tests/unittests/multi_process.py @@ -17,7 +17,7 @@ import time -def train(): +def train(prefix): selected_gpus = os.getenv("FLAGS_selected_gpus") trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS") @@ -29,11 +29,12 @@ def train(): .format(selected_gpus, worker_endpoints, trainers_num, current_endpoint,trainer_id) print(name) - with open("multi_process.check_{}.log".format(trainer_id), "w") as f: + with open("multi_process_{}.check_{}.log".format(prefix, trainer_id), + "w") as f: f.write(name) -def train_abort(): +def train_abort(prefix): selected_gpus = os.getenv("FLAGS_selected_gpus") trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS") @@ -49,8 +50,9 @@ def train_abort(): name = "abort>>> selected_gpus:{} 
worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}"\ .format(selected_gpus, worker_endpoints, trainers_num, current_endpoint,trainer_id) print(name) - with open("multi_process.check_{}.log".format(trainer_id), - "w") as f: + with open( + "multi_process_{}.check_{}.log".format(prefix, trainer_id), + "w") as f: f.write(name) raise else: @@ -60,12 +62,15 @@ def train_abort(): .format(selected_gpus, worker_endpoints, trainers_num, current_endpoint,trainer_id) print(name) - with open("multi_process.check_{}.log".format(trainer_id), "w") as f: + with open("multi_process_{}.check_{}.log".format(prefix, trainer_id), + "w") as f: f.write(name) if __name__ == '__main__': - if len(sys.argv) == 2 and sys.argv[1] == "abort": - train_abort() + if len(sys.argv) == 3 and sys.argv[2] == "abort": + prefix = sys.argv[1] + train_abort(prefix) else: - train() + prefix = sys.argv[1] + train(prefix) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py index 96be49e277deb2..1cd02dadcc2876 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py @@ -26,8 +26,7 @@ import paddle.fluid.dygraph as dygraph from paddle.fluid import core from paddle.fluid.optimizer import SGDOptimizer -#from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, SyncBatchNorm +from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm from paddle.fluid.dygraph.base import to_variable from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase diff --git a/python/paddle/fluid/tests/unittests/test_addmm_op.py b/python/paddle/fluid/tests/unittests/test_addmm_op.py index 8c0b599a37936f..0bcdc45a2ccd0f 100644 --- a/python/paddle/fluid/tests/unittests/test_addmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_addmm_op.py @@ -63,18 +63,104 @@ class TestAddMMOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): # The input type of addmm_op must be Variable. + input = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) + np.array([[-1, -1], [-1, -1]]), [[2]], fluid.CPUPlace()) x1 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) + np.array([[-1, -1], [-1, -1]]), [[2]], fluid.CPUPlace()) x2 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) + np.array([[-1, -1], [-1, -1]]), [[2]], fluid.CPUPlace()) self.assertRaises(TypeError, paddle.addmm, input, x1, x2) + # The input dtype of mul_op must be float32 or float64. 
- input = fluid.layers.data(name='input', shape=[4], dtype="int32") - x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32") - x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32") + input = fluid.layers.data( + name='input', + shape=[4, 4], + dtype="int32", + append_batch_size=False) + x3 = fluid.layers.data( + name='x3', shape=[4, 4], dtype="int32", append_batch_size=False) + x4 = fluid.layers.data( + name='x4', shape=[4, 4], dtype="int32", append_batch_size=False) self.assertRaises(TypeError, paddle.addmm, input, x3, x4) + # x and y dimension mismatch + x5 = fluid.layers.data( + name='x5', + shape=[4, 5], + dtype="float32", + append_batch_size=False) + x6 = fluid.layers.data( + name='x6', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + self.assertRaises(ValueError, paddle.addmm, input, x5, x6) + # input and x are not broadcastable + x7 = fluid.layers.data( + name='x7', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + x8 = fluid.layers.data( + name='x8', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + input1 = fluid.layers.data( + name='input1', + shape=[2, 4], + dtype="float32", + append_batch_size=False) + self.assertRaises(ValueError, paddle.addmm, input1, x7, x8) + # input and x are not broadcastable + x9 = fluid.layers.data( + name='x9', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + x10 = fluid.layers.data( + name='x10', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + input2 = fluid.layers.data( + name='input2', + shape=[1, 2], + dtype="float32", + append_batch_size=False) + self.assertRaises(ValueError, paddle.addmm, input2, x9, x10) + x11 = fluid.layers.data( + name='x11', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + x12 = fluid.layers.data( + name='x12', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + input3 = fluid.layers.data( + name='input3', + shape=[4, 2], + dtype="float32", + append_batch_size=False) + self.assertRaises(ValueError, paddle.addmm, input3, x11, x12) + x13 = fluid.layers.data( + name='x13', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + x14 = fluid.layers.data( + name='x14', + shape=[4, 4], + dtype="float32", + append_batch_size=False) + input4 = fluid.layers.data( + name='input4', + shape=[3, 1], + dtype="float32", + append_batch_size=False) + self.assertRaises(ValueError, paddle.addmm, input4, x13, x14) class TestAddMMOp2(TestAddMMOp): @@ -147,5 +233,23 @@ def test_api_with_dygraph(self): assert np.allclose(np_input + np.dot(np_x, np_y), out.numpy()) +''' +class TestAddMMAPI(unittest.TestCase): + def test_api_error(self): + data_x = np.ones((2, 2)).astype(np.float32) + data_y = np.ones((2, 2)).astype(np.float32) + data_input = np.ones((2, 2)).astype(np.float32) + + paddle.enable_imperative() + + def test_error1(): + data_x_wrong = np.ones((2, 3)).astype(np.float32) + x = paddle.imperative.to_variable(data_x_wrong) + y = paddle.imperative.to_variable(data_y) + input = paddle.imperative.to_variable(data_input) + out = paddle.tensor.addmm( input=input, x=x, y=y, beta=0.5, alpha=5.0 ) + self.assertRaises(ValueError, test_error1) +''' + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index 993ac25d8d4b63..cb1b3ded53472c 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -73,5 +73,15 @@ def test_out(self): 
self.assertTrue(np.allclose(expected_result, out_np)) +class TestBmmAPIError(unittest.TestCase): + def test_api_error(self): + x_data = np.arange(24, dtype='float32').reshape((2, 3, 4)) + y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) + y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) + y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) + self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) + self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index cc2cee602918d5..90d5f58539500b 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -17,6 +17,7 @@ """ from __future__ import print_function +import paddle import paddle.fluid as fluid import paddle.compat as cpt import paddle.fluid.core as core @@ -37,23 +38,26 @@ def setUp(self): def test_dataset_create(self): """ Testcase for dataset create. """ try: - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") except: self.assertTrue(False) try: - dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "QueueDataset") except: self.assertTrue(False) try: - dataset = fluid.DatasetFactory().create_dataset( + dataset = paddle.fleet.DatasetFactory().create_dataset( "FileInstantDataset") except: self.assertTrue(False) try: - dataset = fluid.DatasetFactory().create_dataset("MyOwnDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "MyOwnDataset") self.assertTrue(False) except: self.assertTrue(True) @@ -91,7 +95,8 @@ def test_run_with_dump(self): name=slot, shape=[1], dtype="int64", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist( @@ -125,7 +130,7 @@ def test_dataset_config(self): dataset.set_trainer_num(4) dataset.set_hdfs_config("my_fs_name", "my_fs_ugi") dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi") - dataset.enable_pv_merge() + dataset.set_enable_pv_merge(False) thread_num = dataset.get_thread_num() self.assertEqual(thread_num, 12) @@ -171,7 +176,8 @@ def test_set_download_cmd(self): name=slot, shape=[1], dtype="int64", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist([filename1, filename2]) @@ -222,7 +228,8 @@ def test_in_memory_dataset_run(self): name=slot, shape=[1], dtype="int64", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist([ @@ -293,7 +300,8 @@ def test_in_memory_dataset_masterpatch(self): name=slot, shape=[1], dtype="float32", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) 
dataset.set_thread(1) dataset.set_parse_ins_id(True) @@ -359,7 +367,8 @@ def test_in_memory_dataset_masterpatch1(self): name="slot4", shape=[1], dtype="float32", lod_level=0) slots_vars = [var1, var2, var3, var4] - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(1) dataset.set_parse_ins_id(True) @@ -414,7 +423,8 @@ def test_in_memory_dataset_run_2(self): name=slot, shape=[1], dtype="float32", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist([ @@ -507,7 +517,7 @@ def test_queue_dataset_run(self): name=slot, shape=[1], dtype="int64", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset("QueueDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist( @@ -532,7 +542,7 @@ def test_queue_dataset_run(self): except Exception as e: self.assertTrue(False) - dataset2 = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset2 = paddle.fleet.DatasetFactory().create_dataset("QueueDataset") dataset2.set_use_var(slots_vars) dataset2.set_batch_size(32) dataset2.set_thread(3) @@ -573,7 +583,7 @@ def test_queue_dataset_run_2(self): name=slot, shape=[1], dtype="float32", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset("QueueDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist( @@ -628,7 +638,8 @@ def test_queue_dataset_run_3(self): name=slot, shape=[None, 1], dtype="int64", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_input_type(1) dataset.set_batch_size(1) dataset.set_thread(2) @@ -707,7 +718,7 @@ def get_dataset(self, inputs, files): inputs(list): inputs of get_dataset files(list): files of get_dataset """ - dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset("QueueDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist(files) @@ -864,7 +875,8 @@ def test_dataset_fleet(self): except ImportError as e: print("warning: no mpi4py") exe.run(startup_program) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist([ @@ -884,9 +896,6 @@ def test_dataset_fleet2(self): """ Testcase for InMemoryDataset from create to run. 
""" - - self.skipTest("parameter server will add pslib UT later") - with open("test_in_memory_dataset2_run2_a.txt", "w") as f: data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n" @@ -902,7 +911,7 @@ def test_dataset_fleet2(self): train_program = fluid.Program() startup_program = fluid.Program() scope = fluid.Scope() - from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet with fluid.program_guard(train_program, startup_program): slots = ["slot1_ff", "slot2_ff", "slot3_ff", "slot4_ff"] slots_vars = [] @@ -936,7 +945,8 @@ def test_dataset_fleet2(self): except ImportError as e: print("warning: no mpi4py") exe.run(startup_program) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist([ @@ -952,6 +962,63 @@ def test_dataset_fleet2(self): print("warning: catch expected error") fleet._opt_info = None fleet._fleet_ptr = None + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") + dataset.set_rank_offset("") + dataset.set_pv_batch_size(1) + dataset.set_hdfs_config("", "") + d = paddle.fleet.DatasetBase() + try: + dataset.set_feed_type("MultiSlotInMemoryDataFeed") + except: + print("warning: catch expected error") + dataset.thread_num = 0 + try: + dataset._prepare_to_run() + except: + print("warning: catch expected error") + dataset.set_parse_logkey(True) + dataset.set_merge_by_sid(True) + dataset.set_enable_pv_merge(True) + try: + dataset.preprocess_instance() + except: + print("warning: catch expected error") + try: + dataset.set_current_phase(1) + except: + print("warning: catch expected error") + try: + dataset.postprocess_instance() + except: + print("warning: catch expected error") + dataset.set_fleet_send_batch_size(1024) + try: + dataset.global_shuffle() + except: + print("warning: catch expected error") + dataset.get_pv_data_size() + dataset.get_memory_data_size() + dataset.get_shuffle_data_size() + dataset = paddle.fleet.DatasetFactory().create_dataset( + "QueueDataset") + try: + dataset.local_shuffle() + except: + print("warning: catch expected error") + try: + dataset.global_shuffle() + except: + print("warning: catch expected error") + dataset = paddle.fleet.FileInstantDataset() + try: + dataset.local_shuffle() + except: + print("warning: catch expected error") + try: + dataset.global_shuffle() + except: + print("warning: catch expected error") os.remove("./test_in_memory_dataset2_run2_a.txt") os.remove("./test_in_memory_dataset2_run2_b.txt") diff --git a/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py b/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py index 10aefbb222bb02..22d59e78fff867 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py +++ b/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import paddle import paddle.fluid as fluid import numpy as np import six @@ -96,7 +97,8 @@ def build_network(self): def check_batch_number(self, place, randomize_batch_num=False): main_prog, startup_prog, feeds = self.build_network() - dataset = fluid.DatasetFactory().create_dataset(self.dataset_name) + dataset = paddle.fleet.DatasetFactory().create_dataset( + self.dataset_name) dataset.set_batch_size(BATCH_SIZE) if isinstance(place, fluid.CPUPlace): diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py index 16f0fc0a35e614..8b2f7118ea766a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py @@ -21,6 +21,9 @@ import sys import subprocess +import six +import shutil +import numpy as np import argparse from contextlib import closing import socket @@ -29,7 +32,8 @@ import unittest import paddle.fluid as fluid -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.fleet.base.role_maker as role_maker +from paddle.fleet.base.util_factory import fleet_util from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory @@ -48,18 +52,26 @@ class FleetDistRunnerBase(object): """ def build_role(self, args): + if args.role.upper() == "PSERVER": role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=True, + path=args.gloo_path, current_id=args.current_id, role=role_maker.Role.SERVER, - worker_num=args.trainers, + worker_endpoints=args.trainer_endpoints.split(","), server_endpoints=args.endpoints.split(",")) else: role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=True, + path=args.gloo_path, current_id=args.current_id, role=role_maker.Role.WORKER, - worker_num=args.trainers, + worker_endpoints=args.trainer_endpoints.split(","), server_endpoints=args.endpoints.split(",")) + self.role = role return role def build_strategy(self, args): @@ -114,26 +126,13 @@ def build_optimizer(self, avg_cost, strategy): optimizer.minimize(avg_cost) def run_pserver(self, args): - fleet.init(self.build_role(args)) - strategy = self.build_strategy(args) - avg_cost = self.net(args) - self.build_optimizer(avg_cost, strategy) - fleet.init_server() fleet.run_server() def run_dataset_trainer(self, args): - fleet.init(self.build_role(args)) - strategy = self.build_strategy(args) - avg_cost = self.net(args) - self.build_optimizer(avg_cost, strategy) out = self.do_dataset_training(fleet) def run_pyreader_trainer(self, args): - fleet.init(self.build_role(args)) - strategy = self.build_strategy(args) - avg_cost = self.net(args) - self.build_optimizer(avg_cost, strategy) out = self.do_pyreader_training(fleet) def net(self, args, batch_size=4, lr=0.01): @@ -173,10 +172,14 @@ def setUp(self): print("set begin_port:", DIST_UT_PORT) self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( DIST_UT_PORT, DIST_UT_PORT + 1) - DIST_UT_PORT += 2 + self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + DIST_UT_PORT + 2, DIST_UT_PORT + 3) + DIST_UT_PORT += 4 else: self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._find_free_port(), self._find_free_port()) + self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + self._find_free_port(), self._find_free_port()) self._python_interp = sys.executable self._geo_sgd_need_push_nums = 5 @@ -236,18 +239,22 @@ def _start_trainer(self, 
cmd, required_envs): def _run_cluster(self, model, envs): env = {'GRAD_CLIP': str(self._grad_clip_mode)} python_path = self._python_interp + gloo_path = tempfile.mkdtemp() + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') python_path += " -m coverage run --branch -p" env.update(envs) - tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( - python_path, model, self._ps_endpoints, self._trainers, self._mode, - self._geo_sgd_need_push_nums, self._reader) + tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8}".format( + python_path, model, self._ps_endpoints, self._tr_endpoints, + self._trainers, self._mode, self._geo_sgd_need_push_nums, + self._reader, gloo_path) - ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( - python_path, model, self._ps_endpoints, self._trainers, self._mode, - self._geo_sgd_need_push_nums, self._reader) + ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8}".format( + python_path, model, self._ps_endpoints, self._tr_endpoints, + self._trainers, self._mode, self._geo_sgd_need_push_nums, + self._reader, gloo_path) # Run dist train to compare with local results ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env) @@ -284,6 +291,7 @@ def _run_cluster(self, model, envs): ps0.terminate() ps1.terminate() + shutil.rmtree(gloo_path) return 0, 0 def check_with_place(self, @@ -313,6 +321,9 @@ def runtime_main(test_class): parser.add_argument( '--role', type=str, required=True, choices=['pserver', 'trainer']) parser.add_argument('--endpoints', type=str, required=False, default="") + parser.add_argument( + '--trainer_endpoints', type=str, required=False, default="") + parser.add_argument('--gloo_path', type=str, required=False, default="") parser.add_argument('--current_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) parser.add_argument('--mode', type=str, required=False, default='geo') @@ -322,6 +333,13 @@ def runtime_main(test_class): args = parser.parse_args() model = test_class() + role = model.build_role(args) + fleet.init(role) + strategy = model.build_strategy(args) + avg_cost = model.net(args) + model.build_optimizer(avg_cost, strategy) + fleet_util._set_strategy(strategy) + fleet_util._set_role_maker(role) if args.role == "pserver": model.run_pserver(args) else: diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 5fc37335b21536..18629c4f996a6d 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -22,7 +22,7 @@ class TestDistMnistSync2x2(TestFleetBase): def _setup_config(self): - self._mode = "sync" + self._mode = "async" self._reader = "pyreader" def check_with_place(self, diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py new file mode 100644 index 00000000000000..6d67afe6cbfbb0 --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -0,0 +1,204 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle +from op_test import OpTest + + +class TestFlattenOp(OpTest): + def setUp(self): + self.op_type = "flatten_contiguous_range" + self.start_axis = 0 + self.stop_axis = -1 + self.init_test_case() + self.inputs = {"X": np.random.random(self.in_shape).astype("float64")} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.in_shape).astype("float32") + } + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.in_shape = (3, 2, 5, 4) + self.start_axis = 0 + self.stop_axis = -1 + self.new_shape = (120) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlattenOp_1(TestFlattenOp): + def init_test_case(self): + self.in_shape = (3, 2, 5, 4) + self.start_axis = 1 + self.stop_axis = 2 + self.new_shape = (3, 10, 4) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlattenOp_2(TestFlattenOp): + def init_test_case(self): + self.in_shape = (3, 2, 5, 4) + self.start_axis = 0 + self.stop_axis = 1 + self.new_shape = (6, 5, 4) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlattenOp_3(TestFlattenOp): + def init_test_case(self): + self.in_shape = (3, 2, 5, 4) + self.start_axis = 0 + self.stop_axis = 2 + self.new_shape = (30, 4) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlattenOp_4(TestFlattenOp): + def init_test_case(self): + self.in_shape = (3, 2, 5, 4) + self.start_axis = -2 + self.stop_axis = -1 + self.new_shape = (3, 2, 20) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlattenOp_5(TestFlattenOp): + def init_test_case(self): + self.in_shape = (3, 2, 5, 4) + self.start_axis = 2 + self.stop_axis = 2 + self.new_shape = (3, 2, 5, 4) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlattenOpSixDims(TestFlattenOp): + def init_test_case(self): + self.in_shape = (3, 2, 3, 2, 4, 4) + self.start_axis = 3 + self.stop_axis = 5 + self.new_shape = (3, 2, 3, 32) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis + } + + +class TestFlatten2OpError(unittest.TestCase): + def test_errors(self): + image_shape = (2, 3, 4, 4) + x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * + 
image_shape[3]).reshape(image_shape) / 100. + x = x.astype('float32') + + def test_ValueError1(): + x_var = paddle.nn.data(name="x", shape=image_shape, dtype='float32') + out = paddle.flatten(x_var, start_axis=2, stop_axis=1) + + self.assertRaises(ValueError, test_ValueError1) + + def test_ValueError2(): + x_var = paddle.nn.data(name="x", shape=image_shape, dtype='float32') + paddle.flatten(x_var, start_axis=10, stop_axis=1) + + self.assertRaises(ValueError, test_ValueError2) + + def test_ValueError3(): + x_var = paddle.nn.data(name="x", shape=image_shape, dtype='float32') + paddle.flatten(x_var, start_axis=2, stop_axis=10) + + self.assertRaises(ValueError, test_ValueError3) + + def test_type(): + # dtype must be float32, float64, int8, int32, int64. + x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * + image_shape[3]).reshape(image_shape) / 100. + x2 = x2.astype('float16') + x2_var = paddle.data(name='x2', shape=[3, 2, 4, 5], dtype='float16') + paddle.flatten(x2_var) + + self.assertRaises(TypeError, test_type) + + def test_InputError(): + out = paddle.flatten(x) + + self.assertRaises(ValueError, test_InputError) + + +class TestFlattenPython(unittest.TestCase): + def test_python_api(self): + image_shape = (2, 3, 4, 4) + x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * + image_shape[3]).reshape(image_shape) / 100. + x = x.astype('float32') + + def test_InputError(): + out = paddle.flatten(x) + + self.assertRaises(ValueError, test_InputError) + + def test_Negative(): + paddle.enable_imperative() + img = paddle.imperative.to_variable(x) + out = paddle.flatten(img, start_axis=-2, stop_axis=-1) + return out.numpy().shape + + res_shape = test_Negative() + self.assertTrue((2, 3, 16) == res_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_launch.sh b/python/paddle/fluid/tests/unittests/test_fleet_launch.sh index 577f9f6504fd83..5e5c4e17f5b97b 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_launch.sh +++ b/python/paddle/fluid/tests/unittests/test_fleet_launch.sh @@ -4,7 +4,6 @@ set -e function test_launch_ps(){ fleetrun --server_num=2 --worker_num=2 fleet_ps_training.py 2> ut.elog - if grep -q "server are killed" ut.elog; then echo "test pserver launch succeed" else @@ -20,7 +19,7 @@ fi test_launch_ps # use default values -fleetrun multi_process.py +fleetrun multi_process.py fleetrun # use paddlecloud echo "begin test use paddlecloud" @@ -30,16 +29,16 @@ export POD_IP=127.0.0.1 export PADDLE_TRAINERS=127.0.0.1,127.0.0.2 export PADDLE_TRAINER_ID=0 -export PADDLE_PORT=35019 +export PADDLE_PORT=35789 export TRAINER_PORTS_NUM=2 distributed_args="--ips=${cluster_node_ips} --gpus=0,1 --log_dir=testlog" -CUDA_VISIBLE_DEVICES=0,1 fleetrun ${distributed_args} multi_process.py +CUDA_VISIBLE_DEVICES=0,1 fleetrun ${distributed_args} multi_process.py fleetrun -str1="selected_gpus:0 worker_endpoints:127.0.0.1:35019,127.0.0.1:35020,127.0.0.2:35019,127.0.0.2:35020 trainers_num:4 current_endpoint:127.0.0.1:35019 trainer_id:0" -str2="selected_gpus:1 worker_endpoints:127.0.0.1:35019,127.0.0.1:35020,127.0.0.2:35019,127.0.0.2:35020 trainers_num:4 current_endpoint:127.0.0.1:35020 trainer_id:1" -file_0="multi_process.check_0.log" -file_1="multi_process.check_1.log" +str1="selected_gpus:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790,127.0.0.2:35789,127.0.0.2:35790 trainers_num:4 current_endpoint:127.0.0.1:35789 trainer_id:0" +str2="selected_gpus:1 
worker_endpoints:127.0.0.1:35789,127.0.0.1:35790,127.0.0.2:35789,127.0.0.2:35790 trainers_num:4 current_endpoint:127.0.0.1:35790 trainer_id:1" +file_0="multi_process_fleetrun.check_0.log" +file_1="multi_process_fleetrun.check_1.log" echo "paddlecloud params test" if grep -q "$str1" "$file_0"; then @@ -70,7 +69,7 @@ unset TRAINER_PORTS_NUM echo "" echo "paddle.distributed.launch async poll process test" -if ! CUDA_VISIBLE_DEVICES=0,1 fleetrun ${distributed_args} multi_process.py abort; then +if ! CUDA_VISIBLE_DEVICES=0,1 fleetrun ${distributed_args} multi_process.py fleetrun abort; then echo "train abort as planned" fi diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py index 88a9d235855ce8..351dc0a5d0f66d 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py @@ -14,6 +14,7 @@ """Test cases for role makers.""" from __future__ import print_function +import paddle import os import unittest @@ -162,7 +163,8 @@ def test_pslib_2(self): data = "1 1 1 1\n" f.write(data) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_filelist(["test_fleet_gloo_role_maker_1.txt"]) dataset.set_use_var([show, label]) dataset.load_into_memory() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py index dd5cd715ecd1ed..a91f6cbd69e18e 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py @@ -40,10 +40,9 @@ def test_pslib_1(self): from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib from paddle.fluid.incubate.fleet.base.role_maker import \ GeneralRoleMaker - from paddle.fluid.incubate.fleet.utils.http_server import KVHandler - from paddle.fluid.incubate.fleet.utils.http_server import KVServer - from paddle.fluid.incubate.fleet.utils.http_server import \ - KVHTTPServer + from paddle.fleet.utils import KVHandler + from paddle.fleet.utils import KVServer + from paddle.fleet.utils import KVHTTPServer except: print("warning: no fleet, skip test_pslib_4") return diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py new file mode 100644 index 00000000000000..659cc34b549589 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
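The new-style role makers exercised by the test file that begins here are driven entirely by environment variables. As a minimal sketch of that contract (the variable names are exactly the ones the tests set in setUp(); the snippet itself is illustrative and not part of the patch):

import os
import paddle.fleet.base.role_maker as role_maker

# Emulate a two-trainer, two-pserver PaddleCloud job purely via env vars,
# mirroring TestCloudRoleMaker.setUp() below.
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.2:36001"
os.environ["TRAINING_ROLE"] = "TRAINER"  # "PSERVER" selects the server role
os.environ["PADDLE_TRAINER_ID"] = "0"

ro = role_maker.PaddleCloudRoleMaker(is_collective=False, init_gloo=False)
assert ro.is_worker() and ro.worker_num() == 2  # as asserted in test_tr_rolemaker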
+"""Test cloud role maker.""" + +from __future__ import print_function +import os +import unittest +import paddle.fleet.base.role_maker as role_maker + + +class TestRoleMakerBase(unittest.TestCase): + """ + Test cases for RoleMakerBase + """ + + def test_rolemaker_base(self): + role = role_maker.RoleMakerBase() + self.assertRaises(Exception, role.is_worker) + self.assertRaises(Exception, role.is_server) + self.assertRaises(Exception, role.is_first_worker) + self.assertRaises(Exception, role.worker_num) + self.assertRaises(Exception, role.server_num) + self.assertRaises(Exception, role.worker_index) + self.assertRaises(Exception, role.server_index) + self.assertRaises(Exception, role.role_id) + + trainer_endpoints = role.get_trainer_endpoints() + self.assertTrue(len(trainer_endpoints) == 0) + pserver_endpoints = role.get_pserver_endpoints() + self.assertTrue(len(pserver_endpoints) == 0) + + print(role.to_string()) + self.assertTrue(role._all_gather(role._node_type_comm, 1) is None) + self.assertTrue(role._all_reduce(role._node_type_comm, 1) is None) + role._barrier(role._node_type_comm) + + +class TestCloudRoleMaker(unittest.TestCase): + """ + Test cases for PaddleCloudRoleMaker. + """ + + def setUp(self): + """Set up, set envs.""" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ[ + "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.2:36001" + os.environ["POD_IP"] = "127.0.0.1" + + def test_tr_rolemaker(self): + """Test tr rolenamer.""" + os.environ["TRAINING_ROLE"] = "TRAINER" + os.environ["PADDLE_TRAINER_ID"] = "0" + + try: + import netifaces + except: + print("warning: no netifaces, skip test_tr_rolemaker") + return + + ro = role_maker.PaddleCloudRoleMaker( + is_collective=False, init_gloo=False) + self.assertTrue(ro.is_worker()) + self.assertFalse(ro.is_server()) + self.assertEqual(ro.worker_num(), 2) + self.assertTrue(ro.is_first_worker()) + worker_endpoints = ro.get_trainer_endpoints() + self.assertEqual(worker_endpoints[0], '127.0.0.1:36001') + self.assertEqual(ro.role_id(), 0) + + def test_tr_rolemaker_collective(self): + ro = role_maker.PaddleCloudRoleMaker(is_collective=True) + self.assertEqual(ro.worker_num(), 2) + + def test_ps_rolemaker(self): + """Test ps rolemaker.""" + os.environ["TRAINING_ROLE"] = "PSERVER" + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_PORT"] = "36001" + + try: + import netifaces + except: + print("warning: no netifaces, skip test_ps_rolemaker") + return + + ro = role_maker.PaddleCloudRoleMaker( + is_collective=False, init_gloo=False) + self.assertEqual(ro.server_index(), 0) + self.assertFalse(ro.is_worker()) + self.assertTrue(ro.is_server()) + self.assertEqual(ro.server_num(), 2) + pserver_endpoints = ro.get_pserver_endpoints() + self.assertEqual(pserver_endpoints[0], '127.0.0.1:36001') + self.assertTrue(ro._all_gather(ro._all_comm, 1) is None) + self.assertTrue(ro._all_reduce(ro._all_comm, 1) is None) + + def test_traing_role(self): + """Test training role.""" + os.environ["TRAINING_ROLE"] = "TEST" + try: + import netifaces + except: + print("warning: no netifaces, skip test_training_role") + return + + ro = role_maker.PaddleCloudRoleMaker(is_collective=False) + self.assertRaises(ValueError, ro.generate_role) + + +class TestUserDefinedRoleMaker(unittest.TestCase): + """ + Test cases for UserDefinedRoleMaker. 
+ """ + + def setUp(self): + pass + + def test_ps_rolemaker(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_ps_rolemaker") + return + + ro = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + server_endpoints="127.0.0.1:36001,127.0.0.1:36001", + role=role_maker.Role.SERVER, + current_id=0, + worker_num=2) + self.assertEqual(ro.server_num(), 2) + ro.generate_role() + self.assertTrue(ro.is_server()) + self.assertEqual(ro.role_id(), 0) + + def test_tr_rolemaker(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_tr_rolemaker") + return + + ro = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + server_endpoints="127.0.0.1:36001,127.0.0.1:36001", + role=role_maker.Role.WORKER, + current_id=0, + worker_num=2) + self.assertIn("127.0.0.1:36001", ro.get_pserver_endpoints()) + self.assertTrue(ro.is_worker()) + self.assertEqual(ro.role_id(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_util.py b/python/paddle/fluid/tests/unittests/test_fleet_util.py index 427e077416e979..e52cb5f920c2eb 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_util.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_util.py @@ -12,12 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from __future__ import print_function import paddle +import paddle.fluid as fluid +import unittest +import numpy as np +import tarfile +import tempfile import os +import sys +from paddle.dataset.common import download, DATA_HOME +from paddle.fleet.base.util_factory import fleet_util +import paddle.fleet.base.role_maker as role_maker class TestFleetUtil(unittest.TestCase): + proto_data_url = "https://fleet.bj.bcebos.com/fleet_util_data.tgz" + proto_data_md5 = "59b7f12fd9dc24b64ae8e4629523a92a" + module_name = "fleet_util_data" + pruned_dir = os.path.join("fleet_util_data", "pruned_model") + train_dir = os.path.join("fleet_util_data", "train_program") + def test_util_base(self): import paddle.fleet as fleet util = fleet.UtilBase() @@ -65,6 +80,262 @@ def get_user_id(self): user_id = fleet.util.get_user_id() self.assertEqual(user_id, 10) + def test_fs(self): + from paddle.fleet.utils import LocalFS + fs = LocalFS() + dirs, files = fs.ls_dir("test_tmp") + dirs, files = fs.ls_dir("./") + self.assertFalse(fs.need_upload_download()) + fleet_util.set_file_system(fs) + + def test_barrier(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_barrier") + return + + gloo = fluid.core.Gloo() + gloo.set_rank(0) + gloo.set_size(1) + gloo.set_prefix("123") + gloo.set_iface("lo") + gloo.set_hdfs_store("./tmp_test_fleet_barrier", "", "") + gloo.init() + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.SERVER, + worker_endpoints=["127.0.0.1:6003"], + server_endpoints=["127.0.0.1:6001"]) + role._node_type_comm = gloo + role._role_is_generated = True + fleet_util._set_role_maker(role) + + fleet_util.barrier("worker") + + def test_all_reduce(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_all_reduce") + return + + gloo = fluid.core.Gloo() + gloo.set_rank(0) + gloo.set_size(1) + gloo.set_prefix("123") + gloo.set_iface("lo") + gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "") + gloo.init() + + role = role_maker.UserDefinedRoleMaker( + 
is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.WORKER, + worker_endpoints=["127.0.0.1:6003"], + server_endpoints=["127.0.0.1:6001"]) + role._node_type_comm = gloo + role._role_is_generated = True + fleet_util._set_role_maker(role) + + output = fleet_util.all_reduce(1, "sum", comm_world="server") + print(output) + + # self.assertEqual(output, 1) + + def test_all_gather(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_all_gather") + return + + gloo = fluid.core.Gloo() + gloo.set_rank(0) + gloo.set_size(1) + gloo.set_prefix("123") + gloo.set_iface("lo") + gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "") + gloo.init() + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.SERVER, + worker_endpoints=["127.0.0.1:6003"], + server_endpoints=["127.0.0.1:6001"]) + role._node_type_comm = gloo + role._all_comm = gloo + role._role_is_generated = True + fleet_util._set_role_maker(role) + + output = fleet_util.all_gather(1, comm_world="all") + print(output) + # self.assertTrue(len(output) == 1 and output[0] == 1) + self.assertRaises(Exception, fleet_util.all_gather, 1, "test") + + def download_files(self): + path = download(self.proto_data_url, self.module_name, + self.proto_data_md5) + print('data is downloaded at ' + path) + tar = tarfile.open(path) + unzip_folder = tempfile.mkdtemp() + tar.extractall(unzip_folder) + return unzip_folder + + def test_get_file_shard(self): + self.assertRaises(Exception, fleet_util.get_file_shard, "files") + try: + import netifaces + except: + print("warning: no netifaces, skip test_get_file_shard") + return + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.WORKER, + worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"], + server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"]) + fleet_util._set_role_maker(role) + files = fleet_util.get_file_shard(["1", "2", "3"]) + self.assertTrue(len(files) == 2 and "1" in files and "2" in files) + + def test_program_type_trans(self): + data_dir = self.download_files() + program_dir = os.path.join(data_dir, self.pruned_dir) + text_program = "pruned_main_program.pbtxt" + binary_program = "pruned_main_program.bin" + text_to_binary = fleet_util._program_type_trans(program_dir, + text_program, True) + binary_to_text = fleet_util._program_type_trans(program_dir, + binary_program, False) + self.assertTrue( + os.path.exists(os.path.join(program_dir, text_to_binary))) + self.assertTrue( + os.path.exists(os.path.join(program_dir, binary_to_text))) + + def test_prams_check(self): + data_dir = self.download_files() + + class config: + pass + + feed_config = config() + feed_config.feeded_vars_names = ['concat_1.tmp_0', 'concat_2.tmp_0'] + feed_config.feeded_vars_dims = [682, 1199] + feed_config.feeded_vars_types = [np.float32, np.float32] + feed_config.feeded_vars_filelist = [ + os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_1")), + os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_2")) + ] + + fetch_config = config() + fetch_config.fetch_vars_names = ['similarity_norm.tmp_0'] + + conf = config() + conf.batch_size = 1 + conf.feed_config = feed_config + conf.fetch_config = fetch_config + conf.dump_model_dir = os.path.join(data_dir, self.pruned_dir) + conf.dump_program_filename = "pruned_main_program.pbtxt" + conf.is_text_dump_program = True + conf.save_params_filename = None + + # test saved var's shape + 
conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match"
+
+        self.assertRaises(Exception, fleet_util._params_check)
+
+        # test program.proto without feed_op and fetch_op
+        conf.dump_program_filename = "pruned_main_program.no_feed_fetch"
+        results = fleet_util._params_check(conf)
+        self.assertTrue(len(results) == 1)
+        np.testing.assert_array_almost_equal(
+            results[0], np.array(
+                [[3.0590223e-07]], dtype=np.float32))
+
+        # test feed_var's shape
+        conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match"
+        self.assertRaises(Exception, fleet_util._params_check)
+
+        # test correct case with feed_vars_filelist
+        conf.dump_program_filename = "pruned_main_program.pbtxt"
+        results = fleet_util._params_check(conf)
+        self.assertTrue(len(results) == 1)
+        np.testing.assert_array_almost_equal(
+            results[0], np.array(
+                [[3.0590223e-07]], dtype=np.float32))
+
+        # test correct case without feed_vars_filelist
+        conf.feed_config.feeded_vars_filelist = None
+        # test feed var with lod_level >= 2
+        conf.dump_program_filename = "pruned_main_program.feed_lod2"
+        self.assertRaises(Exception, fleet_util._params_check)
+
+        conf.dump_program_filename = "pruned_main_program.pbtxt"
+        results = fleet_util._params_check(conf)
+        self.assertTrue(len(results) == 1)
+
+    def test_proto_check(self):
+        data_dir = self.download_files()
+
+        class config:
+            pass
+
+        conf = config()
+        conf.train_prog_path = os.path.join(
+            data_dir, os.path.join(self.train_dir, "join_main_program.pbtxt"))
+        conf.is_text_train_program = True
+
+        # test not match
+        conf.pruned_prog_path = os.path.join(
+            data_dir,
+            os.path.join(self.pruned_dir,
+                         "pruned_main_program.save_var_shape_not_match"))
+        conf.is_text_pruned_program = True
+        conf.draw = False
+        res = fleet_util._proto_check(conf)
+        self.assertFalse(res)
+
+        # test match
+        conf.pruned_prog_path = os.path.join(
+            data_dir,
+            os.path.join(self.pruned_dir, "pruned_main_program.pbtxt"))
+        if sys.platform == 'win32':
+            conf.draw = False
+        else:
+            conf.draw = True
+            conf.draw_out_name = "pruned_check"
+        res = fleet_util._proto_check(conf)
+        self.assertTrue(res)
+
+    def test_visualize(self):
+        if sys.platform == 'win32':
+            pass
+        else:
+            data_dir = self.download_files()
+            program_path = os.path.join(
+                data_dir,
+                os.path.join(self.train_dir, "join_main_program.pbtxt"))
+            is_text = True
+            program = fleet_util._load_program(program_path, is_text)
+            output_dir = os.path.join(data_dir, self.train_dir)
+            output_filename = "draw_prog"
+            fleet_util._visualize_graphviz(program, output_dir, output_filename)
+            self.assertTrue(
+                os.path.exists(
+                    os.path.join(output_dir, output_filename + ".dot")))
+            self.assertTrue(
+                os.path.exists(
+                    os.path.join(output_dir, output_filename + ".pdf")))
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fs_interface.py b/python/paddle/fluid/tests/unittests/test_fs_interface.py
index 0d87b94538f05d..7f780bd44f7e2d 100644
--- a/python/paddle/fluid/tests/unittests/test_fs_interface.py
+++ b/python/paddle/fluid/tests/unittests/test_fs_interface.py
@@ -20,9 +20,7 @@ import sys
 import inspect
 
-from paddle.fluid.incubate.fleet.utils.fs import LocalFS, FS
-from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
-from paddle.fluid.incubate.fleet.utils.hdfs import FSTimeOut, FSFileExistsError, FSFileNotExistsError
+from paddle.fleet.utils import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
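For orientation, the consolidated paddle.fleet.utils entry point imported above replaces the old incubate paths for both the local and the HDFS file-system clients. A minimal local sketch, restricted to the calls these tests themselves make (illustrative only):

from paddle.fleet.utils import LocalFS

fs = LocalFS()
dirs, files = fs.ls_dir("./")         # list sub-directories and files, as in test_fs
assert not fs.need_upload_download()  # local files are read in place, no staging step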
class FSTest(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py index fb745454258732..d8a5816a42a2fd 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py @@ -30,6 +30,7 @@ def fusion_gru( wh, # D x 3D bias, # 1 x 3D is_reverse, + origin_mode, act_state, act_gate): return gru(fc(x, wx, bias), @@ -40,7 +41,8 @@ def fusion_gru( (1, wh.shape[1]), dtype='float32'), is_reverse, act_state, - act_gate) + act_gate, + origin_mode=origin_mode) class TestFusionGRUOp(OpTest): @@ -57,6 +59,8 @@ def setUp(self): self.with_bias = True self.act_state = 'tanh' self.act_gate = 'sigmoid' + self.origin_mode = False + self.use_mkldnn = False self.set_confs() T = sum(self.lod[0]) @@ -73,7 +77,7 @@ def setUp(self): (N, self.D), dtype='float32') _, _, _, hidden = fusion_gru( - x, self.lod, h0, wx, wh, bias, self.is_reverse, + x, self.lod, h0, wx, wh, bias, self.is_reverse, self.origin_mode, ACTIVATION[self.act_state], ACTIVATION[self.act_gate]) self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh} @@ -89,7 +93,9 @@ def setUp(self): self.attrs = { 'activation': self.act_state, 'gate_activation': self.act_gate, - 'is_reverse': self.is_reverse + 'is_reverse': self.is_reverse, + 'origin_mode': self.origin_mode, + 'use_mkldnn': self.use_mkldnn } def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/test_hdfs.py b/python/paddle/fluid/tests/unittests/test_hdfs.py index 9826542cee3732..80c7fd4ad57d15 100644 --- a/python/paddle/fluid/tests/unittests/test_hdfs.py +++ b/python/paddle/fluid/tests/unittests/test_hdfs.py @@ -19,9 +19,7 @@ import os import sys -from paddle.fluid.incubate.fleet.utils.fs import LocalFS -from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient -from paddle.fluid.incubate.fleet.utils.hdfs import FSTimeOut, FSFileExistsError, FSFileNotExistsError +from paddle.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError java_home = os.environ["JAVA_HOME"] diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index abc46034957cf7..a61d31e88253d7 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -14,13 +14,15 @@ from __future__ import print_function +import os import unittest import numpy as np import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import Linear -from paddle.fluid.dygraph import declarative +from paddle.fluid.dygraph import declarative, ProgramTranslator +from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME BATCH_SIZE = 32 BATCH_NUM = 20 @@ -77,8 +79,8 @@ def forward(self, x): def train(layer): # create optimizer - adam = fluid.optimizer.AdamOptimizer( - learning_rate=0.1, parameter_list=layer.parameters()) + adam = fluid.optimizer.SGDOptimizer( + learning_rate=0.01, parameter_list=layer.parameters()) # create data loader train_loader = fluid.io.DataLoader.from_generator(capacity=5) train_loader.set_batch_generator(random_batch_reader()) @@ -111,37 +113,43 @@ def setUp(self): # config seed fluid.default_main_program().random_seed = SEED - def train_and_save_model(self): + def train_and_save_model(self, model_path=None, configs=None): layer = LinearNet(784, 1) example_inputs, layer, _ = train(layer) + final_model_path = model_path if 
model_path else self.model_path
         orig_input_types = [type(x) for x in example_inputs]
         fluid.dygraph.jit.save(
-            layer=layer, model_path=self.model_path, input_spec=example_inputs)
+            layer=layer,
+            model_path=final_model_path,
+            input_spec=example_inputs,
+            configs=configs)
         new_input_types = [type(x) for x in example_inputs]
         self.assertEqual(orig_input_types, new_input_types)
         return layer
 
-    def test_save(self):
-        # train and save model
-        self.train_and_save_model()
-
-    def test_load_infernece(self):
+    def test_save_load(self):
         # train and save model
         train_layer = self.train_and_save_model()
         # load model
-        infer_layer = fluid.dygraph.jit.load(self.model_path)
+        program_translator = ProgramTranslator()
+        program_translator.enable(False)
+        loaded_layer = fluid.dygraph.jit.load(self.model_path)
+        self.load_and_inference(train_layer, loaded_layer)
+        self.load_dygraph_state_dict(train_layer)
+        self.load_and_finetune(train_layer, loaded_layer)
+        program_translator.enable(True)
+
+    def load_and_inference(self, train_layer, infer_layer):
         train_layer.eval()
+        infer_layer.eval()
         # inference & compare
         x = fluid.dygraph.to_variable(
             np.random.random((1, 784)).astype('float32'))
         self.assertTrue(
             np.array_equal(train_layer(x).numpy(), infer_layer(x).numpy()))
 
-    def test_load_finetune(self):
-        # train and save model
-        train_layer = self.train_and_save_model()
-        # load model
-        load_train_layer = fluid.dygraph.jit.load(self.model_path)
+    def load_and_finetune(self, train_layer, load_train_layer):
+        train_layer.train()
         load_train_layer.train()
         # train & compare
         _, _, train_loss = train(train_layer)
@@ -149,6 +157,19 @@ def test_load_finetune(self):
         self.assertTrue(
             np.array_equal(train_loss.numpy(), load_train_loss.numpy()))
 
+    def load_dygraph_state_dict(self, train_layer):
+        train_layer.eval()
+        # construct new model
+        new_layer = LinearNet(784, 1)
+        model_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
+        new_layer.set_dict(model_dict)
+        new_layer.eval()
+        # inference & compare
+        x = fluid.dygraph.to_variable(
+            np.random.random((1, 784)).astype('float32'))
+        self.assertTrue(
+            np.array_equal(train_layer(x).numpy(), new_layer(x).numpy()))
+
     def test_save_get_program_failed(self):
         layer = LinearNetNotDeclarative(784, 1)
         example_inputs, layer, _ = train(layer)
@@ -158,6 +179,31 @@ def test_save_get_program_failed(self):
             model_path=self.model_path,
             input_spec=example_inputs)
 
+    def test_load_dygraph_no_path(self):
+        model_path = "model.test_jit_save_load.no_path"
+        new_layer = LinearNet(784, 1)
+        with self.assertRaises(ValueError):
+            model_dict, _ = fluid.dygraph.load_dygraph(model_path)
+
+    def test_load_dygraph_no_var_info(self):
+        model_path = "model.test_jit_save_load.no_var_info"
+        self.train_and_save_model(model_path=model_path)
+        # remove `__variables.info__`
+        var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
+        os.remove(var_info_path)
+        new_layer = LinearNet(784, 1)
+        with self.assertRaises(RuntimeError):
+            model_dict, _ = fluid.dygraph.load_dygraph(model_path)
+
+    def test_load_dygraph_not_var_file(self):
+        model_path = "model.test_jit_save_load.no_var_file"
+        configs = fluid.dygraph.jit.SaveLoadConfig()
+        configs.params_filename = "__params__"
+        self.train_and_save_model(model_path=model_path, configs=configs)
+        new_layer = LinearNet(784, 1)
+        with self.assertRaises(RuntimeError):
+            model_dict, _ = fluid.dygraph.load_dygraph(model_path)
+
 class TestJitSaveLoadConfig(unittest.TestCase):
     def setUp(self):
diff --git a/python/paddle/fluid/tests/unittests/test_launch.sh
b/python/paddle/fluid/tests/unittests/test_launch.sh index f1bf6395f15ce0..98c907a5519653 100644 --- a/python/paddle/fluid/tests/unittests/test_launch.sh +++ b/python/paddle/fluid/tests/unittests/test_launch.sh @@ -3,7 +3,7 @@ set -e # use default values # FIXME: random fails on Unknown command lines -c (or -m). launch_py=${PADDLE_BINARY_DIR}/python/paddle/distributed/launch.py -python ${launch_py} multi_process.py +python ${launch_py} multi_process.py launch # use paddlecloud echo "begin test use paddlecloud" @@ -18,12 +18,12 @@ export PADDLE_PORT=35019 export TRAINER_PORTS_NUM=2 distributed_args="--use_paddlecloud --cluster_node_ips=${cluster_node_ips} --node_ip=${node_ip} --selected_gpus=0,1 --log_dir=testlog" -CUDA_VISIBLE_DEVICES=0,1 python ${launch_py} ${distributed_args} multi_process.py +CUDA_VISIBLE_DEVICES=0,1 python ${launch_py} ${distributed_args} multi_process.py launch str1="selected_gpus:0 worker_endpoints:127.0.0.1:35019,127.0.0.1:35020,127.0.0.2:35019,127.0.0.2:35020 trainers_num:4 current_endpoint:127.0.0.1:35019 trainer_id:0" str2="selected_gpus:1 worker_endpoints:127.0.0.1:35019,127.0.0.1:35020,127.0.0.2:35019,127.0.0.2:35020 trainers_num:4 current_endpoint:127.0.0.1:35020 trainer_id:1" -file_0="multi_process.check_0.log" -file_1="multi_process.check_1.log" +file_0="multi_process_launch.check_0.log" +file_1="multi_process_launch.check_1.log" echo "paddlecloud params test" if grep -q "$str1" "$file_0"; then @@ -54,7 +54,7 @@ unset TRAINER_PORTS_NUM echo "" echo "paddle.distributed.launch async poll process test" -if ! CUDA_VISIBLE_DEVICES=0,1 python ${launch_py} ${distributed_args} multi_process.py abort; then +if ! CUDA_VISIBLE_DEVICES=0,1 python ${launch_py} ${distributed_args} multi_process.py launch abort; then echo "train abort as planned" fi diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index a1ead2aef63f7b..9da70e85f01c0a 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -180,6 +180,51 @@ def test_type(): self.assertRaises(TypeError, test_type) + def test_Flatten(self): + inp = np.ones([3, 4, 4, 5], dtype='float32') + with self.static_graph(): + t = layers.data( + name='data', + shape=[3, 4, 4, 5], + dtype='float32', + append_batch_size=False) + flatten = nn.Flatten() + ret = flatten(t) + static_ret = self.get_static_graph_result( + feed={'data': inp}, fetch_list=[ret])[0] + with self.dynamic_graph(): + t = base.to_variable(inp) + flatten = nn.Flatten() + dy_ret = flatten(t) + dy_ret_value = dy_ret.numpy() + + self.assertTrue(np.array_equal(static_ret, dy_ret_value)) + + with self.static_graph(): + + # the input of Linear must be Variable. 
+ def test_Variable(): + inp = np.ones([3, 32, 32], dtype='float32') + linear = nn.Linear( + 32, + 4, + bias_attr=fluid.initializer.ConstantInitializer(value=1)) + linear_ret1 = linear(inp) + + self.assertRaises(TypeError, test_Variable) + + # the input dtype of Linear must be float16 or float32 or float64 + # float16 only can be set on GPU place + def test_type(): + inp = np.ones([3, 32, 32], dtype='int32') + linear = nn.Linear( + 32, + 4, + bias_attr=fluid.initializer.ConstantInitializer(value=1)) + linear_ret2 = linear(inp) + + self.assertRaises(TypeError, test_type) + def test_layer_norm(self): inp = np.ones([3, 32, 32], dtype='float32') with self.static_graph(): diff --git a/python/paddle/fluid/tests/unittests/test_monitor.py b/python/paddle/fluid/tests/unittests/test_monitor.py index 39601eb0e12ff8..2d4c8f61c0406d 100644 --- a/python/paddle/fluid/tests/unittests/test_monitor.py +++ b/python/paddle/fluid/tests/unittests/test_monitor.py @@ -16,6 +16,7 @@ """ from __future__ import print_function +import paddle import paddle.fluid as fluid import paddle.fluid.core as core import numpy as np @@ -51,7 +52,8 @@ def test_dataset_run_with_stat(self): name=slot, shape=[1], dtype="int64", lod_level=1) slots_vars.append(var) - dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset = paddle.fleet.DatasetFactory().create_dataset( + "InMemoryDataset") dataset.set_batch_size(32) dataset.set_thread(3) dataset.set_filelist([ diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py index fe31add697c656..dd1cf29eff9b75 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. 
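The DatasetFactory relocation applied in test_monitor.py just above recurs throughout this patch; a minimal sketch of the moved entry point, using only setters that already appear in the tests (illustrative, not part of the patch):

import paddle

# Datasets are now created through paddle.fleet instead of fluid.DatasetFactory().
dataset = paddle.fleet.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)  # the configuration API is unchanged by the move
dataset.set_thread(3)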
from __future__ import print_function +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py index 0f14c9d1c3ba99..aed265b21b5781 100644 --- a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py +++ b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py @@ -63,7 +63,7 @@ def case_generator(op_type, Xshape, diagonal, expected): "diagonal: TypeError": "diagonal in {} must be a python Int".format(op_type), "input: ValueError": - "input shape in {} must be at least 2-D".format(op_type), + "x shape in {} must be at least 2-D".format(op_type), } class FailureCase(unittest.TestCase): @@ -71,7 +71,7 @@ def test_failure(self): data = fluid.data(shape=Xshape, dtype='float64', name=cls_name) with self.assertRaisesRegexp( eval(expected.split(':')[-1]), errmsg[expected]): - getattr(tensor, op_type)(input=data, diagonal=diagonal) + getattr(tensor, op_type)(x=data, diagonal=diagonal) class SuccessCase(TrilTriuOpDefaultTest): def initTestCase(self): diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 19af0c92154e0f..b8258f3153a801 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -17,6 +17,7 @@ 'fake_quantize_range_abs_max', 'coalesce_tensor', 'flatten2', + 'flatten_contiguous_range', 'lrn', 'squeeze2', 'reshape2', diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index c974fc14145124..f2c1dd8c67bade 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -63,6 +63,7 @@ from .layer.common import Pad2D #DEFINE_ALIAS from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS +from .layer.common import Flatten #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS from .layer.conv import Conv2D #DEFINE_ALIAS from .layer.conv import Conv2DTranspose #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 674782d4748179..e598c3c957bd6c 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -39,6 +39,7 @@ from .common import Pad2D #DEFINE_ALIAS from .common import Embedding #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS +from .common import Flatten #DEFINE_ALIAS from .common import UpSample #DEFINE_ALIAS from .conv import Conv2D #DEFINE_ALIAS from .conv import Conv2DTranspose #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 8125e528b195b2..45259bea49d42e 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -17,6 +17,7 @@ from ...fluid.dygraph import Pool2D #DEFINE_ALIAS from ...fluid.dygraph import Embedding #DEFINE_ALIAS from ...fluid.dygraph import Linear #DEFINE_ALIAS +from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers from .. 
import functional as F diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 10f93f90fbb875..8b07bddfc15fc2 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -490,14 +490,13 @@ def _tril_triu_op(helper): """Base op of tril_op and triu_op """ op_type = helper.layer_type - x = helper.kwargs.get('input', None) + x = helper.kwargs.get('x', None) assert x is not None, 'x cannot be None in {}'.format(op_type) check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], op_type) if len(x.shape) < 2: - raise ValueError("input shape in {} must be at least 2-D".format( - op_type)) + raise ValueError("x shape in {} must be at least 2-D".format(op_type)) diagonal = helper.kwargs.get('diagonal', 0) if not isinstance(diagonal, (int, )): raise TypeError("diagonal in {} must be a python Int".format(op_type)) @@ -521,18 +520,18 @@ def _tril_triu_op(helper): return out -def tril(input, diagonal=0, name=None): +def tril(x, diagonal=0, name=None): """ :alias_main: paddle.tril :alias: paddle.tril,paddle.tensor.tril,paddle.tensor.creation.tril This op returns the lower triangular part of a matrix (2-D tensor) or batch - of matrices :attr:`input`, the other elements of the result tensor are set + of matrices :attr:`x`, the other elements of the result tensor are set to 0. The lower triangular part of the matrix is defined as the elements on and below the diagonal. Args: - input (Variable): The input variable which is a Tensor. + x (Variable): The input variable x which is a Tensor. Support data types: ``float64``, ``float32``, ``int32``, ``int64``. diagonal (int, optional): The diagonal to consider, default value is 0. If :attr:`diagonal` = 0, all elements on and below the main diagonal are @@ -545,47 +544,41 @@ def tril(input, diagonal=0, name=None): user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable: Tensor, results of lower triangular operation by the specified diagonal of input tensor, - it's data type is the same as input's Tensor. + Variable: Tensor, results of lower triangular operation by the specified diagonal of input tensor x, + it's data type is the same as x's Tensor. Raises: TypeError: diagonal is not a int type. - ValueError: dimension of :attr:`input` is less than 2. + ValueError: dimension of :attr:`x` is less than 2. Examples: .. 
code-block:: python import numpy as np - import paddle.tensor as tensor - import paddle.fluid as fluid + import paddle data = np.arange(1, 13, dtype="int64").reshape(3,-1) # array([[ 1, 2, 3, 4], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12]]) - x = fluid.data(shape=(-1, 4), dtype='int64', name='x') - exe = fluid.Executor(fluid.CPUPlace()) - # example 1, default diagonal - tril = tensor.tril(x) - tril_out, = exe.run(fluid.default_main_program(), feed={"x": data}, - fetch_list=[tril], return_numpy=True) + paddle.enable_imperative() + + x = paddle.imperative.to_variable(data) + + tril1 = paddle.tensor.tril(x) # array([[ 1, 0, 0, 0], # [ 5, 6, 0, 0], # [ 9, 10, 11, 0]]) # example 2, positive diagonal value - tril = tensor.tril(x, diagonal=2) - tril_out, = exe.run(fluid.default_main_program(), feed={"x": data}, - fetch_list=[tril], return_numpy=True) + tril2 = paddle.tensor.tril(x, diagonal=2) # array([[ 1, 2, 3, 0], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12]]) # example 3, negative diagonal value - tril = tensor.tril(x, diagonal=-1) - tril_out, = exe.run(fluid.default_main_program(), feed={"x": data}, - fetch_list=[tril], return_numpy=True) + tril3 = paddle.tensor.tril(x, diagonal=-1) # array([[ 0, 0, 0, 0], # [ 5, 0, 0, 0], # [ 9, 10, 0, 0]]) @@ -593,23 +586,23 @@ def tril(input, diagonal=0, name=None): """ if in_dygraph_mode(): op = getattr(core.ops, 'tril_triu') - return op(input, 'diagonal', diagonal, "lower", True) + return op(x, 'diagonal', diagonal, "lower", True) return _tril_triu_op(LayerHelper('tril', **locals())) -def triu(input, diagonal=0, name=None): +def triu(x, diagonal=0, name=None): """ :alias_main: paddle.triu :alias: paddle.triu,paddle.tensor.triu,paddle.tensor.creation.triu This op returns the upper triangular part of a matrix (2-D tensor) or batch of matrices - :attr:`input`, the other elements of the result tensor are set to 0. + :attr:`x`, the other elements of the result tensor are set to 0. The upper triangular part of the matrix is defined as the elements on and above the diagonal. Args: - input (Variable): The input variable which is a Tensor. + x (Variable): The input variable x which is a Tensor. Support data types: ``float64``, ``float32``, ``int32``, ``int64``. diagonal (int, optional): The diagonal to consider, default value is 0. If :attr:`diagonal` = 0, all elements on and above the main diagonal are @@ -622,47 +615,41 @@ def triu(input, diagonal=0, name=None): user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable: Tensor, results of upper triangular operation by the specified diagonal of input tensor, - it's data type is the same as input's Tensor. + Variable: Tensor, results of upper triangular operation by the specified diagonal of input tensor x, + it's data type is the same as x's Tensor. Raises: TypeError: diagonal is not a int type. - ValueError: dimension of :attr:`input` is less than 2. + ValueError: dimension of :attr:`x` is less than 2. Examples: .. 
code-block:: python
 
             import numpy as np
-            import paddle.fluid as fluid
-            import paddle.tensor as tensor
+            import paddle
 
             data = np.arange(1, 13, dtype="int64").reshape(3,-1)
             # array([[ 1,  2,  3,  4],
             #        [ 5,  6,  7,  8],
             #        [ 9, 10, 11, 12]])
-            x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
-            exe = fluid.Executor(fluid.CPUPlace())
+
+            paddle.enable_imperative()
 
             # example 1, default diagonal
-            triu = tensor.triu(x)
-            triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
-                fetch_list=[triu], return_numpy=True)
+            x = paddle.imperative.to_variable(data)
+            triu1 = paddle.tensor.triu(x)
             # array([[ 1,  2,  3,  4],
             #        [ 0,  6,  7,  8],
             #        [ 0,  0, 11, 12]])
 
             # example 2, positive diagonal value
-            triu = tensor.triu(x, diagonal=2)
-            triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
-                fetch_list=[triu], return_numpy=True)
+            triu2 = paddle.tensor.triu(x, diagonal=2)
             # array([[0, 0, 3, 4],
             #        [0, 0, 0, 8],
             #        [0, 0, 0, 0]])
 
             # example 3, negative diagonal value
-            triu = tensor.triu(x, diagonal=-1)
-            triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
-                fetch_list=[triu], return_numpy=True)
+            triu3 = paddle.tensor.triu(x, diagonal=-1)
             # array([[ 1,  2,  3,  4],
             #        [ 5,  6,  7,  8],
             #        [ 0, 10, 11, 12]])
@@ -670,7 +657,7 @@ def triu(input, diagonal=0, name=None):
     """
     if in_dygraph_mode():
         op = getattr(core.ops, 'tril_triu')
-        return op(input, 'diagonal', diagonal, "lower", False)
+        return op(x, 'diagonal', diagonal, "lower", False)
 
     return _tril_triu_op(LayerHelper('triu', **locals()))
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 6b67394b6bd250..fcff5585bc12a7 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -729,26 +729,32 @@ def bmm(x, y, name=None):
     Examples:
         import paddle
-        import paddle.fluid as fluid
-        x = fluid.layers.data(name='x', shape=[10, 3, 4], dtype='float32')
-        y = fluid.layers.data(name='y', shape=[10, 4, 5], dtype='float32')
-        out = paddle.bmm(x, y)
-
-        # In dygraph mode:
+        import numpy as np
+
+        # In imperative mode:
         # size input1: (2, 2, 3) and input2: (2, 3, 2)
         input1 = np.array([[[1.0, 1.0, 1.0],[2.0, 2.0, 2.0]],[[3.0, 3.0, 3.0],[4.0, 4.0, 4.0]]])
         input2 = np.array([[[1.0, 1.0],[2.0, 2.0],[3.0, 3.0]],[[4.0, 4.0],[5.0, 5.0],[6.0, 6.0]]])
-        with fluid.dygraph.guard():
-            x = fluid.dygraph.to_variable(input1)
-            y = fluid.dygraph.to_variable(input2)
-            out = paddle.bmm(x, y)
-            #output size: (2, 2, 2)
-            #output value:
-            #[[[6.0, 6.0],[12.0, 12.0]],[[45.0, 45.0],[60.0, 60.0]]]
-            out_np = out.numpy()
+        paddle.enable_imperative()
+
+        x = paddle.imperative.to_variable(input1)
+        y = paddle.imperative.to_variable(input2)
+        out = paddle.bmm(x, y)
+        #output size: (2, 2, 2)
+        #output value:
+        #[[[6.0, 6.0],[12.0, 12.0]],[[45.0, 45.0],[60.0, 60.0]]]
+        out_np = out.numpy()
     """
-
+    x_shape = x.shape
+    y_shape = y.shape
+    if not len(x_shape) == len(y_shape) == 3:
+        raise ValueError(
+            "x and y should be 3-dimensional. But received x's dimension: {}, y's dimension: {}".
+            format(x_shape, y_shape))
+    if x_shape[2] != y_shape[1]:
+        raise ValueError(
+            "x's width must be equal to y's height. But received x's shape: {}, y's shape: {}".
+            format(x_shape, y_shape))
     helper = LayerHelper('bmm', **locals())
     if in_dygraph_mode():
         return core.ops.bmm(x, y)
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index c2f67b4e13855b..07d327a21ede6c 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -25,7 +25,6 @@ from ..fluid.layers import cast    #DEFINE_ALIAS
 from ..fluid.layers import expand    #DEFINE_ALIAS
 from ..fluid.layers import expand_as    #DEFINE_ALIAS
-from ..fluid.layers import flatten    #DEFINE_ALIAS
 from ..fluid.layers import reshape    #DEFINE_ALIAS
 from ..fluid.layers import scatter    #DEFINE_ALIAS
 from ..fluid.layers import slice    #DEFINE_ALIAS
@@ -169,6 +168,114 @@ def flip(x, axis, name=None):
 reverse = flip    #DEFINE_ALIAS
 
+
+def flatten(x, start_axis=0, stop_axis=-1, name=None):
+    """
+    **Flatten op**
+
+    Flattens a contiguous range of axes in a tensor according to start_axis and stop_axis.
+
+    For Example:
+
+    .. code-block:: text
+
+        Case 1:
+
+          Given
+            X.shape = (3, 100, 100, 4)
+
+          and
+            start_axis = 1
+            stop_axis = 2
+
+          We get:
+            Out.shape = (3, 100 * 100, 4)
+
+        Case 2:
+
+          Given
+            X.shape = (3, 100, 100, 4)
+
+          and
+            start_axis = 0
+            stop_axis = -1
+
+          We get:
+            Out.shape = (3 * 100 * 100 * 4)
+
+    Args:
+        x (Variable): A tensor of number of dimensions >= axis. A tensor with data type float32,
+                      float64, int8, int32, int64.
+        start_axis (int): the start axis to flatten
+        stop_axis (int): the stop axis to flatten
+        name(str, Optional): For details, please refer to :ref:`api_guide_Name`.
+                        Generally, no setting is required. Default: None.
+
+    Returns:
+        Variable: A tensor with the contents of the input tensor, with input \
+                  axes flattened by indicated start axis and stop axis. \
+                  A Tensor with data type same as input x.
+
+    Raises:
+        ValueError: If x is not a Variable.
+        ValueError: If start_axis or stop_axis is illegal.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            paddle.enable_imperative()
+
+            image_shape=(2, 3, 4, 4)
+            x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3]).reshape(image_shape) / 100.
+            x = x.astype('float32')
+
+            img = paddle.imperative.to_variable(x)
+            out = paddle.flatten(img, start_axis=1, stop_axis=2)
+            # out shape is [2, 12, 4]
+    """
+    if not (isinstance(x, Variable)):
+        raise ValueError("The input x should be a Variable")
+
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64'], 'flatten')
+    helper = LayerHelper('flatten', **locals())
+
+    x_dim = len(x.shape)
+    if not (isinstance(start_axis, int)) or (
+            start_axis > x_dim - 1) or start_axis < -x_dim:
+        raise ValueError(
+            "The start_axis should be an int, and in range [-rank(x), rank(x))")
+    if not (isinstance(stop_axis, int)) or (
+            stop_axis > x_dim - 1) or stop_axis < -x_dim:
+        raise ValueError(
+            "The stop_axis should be an int, and in range [-rank(x), rank(x))")
+    if start_axis < 0:
+        start_axis = start_axis + x_dim
+    if stop_axis < 0:
+        stop_axis = stop_axis + x_dim
+    if start_axis > stop_axis:
+        raise ValueError("The stop_axis should be no less than the start_axis")
+
+    if in_dygraph_mode():
+        dy_out, _ = core.ops.flatten_contiguous_range(
+            x, 'start_axis', start_axis, 'stop_axis', stop_axis)
+        return dy_out
+
+    out = helper.create_variable_for_type_inference(x.dtype)
+    x_shape = helper.create_variable_for_type_inference(x.dtype)
+    helper.append_op(
+        type='flatten_contiguous_range',
+        inputs={"X": x},
+        outputs={'Out': out,
+                 'XShape': x_shape},
+        attrs={"start_axis": start_axis,
+               "stop_axis": stop_axis})
+    return out
+
+
 def roll(x, shifts, axis=None, name=None):
     """
     :alias_main: paddle.roll
@@ -252,13 +359,18 @@ def roll(x, shifts, axis=None, name=None):
     return out
 
-def stack(x, axis=0, out=None, name=None):
+def stack(x, axis=0, name=None):
     """
     :alias_main: paddle.stack
-    :alias: paddle.stack,paddle.tensor.stack,paddle.tensor.manipulation.stack
+    :alias: paddle.stack, paddle.tensor.stack, paddle.tensor.manipulation.stack
 
-
-    This OP stacks all the inputs :code:`x` along axis.
+    This OP stacks all the input tensors ``x`` along ``axis`` dimension.
+    All tensors must be of the same shape and same dtype.
+
+    For example, given N tensors of shape [A, B], if ``axis == 0``, the shape of stacked
+    tensor is [N, A, B]; if ``axis == 1``, the shape of stacked
+    tensor is [A, N, B], etc.
+
     .. code-block:: text
@@ -284,7 +396,6 @@
 
         Case 2:
 
-
             Input:
                 x[0].shape = [1, 2]
                 x[0].data = [ [1.0 , 2.0 ] ]
@@ -295,7 +406,7 @@
             Attrs:
-                axis = 1 or axis = -2
+                axis = 1 or axis = -2  # If axis = -2, axis = axis+ndim(x[0])+1 = -2+2+1 = 1.
 
             Output:
                 Out.shape = [1, 3, 2]
@@ -304,65 +415,40 @@
                           [5.0, 6.0] ] ]
 
     Args:
-        x (Variable|list(Variable)): Input :code:`x` can be a single Tensor, a :code:`list` of Tensors.
-                                     If :code:`x` is a :code:`list`, the shapes of all these Tensors
-                                     must be the same. Supposing input is N dims
-                                     Tensors :math:`[d_0, d_1, ..., d_{n-1}]`, the output is N+1 dims
-                                     Tensor :math:`[d_0, d_1, d_{axis-1}, len(x), d_{axis}, ..., d_{n-1}]`.
-                                     Support data types: float32, float64, int32, int64.
-        axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is :math:`[-(R+1), R+1)`.
-                              R is the first tensor of inputs. If ``axis`` < 0, :math:`axis=axis+rank(x[0])+1`.
-                              The default value of axis is 0.
-
+        x (Tensor|list[Tensor]): Input ``x`` can be a single tensor, or a ``list`` of tensors.
+                                 If ``x`` is a ``list``, the Tensors in ``x``
+                                 must be of the same shape and dtype. Support data types: float32, float64, int32, int64.
+ axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is ``[-(R+1), R+1)``, + where ``R`` is the number of dimensions of the first input tensor ``x[0]``. + If ``axis < 0``, ``axis = axis+R+1``. The default value of axis is 0. + name (str, optional): Please refer to :ref:`api_guide_Name`, Default None. + Returns: - Variable: The stacked Tensor, has same data type with input Tensors. Output dim is :math:`rank(x[0])+1`. + Tensor: The stacked tensor with same data type as input. Example: .. code-block:: python - import numpy as np + import paddle - import paddle.fluid as fluid + import numpy as np data1 = np.array([[1.0, 2.0]]) data2 = np.array([[3.0, 4.0]]) data3 = np.array([[5.0, 6.0]]) - with fluid.dygraph.guard(): - x1 = fluid.dygraph.to_variable(data1) - x2 = fluid.dygraph.to_variable(data2) - x3 = fluid.dygraph.to_variable(data3) - result = paddle.stack([x1, x2, x3], axis=0) - # result shape: [3, 1, 2] - # result value: [[[1.0, 2.0]], - # [[3.0, 4.0]], - # [[5.0, 6.0]]] - """ - - helper = LayerHelper('stack', **locals()) - axis = 0 if axis is None else axis - - if not isinstance(x, list) and not isinstance(x, tuple): - x = [x] - out = helper.create_variable_for_type_inference(x[0].dtype) - if not in_dygraph_mode() and \ - x[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY: - assert len(x) == 1, "If the elements of 'x' in stack are Variable(LoDTensorArray), " \ - "number of the elements must be 1, but received %s." % len(x) - out_index = helper.create_variable_for_type_inference(dtype="int32") - helper.append_op( - type='tensor_array_to_tensor', - inputs={'X': x[0]}, - outputs={'Out': [out], - 'OutIndex': [out_index]}, - attrs={'axis': axis, - 'use_stack': True}) - else: - helper.append_op( - type='stack', - inputs={'X': x}, - outputs={'Y': out}, - attrs={'axis': axis}) - return out + paddle.enable_imperative() + x1 = paddle.imperative.to_variable(data1) + x2 = paddle.imperative.to_variable(data2) + x3 = paddle.imperative.to_variable(data3) + + out = paddle.stack([x1, x2, x3], axis=0) + print(out.shape) # [3, 1, 2] + print(out.numpy()) + # [[[1., 2.]], + # [[3., 4.]], + # [[5., 6.]]] + """ + return layers.stack(x, axis, name) def split(x, num_or_sections, axis=0, name=None): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 878fdbfc1f5761..9b1d7ec3a542c4 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -915,7 +915,7 @@ def __check_input(x, y): return out -def addmm(input, x, y, alpha=1.0, beta=1.0, name=None): +def addmm(input, x, y, beta=1.0, alpha=1.0, name=None): """ :alias_main: paddle.addmm :alias: paddle.addmm,paddle.tensor.addmm,paddle.tensor.math.addmm @@ -935,8 +935,8 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None): input (Variable): The input Tensor/LoDTensor to be added to the final result. x (Variable): The first input Tensor/LoDTensor for matrix multiplication. y (Variable): The second input Tensor/LoDTensor for matrix multiplication. - alpha (float): Coefficient of $x*y$. beta (float): Coefficient of $input$. + alpha (float): Coefficient of $x*y$. name (str, optional): Name of the output. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default is None. 
 
 
 def split(x, num_or_sections, axis=0, name=None):
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 878fdbfc1f5761..9b1d7ec3a542c4 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -915,7 +915,7 @@ def __check_input(x, y):
     return out
 
 
-def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
+def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
     """
     :alias_main: paddle.addmm
     :alias: paddle.addmm,paddle.tensor.addmm,paddle.tensor.math.addmm
@@ -935,8 +935,8 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
         input (Variable): The input Tensor/LoDTensor to be added to the final result.
         x (Variable): The first input Tensor/LoDTensor for matrix multiplication.
         y (Variable): The second input Tensor/LoDTensor for matrix multiplication.
-        alpha (float): Coefficient of $x*y$.
         beta (float): Coefficient of $input$.
+        alpha (float): Coefficient of $x*y$.
         name (str, optional): Name of the output. Normally there is no need for user to set this property.
                               For more information, please refer to :ref:`api_guide_Name`. Default is None.
 
     Returns:
@@ -947,25 +947,43 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
 
             import numpy as np
             import paddle
-            import paddle.fluid as fluid
-
-            input = fluid.data(name='input', shape=[2, 2], dtype='float32')
-            x = fluid.data(name='x', shape=[2, 2], dtype='float32')
-            y = fluid.data(name='y', shape=[2, 2], dtype='float32')
-            out = paddle.addmm( input=input, x=x, y=y, alpha=5.0, beta=0.5 )
 
             data_x = np.ones((2, 2)).astype(np.float32)
             data_y = np.ones((2, 2)).astype(np.float32)
             data_input = np.ones((2, 2)).astype(np.float32)
 
-            place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            results = exe.run(fluid.default_main_program(),
-                              fetch_list=[out], feed={"input": data_input, 'x': data_x, "y": data_y})
-            print( np.array(results[0]) )
+            paddle.enable_imperative()
+
+            x = paddle.imperative.to_variable(data_x)
+            y = paddle.imperative.to_variable(data_y)
+            input = paddle.imperative.to_variable(data_input)
+
+            out = paddle.tensor.addmm( input=input, x=x, y=y, beta=0.5, alpha=5.0 )
+
+            print( out.numpy() )
             # [[10.5 10.5]
             #  [10.5 10.5]]
     """
+    input_shape = input.shape
+    x_shape = x.shape
+    y_shape = y.shape
+    if not len(input_shape) == len(x_shape) == len(y_shape) == 2:
+        raise ValueError(
+            "The dimension of input, x, y should be 2 but received "
+            "input's shape: {}, x's shape: {}, y's shape: {}".format(
+                input_shape, x_shape, y_shape))
+    if input_shape[0] != x_shape[0] and input_shape[0] != 1:
+        raise ValueError(
+            "When x's dimension[0] is not equal to input's dimension[0], "
+            "input's dimension[0] must be 1 but got {}".format(input_shape[0]))
+    if input_shape[1] != y_shape[1] and input_shape[1] != 1:
+        raise ValueError(
+            "When y's dimension[1] is not equal to input's dimension[1], "
+            "input's dimension[1] must be 1 but got {}".format(input_shape[1]))
+    if x_shape[1] != y_shape[0]:
+        raise ValueError(
+            "The input Variable x's width must be equal to Variable y's "
+            "height. But received x's shape = {}, y's shape = {}.".format(
+                x_shape, y_shape))
+
     if in_dygraph_mode():
         out = core.ops.addmm(input, x, y, "Alpha", alpha, "Beta", beta)
         return out
@@ -974,7 +992,7 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
     attrs = {'Alpha': alpha, 'Beta': beta}
 
     helper = LayerHelper("addmm", **locals())
-    check_variable_and_dtype(x, 'Input', ['float32', 'float64'], 'addmm')
+    check_variable_and_dtype(input, 'Input', ['float32', 'float64'], 'addmm')
     check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'addmm')
    check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'addmm')
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
diff --git a/python/requirements.txt b/python/requirements.txt
index 5e081f5e85b6e0..13a1c9a9d638da 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -21,3 +21,4 @@ prettytable
 objgraph
 astor
 pathlib
+netifaces
diff --git a/python/setup.py.in b/python/setup.py.in
index df200da2cfc5b9..72819a7b9eed35 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -152,6 +152,7 @@ packages=['paddle',
           'paddle.fleet.dataset',
           'paddle.fleet.metrics',
           'paddle.fleet.proto',
+          'paddle.fleet.utils',
           'paddle.framework',
           'paddle.fluid',
           'paddle.fluid.dygraph',
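A small numpy cross-check of the example output above (illustrative only, not
part of the diff): ``addmm`` computes ``beta * input + alpha * (x @ y)``, so
with all-ones 2x2 inputs, ``beta=0.5``, and ``alpha=5.0``, every entry is 10.5:

    import numpy as np

    data_x = np.ones((2, 2), dtype=np.float32)
    data_y = np.ones((2, 2), dtype=np.float32)
    data_input = np.ones((2, 2), dtype=np.float32)

    # beta * input + alpha * (x @ y) = 0.5 * 1.0 + 5.0 * 2.0 = 10.5
    out = 0.5 * data_input + 5.0 * (data_x @ data_y)
    print(out)
    # [[10.5 10.5]
    #  [10.5 10.5]]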