Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Let real function return nested StructType #7059

Merged
merged 7 commits into from
Jan 6, 2023
4 changes: 1 addition & 3 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod):

def decl_ret(dtype, real_func=False):
if isinstance(dtype, StructType):
for member in dtype.members.values():
decl_ret(member, real_func)
return
dtype = dtype.dtype
if isinstance(dtype, MatrixType):
if real_func:
for i in range(dtype.n * dtype.m):
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def func_call_rvalue(self, key, args):
if id(self.return_type) in primitive_types.type_ids:
return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0, )))
if isinstance(self.return_type, StructType):
return self.return_type.from_real_func_ret(func_call)[0]
return self.return_type.from_real_func_ret(func_call, (0, ))
raise TaichiTypeError(f"Unsupported return type: {self.return_type}")

def do_compile(self, key, args):
Expand Down
11 changes: 6 additions & 5 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,12 +1509,13 @@ def __call__(self, *args):
# type cast
return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim))

def from_real_func_ret(self, func_ret, ret_index=0):
def from_real_func_ret(self, func_ret, ret_index=()):
return self([
expr.Expr(ti_python_core.make_get_element_expr(
func_ret.ptr, (i, )))
for i in range(ret_index, ret_index + self.m * self.n)
]), ret_index + self.m * self.n
expr.Expr(
ti_python_core.make_get_element_expr(func_ret.ptr,
ret_index + (i, )))
for i in range(self.m * self.n)
])

def cast(self, mat):
if in_python_scope():
Expand Down
11 changes: 5 additions & 6 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,21 +704,20 @@ def __call__(self, *args, **kwargs):
struct = self.cast(entries)
return struct

def from_real_func_ret(self, func_ret, ret_index=0):
def from_real_func_ret(self, func_ret, ret_index=()):
d = {}
items = self.members.items()
for index, pair in enumerate(items):
name, dtype = pair
if isinstance(dtype, CompoundType):
d[name], ret_index = dtype.from_real_func_ret(
func_ret, ret_index)
d[name] = dtype.from_real_func_ret(func_ret,
ret_index + (index, ))
else:
d[name] = expr.Expr(
_ti_core.make_get_element_expr(func_ret.ptr,
(ret_index, )))
ret_index += 1
ret_index + (index, )))

return Struct(d), ret_index
return Struct(d)

def cast(self, struct):
# sanity check members
Expand Down
68 changes: 48 additions & 20 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,18 +1271,10 @@ void TaskCodeGenLLVM::visit(ReturnStmt *stmt) {
if (std::any_of(types.begin(), types.end(),
[](const DataType &t) { return t.is_pointer(); })) {
TI_NOT_IMPLEMENTED
} else if (now_real_func) {
TI_ASSERT(stmt->values.size() == now_real_func->rets.size());
auto *result_buf = call("RuntimeContext_get_result_buffer", get_context());
auto *ret_type = get_real_func_ret_type(now_real_func);
result_buf = builder->CreatePointerCast(
result_buf, llvm::PointerType::get(ret_type, 0));
for (int i = 0; i < stmt->values.size(); i++) {
auto *gep =
builder->CreateGEP(ret_type, result_buf,
{tlctx->get_constant(0), tlctx->get_constant(i)});
builder->CreateStore(llvm_val[stmt->values[i]], gep);
}
} else if (current_real_func) {
TI_ASSERT(stmt->values.size() ==
current_real_func->ret_type->get_num_elements());
create_return(stmt->values);
} else {
TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value);
int idx{0};
Expand Down Expand Up @@ -2707,11 +2699,11 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
auto guard = get_function_creation_guard(
{llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)},
stmt->func->get_name());
Function *old_real_func = now_real_func;
now_real_func = stmt->func;
Function *old_real_func = current_real_func;
current_real_func = stmt->func;
func_map.insert({stmt->func, guard.body});
stmt->func->ir->accept(this);
now_real_func = old_real_func;
current_real_func = old_real_func;
}
llvm::Function *llvm_func = func_map[stmt->func];
auto *new_ctx = call("allocate_runtime_context", get_runtime());
Expand Down Expand Up @@ -2749,14 +2741,50 @@ void TaskCodeGenLLVM::visit(GetElementStmt *stmt) {
llvm_val[stmt] = val;
}

llvm::Type *TaskCodeGenLLVM::get_real_func_ret_type(Function *real_func) {
std::vector<llvm::Type *> tps;
for (auto &ret : real_func->rets) {
tps.push_back(tlctx->get_data_type(ret.dt));
void TaskCodeGenLLVM::create_return(llvm::Value *buffer,
llvm::Type *buffer_type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we obtain buffer_type from buffer->getType()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No... While it works fine on linux machines, it doesn't work on Windows because the pointers are opaque pointers, and buffer->getType() only returns ptr in LLVM 15. I don't know why it works on linux machines... maybe the LLVM binaries are different....

const std::vector<Stmt *> &elements,
const Type *current_type,
int &current_element,
std::vector<llvm::Value *> &current_index) {
if (auto primitive_type = current_type->cast<PrimitiveType>()) {
TI_ASSERT((Type *)elements[current_element]->ret_type == current_type);
auto *gep = builder->CreateGEP(buffer_type, buffer, current_index);
builder->CreateStore(llvm_val[elements[current_element]], gep);
current_element++;
} else if (auto struct_type = current_type->cast<StructType>()) {
int i = 0;
for (const auto &element_type : struct_type->elements()) {
current_index.push_back(tlctx->get_constant(i++));
create_return(buffer, buffer_type, elements, element_type,
current_element, current_index);
current_index.pop_back();
}
} else {
auto tensor_type = current_type->as<TensorType>();
int num_elements = tensor_type->get_num_elements();
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
Type *element_type = tensor_type->get_element_type();
for (int i = 0; i < num_elements; i++) {
current_index.push_back(tlctx->get_constant(i));
create_return(buffer, buffer_type, elements, element_type,
current_element, current_index);
current_index.pop_back();
}
}
return llvm::StructType::get(*llvm_context, tps);
}

void TaskCodeGenLLVM::create_return(const std::vector<Stmt *> &elements) {
auto buffer = call("RuntimeContext_get_result_buffer", get_context());
auto ret_type = current_real_func->ret_type;
auto buffer_type = tlctx->get_data_type(ret_type);
buffer = builder->CreatePointerCast(buffer,
llvm::PointerType::get(buffer_type, 0));
int current_element = 0;
std::vector<llvm::Value *> current_index = {tlctx->get_constant(0)};
create_return(buffer, buffer_type, elements, ret_type, current_element,
current_index);
};

LLVMCompiledTask LLVMCompiledTask::clone() const {
return {tasks, llvm::CloneModule(*module), used_tree_ids,
struct_for_tls_sizes};
Expand Down
14 changes: 11 additions & 3 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
bool returned{false};
std::unordered_set<int> used_tree_ids;
std::unordered_set<int> struct_for_tls_sizes;
Function *now_real_func{nullptr};
Function *current_real_func{nullptr};

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

Expand Down Expand Up @@ -95,8 +95,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *get_mesh_xlogue_function_type();

llvm::Type *get_real_func_ret_type(Function *real_func);

llvm::Value *get_root(int snode_tree_id);

llvm::Value *get_runtime();
Expand Down Expand Up @@ -138,6 +136,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *create_print(std::string tag, llvm::Value *value);

void create_return(const std::vector<Stmt *> &elements);

llvm::Value *cast_pointer(llvm::Value *val,
std::string dest_ty_name,
int addr_space = 0);
Expand Down Expand Up @@ -402,6 +402,14 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

~TaskCodeGenLLVM() override = default;

private:
void create_return(llvm::Value *buffer,
llvm::Type *buffer_type,
const std::vector<Stmt *> &elements,
const Type *current_type,
int &current_element,
std::vector<llvm::Value *> &current_index);
};

} // namespace taichi::lang
Expand Down