Skip to content

Commit

Permalink
support multiple template parameter in KernelType for REGISTER_OP_XPU…
Browse files Browse the repository at this point in the history
…_KERNEL (#2932)
  • Loading branch information
jacquesqiao authored Jul 18, 2017
1 parent 861b66d commit 051676a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 8 additions & 6 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class OpRegisterHelper {
/**
* Macro to Register OperatorKernel.
*/
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
Expand All @@ -320,17 +320,19 @@ class OpRegisterHelper {
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
.reset(new __VA_ARGS__()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }

#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
// (type, KernelType)
#define REGISTER_OP_GPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)

#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
// (type, KernelType)
#define REGISTER_OP_CPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)

/**
* Macro to mark what Operator and Kernel we will use and tell the compiler to
Expand Down
4 changes: 3 additions & 1 deletion paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class OpWithKernelTest : public OperatorWithKernel {
const std::vector<Tensor*>& outputs) const override {}
};

template <typename T1, typename T2>
class CPUKernelTest : public OpKernel {
public:
void Compute(const KernelContext& ctx) const {
Expand Down Expand Up @@ -171,7 +172,8 @@ class CPUKernalMultiInputsTest : public OpKernel {

REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>);

// test with single input
TEST(OpKernel, all) {
Expand Down

0 comments on commit 051676a

Please sign in to comment.