Skip to content

Commit

Permalink
[ir] Let real function return nested StructType
Browse files Browse the repository at this point in the history
ghstack-source-id: c7345fbae605793d13fd4c358e65ab20815c9155
Pull Request resolved: taichi-dev#7059
  • Loading branch information
lin-hitonami authored and quadpixels committed May 13, 2023
1 parent 8ef64dc commit 6d30f12
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 38 deletions.
4 changes: 1 addition & 3 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod):

def decl_ret(dtype, real_func=False):
if isinstance(dtype, StructType):
for member in dtype.members.values():
decl_ret(member, real_func)
return
dtype = dtype.dtype
if isinstance(dtype, MatrixType):
if real_func:
for i in range(dtype.n * dtype.m):
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def func_call_rvalue(self, key, args):
if id(self.return_type) in primitive_types.type_ids:
return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0, )))
if isinstance(self.return_type, StructType):
return self.return_type.from_real_func_ret(func_call)[0]
return self.return_type.from_real_func_ret(func_call, (0, ))
raise TaichiTypeError(f"Unsupported return type: {self.return_type}")

def do_compile(self, key, args):
Expand Down
11 changes: 6 additions & 5 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,12 +1509,13 @@ def __call__(self, *args):
# type cast
return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim))

def from_real_func_ret(self, func_ret, ret_index=0):
def from_real_func_ret(self, func_ret, ret_index=()):
return self([
expr.Expr(ti_python_core.make_get_element_expr(
func_ret.ptr, (i, )))
for i in range(ret_index, ret_index + self.m * self.n)
]), ret_index + self.m * self.n
expr.Expr(
ti_python_core.make_get_element_expr(func_ret.ptr,
ret_index + (i, )))
for i in range(self.m * self.n)
])

def cast(self, mat):
if in_python_scope():
Expand Down
11 changes: 5 additions & 6 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,21 +704,20 @@ def __call__(self, *args, **kwargs):
struct = self.cast(entries)
return struct

def from_real_func_ret(self, func_ret, ret_index=0):
def from_real_func_ret(self, func_ret, ret_index=()):
d = {}
items = self.members.items()
for index, pair in enumerate(items):
name, dtype = pair
if isinstance(dtype, CompoundType):
d[name], ret_index = dtype.from_real_func_ret(
func_ret, ret_index)
d[name] = dtype.from_real_func_ret(func_ret,
ret_index + (index, ))
else:
d[name] = expr.Expr(
_ti_core.make_get_element_expr(func_ret.ptr,
(ret_index, )))
ret_index += 1
ret_index + (index, )))

return Struct(d), ret_index
return Struct(d)

def cast(self, struct):
# sanity check members
Expand Down
68 changes: 48 additions & 20 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,18 +1271,10 @@ void TaskCodeGenLLVM::visit(ReturnStmt *stmt) {
if (std::any_of(types.begin(), types.end(),
[](const DataType &t) { return t.is_pointer(); })) {
TI_NOT_IMPLEMENTED
} else if (now_real_func) {
TI_ASSERT(stmt->values.size() == now_real_func->rets.size());
auto *result_buf = call("RuntimeContext_get_result_buffer", get_context());
auto *ret_type = get_real_func_ret_type(now_real_func);
result_buf = builder->CreatePointerCast(
result_buf, llvm::PointerType::get(ret_type, 0));
for (int i = 0; i < stmt->values.size(); i++) {
auto *gep =
builder->CreateGEP(ret_type, result_buf,
{tlctx->get_constant(0), tlctx->get_constant(i)});
builder->CreateStore(llvm_val[stmt->values[i]], gep);
}
} else if (current_real_func) {
TI_ASSERT(stmt->values.size() ==
current_real_func->ret_type->get_num_elements());
create_return(stmt->values);
} else {
TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value);
int idx{0};
Expand Down Expand Up @@ -2707,11 +2699,11 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
auto guard = get_function_creation_guard(
{llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)},
stmt->func->get_name());
Function *old_real_func = now_real_func;
now_real_func = stmt->func;
Function *old_real_func = current_real_func;
current_real_func = stmt->func;
func_map.insert({stmt->func, guard.body});
stmt->func->ir->accept(this);
now_real_func = old_real_func;
current_real_func = old_real_func;
}
llvm::Function *llvm_func = func_map[stmt->func];
auto *new_ctx = call("allocate_runtime_context", get_runtime());
Expand Down Expand Up @@ -2749,14 +2741,50 @@ void TaskCodeGenLLVM::visit(GetElementStmt *stmt) {
llvm_val[stmt] = val;
}

llvm::Type *TaskCodeGenLLVM::get_real_func_ret_type(Function *real_func) {
std::vector<llvm::Type *> tps;
for (auto &ret : real_func->rets) {
tps.push_back(tlctx->get_data_type(ret.dt));
void TaskCodeGenLLVM::create_return(llvm::Value *buffer,
llvm::Type *buffer_type,
const std::vector<Stmt *> &elements,
const Type *current_type,
int &current_element,
std::vector<llvm::Value *> &current_index) {
if (auto primitive_type = current_type->cast<PrimitiveType>()) {
TI_ASSERT((Type *)elements[current_element]->ret_type == current_type);
auto *gep = builder->CreateGEP(buffer_type, buffer, current_index);
builder->CreateStore(llvm_val[elements[current_element]], gep);
current_element++;
} else if (auto struct_type = current_type->cast<StructType>()) {
int i = 0;
for (const auto &element_type : struct_type->elements()) {
current_index.push_back(tlctx->get_constant(i++));
create_return(buffer, buffer_type, elements, element_type,
current_element, current_index);
current_index.pop_back();
}
} else {
auto tensor_type = current_type->as<TensorType>();
int num_elements = tensor_type->get_num_elements();
Type *element_type = tensor_type->get_element_type();
for (int i = 0; i < num_elements; i++) {
current_index.push_back(tlctx->get_constant(i));
create_return(buffer, buffer_type, elements, element_type,
current_element, current_index);
current_index.pop_back();
}
}
return llvm::StructType::get(*llvm_context, tps);
}

void TaskCodeGenLLVM::create_return(const std::vector<Stmt *> &elements) {
auto buffer = call("RuntimeContext_get_result_buffer", get_context());
auto ret_type = current_real_func->ret_type;
auto buffer_type = tlctx->get_data_type(ret_type);
buffer = builder->CreatePointerCast(buffer,
llvm::PointerType::get(buffer_type, 0));
int current_element = 0;
std::vector<llvm::Value *> current_index = {tlctx->get_constant(0)};
create_return(buffer, buffer_type, elements, ret_type, current_element,
current_index);
};

LLVMCompiledTask LLVMCompiledTask::clone() const {
return {tasks, llvm::CloneModule(*module), used_tree_ids,
struct_for_tls_sizes};
Expand Down
14 changes: 11 additions & 3 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
bool returned{false};
std::unordered_set<int> used_tree_ids;
std::unordered_set<int> struct_for_tls_sizes;
Function *now_real_func{nullptr};
Function *current_real_func{nullptr};

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

Expand Down Expand Up @@ -95,8 +95,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *get_mesh_xlogue_function_type();

llvm::Type *get_real_func_ret_type(Function *real_func);

llvm::Value *get_root(int snode_tree_id);

llvm::Value *get_runtime();
Expand Down Expand Up @@ -138,6 +136,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *create_print(std::string tag, llvm::Value *value);

void create_return(const std::vector<Stmt *> &elements);

llvm::Value *cast_pointer(llvm::Value *val,
std::string dest_ty_name,
int addr_space = 0);
Expand Down Expand Up @@ -402,6 +402,14 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

~TaskCodeGenLLVM() override = default;

private:
void create_return(llvm::Value *buffer,
llvm::Type *buffer_type,
const std::vector<Stmt *> &elements,
const Type *current_type,
int &current_element,
std::vector<llvm::Value *> &current_index);
};

} // namespace taichi::lang
Expand Down

0 comments on commit 6d30f12

Please sign in to comment.