[Refactor] Use wrapped create_call (#3421)
* [Refactor] Use wrapped create_call

* Auto Format

* fix

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
Authored by sjwsl and taichi-gardener on Nov 8, 2021
Parent: 0ea71b1 · Commit: bb68caf
Showing 4 changed files with 42 additions and 67 deletions.
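
The change is mechanical: each direct builder->CreateCall(get_runtime_function(name), args) call site becomes a call to the existing create_call wrapper, which performs the runtime-function lookup and validates the arguments against the callee's signature before emitting the call. Schematically, with lines taken from the first hunk below:

    // Before: the call site pairs the name lookup with a raw CreateCall.
    llvm_val[stmt] =
        builder->CreateCall(get_runtime_function("__nv_sqrtf"), input);

    // After: the wrapper does the lookup and the signature check.
    llvm_val[stmt] = create_call("__nv_sqrtf", input);

Alongside this, both create_call overloads switch their argument type from std::vector<llvm::Value *> to llvm::ArrayRef<llvm::Value *>, and a now-unused variadic overload of check_func_call_signature is deleted.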
56 changes: 22 additions & 34 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -201,47 +201,39 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

     auto op = stmt->op_type;

-#define UNARY_STD(x)                                                         \
-  else if (op == UnaryOpType::x) {                                           \
-    if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {             \
-      llvm_val[stmt] =                                                       \
-          builder->CreateCall(get_runtime_function("__nv_" #x "f"), input);  \
-    } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {      \
-      llvm_val[stmt] =                                                       \
-          builder->CreateCall(get_runtime_function("__nv_" #x), input);      \
-    } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {      \
-      llvm_val[stmt] = builder->CreateCall(get_runtime_function(#x), input); \
-    } else {                                                                 \
-      TI_NOT_IMPLEMENTED                                                     \
-    }                                                                        \
+#define UNARY_STD(x)                                                     \
+  else if (op == UnaryOpType::x) {                                       \
+    if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {         \
+      llvm_val[stmt] = create_call("__nv_" #x "f", input);               \
+    } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {  \
+      llvm_val[stmt] = create_call("__nv_" #x, input);                   \
+    } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {  \
+      llvm_val[stmt] = create_call(#x, input);                           \
+    } else {                                                             \
+      TI_NOT_IMPLEMENTED                                                 \
+    }                                                                    \
   }
     if (op == UnaryOpType::abs) {
       if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
-        llvm_val[stmt] =
-            builder->CreateCall(get_runtime_function("__nv_fabsf"), input);
+        llvm_val[stmt] = create_call("__nv_fabsf", input);
       } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
-        llvm_val[stmt] =
-            builder->CreateCall(get_runtime_function("__nv_fabs"), input);
+        llvm_val[stmt] = create_call("__nv_fabs", input);
       } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
-        llvm_val[stmt] =
-            builder->CreateCall(get_runtime_function("__nv_abs"), input);
+        llvm_val[stmt] = create_call("__nv_abs", input);
       } else {
         TI_NOT_IMPLEMENTED
       }
     } else if (op == UnaryOpType::sqrt) {
       if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
-        llvm_val[stmt] =
-            builder->CreateCall(get_runtime_function("__nv_sqrtf"), input);
+        llvm_val[stmt] = create_call("__nv_sqrtf", input);
       } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
-        llvm_val[stmt] =
-            builder->CreateCall(get_runtime_function("__nv_sqrt"), input);
+        llvm_val[stmt] = create_call("__nv_sqrt", input);
       } else {
         TI_NOT_IMPLEMENTED
       }
     } else if (op == UnaryOpType::logic_not) {
       if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
-        llvm_val[stmt] =
-            builder->CreateCall(get_runtime_function("logic_not_i32"), input);
+        llvm_val[stmt] = create_call("logic_not_i32", input);
       } else {
         TI_NOT_IMPLEMENTED
       }
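
For reference, a sketch of what one instantiation of the rewritten UNARY_STD macro expands to. The operator sin is an illustrative choice, not a line from this diff; __nv_sinf and __nv_sin are the matching CUDA libdevice routines:

    // Hypothetical expansion of UNARY_STD(sin) after this change:
    else if (op == UnaryOpType::sin) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = create_call("__nv_sinf", input);  // f32 libdevice call
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = create_call("__nv_sin", input);   // f64 libdevice call
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
        llvm_val[stmt] = create_call("sin", input);        // i32 runtime function
      } else {
        TI_NOT_IMPLEMENTED
      }
    }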
@@ -366,11 +358,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
     }
     TI_ASSERT(atomics.at(prim_type).find(op) != atomics.at(prim_type).end());

-    return builder->CreateCall(
-        get_runtime_function(atomics.at(prim_type).at(op)),
-        {llvm_val[stmt->dest], llvm_val[stmt->val]});
-
-    return nullptr;
+    return create_call(atomics.at(prim_type).at(op),
+                       {llvm_val[stmt->dest], llvm_val[stmt->val]});
   }

   void visit(AtomicOpStmt *stmt) override {
@@ -590,10 +579,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
   void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
     const auto arg_id = stmt->arg_id;
     const auto axis = stmt->axis;
-    llvm_val[stmt] =
-        builder->CreateCall(get_runtime_function("Context_get_extra_args"),
-                            {get_context(), tlctx->get_constant(arg_id),
-                             tlctx->get_constant(axis)});
+    llvm_val[stmt] = create_call("Context_get_extra_args",
+                                 {get_context(), tlctx->get_constant(arg_id),
+                                  tlctx->get_constant(axis)});
   }

   void visit(BinaryOpStmt *stmt) override {
44 changes: 18 additions & 26 deletions taichi/codegen/codegen_llvm.cpp
@@ -190,14 +190,11 @@ void CodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
 #define UNARY_STD(x)                                                    \
   else if (op == UnaryOpType::x) {                                      \
     if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {        \
-      llvm_val[stmt] =                                                  \
-          builder->CreateCall(get_runtime_function(#x "_f32"), input);  \
+      llvm_val[stmt] = create_call(#x "_f32", input);                   \
     } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \
-      llvm_val[stmt] =                                                  \
-          builder->CreateCall(get_runtime_function(#x "_f64"), input);  \
+      llvm_val[stmt] = create_call(#x "_f64", input);                   \
     } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \
-      llvm_val[stmt] =                                                  \
-          builder->CreateCall(get_runtime_function(#x "_i32"), input);  \
+      llvm_val[stmt] = create_call(#x "_i32", input);                   \
     } else {                                                            \
       TI_NOT_IMPLEMENTED                                                \
     }                                                                   \
@@ -762,7 +759,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag,
   value =
       builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64));
   args.push_back(value);
-  return builder->CreateCall(runtime_printf, args);
+  return create_call(runtime_printf, args);
 }

 llvm::Value *CodeGenLLVM::create_print(std::string tag, llvm::Value *value) {
@@ -822,7 +819,7 @@ void CodeGenLLVM::visit(PrintStmt *stmt) {
   args.insert(args.begin(),
               builder->CreateGlobalStringPtr(formats.c_str(), "format_string"));

-  llvm_val[stmt] = builder->CreateCall(runtime_printf, args);
+  llvm_val[stmt] = create_call(runtime_printf, args);
 }

 void CodeGenLLVM::visit(ConstStmt *stmt) {
@@ -944,12 +941,12 @@ void CodeGenLLVM::emit_gc(OffloadedStmt *stmt) {
 }

 llvm::Value *CodeGenLLVM::create_call(llvm::Value *func,
-                                      std::vector<llvm::Value *> args) {
+                                      llvm::ArrayRef<llvm::Value *> args) {
   check_func_call_signature(func, args);
   return builder->CreateCall(func, args);
 }
 llvm::Value *CodeGenLLVM::create_call(std::string func_name,
-                                      std::vector<llvm::Value *> args) {
+                                      llvm::ArrayRef<llvm::Value *> args) {
   auto func = get_runtime_function(func_name);
   return create_call(func, args);
 }
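
These two overloads are the funnel the rest of the commit routes everything through: the string overload resolves the callee with get_runtime_function, and the value overload checks the arguments against the callee's signature before the one remaining raw CreateCall. The same code as the hunk above, with comments added:

    // Resolve a runtime function by name, then defer to the value overload.
    llvm::Value *CodeGenLLVM::create_call(std::string func_name,
                                          llvm::ArrayRef<llvm::Value *> args) {
      auto func = get_runtime_function(func_name);
      return create_call(func, args);
    }

    // Verify the arguments match the callee's parameters, then emit the call.
    llvm::Value *CodeGenLLVM::create_call(llvm::Value *func,
                                          llvm::ArrayRef<llvm::Value *> args) {
      check_func_call_signature(func, args);
      return builder->CreateCall(func, args);
    }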
@@ -1090,8 +1087,7 @@ void CodeGenLLVM::visit(ReturnStmt *stmt) {
     auto extended = builder->CreateZExt(
         builder->CreateBitCast(llvm_val[stmt->value], intermediate_type),
         dest_ty);
-    builder->CreateCall(get_runtime_function("LLVMRuntime_store_result"),
-                        {get_runtime(), extended});
+    create_call("LLVMRuntime_store_result", {get_runtime(), extended});
   }
 }

@@ -1210,12 +1206,10 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) {
           llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest],
           llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
     } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f32)) {
-      old_value =
-          builder->CreateCall(get_runtime_function("atomic_min_f32"),
+      old_value = create_call("atomic_min_f32",
                               {llvm_val[stmt->dest], llvm_val[stmt->val]});
     } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f64)) {
-      old_value =
-          builder->CreateCall(get_runtime_function("atomic_min_f64"),
+      old_value = create_call("atomic_min_f64",
                               {llvm_val[stmt->dest], llvm_val[stmt->val]});
     } else {
       TI_NOT_IMPLEMENTED
@@ -1226,12 +1220,10 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) {
           llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest],
           llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
     } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f32)) {
-      old_value =
-          builder->CreateCall(get_runtime_function("atomic_max_f32"),
+      old_value = create_call("atomic_max_f32",
                               {llvm_val[stmt->dest], llvm_val[stmt->val]});
     } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f64)) {
-      old_value =
-          builder->CreateCall(get_runtime_function("atomic_max_f64"),
+      old_value = create_call("atomic_max_f64",
                               {llvm_val[stmt->dest], llvm_val[stmt->val]});
     } else {
       TI_NOT_IMPLEMENTED
@@ -1527,8 +1519,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
   std::vector<llvm::Value *> sizes(num_indices);

   for (int i = 0; i < num_indices; i++) {
-    auto raw_arg = builder->CreateCall(
-        get_runtime_function("Context_get_extra_args"),
+    auto raw_arg = create_call(
+        "Context_get_extra_args",
         {get_context(), tlctx->get_constant(arg_id), tlctx->get_constant(i)});
     sizes[i] = raw_arg;
   }
@@ -1550,8 +1542,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
 void CodeGenLLVM::visit(ExternalTensorShapeAlongAxisStmt *stmt) {
   const auto arg_id = stmt->arg_id;
   const auto axis = stmt->axis;
-  llvm_val[stmt] = builder->CreateCall(
-      get_runtime_function("Context_get_extra_args"),
+  llvm_val[stmt] = create_call(
+      "Context_get_extra_args",
       {get_context(), tlctx->get_constant(arg_id), tlctx->get_constant(axis)});
 }

@@ -2131,7 +2123,7 @@ void CodeGenLLVM::visit_call_bitcode(ExternalFuncCallStmt *stmt) {
     arg_values[i] =
         builder->CreatePointerCast(tmp_value, func_ptr->getArg(i)->getType());
   }
-  builder->CreateCall(func_ptr, arg_values);
+  create_call(func_ptr, arg_values);
 }

 void CodeGenLLVM::visit_call_shared_object(ExternalFuncCallStmt *stmt) {
@@ -2159,7 +2151,7 @@ void CodeGenLLVM::visit_call_shared_object(ExternalFuncCallStmt *stmt) {

   auto addr = tlctx->get_constant((std::size_t)stmt->so_func);
   auto func = builder->CreateIntToPtr(addr, func_ptr_type);
-  builder->CreateCall(func, arg_values);
+  create_call(func, arg_values);
 }

 void CodeGenLLVM::visit(ExternalFuncCallStmt *stmt) {
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.h
@@ -137,10 +137,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
   void emit_gc(OffloadedStmt *stmt);

   llvm::Value *create_call(llvm::Value *func,
-                           std::vector<llvm::Value *> args = {});
+                           llvm::ArrayRef<llvm::Value *> args = {});

   llvm::Value *create_call(std::string func_name,
-                           std::vector<llvm::Value *> args = {});
+                           llvm::ArrayRef<llvm::Value *> args = {});
   llvm::Value *call(SNode *snode,
                     llvm::Value *node_ptr,
                     const std::string &method,
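
llvm::ArrayRef is a non-owning view over contiguous elements: it binds to a std::vector, a C array, or a braced initializer list without copying, which is why a single parameter type now serves every call-site style in this commit, including the = {} defaults above. A minimal standalone sketch (not from this commit) of the behavior relied on:

    #include <vector>
    #include "llvm/ADT/ArrayRef.h"

    // One ArrayRef parameter accepts several argument forms without a copy.
    static int count_args(llvm::ArrayRef<int> args) {
      return static_cast<int>(args.size());
    }

    int main() {
      std::vector<int> v = {1, 2, 3};
      int a = count_args(v);       // bound to a std::vector
      int b = count_args({4, 5});  // bound to a braced initializer list
      int c = count_args({});      // the default-argument case
      return (a == 3 && b == 2 && c == 0) ? 0 : 1;
    }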
5 changes: 0 additions & 5 deletions taichi/llvm/llvm_codegen_utils.h
@@ -46,11 +46,6 @@ std::string type_name(llvm::Type *type);
 void check_func_call_signature(llvm::Value *func,
                                std::vector<llvm::Value *> arglist);

-template <typename... Args>
-inline bool check_func_call_signature(llvm::Value *func, Args &&... args) {
-  return check_func_call_signature(func, {args...});
-}
-
 class LLVMModuleBuilder {
  public:
   std::unique_ptr<llvm::Module> module{nullptr};
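
The deleted template existed so callers could pass loose arguments, as in check_func_call_signature(func, a, b). With every call site now routed through create_call, arguments always arrive as a single list, and the surviving vector-based overload still accepts them because llvm::ArrayRef defines an implicit conversion to std::vector. A small sketch of that conversion; sum_vec is a hypothetical stand-in for the vector-taking function:

    #include <vector>
    #include "llvm/ADT/ArrayRef.h"

    // Takes a std::vector by value, like check_func_call_signature above.
    static int sum_vec(std::vector<int> xs) {
      int total = 0;
      for (int x : xs)
        total += x;
      return total;
    }

    int main() {
      int storage[] = {1, 2, 3};
      llvm::ArrayRef<int> view(storage);
      // ArrayRef's operator std::vector<T>() materializes a copy here.
      return sum_vec(view) == 6 ? 0 : 1;
    }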
