Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Replace FuncCallExpression with FrontendFuncCallStmt #7027

Merged
merged 13 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,12 @@ def func_call_rvalue(self, key, args):
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args,
real_func_arg=True)
func_call = Expr(
_ti_core.make_func_call_expr(
self.taichi_functions[key.instance_id], non_template_args))
impl.get_runtime().prog.current_ast_builder().insert_expr_stmt(
func_call.ptr)
func_call = impl.get_runtime().prog.current_ast_builder(
).insert_func_call(self.taichi_functions[key.instance_id],
non_template_args)
if self.return_type is None:
return None
func_call = Expr(func_call)
if id(self.return_type) in primitive_types.type_ids:
return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0, )))
if isinstance(self.return_type, StructType):
Expand Down
4 changes: 2 additions & 2 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->axis);
}

void visit(FuncCallExpression *expr) override {
emit(ExprOpCode::FuncCallExpression);
void visit(FrontendFuncCallStmt *expr) override {
emit(StmtOpCode::FrontendFuncCallStmt);
emit(expr->func);
emit(expr->args.exprs);
}
Expand Down
19 changes: 10 additions & 9 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2723,27 +2723,28 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
}
llvm::Value *result_buffer = nullptr;
auto *ret_type = get_real_func_ret_type(stmt->func);
result_buffer = builder->CreateAlloca(ret_type);
auto *result_buffer_u64 = builder->CreatePointerCast(
result_buffer, llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
if (stmt->ret_type) {
auto *ret_type = tlctx->get_data_type(stmt->ret_type);
result_buffer = builder->CreateAlloca(ret_type);
auto *result_buffer_u64 = builder->CreatePointerCast(
result_buffer,
llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
}
call(llvm_func, new_ctx);
llvm_val[stmt] = result_buffer;
call("recycle_runtime_context", get_runtime(), new_ctx);
}

void TaskCodeGenLLVM::visit(GetElementStmt *stmt) {
auto *real_func = stmt->src->as<FuncCallStmt>()->func;
auto *real_func_ret_type = tlctx->get_data_type(real_func->ret_type);
auto *struct_type = tlctx->get_data_type(stmt->src->ret_type);
std::vector<llvm::Value *> index;
index.reserve(stmt->index.size() + 1);
index.push_back(tlctx->get_constant(0));
for (auto &i : stmt->index) {
index.push_back(tlctx->get_constant(i));
}
auto *gep =
builder->CreateGEP(real_func_ret_type, llvm_val[stmt->src], index);
auto *gep = builder->CreateGEP(struct_type, llvm_val[stmt->src], index);
auto *val = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), gep);
llvm_val[stmt] = val;
}
Expand Down
1 change: 0 additions & 1 deletion taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ PER_EXPRESSION(AtomicOpExpression)
PER_EXPRESSION(SNodeOpExpression)
PER_EXPRESSION(ConstExpression)
PER_EXPRESSION(ExternalTensorShapeAlongAxisExpression)
PER_EXPRESSION(FuncCallExpression)
PER_EXPRESSION(MeshPatchIndexExpression)
PER_EXPRESSION(MeshRelationAccessExpression)
PER_EXPRESSION(MeshIndexConversionExpression)
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/frontend_statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FrontendReturnStmt)
PER_STATEMENT(FrontendFuncCallStmt)
6 changes: 0 additions & 6 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,6 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit(", ", expr->axis, ')');
}

void visit(FuncCallExpression *expr) override {
emit("func_call(\"", expr->func->func_key.get_full_name(), "\", ");
emit_vector(expr->args.exprs);
emit(')');
}

void visit(MeshPatchIndexExpression *expr) override {
emit("mesh_patch_idx()");
}
Expand Down
43 changes: 17 additions & 26 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,37 +1150,14 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void FuncCallExpression::type_check(CompileConfig *) {
for (auto &arg : args.exprs) {
TI_ASSERT_TYPE_CHECKED(arg);
// no arg type compatibility check for now due to lack of specification
}
ret_type = PrimitiveType::u64;
ret_type.set_is_pointer(true);
}

void FuncCallExpression::flatten(FlattenContext *ctx) {
std::vector<Stmt *> stmt_args;
for (auto &arg : args.exprs) {
stmt_args.push_back(flatten_rvalue(arg, ctx));
}
ctx->push_back<FuncCallStmt>(func, stmt_args);
stmt = ctx->back_stmt();
}

void GetElementExpression::type_check(CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(src);
auto func_call = src.cast<FuncCallExpression>();
TI_ASSERT(func_call);
// The return values are flattened now,
// so the length of stmt->index is 1.
// Will be refactored soon.
TI_ASSERT(index[0] < func_call->func->rets.size());
ret_type = func_call->func->rets[index[0]].dt;

ret_type = src->ret_type->as<StructType>()->get_element_type(index);
}

void GetElementExpression::flatten(FlattenContext *ctx) {
ctx->push_back<GetElementStmt>(src->get_flattened_stmt(), index);
ctx->push_back<GetElementStmt>(flatten_rvalue(src, ctx), index);
stmt = ctx->back_stmt();
}
// Mesh related.
Expand Down Expand Up @@ -1391,6 +1368,20 @@ Expr ASTBuilder::expr_alloca() {
return var;
}

std::optional<Expr> ASTBuilder::insert_func_call(Function *func,
const ExprGroup &args) {
if (func->ret_type) {
auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
this->insert(std::make_unique<FrontendFuncCallStmt>(
func, args, std::static_pointer_cast<IdExpression>(var.expr)->id));
var.expr->ret_type = func->ret_type;
return var;
} else {
this->insert(std::make_unique<FrontendFuncCallStmt>(func, args));
return std::nullopt;
}
}

Expr ASTBuilder::make_matrix_expr(const std::vector<int> &shape,
const DataType &dt,
const std::vector<Expr> &elements) {
Expand Down
20 changes: 13 additions & 7 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,20 +774,25 @@ class ExternalTensorShapeAlongAxisExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class FuncCallExpression : public Expression {
class FrontendFuncCallStmt : public Stmt {
public:
std::optional<Identifier> ident;
Function *func;
ExprGroup args;

void type_check(CompileConfig *config) override;

FuncCallExpression(Function *func, const ExprGroup &args)
: func(func), args(args) {
explicit FrontendFuncCallStmt(
Function *func,
const ExprGroup &args,
const std::optional<Identifier> &id = std::nullopt)
: ident(id), func(func), args(args) {
TI_ASSERT(id.has_value() == !func->rets.empty());
}

void flatten(FlattenContext *ctx) override;
bool is_container_statement() const override {
return false;
}

TI_DEFINE_ACCEPT_FOR_EXPRESSION
TI_DEFINE_ACCEPT
};

class GetElementExpression : public Expression {
Expand Down Expand Up @@ -962,6 +967,7 @@ class ASTBuilder {
mesh::ConvType &conv_type);

void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb);
std::optional<Expr> insert_func_call(Function *func, const ExprGroup &args);
void create_assert_stmt(const Expr &cond,
const std::string &msg,
const std::vector<Expr> &args);
Expand Down
3 changes: 3 additions & 0 deletions taichi/program/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ int Callable::insert_texture_arg(const DataType &dt) {
}

void Callable::finalize_rets() {
if (rets.empty()) {
return;
}
std::vector<const Type *> types;
types.reserve(rets.size());
for (const auto &ret : rets) {
Expand Down
4 changes: 1 addition & 3 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ void export_lang(py::module &m) {
.def("expand_exprs", &ASTBuilder::expand_exprs)
.def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion)
.def("expr_subscript", &ASTBuilder::expr_subscript)
.def("insert_func_call", &ASTBuilder::insert_func_call)
.def("sifakis_svd_f32", sifakis_svd_export<float32, int32>)
.def("sifakis_svd_f64", sifakis_svd_export<float64, int64>)
.def("expr_var", &ASTBuilder::make_var)
Expand Down Expand Up @@ -823,9 +824,6 @@ void export_lang(py::module &m) {
with_runtime_context);
});

m.def("make_func_call_expr",
Expr::make<FuncCallExpression, Function *, const ExprGroup &>);

m.def("make_get_element_expr",
Expr::make<GetElementExpression, const Expr &, std::vector<int>>);

Expand Down
12 changes: 12 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,18 @@ class IRPrinter : public IRVisitor {
}
}

void visit(FrontendFuncCallStmt *stmt) override {
std::string args;
for (int i = 0; i < stmt->args.exprs.size(); i++) {
if (i) {
args += ", ";
}
args += expr_to_string(stmt->args.exprs[i]);
}
print("{}${} = call \"{}\", args = ({}), ret = {}", stmt->type_hint(),
stmt->id, stmt->func->get_name(), args, stmt->ident->name());
}

void visit(FuncCallStmt *stmt) override {
std::vector<std::string> args;
for (const auto &arg : stmt->args) {
Expand Down
17 changes: 17 additions & 0 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ class LowerAST : public IRVisitor {
}
}

void visit(FrontendFuncCallStmt *stmt) override {
Block *block = stmt->parent;
std::vector<Stmt *> args;
lin-hitonami marked this conversation as resolved.
Show resolved Hide resolved
args.reserve(stmt->args.exprs.size());
auto fctx = make_flatten_ctx();
for (const auto &arg : stmt->args.exprs) {
args.push_back(flatten_rvalue(arg, &fctx));
}
auto lowered = fctx.push_back<FuncCallStmt>(stmt->func, args);
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
if (const auto &ident = stmt->ident) {
TI_ASSERT(block->local_var_to_stmt.find(ident.value()) ==
block->local_var_to_stmt.end());
block->local_var_to_stmt.insert(std::make_pair(ident.value(), lowered));
}
}

void visit(FrontendIfStmt *stmt) override {
auto fctx = make_flatten_ctx();
auto condition_stmt = flatten_rvalue(stmt->condition, &fctx);
Expand Down
18 changes: 9 additions & 9 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,18 +407,18 @@ class TypeCheck : public IRVisitor {
void visit(FuncCallStmt *stmt) override {
auto *func = stmt->func;
TI_ASSERT(func);
stmt->ret_type = PrimitiveType::u64;
stmt->ret_type.set_is_pointer(true);
stmt->ret_type = func->ret_type;
}

void visit(FrontendFuncCallStmt *stmt) override {
auto *func = stmt->func;
TI_ASSERT(func);
stmt->ret_type = func->ret_type;
}

void visit(GetElementStmt *stmt) override {
TI_ASSERT(stmt->src->is<FuncCallStmt>());
auto *func = stmt->src->as<FuncCallStmt>()->func;
// The return values are flattened now,
// so the length of stmt->index is 1.
// Will be refactored soon.
TI_ASSERT(stmt->index[0] < func->rets.size());
stmt->ret_type = func->rets[stmt->index[0]].dt;
stmt->ret_type =
stmt->src->ret_type->as<StructType>()->get_element_type(stmt->index);
}

void visit(ArgLoadStmt *stmt) override {
Expand Down