Merge pull request #4455 from reyoung/feature/make_paddle_support_double
Support double precision
reyoung authored Sep 28, 2017
2 parents 5d6d2bc + d53b38e commit 184768e
Showing 72 changed files with 341 additions and 142 deletions.
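In summary, kernels are now looked up by a (data type, place) pair instead of by place alone. The new paddle/framework/data_type.h translates a runtime std::type_index into the DataType enum generated from framework.proto; OpKernel becomes the class template OpKernel<T>, whose ELEMENT_TYPE member records the element type; OpKernelRegistrar accepts a parameter pack so a single registration can cover several precisions of one kernel; and OperatorWithKernel::Run selects a kernel by combining the device context's place with IndicateDataType(), which by default infers the type from the operator's inputs. The per-operator files further down are the mechanical half of the change: every kernel class switches its base from framework::OpKernel to framework::OpKernel<T>.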
36 changes: 36 additions & 0 deletions paddle/framework/data_type.h
@@ -0,0 +1,36 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"

namespace paddle {
namespace framework {

inline DataType ToDataType(std::type_index type) {
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32;
} else {
PADDLE_THROW("Not supported");
return static_cast<DataType>(-1);
}
}

} // namespace framework
} // namespace paddle
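To make the intent of this new header concrete, here is a standalone sketch of the same mapping. It is not Paddle code: a hand-written DataType enum stands in for the one generated from framework.proto, and a plain exception replaces PADDLE_THROW.

#include <iostream>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>

// Stand-in for the protobuf-generated paddle::framework::DataType.
enum class DataType { FP32, FP64, INT32 };

// Map a runtime type_index to the enum that keys kernels.
inline DataType ToDataType(std::type_index type) {
  if (type == std::type_index(typeid(float))) return DataType::FP32;
  if (type == std::type_index(typeid(double))) return DataType::FP64;
  if (type == std::type_index(typeid(int))) return DataType::INT32;
  throw std::runtime_error("Not supported");  // PADDLE_THROW in the real code
}

int main() {
  // A double tensor now maps to FP64, which is what this commit wires through.
  std::cout << (ToDataType(std::type_index(typeid(double))) == DataType::FP64)
            << "\n";  // prints 1
}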
34 changes: 30 additions & 4 deletions paddle/framework/op_registry.h
@@ -100,13 +100,39 @@ class OpRegistrar : public Registrar {
}
};

template <typename PlaceType, typename KernelType>
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;

template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;

void operator()(const char* op_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);

constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
func(op_type);
}
};

template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type) const {}
};

// Users can register multiple kernels in one place. Their data types may differ.
template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar {
public:
explicit OpKernelRegistrar(const char* op_type) {
OperatorWithKernel::OpKernelKey key;
key.place_ = PlaceType();
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type);
}
};
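The functor pair above unrolls the kernel parameter pack at compile time: the false specialization registers the kernel at index I and recurses to I + 1, and the true specialization fires once I reaches the pack size and stops. The following standalone sketch shows the same pattern with a plain std::map standing in for AllOpKernels() and trivial stand-in kernel types; it is an illustration, not Paddle's registry.

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <typeinfo>

struct KernelBase {
  virtual ~KernelBase() = default;
};
template <typename T>
struct Kernel : KernelBase {
  using ELEMENT_TYPE = T;  // mirrors OpKernel<T>::ELEMENT_TYPE
};

// Stand-in for AllOpKernels(): op type + element type -> kernel instance.
std::map<std::string, std::unique_ptr<KernelBase>>& Registry() {
  static std::map<std::string, std::unique_ptr<KernelBase>> registry;
  return registry;
}

template <bool at_end, size_t I, typename... KernelTypes>
struct RegistrarFunctor;

template <size_t I, typename... KernelTypes>
struct RegistrarFunctor<false, I, KernelTypes...> {
  using KERNEL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;

  void operator()(const char* op_type) const {
    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
    // Register the I-th kernel under its element type...
    Registry()[std::string(op_type) + "/" + typeid(T).name()].reset(
        new KERNEL_TYPE);
    // ...then recurse to I + 1, switching to the terminating specialization
    // once the end of the pack is reached.
    constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
    RegistrarFunctor<I + 1 == size, I + 1, KernelTypes...> next;
    next(op_type);
  }
};

template <size_t I, typename... KernelTypes>
struct RegistrarFunctor<true, I, KernelTypes...> {
  void operator()(const char*) const {}  // end of pack: nothing left to register
};

int main() {
  RegistrarFunctor<false, 0, Kernel<float>, Kernel<double>> reg;
  reg("add");                              // one call registers both precisions
  std::cout << Registry().size() << "\n";  // prints 2
}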

77 changes: 62 additions & 15 deletions paddle/framework/operator.h
@@ -22,6 +22,7 @@ limitations under the License. */

#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
@@ -403,7 +404,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
const Scope& scope_;
};

class OpKernel {
class OpKernelBase {
public:
/**
* ExecutionContext is the only parameter of Kernel Run function.
@@ -414,33 +415,47 @@ class OpKernel {

virtual void Compute(const ExecutionContext& context) const = 0;

virtual ~OpKernel() {}
virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
public:
using ELEMENT_TYPE = T;
};

class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
platform::Place place_;
DataType data_type_;

OpKernelKey() = default;
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
OpKernelKey(DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}

OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}

bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_);
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};

struct OpKernelHash {
std::hash<bool> hash_;
std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
}
};

using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;

OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
Expand All @@ -451,8 +466,10 @@ class OperatorWithKernel : public OperatorBase {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);

auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
ExecutionContext ctx(*this, scope, dev_ctx);
auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
}

static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -462,13 +479,43 @@ class OperatorWithKernel : public OperatorBase {
}

bool SupportGPU() const override {
OperatorWithKernel::OpKernelKey key;
key.place_ = platform::GPUPlace();
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_);
});
}

protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0;

// Indicate the kernel DataType by the input data. By default, all inputs must
// share the same data type.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same.");
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
};

} // namespace framework
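Two run-time pieces tie this together. OpKernelHash packs the data type above the place index in a single integer before hashing, so float and double kernels for the same place land under different keys, and IndicateDataType() walks the operator's inputs, checks that they share one element type, and returns it so Run() can build the lookup key. A small sketch of the packing follows; it assumes a 4-bit place field and illustrative enum values, since NUM_PLACE_TYPE_LIMIT_IN_BIT and the real enums are defined outside the excerpt shown here.

#include <cstddef>
#include <functional>
#include <iostream>

enum class Place { kCPU = 0, kGPU = 1 };                // stand-in for platform::Place
enum class DataType { INT32 = 2, FP32 = 5, FP64 = 6 };  // illustrative values

constexpr int kPlaceBits = 4;  // stand-in for NUM_PLACE_TYPE_LIMIT_IN_BIT

// Mirror of OpKernelHash: data type in the high bits, place in the low bits.
std::size_t HashKernelKey(DataType data_type, Place place) {
  int pre_hash = (static_cast<int>(data_type) << kPlaceBits) |
                 (static_cast<int>(place) & ((1 << kPlaceBits) - 1));
  return std::hash<int>()(pre_hash);
}

int main() {
  // Same place, different element types produce different pre-hash values,
  // which is what lets FP32 and FP64 kernels coexist for one operator.
  std::cout << (HashKernelKey(DataType::FP32, Place::kCPU) !=
                HashKernelKey(DataType::FP64, Place::kCPU))
            << "\n";  // prints 1
}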
7 changes: 5 additions & 2 deletions paddle/framework/operator_test.cc
@@ -114,10 +114,13 @@ class OpWithKernelTest : public OperatorWithKernel {

protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}
};

template <typename T1, typename T2>
class CPUKernelTest : public OpKernel {
class CPUKernelTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
@@ -144,7 +147,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
}
};

class CPUKernalMultiInputsTest : public OpKernel {
class CPUKernalMultiInputsTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op().Inputs("xs");
12 changes: 2 additions & 10 deletions paddle/framework/tensor.h
@@ -29,20 +29,10 @@ limitations under the License. */

namespace paddle {

namespace pybind {
namespace details {
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
}
} // namespace pybind

namespace framework {

class Tensor {
public:
template <bool less, size_t i, typename... args>
friend struct pybind::details::CastToPyBufferImpl;

template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;

@@ -119,6 +109,8 @@ class Tensor {
return holder_->place();
}

std::type_index type() const { return holder_->type(); }

private:
template <typename T>
inline void check_memory_size() const;
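The new Tensor::type() accessor simply forwards the std::type_index stored in the tensor's type-erased holder; it is what lets IndicateDataType() recover an element type without knowing T at the call site, and is presumably why the pybind CastToPyBufferImpl friend declaration could be dropped. A minimal sketch of that pattern, with a simplified holder rather than Paddle's actual Placeholder classes:

#include <cstddef>
#include <iostream>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Simplified type-erased holder: no allocator, no place, no dims.
struct Holder {
  virtual ~Holder() = default;
  virtual std::type_index type() const = 0;
};

template <typename T>
struct TypedHolder : Holder {
  std::vector<T> data;
  std::type_index type() const override { return std::type_index(typeid(T)); }
};

class MiniTensor {
 public:
  template <typename T>
  void mutable_data(std::size_t n) {
    auto* h = new TypedHolder<T>();
    h->data.resize(n);
    holder_.reset(h);
  }
  // The accessor added in this commit: report the erased element type.
  std::type_index type() const { return holder_->type(); }

 private:
  std::unique_ptr<Holder> holder_;
};

int main() {
  MiniTensor t;
  t.mutable_data<double>(16);
  std::cout << (t.type() == std::type_index(typeid(double))) << "\n";  // prints 1
}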
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.cu
@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
}

template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel {
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.h
@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;

template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel {
class AccuracyKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
20 changes: 10 additions & 10 deletions paddle/operators/activation_op.h
@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {

template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel {
class ActivationKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel {
class ActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -202,7 +202,7 @@ struct SquareGradFunctor {
};

template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel {
class BReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel {
class BReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel {
class SoftReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel {
class SoftReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel {
class PowKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel {
class PowGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel {
class STanhKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
@@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel {
class STanhGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
2 changes: 1 addition & 1 deletion paddle/operators/add_op.h
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
class AddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X");
(Diffs for the remaining changed files are not shown here.)
