From 3a5693e0a8689cb9490e4310c0c46e28dcad0514 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 27 Sep 2017 16:16:00 -0700 Subject: [PATCH 1/8] Add Skeleton of Double support --- paddle/framework/data_type.h | 36 +++++++++ paddle/framework/op_registry.h | 5 +- paddle/framework/operator.h | 77 +++++++++++++++---- paddle/framework/tensor.h | 12 +-- paddle/operators/accuracy_op.cu | 2 +- paddle/operators/accuracy_op.h | 2 +- paddle/operators/activation_op.h | 20 ++--- paddle/operators/add_op.h | 2 +- paddle/operators/clip_op.h | 4 +- paddle/operators/concat_op.h | 2 +- paddle/operators/cos_sim_op.h | 4 +- paddle/operators/crop_op.h | 4 +- paddle/operators/cross_entropy_op.cc | 12 +++ paddle/operators/cross_entropy_op.cu | 4 +- paddle/operators/cross_entropy_op.h | 4 +- paddle/operators/dropout_op.cu | 2 +- paddle/operators/dropout_op.h | 4 +- paddle/operators/elementwise_add_op.h | 4 +- paddle/operators/elementwise_div_op.h | 4 +- paddle/operators/elementwise_mul_op.h | 4 +- paddle/operators/elementwise_sub_op.h | 4 +- paddle/operators/fill_zeros_like_op.h | 2 +- paddle/operators/gather_op.h | 4 +- paddle/operators/gaussian_random_op.cc | 2 +- paddle/operators/gaussian_random_op.cu | 2 +- paddle/operators/gemm_conv2d_op.h | 4 +- paddle/operators/lookup_table_op.cu | 4 +- paddle/operators/lookup_table_op.h | 4 +- paddle/operators/lstm_unit_op.cu | 4 +- paddle/operators/lstm_unit_op.h | 4 +- paddle/operators/mean_op.h | 4 +- paddle/operators/minus_op.h | 2 +- paddle/operators/modified_huber_loss_op.cu | 2 +- paddle/operators/modified_huber_loss_op.h | 4 +- paddle/operators/mul_op.h | 4 +- paddle/operators/multiplex_op.cu | 4 +- paddle/operators/multiplex_op.h | 4 +- paddle/operators/pad_op.h | 4 +- paddle/operators/prelu_op.h | 4 +- paddle/operators/rank_loss_op.h | 4 +- paddle/operators/reshape_op.h | 4 +- paddle/operators/rowwise_add_op.h | 4 +- paddle/operators/scale_op.h | 2 +- paddle/operators/scatter_op.h | 4 +- paddle/operators/sequence_pool_op.h | 4 +- paddle/operators/sgd_op.h | 2 +- paddle/operators/smooth_l1_loss_op.h | 4 +- paddle/operators/softmax_op.h | 4 +- .../softmax_with_cross_entropy_op.cu | 4 +- .../operators/softmax_with_cross_entropy_op.h | 4 +- paddle/operators/split_op.h | 2 +- paddle/operators/squared_l2_distance_op.h | 4 +- paddle/operators/sum_op.h | 4 +- paddle/operators/top_k_op.cu | 2 +- paddle/operators/top_k_op.h | 2 +- paddle/operators/transpose_op.h | 4 +- paddle/operators/uniform_random_op.cc | 2 +- paddle/operators/uniform_random_op.cu | 2 +- paddle/platform/place.cc | 2 +- paddle/pybind/tensor_py.h | 8 +- 60 files changed, 217 insertions(+), 129 deletions(-) create mode 100644 paddle/framework/data_type.h diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h new file mode 100644 index 0000000000000..55e3931f870d6 --- /dev/null +++ b/paddle/framework/data_type.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/
+
+#pragma once
+#include <typeindex>
+#include "paddle/framework/framework.pb.h"
+
+namespace paddle {
+namespace framework {
+
+inline DataType ToDataType(std::type_index type) {
+  if (typeid(float).hash_code() == type.hash_code()) {
+    return DataType::FP32;
+  } else if (typeid(double).hash_code() == type.hash_code()) {
+    return DataType::FP64;
+  } else if (typeid(int).hash_code() == type.hash_code()) {
+    return DataType::INT32;
+  } else {
+    PADDLE_THROW("Not supported");
+    return static_cast<DataType>(-1);
+  }
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 90077d0192421..0db67e4c67852 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -104,8 +104,9 @@ template <typename PlaceType, typename KernelType>
 class OpKernelRegistrar : public Registrar {
  public:
   explicit OpKernelRegistrar(const char* op_type) {
-    OperatorWithKernel::OpKernelKey key;
-    key.place_ = PlaceType();
+    using T = typename KernelType::ELEMENT_TYPE;
+    OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
+                                        PlaceType());
     OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
   }
 };
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 77c7c855c0ffe..4e81d1eaa9dfc 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "op_info.h"
 #include "paddle/framework/attribute.h"
+#include "paddle/framework/data_type.h"
 #include "paddle/framework/framework.pb.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/scope.h"
@@ -407,7 +408,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
   const Scope& scope_;
 };
 
-class OpKernel {
+class OpKernelBase {
  public:
   /**
    * ExecutionContext is the only parameter of Kernel Run function.
@@ -418,33 +419,47 @@
 
   virtual void Compute(const ExecutionContext& context) const = 0;
 
-  virtual ~OpKernel() {}
+  virtual ~OpKernelBase() = default;
+};
+
+template <typename T>
+class OpKernel : public OpKernelBase {
+ public:
+  using ELEMENT_TYPE = T;
 };
 
 class OperatorWithKernel : public OperatorBase {
  public:
   struct OpKernelKey {
     platform::Place place_;
+    DataType data_type_;
 
-    OpKernelKey() = default;
-    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
-      place_ = dev_ctx.GetPlace();
-    }
+    OpKernelKey(DataType data_type, platform::Place place)
+        : place_(place), data_type_(data_type) {}
+
+    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
+        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
 
     bool operator==(const OpKernelKey& o) const {
-      return platform::places_are_same_class(place_, o.place_);
+      return platform::places_are_same_class(place_, o.place_) &&
+             data_type_ == o.data_type_;
     }
   };
 
   struct OpKernelHash {
-    std::hash<bool> hash_;
+    std::hash<int> hash_;
     size_t operator()(const OpKernelKey& key) const {
-      return hash_(platform::is_gpu_place(key.place_));
+      int place = key.place_.which();
+      int data_type = static_cast<int>(key.data_type_);
+      // NOTE: Number of places limit to 16.
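+      // The key is folded into one int before hashing: the data type id is
+      // shifted into the bits above bit 3, and the place's variant index
+      // (Place::which()) fills the low 4 bits -- hence the 16-place limit
+      // noted above.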
+      int pre_hash = data_type << 4 | (place & 0x0F);
+      return hash_(pre_hash);
     }
   };
 
   using OpKernelMap =
-      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
+      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
+                         OpKernelHash>;
 
   OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
@@ -458,8 +473,10 @@ class OperatorWithKernel : public OperatorBase {
 
   void Run(const Scope& scope,
            const platform::DeviceContext& dev_ctx) const final {
-    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
+    ExecutionContext ctx(*this, scope, dev_ctx);
+    auto& opKernel = AllOpKernels().at(type_).at(
+        OpKernelKey(IndicateDataType(ctx), dev_ctx));
+    opKernel->Compute(ctx);
   }
 
   static std::unordered_map<std::string, OpKernelMap>&
@@ -469,13 +486,43 @@ class OperatorWithKernel : public OperatorBase {
   }
 
   bool SupportGPU() const override {
-    OperatorWithKernel::OpKernelKey key;
-    key.place_ = platform::GPUPlace();
-    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
+    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
+    return std::any_of(op_kernels.begin(), op_kernels.end(),
+                       [](OpKernelMap::const_reference kern_pair) {
+                         return platform::is_gpu_place(kern_pair.first.place_);
+                       });
   }
 
 protected:
  virtual void InferShape(InferShapeContextBase* ctx) const = 0;
+
+  // Indicate the kernel DataType by input data. By default, all input data
+  // must have the same type.
+  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
+    auto& scope = ctx.scope();
+    int data_type = -1;
+    for (auto& input : this->inputs_) {
+      for (auto& ipt_name : input.second) {
+        auto* var = scope.FindVar(ipt_name);
+        if (var != nullptr) {
+          const Tensor* t = nullptr;
+          if (var->IsType<Tensor>()) {
+            t = &var->Get<Tensor>();
+          } else if (var->IsType<LoDTensor>()) {
+            t = &var->Get<LoDTensor>();
+          }
+          if (t != nullptr) {
+            int tmp = static_cast<int>(ToDataType(t->type()));
+            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
+                           "DataType of Paddle Op must be same.");
+            data_type = tmp;
+          }
+        }
+      }
+    }
+    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
+    return static_cast<DataType>(data_type);
+  }
 };
 
 }  // namespace framework
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index f040c09c089ec..80a3f0a3935ef 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -29,20 +29,10 @@ limitations under the License.
*/ namespace paddle { -namespace pybind { -namespace details { -template -struct CastToPyBufferImpl; -} -} // namespace pybind - namespace framework { class Tensor { public: - template - friend struct pybind::details::CastToPyBufferImpl; - template friend struct EigenTensor; @@ -119,6 +109,8 @@ class Tensor { return holder_->place(); } + std::type_index type() const { return holder_->type(); } + private: template inline void check_memory_size() const; diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 75e8a989036f0..0ca9ef941d4cb 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata, } template -class AccuracyOpCUDAKernel : public framework::OpKernel { +class AccuracyOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h index fe704efe1c979..12c6b9aac8819 100644 --- a/paddle/operators/accuracy_op.h +++ b/paddle/operators/accuracy_op.h @@ -35,7 +35,7 @@ template ; template -class AccuracyKernel : public framework::OpKernel { +class AccuracyKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 15f8afb4ba45c..e400992ae2968 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class ActivationKernel : public framework::OpKernel { +class ActivationKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel { }; template -class ActivationGradKernel : public framework::OpKernel { +class ActivationGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -202,7 +202,7 @@ struct SquareGradFunctor { }; template -class BReluKernel : public framework::OpKernel { +class BReluKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel { }; template -class BReluGradKernel : public framework::OpKernel { +class BReluGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel { }; template -class SoftReluKernel : public framework::OpKernel { +class SoftReluKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel { }; template -class SoftReluGradKernel : public framework::OpKernel { +class SoftReluGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel { }; template -class 
PowKernel : public framework::OpKernel { +class PowKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel { }; template -class PowGradKernel : public framework::OpKernel { +class PowGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel { }; template -class STanhKernel : public framework::OpKernel { +class STanhKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); @@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel { }; template -class STanhGradKernel : public framework::OpKernel { +class STanhGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index a7307b6818aa3..75163032a1ff1 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -25,7 +25,7 @@ template ; template -class AddKernel : public framework::OpKernel { +class AddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* input0 = context.Input("X"); diff --git a/paddle/operators/clip_op.h b/paddle/operators/clip_op.h index ce1d4e1f46041..ac702e9935201 100644 --- a/paddle/operators/clip_op.h +++ b/paddle/operators/clip_op.h @@ -56,7 +56,7 @@ class ClipGradFunctor { }; template -class ClipKernel : public framework::OpKernel { +class ClipKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max = context.Attr("max"); @@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel { }; template -class ClipGradKernel : public framework::OpKernel { +class ClipGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max = context.Attr("max"); diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h index f977054fdf8aa..b0801ab062dc8 100644 --- a/paddle/operators/concat_op.h +++ b/paddle/operators/concat_op.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { template -class ConcatKernel : public framework::OpKernel { +class ConcatKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput("X"); diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index bcf6f758cae56..68c56f531f941 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -28,7 +28,7 @@ template ; template -class CosSimKernel : public framework::OpKernel { +class CosSimKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { // get Tensor @@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel { }; template -class CosSimGradKernel : public framework::OpKernel { +class CosSimGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { // get Tensor diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h index ac3aeaf41e206..2e72583d68d0a 100644 --- 
a/paddle/operators/crop_op.h +++ b/paddle/operators/crop_op.h @@ -27,7 +27,7 @@ using EigenTensor = framework::EigenTensor; using framework::Tensor; template -class CropKernel : public framework::OpKernel { +class CropKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); @@ -69,7 +69,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { } template -class CropGradKernel : public framework::OpKernel { +class CropGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { size_t rank = diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 26fc9b51c44d2..4b67887f3638f 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -47,6 +47,12 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->ShareLoD("X", /*->*/ "Y"); } + + // CrossEntropy's data type just determined by "X" + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class CrossEntropyGradientOp : public framework::OperatorWithKernel { @@ -87,6 +93,12 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } + + // CrossEntropy's data type just determined by "X" + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 1cfeb7a53b047..76d63f77adccb 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, } // namespace template -class CrossEntropyOpCUDAKernel : public framework::OpKernel { +class CrossEntropyOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -69,7 +69,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { }; template -class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { +class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 1f67461d3fadb..fa81d3b4310a8 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -26,7 +26,7 @@ template ; template -class CrossEntropyOpKernel : public framework::OpKernel { +class CrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), @@ -42,7 +42,7 @@ class CrossEntropyOpKernel : public framework::OpKernel { }; template -class CrossEntropyGradientOpKernel : public framework::OpKernel { +class CrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { 
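    // CPU specialization of the cross-entropy gradient; the enforce below
    // rejects dispatch with anything other than a CPU device context.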
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu index a04e4a22cc09d..30c769000f2b9 100644 --- a/paddle/operators/dropout_op.cu +++ b/paddle/operators/dropout_op.cu @@ -47,7 +47,7 @@ struct MaskGenerator { // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. template -class GPUDropoutKernel : public framework::OpKernel { +class GPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h index d57f64afcb355..745525fe81dad 100644 --- a/paddle/operators/dropout_op.h +++ b/paddle/operators/dropout_op.h @@ -26,7 +26,7 @@ template ; template -class CPUDropoutKernel : public framework::OpKernel { +class CPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); @@ -62,7 +62,7 @@ class CPUDropoutKernel : public framework::OpKernel { }; template -class DropoutGradKernel : public framework::OpKernel { +class DropoutGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(context.Attr("is_training"), diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index e9f78ef26e058..f04fe3ec6069a 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class ElementwiseAddKernel : public framework::OpKernel { +class ElementwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseCompute(ctx); @@ -101,7 +101,7 @@ struct ElementwiseAddBroadCast2GradFunctor { }; template -class ElementwiseAddGradKernel : public framework::OpKernel { +class ElementwiseAddGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseGradCompute, diff --git a/paddle/operators/elementwise_div_op.h b/paddle/operators/elementwise_div_op.h index 99b6d9c1991ed..8946ff3d25c2a 100644 --- a/paddle/operators/elementwise_div_op.h +++ b/paddle/operators/elementwise_div_op.h @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class ElementwiseDivKernel : public framework::OpKernel { +class ElementwiseDivKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseCompute(ctx); @@ -103,7 +103,7 @@ struct ElementwiseDivBroadCast2GradFunctor { }; template -class ElementwiseDivGradKernel : public framework::OpKernel { +class ElementwiseDivGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseGradCompute, diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h index 6ab642378bb0a..4469b07eaa08a 100644 --- a/paddle/operators/elementwise_mul_op.h +++ b/paddle/operators/elementwise_mul_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { template -class ElementwiseMulKernel : public framework::OpKernel { +class ElementwiseMulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseCompute(ctx); @@ -102,7 +102,7 @@ 
struct ElementwiseMulBroadCast2GradFunctor { }; template -class ElementwiseMulGradKernel : public framework::OpKernel { +class ElementwiseMulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseGradCompute, diff --git a/paddle/operators/elementwise_sub_op.h b/paddle/operators/elementwise_sub_op.h index 3ca1376c73b33..3f40c1c5bcea5 100644 --- a/paddle/operators/elementwise_sub_op.h +++ b/paddle/operators/elementwise_sub_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { template -class ElementwiseSubKernel : public framework::OpKernel { +class ElementwiseSubKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseCompute(ctx); @@ -102,7 +102,7 @@ struct ElementwiseSubBroadCast2GradFunctor { }; template -class ElementwiseSubGradKernel : public framework::OpKernel { +class ElementwiseSubGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { ElementwiseGradCompute, diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h index 4474581784531..cdf56a723b117 100644 --- a/paddle/operators/fill_zeros_like_op.h +++ b/paddle/operators/fill_zeros_like_op.h @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class FillZerosLikeKernel : public framework::OpKernel { +class FillZerosLikeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* output = context.Output("Y"); diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h index 381854f301870..073e566e8f696 100644 --- a/paddle/operators/gather_op.h +++ b/paddle/operators/gather_op.h @@ -24,7 +24,7 @@ namespace operators { using Tensor = framework::Tensor; template -class GatherOpKernel : public framework::OpKernel { +class GatherOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *X = ctx.Input("X"); @@ -37,7 +37,7 @@ class GatherOpKernel : public framework::OpKernel { }; template -class GatherGradientOpKernel : public framework::OpKernel { +class GatherGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *Index = ctx.Input("Index"); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 05120a6e7bcfd..fc340c181cc6c 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -16,7 +16,7 @@ namespace paddle { namespace operators { template -class CPUGaussianRandomKernel : public framework::OpKernel { +class CPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { float mean = context.Attr("mean"); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 2d63b3049988c..315560bf1ba8a 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -37,7 +37,7 @@ struct GaussianGenerator { }; template -class GPUGaussianRandomKernel : public framework::OpKernel { +class GPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); diff --git a/paddle/operators/gemm_conv2d_op.h 
b/paddle/operators/gemm_conv2d_op.h index 5c9e81732aa72..323e3f7c3bd50 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/gemm_conv2d_op.h @@ -25,7 +25,7 @@ namespace operators { using Tensor = framework::Tensor; template -class GemmConv2DKernel : public framework::OpKernel { +class GemmConv2DKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); @@ -98,7 +98,7 @@ class GemmConv2DKernel : public framework::OpKernel { }; template -class GemmConvGrad2DKernel : public framework::OpKernel { +class GemmConvGrad2DKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 62f63b4f3c876..c3808fa9a8de0 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -61,7 +61,7 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, } template -class LookupTableCUDAKernel : public framework::OpKernel { +class LookupTableCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto table_t = context.Input("W"); @@ -85,7 +85,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { }; template -class LookupTableGradCUDAKernel : public framework::OpKernel { +class LookupTableGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto ids_t = context.Input("Ids"); diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h index a1298906dd4b4..dfead2fc5b25b 100644 --- a/paddle/operators/lookup_table_op.h +++ b/paddle/operators/lookup_table_op.h @@ -23,7 +23,7 @@ namespace operators { using Tensor = framework::Tensor; template -class LookupTableKernel : public framework::OpKernel { +class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto table_t = context.Input("W"); // float tensor @@ -44,7 +44,7 @@ class LookupTableKernel : public framework::OpKernel { }; template -class LookupTableGradKernel : public framework::OpKernel { +class LookupTableGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto ids_t = context.Input("Ids"); diff --git a/paddle/operators/lstm_unit_op.cu b/paddle/operators/lstm_unit_op.cu index 6e5e4978994c2..b1db0d5322714 100644 --- a/paddle/operators/lstm_unit_op.cu +++ b/paddle/operators/lstm_unit_op.cu @@ -90,7 +90,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, } template -class LstmUnitOpCUDAKernel : public framework::OpKernel { +class LstmUnitOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -121,7 +121,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel { }; template -class LstmUnitGradOpCUDAKernel : public framework::OpKernel { +class LstmUnitGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h index 
683034fe15df8..0dc9a7d9a7aae 100644 --- a/paddle/operators/lstm_unit_op.h +++ b/paddle/operators/lstm_unit_op.h @@ -33,7 +33,7 @@ inline T tanh(T x) { } template -class LstmUnitKernel : public framework::OpKernel { +class LstmUnitKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), @@ -76,7 +76,7 @@ class LstmUnitKernel : public framework::OpKernel { }; template -class LstmUnitGradKernel : public framework::OpKernel { +class LstmUnitGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index ce31e178d8e37..c99286a5b928f 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -28,7 +28,7 @@ template ; template -class MeanKernel : public framework::OpKernel { +class MeanKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("X"); @@ -45,7 +45,7 @@ class MeanKernel : public framework::OpKernel { }; template -class MeanGradKernel : public framework::OpKernel { +class MeanGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto OG = context.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/minus_op.h b/paddle/operators/minus_op.h index 6310a4fd51415..bd9a2790aa2b2 100644 --- a/paddle/operators/minus_op.h +++ b/paddle/operators/minus_op.h @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class MinusKernel : public framework::OpKernel { +class MinusKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* left_tensor = context.Input("X"); diff --git a/paddle/operators/modified_huber_loss_op.cu b/paddle/operators/modified_huber_loss_op.cu index bce760f95e72c..8854e166cd99c 100644 --- a/paddle/operators/modified_huber_loss_op.cu +++ b/paddle/operators/modified_huber_loss_op.cu @@ -39,7 +39,7 @@ struct ModifiedHuberLossBackward { }; template -class ModifiedHuberLossGradGPUKernel : public framework::OpKernel { +class ModifiedHuberLossGradGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("Y"); diff --git a/paddle/operators/modified_huber_loss_op.h b/paddle/operators/modified_huber_loss_op.h index cb51007749e3c..aba75efad9c19 100644 --- a/paddle/operators/modified_huber_loss_op.h +++ b/paddle/operators/modified_huber_loss_op.h @@ -47,7 +47,7 @@ struct ModifiedHuberLossForward { }; template -class ModifiedHuberLossKernel : public framework::OpKernel { +class ModifiedHuberLossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("X"); @@ -73,7 +73,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel { // CPU backward kernel template -class ModifiedHuberLossGradCPUKernel : public framework::OpKernel { +class ModifiedHuberLossGradCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("Y"); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index ac7136a76933d..684b1ea0c0c8d 100644 --- a/paddle/operators/mul_op.h 
+++ b/paddle/operators/mul_op.h @@ -28,7 +28,7 @@ template ; template -class MulKernel : public framework::OpKernel { +class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* x = context.Input("X"); @@ -52,7 +52,7 @@ class MulKernel : public framework::OpKernel { }; template -class MulGradKernel : public framework::OpKernel { +class MulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { int x_num_col_dims = ctx.template Attr("x_num_col_dims"); diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 505776612e711..72b1f96eafde3 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -21,7 +21,7 @@ namespace operators { using Tensor = framework::Tensor; template -class MultiplexGPUKernel : public framework::OpKernel { +class MultiplexGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); @@ -51,7 +51,7 @@ class MultiplexGPUKernel : public framework::OpKernel { }; template -class MultiplexGradGPUKernel : public framework::OpKernel { +class MultiplexGradGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* d_out = ctx.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h index 637c63a34af39..ab3cafaa324a2 100644 --- a/paddle/operators/multiplex_op.h +++ b/paddle/operators/multiplex_op.h @@ -23,7 +23,7 @@ namespace paddle { namespace operators { template -class MultiplexCPUKernel : public framework::OpKernel { +class MultiplexCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); @@ -48,7 +48,7 @@ class MultiplexCPUKernel : public framework::OpKernel { }; template -class MultiplexGradCPUKernel : public framework::OpKernel { +class MultiplexGradCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* d_out = ctx.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h index 2cc3b945ae5b2..9534dbf54529e 100644 --- a/paddle/operators/pad_op.h +++ b/paddle/operators/pad_op.h @@ -47,7 +47,7 @@ void PadFunction(const framework::ExecutionContext& context) { } template -class PadKernel : public framework::OpKernel { +class PadKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { int rank = context.Input("X")->dims().size(); @@ -97,7 +97,7 @@ void PadGradFunction(const framework::ExecutionContext& context) { } template -class PadGradKernel : public framework::OpKernel { +class PadGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { size_t rank = diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h index 6b78ed295cbac..5ad31c2203ae6 100644 --- a/paddle/operators/prelu_op.h +++ b/paddle/operators/prelu_op.h @@ -40,7 +40,7 @@ class PReluFunctor { }; template -class PReluKernel : public framework::OpKernel { +class PReluKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); @@ -77,7 +77,7 @@ class PReluGradFunctor { }; template -class PReluGradKernel 
: public framework::OpKernel { +class PReluGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* dx = context.Output(framework::GradVarName("X")); diff --git a/paddle/operators/rank_loss_op.h b/paddle/operators/rank_loss_op.h index 7df195ff47ecf..f184d6efcb496 100644 --- a/paddle/operators/rank_loss_op.h +++ b/paddle/operators/rank_loss_op.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { template -class RankLossKernel : public framework::OpKernel { +class RankLossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); @@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel { }; template -class RankLossGradKernel : public framework::OpKernel { +class RankLossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* d_left_t = diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index 873acf30782d3..628dfe4c0fadc 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { template -class ReshapeKernel : public framework::OpKernel { +class ReshapeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); @@ -39,7 +39,7 @@ class ReshapeKernel : public framework::OpKernel { }; template -class ReshapeGradKernel : public framework::OpKernel { +class ReshapeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* d_out = ctx.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 35774b940926f..b43e5d868b383 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -28,7 +28,7 @@ template ; template -class RowwiseAddKernel : public framework::OpKernel { +class RowwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); @@ -50,7 +50,7 @@ class RowwiseAddKernel : public framework::OpKernel { }; template -class RowwiseAddGradKernel : public framework::OpKernel { +class RowwiseAddGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* dout = context.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/scale_op.h b/paddle/operators/scale_op.h index 02fbdc52bbf89..dc6bc768997f4 100644 --- a/paddle/operators/scale_op.h +++ b/paddle/operators/scale_op.h @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class ScaleKernel : public framework::OpKernel { +class ScaleKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& context) const { auto* tensor = context.Output("Out"); diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h index e9595638a86a4..a8eb54399a932 100644 --- a/paddle/operators/scatter_op.h +++ b/paddle/operators/scatter_op.h @@ -24,7 +24,7 @@ namespace operators { using Tensor = framework::Tensor; template -class ScatterOpKernel : public framework::OpKernel { +class ScatterOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *Ref = ctx.Input("Ref"); @@ -40,7 +40,7 @@ class 
ScatterOpKernel : public framework::OpKernel { }; template -class ScatterGradientOpKernel : public framework::OpKernel { +class ScatterGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *dRef = ctx.Output(framework::GradVarName("Ref")); diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index cb80586e88f8d..752d714125578 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -38,7 +38,7 @@ enum SeqPoolType { }; template -class SequencePoolKernel : public framework::OpKernel { +class SequencePoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); @@ -85,7 +85,7 @@ class SequencePoolKernel : public framework::OpKernel { }; template -class SequencePoolGradKernel : public framework::OpKernel { +class SequencePoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index f8888f9c362e1..a3fe3308942f9 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -25,7 +25,7 @@ template ; template -class SGDOpKernel : public framework::OpKernel { +class SGDOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto param = ctx.Input("param"); diff --git a/paddle/operators/smooth_l1_loss_op.h b/paddle/operators/smooth_l1_loss_op.h index 0604fb5e1c2f1..39d0070b6c890 100644 --- a/paddle/operators/smooth_l1_loss_op.h +++ b/paddle/operators/smooth_l1_loss_op.h @@ -45,7 +45,7 @@ struct SmoothL1LossForward { }; template -class SmoothL1LossKernel : public framework::OpKernel { +class SmoothL1LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("X"); @@ -115,7 +115,7 @@ struct SmoothL1LossBackward { }; template -class SmoothL1LossGradKernel : public framework::OpKernel { +class SmoothL1LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("InsideWeight"); diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 7220f486be055..9996536454b1b 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -26,7 +26,7 @@ template ; template -class SoftmaxKernel : public framework::OpKernel { +class SoftmaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto X = context.Input("X"); @@ -40,7 +40,7 @@ class SoftmaxKernel : public framework::OpKernel { }; template -class SoftmaxGradKernel : public framework::OpKernel { +class SoftmaxGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto Y = context.Input("Y"); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index 1cf4296dccf68..c3086e729e493 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, } // namespace template -class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { 
+class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), @@ -73,7 +73,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { }; template -class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { +class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index bf792c1f59e2e..a8b18504e1c3a 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -27,7 +27,7 @@ template ; template -class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { +class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), @@ -47,7 +47,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { }; template -class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { +class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* out_grad = diff --git a/paddle/operators/split_op.h b/paddle/operators/split_op.h index 860690ee89507..bc1b12279e350 100644 --- a/paddle/operators/split_op.h +++ b/paddle/operators/split_op.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { template -class SplitKernel : public framework::OpKernel { +class SplitKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h index 097ac04fc09a1..259ef40296469 100644 --- a/paddle/operators/squared_l2_distance_op.h +++ b/paddle/operators/squared_l2_distance_op.h @@ -28,7 +28,7 @@ template ; template -class SquaredL2DistanceKernel : public framework::OpKernel { +class SquaredL2DistanceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("X"); @@ -68,7 +68,7 @@ class SquaredL2DistanceKernel : public framework::OpKernel { }; template -class SquaredL2DistanceGradKernel : public framework::OpKernel { +class SquaredL2DistanceGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("sub_result"); diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h index 0b1e9ebaa38d4..7e8fbb9e41c69 100644 --- a/paddle/operators/sum_op.h +++ b/paddle/operators/sum_op.h @@ -22,7 +22,7 @@ template ; template -class SumKernel : public framework::OpKernel { +class SumKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto ins = context.MultiInput("X"); @@ -43,7 +43,7 @@ class SumKernel : public framework::OpKernel { }; template -class SumGradKernel : public framework::OpKernel { +class SumGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* input = 
context.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu index 53fe505b77bfa..7be6932f1e301 100644 --- a/paddle/operators/top_k_op.cu +++ b/paddle/operators/top_k_op.cu @@ -279,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int* indices, } template -class TopkOpCUDAKernel : public framework::OpKernel { +class TopkOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h index ef66acc1d5692..4b248faa120bc 100644 --- a/paddle/operators/top_k_op.h +++ b/paddle/operators/top_k_op.h @@ -28,7 +28,7 @@ template ; template -class TopkKernel : public framework::OpKernel { +class TopkKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { // Get the top k elements of each row of input tensor diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h index ea299dce72ad3..aaa3f47ab5545 100644 --- a/paddle/operators/transpose_op.h +++ b/paddle/operators/transpose_op.h @@ -38,7 +38,7 @@ void EigenTranspose(const framework::ExecutionContext& context, } template -class TransposeKernel : public framework::OpKernel { +class TransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); @@ -73,7 +73,7 @@ class TransposeKernel : public framework::OpKernel { }; template -class TransposeGradKernel : public framework::OpKernel { +class TransposeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* out_grad = diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 2771df56086ff..878d71802abb3 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -21,7 +21,7 @@ namespace operators { // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. template -class CPUUniformRandomKernel : public framework::OpKernel { +class CPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* tensor = ctx.Output("Out"); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 6614b53b3f990..5612ce9eb1c64 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -40,7 +40,7 @@ struct UniformGenerator { // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. 
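// uniform_random has no inputs, so the default input-based IndicateDataType
// cannot work for it; PATCH 2/8 below overrides it to take the kernel data
// type from a "data_type" attribute instead.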
template -class GPUUniformRandomKernel : public framework::OpKernel { +class GPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); diff --git a/paddle/platform/place.cc b/paddle/platform/place.cc index b31515e1f028a..856e54df89c1c 100644 --- a/paddle/platform/place.cc +++ b/paddle/platform/place.cc @@ -47,7 +47,7 @@ bool is_cpu_place(const Place &p) { } bool places_are_same_class(const Place &p1, const Place &p2) { - return is_gpu_place(p1) == is_gpu_place(p2); + return p1.which() == p2.which(); } std::ostream &operator<<(std::ostream &os, const Place &p) { diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h index f0d5a6f9ff963..10621e90eebf5 100644 --- a/paddle/pybind/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -42,7 +42,7 @@ template struct CastToPyBufferImpl { using CUR_TYPE = typename std::tuple_element>::type; py::buffer_info operator()(framework::Tensor &tensor) { - if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { + if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector dims_outside; std::vector strides; @@ -56,13 +56,13 @@ struct CastToPyBufferImpl { prod *= dims_outside[i - 1]; } framework::Tensor dst_tensor; - if (paddle::platform::is_gpu_place(tensor.holder_->place())) { + if (paddle::platform::is_gpu_place(tensor.place())) { dst_tensor.CopyFrom(tensor, platform::CPUPlace()); - } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { + } else if (paddle::platform::is_cpu_place(tensor.place())) { dst_tensor = tensor; } return py::buffer_info( - dst_tensor.mutable_data(dst_tensor.holder_->place()), + dst_tensor.mutable_data(dst_tensor.place()), sizeof(CUR_TYPE), py::format_descriptor::format(), (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); } else { From 2c05465d2f0f1134610364508fa73281fd44f1ad Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 27 Sep 2017 17:10:07 -0700 Subject: [PATCH 2/8] Fix unit-tests --- paddle/framework/operator_test.cc | 7 +++++-- paddle/operators/gather_op.cc | 10 ++++++++++ paddle/operators/gaussian_random_op.cc | 7 +++++++ paddle/operators/lookup_table_op.cc | 10 ++++++++++ paddle/operators/multiplex_op.cc | 10 ++++++++++ paddle/operators/scatter_op.cc | 10 ++++++++++ paddle/operators/softmax_with_cross_entropy_op.cc | 11 +++++++++++ paddle/operators/uniform_random_op.cc | 7 +++++++ 8 files changed, 70 insertions(+), 2 deletions(-) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 8b4bb01a7bb80..7f0ec90adef78 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -116,10 +116,13 @@ class OpWithKernelTest : public OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase* ctx) const override {} + DataType IndicateDataType(const ExecutionContext& ctx) const override { + return DataType::FP32; + } }; template -class CPUKernelTest : public OpKernel { +class CPUKernelTest : public OpKernel { public: void Compute(const ExecutionContext& ctx) const { std::cout << "this is cpu kernel" << std::endl; @@ -146,7 +149,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker } }; -class CPUKernalMultiInputsTest : public OpKernel { +class CPUKernalMultiInputsTest : public OpKernel { public: void Compute(const ExecutionContext& ctx) const { auto xs = ctx.op().Inputs("xs"); diff --git a/paddle/operators/gather_op.cc 
b/paddle/operators/gather_op.cc index 0e3cd174adee1..da22bd0c52c27 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel { output_dims[0] = batch_size; ctx->SetOutputDim("Out", output_dims); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class GatherGradOp : public framework::OperatorWithKernel { @@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContextBase* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class GatherOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index fc340c181cc6c..5cd2c7d2c066c 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel { "dims can be one int or array. dims must be set."); ctx->SetOutputDim("Out", framework::make_ddim(temp)); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("data_type")); + } }; class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { @@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator. "Random seed of generator." "0 means use system wide seed") .SetDefault(0); + AddAttr("data_type", "output data type") + .SetDefault(framework::DataType::FP32); } }; diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 9b1314bfbade8..929008fbcbe03 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->ShareLoD("Ids", /*->*/ "Out"); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("W")->type()); + } }; class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { @@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); ctx->SetOutputDim(framework::GradVarName("W"), table_dims); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("W")->type()); + } }; } // namespace operators diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 9896d269ccc86..a069127a19a1d 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", in_dim); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.MultiInput("X")[0]->type()); + } }; class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { @@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel { } ctx->SetOutputsDim(framework::GradVarName("X"), d_ins); } + + 
diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc
index 3fc4a39ebc552..619acfc8b62ed 100644
--- a/paddle/operators/scatter_op.cc
+++ b/paddle/operators/scatter_op.cc
@@ -48,6 +48,11 @@ class ScatterOp : public framework::OperatorWithKernel {
     }
     ctx->SetOutputDim("Out", ref_dims);
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("X")->type());
+  }
 };
 
 class ScatterGradOp : public framework::OperatorWithKernel {
@@ -60,6 +65,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
                       ctx->GetInputDim("Updates"));
     ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("X")->type());
+  }
 };
 
 class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index e2299b254458c..de7c532421c48 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -13,6 +13,7 @@ limitations under the License. */
 
 #include "paddle/operators/softmax_with_cross_entropy_op.h"
+#include
 
 namespace paddle {
 namespace operators {
@@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("Logits", /*->*/ "Softmax");
     ctx->ShareLoD("Logits", /*->*/ "Loss");
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
+  }
 };
 
 class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
@@ -149,6 +155,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("Logits"),
                       ctx->GetInputDim("Softmax"));
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
+  }
 };
 
 }  // namespace operators
diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc
index 878d71802abb3..97b1d0bed4595 100644
--- a/paddle/operators/uniform_random_op.cc
+++ b/paddle/operators/uniform_random_op.cc
@@ -62,6 +62,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
     }
     ctx->SetOutputDim("Out", framework::make_ddim(temp));
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return static_cast<framework::DataType>(Attr<int>("data_type"));
+  }
 };
 
 class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -80,6 +85,8 @@ Used to initialize tensor with uniform random generator.
              "Random seed of uniform random. "
              "0 means generate a seed by system")
         .SetDefault(0);
+    AddAttr<int>("data_type", "output tensor data type")
+        .SetDefault(framework::DataType::FP32);
   }
 };
 }  // namespace operators

From f1913d46972b11d852f42072eedd5485c721d2c5 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 27 Sep 2017 17:28:12 -0700
Subject: [PATCH 3/8] Change registry, test register double kernel

---
 paddle/framework/op_registry.h         | 34 ++++++++++++++++++++++----
 paddle/operators/elementwise_mul_op.cc |  6 +++--
 paddle/operators/elementwise_mul_op.cu |  6 +++--
 3 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 0db67e4c67852..804f901dfa278 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -100,14 +100,38 @@ class OpRegistrar : public Registrar {
   }
 };
 
-template <typename PlaceType, typename KernelType>
+template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
+struct OpKernelRegistrarFunctor;
+
+template <typename PlaceType, size_t I, typename... KernelType>
+struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelType...> {
+  using KT = typename std::tuple_element<I, std::tuple<KernelType...>>::type;
+
+  void operator()(const char* op_type) const {
+    using T = typename KT::ELEMENT_TYPE;
+    OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
+                                        PlaceType());
+    OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KT);
+
+    constexpr auto size = std::tuple_size<std::tuple<KernelType...>>::value;
+    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelType...>
+        func;
+    func(op_type);
+  }
+};
+
+template <typename PlaceType, size_t I, typename... KernelType>
+struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
+  void operator()(const char* op_type) const {}
+};
+
+// Users can register many kernels in one place. The data types can differ.
+template <typename PlaceType, typename... KernelType>
 class OpKernelRegistrar : public Registrar {
  public:
   explicit OpKernelRegistrar(const char* op_type) {
-    using T = typename KernelType::ELEMENT_TYPE;
-    OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
-                                        PlaceType());
-    OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
+    OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
+    func(op_type);
   }
 };
diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc
index bda5dfe03e974..da7765aa6a7a8 100644
--- a/paddle/operators/elementwise_mul_op.cc
+++ b/paddle/operators/elementwise_mul_op.cc
@@ -36,7 +36,9 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
             elementwise_mul_grad, ops::ElementwiseOpGrad);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul,
-    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul_grad,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/elementwise_mul_op.cu b/paddle/operators/elementwise_mul_op.cu
index da08a75596c4d..056f081d3e6ac 100644
--- a/paddle/operators/elementwise_mul_op.cu
+++ b/paddle/operators/elementwise_mul_op.cu
@@ -19,7 +19,9 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_GPU_KERNEL(
     elementwise_mul,
-    ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>);
+    ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
+    ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(
     elementwise_mul_grad,
-    ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>);
+    ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
+    ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>);
From ae3dca770c34d14ec81587eef9798c6a7070ec13 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 27 Sep 2017 19:15:51 -0700
Subject: [PATCH 4/8] Fix CI

---
 paddle/operators/scatter_op.cc                    | 4 ++--
 paddle/operators/softmax_with_cross_entropy_op.cc | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc
index 619acfc8b62ed..cadd8841b6ab3 100644
--- a/paddle/operators/scatter_op.cc
+++ b/paddle/operators/scatter_op.cc
@@ -51,7 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel {
 
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
   }
 };
 
@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
 
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
  }
 };
 
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index de7c532421c48..a76489871f30d 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -158,7 +158,8 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
+    return framework::ToDataType(
+        ctx.Input<Tensor>(framework::GradVarName("Loss"))->type());
   }
 };
 
From f2feb333843d83b74a9d29f0c73ffe8409795df0 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 28 Sep 2017 10:50:14 -0700
Subject: [PATCH 5/8] Follow comments

---
 paddle/framework/operator.h |  4 ++--
 paddle/platform/place.h     | 11 +++++++++++
 paddle/platform/variant.h   |  2 ++
 3 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 4e81d1eaa9dfc..7d563a3c05987 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -451,8 +451,8 @@ class OperatorWithKernel : public OperatorBase {
     size_t operator()(const OpKernelKey& key) const {
       int place = key.place_.which();
       int data_type = static_cast<int>(key.data_type_);
-      // NOTE: Number of places limit to 16.
-      int pre_hash = data_type << 4 | (place & 0x0F);
+      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
+                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
       return hash_(pre_hash);
     }
   };
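
To make the new pre_hash concrete: the place index occupies the low NUM_PLACE_TYPE_LIMIT_IN_BIT bits and the data type is shifted above them, so every (data type, place) pair maps to a distinct integer as long as the place count stays within the bit budget. A small worked example with illustrative enum values (the real numbering comes from framework.pb.h, so the constants below are assumptions):

    #include <cstdio>

    enum DataType { FP32 = 0, FP64 = 1, INT32 = 2 };  // illustrative values
    #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4

    int PreHash(DataType data_type, int place_index) {
      // data type sits above the (at most 2^4) place types
      return static_cast<int>(data_type) << NUM_PLACE_TYPE_LIMIT_IN_BIT |
             (place_index & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
    }

    int main() {
      std::printf("%d\n", PreHash(FP64, 1));  // 1 << 4 | 1 = 17
      std::printf("%d\n", PreHash(FP32, 0));  // 0
      return 0;
    }
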
diff --git a/paddle/platform/place.h b/paddle/platform/place.h
index 1117476bb37f1..0efc6932349a5 100644
--- a/paddle/platform/place.h
+++ b/paddle/platform/place.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <iostream>
+
 #include "paddle/platform/variant.h"
 
 namespace paddle {
@@ -46,8 +47,18 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
   bool operator()(const GPUPlace &gpu) const { return true; }
 };
 
+// Define the max number of Place types in bit length, i.e., the number of
+// places must be less than or equal to 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
+#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
+
 typedef boost::variant<CPUPlace, GPUPlace> Place;
 
+// static check that the number of place types is less than or equal to
+// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
+BOOST_MPL_ASSERT((boost::mpl::less_equal<
+                  Place::types::size,
+                  boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>));
+
 void set_place(const Place &);
 const Place &get_place();
 
diff --git a/paddle/platform/variant.h b/paddle/platform/variant.h
index c2257af1b5dd1..16ee00efe7a9b 100644
--- a/paddle/platform/variant.h
+++ b/paddle/platform/variant.h
@@ -29,4 +29,6 @@
 #endif
 #endif
 
+#include <boost/mpl/comparison.hpp>
+#include <boost/mpl/less_equal.hpp>
 #include <boost/variant.hpp>

From b9c863723870ecec231ca0ff0b84e7beab8fe5ce Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 28 Sep 2017 11:09:42 -0700
Subject: [PATCH 6/8] Fix compile

---
 paddle/operators/concat_op.h | 2 +-
 paddle/operators/reduce_op.h | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h
index bff453971a00a..c113f19fb5cf8 100644
--- a/paddle/operators/concat_op.h
+++ b/paddle/operators/concat_op.h
@@ -44,7 +44,7 @@ class ConcatKernel : public framework::OpKernel<T> {
 };
 
 template <typename Place, typename T>
-class ConcatGradKernel : public framework::OpKernel {
+class ConcatGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h
index 2fbf94e34f396..ba3f3db81dc62 100644
--- a/paddle/operators/reduce_op.h
+++ b/paddle/operators/reduce_op.h
@@ -87,7 +87,7 @@ struct MaxOrMinGradFunctor {
 };
 
 template <typename Place, typename T, typename Functor>
-class ReduceKernel : public framework::OpKernel {
+class ReduceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     int rank = context.Input<Tensor>("X")->dims().size();
@@ -141,7 +141,7 @@ class ReduceKernel : public framework::OpKernel {
 };
 
 template <typename Place, typename T, typename Functor>
-class ReduceGradKernel : public framework::OpKernel {
+class ReduceGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     int rank = context.Input<Tensor>("X")->dims().size();

From 87da1542260bfe6d8002c8da05008d5dde426b7c Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 28 Sep 2017 12:04:09 -0700
Subject: [PATCH 7/8] Fix sigmoid_xe_with_logits_op compile

---
 paddle/operators/sigmoid_cross_entropy_with_logits_op.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.h b/paddle/operators/sigmoid_cross_entropy_with_logits_op.h
index a6de9043fdbcd..41c619f181c87 100644
--- a/paddle/operators/sigmoid_cross_entropy_with_logits_op.h
+++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.h
@@ -21,7 +21,7 @@ namespace operators {
 
 // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
 template <typename Place, typename T>
-class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
+class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     const framework::Tensor *X = context.Input<framework::Tensor>("X");
@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
 
 // dX = sigmoid(X) - labels
 template <typename Place, typename T>
-class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel {
+class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     const framework::Tensor *X = context.Input<framework::Tensor>("X");
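
Before the final patch renames KT to KERNEL_TYPE, here is a standalone sketch of the tuple-walking recursion that OpKernelRegistrarFunctor uses, showing why it terminates. FloatKernel and DoubleKernel are hypothetical stand-ins, not PaddlePaddle types; the snippet compiles on its own:

    #include <cstdio>
    #include <tuple>

    struct FloatKernel  { using ELEMENT_TYPE = float; };   // stand-in kernels
    struct DoubleKernel { using ELEMENT_TYPE = double; };

    template <bool at_end, size_t I, typename... KernelType>
    struct ForEachKernel;

    template <size_t I, typename... KernelType>
    struct ForEachKernel<false, I, KernelType...> {
      void operator()() const {
        using KT =
            typename std::tuple_element<I, std::tuple<KernelType...>>::type;
        std::printf("register kernel, element size %zu\n",
                    sizeof(typename KT::ELEMENT_TYPE));
        constexpr auto size = std::tuple_size<std::tuple<KernelType...>>::value;
        // I + 1 == size selects the terminating specialization below.
        ForEachKernel<I + 1 == size, I + 1, KernelType...>()();
      }
    };

    template <size_t I, typename... KernelType>
    struct ForEachKernel<true, I, KernelType...> {
      void operator()() const {}  // past the last kernel: stop
    };

    int main() {
      ForEachKernel<false, 0, FloatKernel, DoubleKernel>()();  // prints 4, 8
      return 0;
    }
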
From d53b38e340b5f56f9547b53449fe6cdceefd3b97 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 28 Sep 2017 14:32:25 -0700
Subject: [PATCH 8/8] Follow comments, change KT to KERNEL_TYPE

---
 paddle/framework/op_registry.h | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 804f901dfa278..4db38badaea8a 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -103,18 +103,19 @@ class OpRegistrar : public Registrar {
 template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
 struct OpKernelRegistrarFunctor;
 
-template <typename PlaceType, size_t I, typename... KernelType>
-struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelType...> {
-  using KT = typename std::tuple_element<I, std::tuple<KernelType...>>::type;
+template <typename PlaceType, size_t I, typename... KernelTypes>
+struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
+  using KERNEL_TYPE =
+      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
 
   void operator()(const char* op_type) const {
-    using T = typename KT::ELEMENT_TYPE;
+    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
     OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
                                         PlaceType());
-    OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KT);
+    OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
 
-    constexpr auto size = std::tuple_size<std::tuple<KernelType...>>::value;
-    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelType...>
+    constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
+    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
         func;
     func(op_type);
   }
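
End-to-end, the series means one REGISTER_OP_*_KERNEL invocation can carry several kernel instantiations that differ only in element type. As a hedged sketch (the macro plumbing is not shown in these patches, so the expansion below, including the registrar variable name, is an assumption based on OpKernelRegistrar's signature):

    // Assumed shape of what REGISTER_OP_CPU_KERNEL(elementwise_mul, ...)
    // expands to after this series: a single static registrar object whose
    // constructor walks every listed kernel type.
    static paddle::framework::OpKernelRegistrar<
        paddle::platform::CPUPlace,
        paddle::operators::ElementwiseMulKernel<paddle::platform::CPUPlace,
                                                float>,
        paddle::operators::ElementwiseMulKernel<paddle::platform::CPUPlace,
                                                double>>
        __op_kernel_registrar_elementwise_mul__("elementwise_mul");
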