Skip to content

Commit

Permalink
Refine scalar fn registry (PaddlePaddle#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Aug 13, 2020
1 parent 83d1754 commit 27e3c03
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 186 deletions.
8 changes: 4 additions & 4 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "cinn/lang/module.h"
#include "cinn/lang/placeholder.h"
#include "cinn/optim/ir_simplify.h"
#include "cinn/runtime/cpu/use_extern_funcs.h"

namespace cinn {
namespace backends {
Expand Down Expand Up @@ -396,9 +397,8 @@ TEST(CodeGenC, matmul_packed) {
C->Bind(C_buf);

{
poly::Iterator i_outer, i_inner, j_outer, j_inner, k_outer, k_inner;
std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn.as_int32(), bn.as_int32());
std::tie(k_outer, k_inner) = C->stage()->Split(poly::Iterator("k0"), 4);
auto [i_outer, i_inner, j_outer, j_inner] = C->stage()->Tile(0, 1, bn.as_int32(), bn.as_int32());
auto [k_outer, k_inner] = C->stage()->Split(poly::Iterator("k0"), 4);
C->stage()->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner});
}

Expand Down Expand Up @@ -469,7 +469,7 @@ TEST(CodeGenC, call_extern) {
Placeholder<float> x("x", {M});

ir::Tensor y = Compute(
{M}, [=](Var i) -> Expr { return lang::CallExtern("tanh", {x(i)}); }, "y");
{M}, [=](Var i) -> Expr { return lang::CallExtern("cinn_cpu_tanh_fp32", {x(i)}); }, "y");
y->WithBuffer();

auto yexpr = Lower("yy", {y});
Expand Down
2 changes: 1 addition & 1 deletion cinn/backends/codegen_cuda_dev_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1480,7 +1480,7 @@ TEST(Cuda, external_function) {

auto fn = Lower("fn5", {A, B, C});

Module::Builder builder("module", common::DefaultHostTarget());
Module::Builder builder("module", common::DefaultNVGPUTarget());
builder.AddFunction(fn);

auto source_code = codegen.Compile(builder.Build());
Expand Down
13 changes: 1 addition & 12 deletions cinn/backends/extern_func_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,7 @@ std::ostream& operator<<(std::ostream& os, const ExternFuncID& x) {
return os;
}

ExternFunctionEmitterRegistry::ExternFunctionEmitterRegistry() {
// Register the runtime functions.
RuntimeSymbolRegistry::Global().RegisterFn(extern_tanh_host_repr, reinterpret_cast<void*>(__cinn_host_tanh_fp32));
RuntimeSymbolRegistry::Global().RegisterFn(extern_tanh_v_host_repr, reinterpret_cast<void*>(__cinn_host_tanh_v));

// tanh
Register(ExternFuncID(backend_C, extern_func__tanh), new ExternFuncEmitter_C_tanh);
Register(ExternFuncID(backend_llvm_host, extern_func__tanh), new ExternFuncEmitter_LLVM_tanh);
// tanh_v
Register(ExternFuncID(backend_C, extern_func__tanh_v), new ExternFuncEmitter_C_tanh_v);
Register(ExternFuncID(backend_llvm_host, extern_func__tanh_v), new ExternFuncEmitter_LLVM_tanh_v);
}
ExternFunctionEmitterRegistry::ExternFunctionEmitterRegistry() {}

const FunctionProto& ExternFunctionEmitter::func_proto() const {
auto* proto = ExternFunctionProtoRegistry::Global().Lookup(func_name());
Expand Down
93 changes: 0 additions & 93 deletions cinn/backends/extern_func_emitter_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,99 +8,6 @@
namespace cinn {
namespace backends {

// C tanh --
// Emitter that prints a call to the host tanh runtime function in C source.
// @{
// Stores the (type-erased) CodeGenC instance; a null codegen is rejected.
void ExternFuncEmitter_C_tanh::BindCodeGen(void* codegen) {
  CHECK(codegen);
  auto* c_codegen = reinterpret_cast<CodeGenC*>(codegen);
  codegen_        = c_codegen;
}

// This emitter targets the C backend.
const char* ExternFuncEmitter_C_tanh::backend_kind() const {
  return backend_C;
}

// Name of the extern function this emitter handles.
const char* ExternFuncEmitter_C_tanh::func_name() const {
  return extern_func__tanh;
}

// The value is returned directly, not through a write argument.
bool ExternFuncEmitter_C_tanh::RetValuePacked() const {
  return false;
}

// Prints `<extern_tanh_host_repr>(<arg>)` to the bound codegen's stream.
// Requires exactly one read argument and no write arguments.
void ExternFuncEmitter_C_tanh::EmitImpl(const ir::Call* op) {
  CHECK(codegen_) << "codegen_ should be bind first";

  CHECK_EQ(op->read_args.size(), 1UL);
  CHECK(op->write_args.empty());

  auto& out = codegen_->os();
  out << extern_tanh_host_repr << "(";
  codegen_->Print(op->read_args[0]);
  out << ")";
}
// @}

// LLVM tanh --
// Emitter that generates an LLVM call to the host tanh runtime function.
// @{
// Stores the (type-erased) CodeGenLLVM instance used by EmitImpl().
void ExternFuncEmitter_LLVM_tanh::BindCodeGen(void* codegen) { codegen_ = reinterpret_cast<CodeGenLLVM*>(codegen); }

// Name of the extern function this emitter handles.
const char* ExternFuncEmitter_LLVM_tanh::func_name() const { return extern_func__tanh; }

// Declares (or reuses) `float <extern_tanh_host_repr>(float)` in the module,
// emits a call to it with the single read argument, and stores the call in
// the codegen's `extern_func_emit_res_`.
void ExternFuncEmitter_LLVM_tanh::EmitImpl(const ir::Call* op) {
  CHECK(codegen_);
  // read_args[0] is indexed below; validate it like the C emitter does.
  CHECK_EQ(op->read_args.size(), 1UL);
  CodeGenLLVMforEmitter codegen_for_emitter(codegen_);

  // function type.
  llvm::Type* f32             = codegen_for_emitter.b()->getFloatTy();
  llvm::FunctionType* fn_type = llvm::FunctionType::get(f32, {f32}, false);

  llvm::Function* custom_function = llvm::dyn_cast<llvm::Function>(
      codegen_for_emitter.m()->getOrInsertFunction(extern_tanh_host_repr, fn_type).getCallee());
  // dyn_cast yields nullptr when the callee is not a plain llvm::Function
  // (e.g. the symbol already exists with a different type); fail loudly
  // instead of dereferencing a null pointer.
  CHECK(custom_function) << "No function called " << extern_tanh_host_repr;
  custom_function->setCallingConv(llvm::CallingConv::C);

  auto* arg = codegen_->Visit(&op->read_args[0]);

  auto* ret = codegen_for_emitter.b()->CreateCall(custom_function, {arg});

  codegen_->extern_func_emit_res_ = ret;
}

// The value is returned directly, not through a write argument.
bool ExternFuncEmitter_LLVM_tanh::RetValuePacked() const { return false; }

// This emitter targets the LLVM host backend.
const char* ExternFuncEmitter_LLVM_tanh::backend_kind() const { return backend_llvm_host; }
// @}

// @{
void ExternFuncEmitter_C_tanh_v::BindCodeGen(void* codegen) { codegen_ = reinterpret_cast<CodeGenC*>(codegen); }
const char* ExternFuncEmitter_C_tanh_v::func_name() const { return extern_func__tanh_v; }
void ExternFuncEmitter_C_tanh_v::EmitImpl(const ir::Call* op) {
auto& os = codegen_->os();
os << extern_tanh_v_host_repr;
os << "(";
codegen_->Print(op->read_args[0]);
os << ", ";
codegen_->Print(op->write_args[0]);
os << ")";
}
bool ExternFuncEmitter_C_tanh_v::RetValuePacked() const { return true; }
const char* ExternFuncEmitter_C_tanh_v::backend_kind() const { return backend_C; }
// @}

// LLVM tanh_v --
// Emitter that generates an LLVM call to the buffer-based host tanh function.
// @{
// This emitter targets the LLVM host backend.
const char* ExternFuncEmitter_LLVM_tanh_v::backend_kind() const {
  return backend_llvm_host;
}

// The result is delivered through the write argument.
bool ExternFuncEmitter_LLVM_tanh_v::RetValuePacked() const {
  return true;
}

// Name of the extern function this emitter handles.
const char* ExternFuncEmitter_LLVM_tanh_v::func_name() const {
  return extern_func__tanh_v;
}

// Stores the (type-erased) CodeGenLLVM instance used by EmitImpl().
void ExternFuncEmitter_LLVM_tanh_v::BindCodeGen(void* codegen) {
  codegen_ = reinterpret_cast<CodeGenLLVM*>(codegen);
}

// Declares (or reuses) `void <extern_tanh_v_host_repr>(cinn_buffer_t*, cinn_buffer_t*)`,
// looks up the buffers of the first read and write tensors by name, emits
// the call, and stores it in the codegen's `extern_func_emit_res_`.
void ExternFuncEmitter_LLVM_tanh_v::EmitImpl(const ir::Call* op) {
  CHECK(codegen_);
  CodeGenLLVMforEmitter emitter_view(codegen_);

  // Callee signature: two buffer pointers in, nothing returned.
  llvm::Type* buffer_ptr_ty     = backends::llvm_type_of<cinn_buffer_t*>(codegen_->m());
  llvm::Type* void_ty           = codegen_->b()->getVoidTy();
  llvm::FunctionType* callee_ty = llvm::FunctionType::get(void_ty, {buffer_ptr_ty, buffer_ptr_ty}, false);

  auto callee_ref        = emitter_view.m()->getOrInsertFunction(extern_tanh_v_host_repr, callee_ty);
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(callee_ref.getCallee());
  CHECK(callee) << "No function called " << extern_tanh_v_host_repr;
  callee->setCallingConv(llvm::CallingConv::C);

  // Resolve the input and output buffer variables (non-lazy lookup).
  auto* input_buf  = emitter_view.GetVar(op->read_args[0].as_tensor()->buffer->name, false /*lazy*/);
  auto* output_buf = emitter_view.GetVar(op->write_args[0].as_tensor()->buffer->name, false /*lazy*/);

  auto* call_inst                 = emitter_view.b()->CreateCall(callee, {input_buf, output_buf});
  codegen_->extern_func_emit_res_ = call_inst;
}
// @}

// Stores the type-erased codegen pointer as a CodeGenLLVM*. No null check is
// performed here.
void ExternFunctionLLVMEmitter::BindCodeGen(void* codegen) { codegen_ = reinterpret_cast<CodeGenLLVM*>(codegen); }

// Returns the stored function name as a C string; the pointer refers to
// `fn_name_`'s internal storage.
const char* ExternFunctionLLVMEmitter::func_name() const { return fn_name_.c_str(); }
Expand Down
51 changes: 0 additions & 51 deletions cinn/backends/extern_func_emitter_builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,56 +43,5 @@ class ExternFunctionLLVMEmitter : public ExternFunctionEmitter {
std::string fn_name_;
};

/**
* Emitter for the scalar tanh extern call in CodeGenC (C source backend).
*
* A codegen must be bound with BindCodeGen() before EmitImpl() is called.
*/
class ExternFuncEmitter_C_tanh : public ExternFunctionEmitter {
public:
// Binds the CodeGenC instance, passed as a type-erased pointer.
void BindCodeGen(void* codegen) override;
// Name of the extern function handled by this emitter.
const char* func_name() const override;
// Emits the code for the given extern call node.
void EmitImpl(const ir::Call* op) override;
// Whether the result is passed through a write argument instead of returned.
bool RetValuePacked() const override;
// Identifier of the backend this emitter belongs to.
const char* backend_kind() const override;

private:
// Bound code generator; null until BindCodeGen() is called.
CodeGenC* codegen_{};
};

/**
* Emitter for the scalar tanh extern call in CodeGenLLVM (LLVM host backend).
*
* A codegen must be bound with BindCodeGen() before EmitImpl() is called.
*/
class ExternFuncEmitter_LLVM_tanh : public ExternFunctionEmitter {
public:
// Binds the CodeGenLLVM instance, passed as a type-erased pointer.
void BindCodeGen(void* codegen) override;
// Name of the extern function handled by this emitter.
const char* func_name() const override;
// Emits the code for the given extern call node.
void EmitImpl(const ir::Call* op) override;
// Whether the result is passed through a write argument instead of returned.
bool RetValuePacked() const override;
// Identifier of the backend this emitter belongs to.
const char* backend_kind() const override;

private:
// Bound code generator; null until BindCodeGen() is called.
CodeGenLLVM* codegen_{};
};

/**
* Emitter for the buffer-based tanh_v extern call in CodeGenC (C source
* backend).
*
* EmitImpl() uses the bound codegen, so BindCodeGen() has to be called first.
*/
class ExternFuncEmitter_C_tanh_v : public ExternFunctionEmitter {
public:
// Binds the CodeGenC instance, passed as a type-erased pointer.
void BindCodeGen(void* codegen) override;
// Name of the extern function handled by this emitter.
const char* func_name() const override;
// Emits the code for the given extern call node.
void EmitImpl(const ir::Call* op) override;
// Whether the result is passed through a write argument instead of returned.
bool RetValuePacked() const override;
// Identifier of the backend this emitter belongs to.
const char* backend_kind() const override;

private:
// Bound code generator; null until BindCodeGen() is called.
CodeGenC* codegen_{};
};

/**
* Emitter for the buffer-based tanh_v extern call in CodeGenLLVM (LLVM host
* backend).
*
* A codegen must be bound with BindCodeGen() before EmitImpl() is called.
*/
class ExternFuncEmitter_LLVM_tanh_v : public ExternFunctionEmitter {
public:
// Binds the CodeGenLLVM instance, passed as a type-erased pointer.
void BindCodeGen(void* codegen) override;
// Name of the extern function handled by this emitter.
const char* func_name() const override;
// Emits the code for the given extern call node.
void EmitImpl(const ir::Call* op) override;
// Whether the result is passed through a write argument instead of returned.
bool RetValuePacked() const override;
// Identifier of the backend this emitter belongs to.
const char* backend_kind() const override;

private:
// Bound code generator; null until BindCodeGen() is called.
CodeGenLLVM* codegen_{};
};

} // namespace backends
} // namespace cinn
3 changes: 2 additions & 1 deletion cinn/backends/llvm/execution_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "cinn/lang/placeholder.h"
#include "cinn/optim/optimize.h"
#include "cinn/runtime/cpu/host_intrinsics.h"
#include "cinn/runtime/cpu/use_extern_funcs.h"

namespace cinn {
namespace backends {
Expand Down Expand Up @@ -288,7 +289,7 @@ TEST(ExecutionEngine, call_extern) {
add_out->stage()->ComputeInline();

ir::Tensor res = Compute(
{M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("tanh", {add_out(i, j)}); }, "res");
{M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("cinn_cpu_tanh_fp32", {add_out(i, j)}); }, "res");

auto func = Lower("comp", {x, y, res});

Expand Down
2 changes: 1 addition & 1 deletion cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Expr Optimize(Expr e, bool runtime_debug_info) {

RemoveNestedBlock(&copied);

ActivateToExternCall(&copied);
// ActivateToExternCall(&copied);
ExternCallMultiOutputShallowStore(&copied);

Simplify(&copied);
Expand Down
2 changes: 1 addition & 1 deletion cinn/pybind/bind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct Visitor : Ts... {
};

template <class... Ts>
Visitor(Ts...) -> Visitor<Ts...>;
Visitor(Ts...)->Visitor<Ts...>;

using ExprOp = std::variant<ir::IntImm,
ir::UIntImm,
Expand Down
1 change: 1 addition & 0 deletions cinn/runtime/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(srcs
)

cc_test(test_mkl_math SRCS mkl_math_test.cc mkl_math.cc DEPS core)
cc_test(test_host_intrinsics SRCS host_intrinsics_test.cc DEPS core)

foreach(cpp ${srcs})
set(core_src
Expand Down
29 changes: 7 additions & 22 deletions cinn/runtime/cpu/host_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ using namespace std;
float cinn_cpu_##name__##_fp32(float a) { return name__(a); }

#define CINN_IMP_CPU_FUNC_INT_BINARY(name__, rule__) \
int cinn_cpu_##name__##_int(int a, int b) { return a rule__ b; }
int cinn_cpu_##name__##_int32(int a, int b) { return a rule__ b; }

#define CINN_IMP_CPU_FUNC_INT_UNARY(name__, rule__) \
int cinn_cpu_##name__##_int(int a) { return rule__(a); }
int cinn_cpu_##name__##_int32(int a) { return rule__(a); }

CINN_IMP_CPU_FUNC_FP32(exp);
CINN_IMP_CPU_FUNC_FP32(erf);
Expand Down Expand Up @@ -67,24 +67,18 @@ void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out) {
// Host intrinsic: returns the smallest integral value not less than x,
// computed with std::ceil.
float __cinn_host_ceil_fp32(float x) {
  const float rounded_up = std::ceil(x);
  return rounded_up;
}
}

namespace cinn {
namespace runtime {
namespace cpu {
using backends::FunctionProto;

namespace {

bool RegisterRuntimeSymbols() {
auto host_target = common::DefaultHostTarget();
REGISTER_EXTERN_FUNC(host_intrinsics) {
auto host_target = cinn::common::DefaultHostTarget();
using cinn::backends::FunctionProto;

#define REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT_FLOAT(func__) \
REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT(cinn_cpu_##func__##_fp32, host_target, float, float);

#define REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT_INT(func__) \
REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT(cinn_cpu_##func__##_int, host_target, int, int);
REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT(cinn_cpu_##func__##_int32, host_target, int, int);

#define REGISTER_EXTERN_FUNC_TWO_IN_ONE_OUT_INT(func__) \
REGISTER_EXTERN_FUNC_TWO_IN_ONE_OUT(cinn_cpu_##func__##_int, host_target, int, int, int);
REGISTER_EXTERN_FUNC_TWO_IN_ONE_OUT(cinn_cpu_##func__##_int32, host_target, int, int, int);

REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT_FLOAT(exp);
REGISTER_EXTERN_FUNC_ONE_IN_ONE_OUT_FLOAT(erf);
Expand Down Expand Up @@ -131,13 +125,4 @@ bool RegisterRuntimeSymbols() {
.AddInputType<float>()
.SetShapeInference(FunctionProto::ShapeFollowNthArgument(0))
.End();

return true;
}

[[maybe_unused]] bool x = RegisterRuntimeSymbols();

} // namespace
} // namespace cpu
} // namespace runtime
} // namespace cinn
7 changes: 7 additions & 0 deletions cinn/runtime/cpu/host_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,11 @@ float __cinn_host_tanh_fp32(float x);
float __cinn_host_ceil_fp32(float x);
void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out);
//@}

//! math map like functions
//@{
float cinn_cpu_tanh_v_fp32(float x);
float cinn_cpu_cos_v_fp32(float x);
float cinn_cpu_sin_v_fp32(float x);
//@}
}
54 changes: 54 additions & 0 deletions cinn/runtime/cpu/host_intrinsics_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "cinn/runtime/cpu/host_intrinsics.h"

#include <gtest/gtest.h>

#include "cinn/backends/compiler.h"
#include "cinn/backends/llvm/execution_engine.h"
#include "cinn/backends/llvm/simple_jit.h"
#include "cinn/cinn.h"
#include "cinn/common/ir_util.h"
#include "cinn/common/target.h"
#include "cinn/common/test_helper.h"
#include "cinn/runtime/cpu/use_extern_funcs.h"

namespace cinn {
namespace runtime {
namespace cpu {

// End-to-end check of the "cinn_cpu_tanh_fp32" extern function: a 10x20
// tensor computed through CallExtern is lowered, linked into a SimpleJIT for
// the default host target, executed, and every output element is compared
// against std::tanh of the matching input element.
TEST(tanh, basic) {
  Expr M(10), N(20);
  Placeholder<float> input("x", {M, N});
  auto output = Compute({M, N}, [&](Expr i, Expr j) { return CallExtern("cinn_cpu_tanh_fp32", {input(i, j)}); });

  auto engine = backends::SimpleJIT::Create();

  lang::Module::Builder module_builder("module1", common::DefaultHostTarget());

  auto lowered = Lower("fn", {input, output});
  LOG(INFO) << "fn:\n" << lowered;

  module_builder.AddFunction(lowered);
  engine->Link(module_builder.Build());

  // Fetch the compiled entry point and make sure the lookup succeeded.
  auto compiled = reinterpret_cast<lower_func_ptr_t>(engine->Lookup("fn"));
  ASSERT_TRUE(compiled);

  // Input buffer is filled via set_random(), output buffer via set_zero().
  const auto rows = M.as_int32();
  const auto cols = N.as_int32();
  auto* in_buf    = common::BufferBuilder(Float(32), {rows, cols}).set_random().Build();
  auto* res_buf   = common::BufferBuilder(Float(32), {rows, cols}).set_zero().Build();

  auto packed_args = common::ArgsBuilder().Add(in_buf).Add(res_buf).Build();
  compiled(packed_args.data(), packed_args.size());

  auto* in_values  = reinterpret_cast<float*>(in_buf->memory);
  auto* res_values = reinterpret_cast<float*>(res_buf->memory);

  const auto element_count = in_buf->num_elements();
  for (int idx = 0; idx < element_count; ++idx) {
    LOG_FIRST_N(INFO, 3) << res_values[idx];
    ASSERT_NEAR(res_values[idx], std::tanh(in_values[idx]), 1e-5);
  }
}

} // namespace cpu
} // namespace runtime
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/runtime/cpu/mkl_math_test.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include <gtest/gtest.h>

#include "cinn/backends/compiler.h"
#include "cinn/backends/llvm/execution_engine.h"
#include "cinn/backends/llvm/simple_jit.h"
#include "cinn/cinn.h"
#include "cinn/common/ir_util.h"
#include "cinn/common/target.h"
#include "cinn/common/test_helper.h"
#include "cinn/runtime/cpu/host_intrinsics.h"
#include "cinn/runtime/cpu/use_extern_funcs.h"
Expand Down
2 changes: 2 additions & 0 deletions cinn/runtime/cpu/use_extern_funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
#include "cinn/backends/extern_func_jit_register.h"

USE_EXTERN_FUNC(cinn_cpu_mkl_gemm_fp32);

USE_EXTERN_FUNC(host_intrinsics)

0 comments on commit 27e3c03

Please sign in to comment.