diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index e9f60c614f112..cb8fa802123ed 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -216,6 +216,11 @@ def visit_block(self, list_stmt): for i, l in enumerate(list_stmt): list_stmt[i] = self.visit(l) + def visit_Return(self, node): + ret = self.parse_stmt('__retval.assign(0)') + ret.value.args[0] = node.value + return ret + def visit_If(self, node): self.generic_visit(node, ['body', 'orelse']) @@ -505,9 +510,11 @@ def visit_FunctionDef(self, node): arg_decls.append(arg_init) # remove original args node.args.args = [] + ret_stmt = None else: # Transform as func (all parameters passed by value) arg_decls = [] + arg_decls.append(self.parse_stmt('__retval = ti.expr_init(0)')) # TODO(archibate): init by ret type for i, arg in enumerate(args.args): if i == 0 and self.is_classfunc: continue @@ -517,9 +524,12 @@ def visit_FunctionDef(self, node): arg_init.value.args[0] = self.parse_expr(arg.arg + '_by_value__') args.args[i].arg += '_by_value__' arg_decls.append(arg_init) + ret_stmt = self.parse_stmt('return __retval') with self.variable_scope(): self.generic_visit(node) node.body = arg_decls + node.body + if ret_stmt is not None: + node.body.append(ret_stmt) return node def visit_UnaryOp(self, node):