Skip to content

Commit

Permalink
[llvm] Let real function support returning struct (taichi-dev#6614)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#602 taichi-dev#6590

### Brief Summary
Only supports scalar struct (every element in the struct is a scalar)
for now.

This PR does the following things:
1. Let `FuncCallStmt` return the `real_func_ret_struct *` result buffer
instead of returning the return value directly.
2. Add `GetElementStmt` and `GetElementExpression` to get the i-th
return value in a result buffer
3. Add `StructType.from_real_func_ret` to construct the returned struct
to the `StructType` in Python

Will add support for nested struct and matrix in struct in the following
PRs.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent ad7e166 commit 9572f9a
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 31 deletions.
19 changes: 11 additions & 8 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from taichi.lang.impl import current_cfg
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
from taichi.lang.snode import append, deactivate
from taichi.lang.struct import Struct, StructType
from taichi.lang.util import is_taichi_class, to_taichi_type
from taichi.types import (annotations, ndarray_type, primitive_types,
texture_type)
Expand Down Expand Up @@ -562,7 +563,11 @@ def build_FunctionDef(ctx, node):
def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type)
if isinstance(ctx.func.return_type, StructType):
for tp in ctx.func.return_type.members.values():
kernel_arguments.decl_ret(tp)
else:
kernel_arguments.decl_ret(ctx.func.return_type)

for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation,
Expand Down Expand Up @@ -747,6 +752,11 @@ def build_Return(ctx, node):
ti_ops.cast(exp, ctx.func.return_type.dtype)
for exp in values
]))
elif isinstance(ctx.func.return_type, StructType):
values = node.value.ptr
assert isinstance(values, Struct)
ctx.ast_builder.create_kernel_exprgroup_return(
expr.make_expr_group(values._members))
else:
raise TaichiSyntaxError(
"The return type is not supported now!")
Expand Down Expand Up @@ -1401,13 +1411,6 @@ def build_If(ctx, node):
@staticmethod
def build_Expr(ctx, node):
build_stmt(ctx, node.value)
if not isinstance(node.value, ast.Call):
return None
is_taichi_function = getattr(node.value.func.ptr,
'_is_taichi_function', False)
if is_taichi_function and node.value.func.ptr._is_real_function:
func_call_result = node.value.ptr
ctx.ast_builder.insert_expr_stmt(func_call_result.ptr)
return None

@staticmethod
Expand Down
14 changes: 12 additions & 2 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from taichi.lang.enums import AutodiffMode, Layout
from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError,
TaichiRuntimeTypeError, TaichiSyntaxError,
handle_exception_from_cpp)
TaichiTypeError, handle_exception_from_cpp)
from taichi.lang.expr import Expr
from taichi.lang.kernel_arguments import KernelArgument
from taichi.lang.matrix import Matrix, MatrixType
from taichi.lang.shell import _shell_pop_print, oinspect
from taichi.lang.struct import StructType
from taichi.lang.util import (cook_dtype, has_paddle, has_pytorch,
to_taichi_type)
from taichi.types import (ndarray_type, primitive_types, sparse_matrix_builder,
Expand Down Expand Up @@ -264,9 +265,18 @@ def func_call_rvalue(self, key, args):
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args,
real_func_arg=True)
return Expr(
func_call = Expr(
_ti_core.make_func_call_expr(
self.taichi_functions[key.instance_id], non_template_args))
impl.get_runtime().prog.current_ast_builder().insert_expr_stmt(
func_call.ptr)
if self.return_type is None:
return None
if id(self.return_type) in primitive_types.type_ids:
return Expr(_ti_core.make_get_element_expr(func_call.ptr, 0))
if isinstance(self.return_type, StructType):
return self.return_type.from_real_func_ret(func_call)[0]
raise TaichiTypeError(f"Unsupported return type: {self.return_type}")

def do_compile(self, key, args):
tree, ctx = _get_tree_and_ctx(self,
Expand Down
18 changes: 17 additions & 1 deletion python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numbers
from types import MethodType

from taichi.lang import impl, ops
from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, ops
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
Expand Down Expand Up @@ -682,6 +683,21 @@ def __call__(self, *args, **kwargs):
struct = self.cast(entries)
return struct

def from_real_func_ret(self, func_ret, ret_index=0):
d = {}
items = self.members.items()
for index, pair in enumerate(items):
name, dtype = pair
if isinstance(dtype, CompoundType):
d[name], ret_index = dtype.from_real_func_ret(
func_ret, ret_index)
else:
d[name] = expr.Expr(
_ti_core.make_get_element_expr(func_ret.ptr, ret_index))
ret_index += 1

return Struct(d), ret_index

def cast(self, struct):
# sanity check members
if self.members.keys() != struct.entries.keys():
Expand Down
53 changes: 42 additions & 11 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,18 @@ void TaskCodeGenLLVM::visit(ReturnStmt *stmt) {
if (std::any_of(types.begin(), types.end(),
[](const DataType &t) { return t.is_pointer(); })) {
TI_NOT_IMPLEMENTED
} else if (now_real_func) {
TI_ASSERT(stmt->values.size() == now_real_func->rets.size());
auto *result_buf = call("RuntimeContext_get_result_buffer", get_context());
auto *ret_type = get_real_func_ret_type(now_real_func);
result_buf = builder->CreatePointerCast(
result_buf, llvm::PointerType::get(ret_type, 0));
for (int i = 0; i < stmt->values.size(); i++) {
auto *gep =
builder->CreateGEP(ret_type, result_buf,
{tlctx->get_constant(0), tlctx->get_constant(i)});
builder->CreateStore(llvm_val[stmt->values[i]], gep);
}
} else {
TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value);
int idx{0};
Expand Down Expand Up @@ -2695,8 +2707,11 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
auto guard = get_function_creation_guard(
{llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)},
stmt->func->get_name());
Function *old_real_func = now_real_func;
now_real_func = stmt->func;
func_map.insert({stmt->func, guard.body});
stmt->func->ir->accept(this);
now_real_func = old_real_func;
}
llvm::Function *llvm_func = func_map[stmt->func];
auto *new_ctx = call("allocate_runtime_context", get_runtime());
Expand All @@ -2708,20 +2723,36 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
}
llvm::Value *result_buffer = nullptr;
if (stmt->ret_type->is<PrimitiveType>() &&
!stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
result_buffer = builder->CreateAlloca(tlctx->get_data_type<uint64>());
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer);
call(llvm_func, new_ctx);
auto *ret_val_u64 =
builder->CreateLoad(builder->getInt64Ty(), result_buffer);
llvm_val[stmt] = bitcast_from_u64(ret_val_u64, stmt->ret_type);
} else {
call(llvm_func, new_ctx);
}
auto *ret_type = get_real_func_ret_type(stmt->func);
result_buffer = builder->CreateAlloca(ret_type);
auto *result_buffer_u64 = builder->CreatePointerCast(
result_buffer, llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
call(llvm_func, new_ctx);
llvm_val[stmt] = result_buffer;
call("recycle_runtime_context", get_runtime(), new_ctx);
}

void TaskCodeGenLLVM::visit(GetElementStmt *stmt) {
auto *real_func = stmt->src->as<FuncCallStmt>()->func;
auto &rets = real_func->rets;
auto *ret_type = get_real_func_ret_type(real_func);
auto *gep = builder->CreateGEP(
ret_type, llvm_val[stmt->src],
{tlctx->get_constant(0), tlctx->get_constant(stmt->index)});
auto *val =
builder->CreateLoad(tlctx->get_data_type(rets[stmt->index].dt), gep);
llvm_val[stmt] = val;
}

llvm::Type *TaskCodeGenLLVM::get_real_func_ret_type(Function *real_func) {
std::vector<llvm::Type *> tps;
for (auto &ret : real_func->rets) {
tps.push_back(tlctx->get_data_type(ret.dt));
}
return llvm::StructType::get(*llvm_context, tps);
}

LLVMCompiledTask LLVMCompiledTask::clone() const {
return {tasks, llvm::CloneModule(*module), used_tree_ids,
struct_for_tls_sizes};
Expand Down
5 changes: 5 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
bool returned{false};
std::unordered_set<int> used_tree_ids;
std::unordered_set<int> struct_for_tls_sizes;
Function *now_real_func{nullptr};

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

Expand Down Expand Up @@ -94,6 +95,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *get_mesh_xlogue_function_type();

llvm::Type *get_real_func_ret_type(Function *real_func);

llvm::Value *get_root(int snode_tree_id);

llvm::Value *get_runtime();
Expand Down Expand Up @@ -393,6 +396,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(FuncCallStmt *stmt) override;

void visit(GetElementStmt *stmt) override;

llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type);
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

Expand Down
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ PER_STATEMENT(WhileStmt)
PER_STATEMENT(WhileControlStmt)
PER_STATEMENT(ContinueStmt)
PER_STATEMENT(FuncCallStmt)
PER_STATEMENT(GetElementStmt)
PER_STATEMENT(ReturnStmt)

PER_STATEMENT(ArgLoadStmt)
Expand Down
19 changes: 14 additions & 5 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,11 +1114,8 @@ void FuncCallExpression::type_check(CompileConfig *) {
TI_ASSERT_TYPE_CHECKED(arg);
// no arg type compatibility check for now due to lack of specification
}
TI_ASSERT_INFO(func->rets.size() <= 1,
"Too many (> 1) return values for FuncCallExpression");
if (func->rets.size() == 1) {
ret_type = func->rets[0].dt;
}
ret_type = PrimitiveType::u64;
ret_type.set_is_pointer(true);
}

void FuncCallExpression::flatten(FlattenContext *ctx) {
Expand All @@ -1131,6 +1128,18 @@ void FuncCallExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void GetElementExpression::type_check(CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(src);
auto func_call = src.cast<FuncCallExpression>();
TI_ASSERT(func_call);
TI_ASSERT(index < func_call->func->rets.size());
ret_type = func_call->func->rets[index].dt;
}

void GetElementExpression::flatten(FlattenContext *ctx) {
ctx->push_back<GetElementStmt>(src->stmt, index);
stmt = ctx->back_stmt();
}
// Mesh related.

void MeshPatchIndexExpression::flatten(FlattenContext *ctx) {
Expand Down
15 changes: 15 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,21 @@ class FuncCallExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class GetElementExpression : public Expression {
public:
Expr src;
int index;

void type_check(CompileConfig *config) override;

GetElementExpression(const Expr &src, int index) : src(src), index(index) {
}

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

// Mesh related.

class MeshPatchIndexExpression : public Expression {
Expand Down
15 changes: 15 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,21 @@ class ReferenceStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* Gets an element from a struct
*/
class GetElementStmt : public Stmt {
public:
Stmt *src;
int index;
GetElementStmt(Stmt *src, int index) : src(src), index(index) {
TI_STMT_REG_FIELDS;
}

TI_STMT_DEF_FIELDS(ret_type, src, index);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* Exit the kernel or function with a return value.
*/
Expand Down
3 changes: 3 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,9 @@ void export_lang(py::module &m) {
m.def("make_func_call_expr",
Expr::make<FuncCallExpression, Function *, const ExprGroup &>);

m.def("make_get_element_expr",
Expr::make<GetElementExpression, const Expr &, int>);

m.def("value_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(cast));
m.def("bits_cast",
static_cast<Expr (*)(const Expr &expr, DataType)>(bit_cast));
Expand Down
13 changes: 9 additions & 4 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,15 @@ class TypeCheck : public IRVisitor {
void visit(FuncCallStmt *stmt) override {
auto *func = stmt->func;
TI_ASSERT(func);
TI_ASSERT(func->rets.size() <= 1);
if (func->rets.size() == 1) {
stmt->ret_type = func->rets[0].dt;
}
stmt->ret_type = PrimitiveType::u64;
stmt->ret_type.set_is_pointer(true);
}

void visit(GetElementStmt *stmt) override {
TI_ASSERT(stmt->src->is<FuncCallStmt>());
auto *func = stmt->src->as<FuncCallStmt>()->func;
TI_ASSERT(stmt->index < func->rets.size());
stmt->ret_type = func->rets[stmt->index].dt;
}

void visit(ArgLoadStmt *stmt) override {
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,19 @@ def test_real_func_matrix_arg():
real_matrix_scalarize=True)
def test_real_func_matrix_arg_real_matrix():
_test_real_func_matrix_arg()


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_struct_ret():
s = ti.types.struct(a=ti.i16, b=ti.f64)

@ti.experimental.real_func
def bar() -> s:
return s(a=123, b=ti.f64(1.2345e300))

@ti.kernel
def foo() -> ti.f64:
a = bar()
return a.a * a.b

assert foo() == pytest.approx(123 * 1.2345e300)

0 comments on commit 9572f9a

Please sign in to comment.