diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index b9866577e9b6..39ecfd9e13a7 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -296,15 +296,6 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); */ TVM_DLL tvm::Array AllVars(const Expr& expr); -/*! - * \brief Get all global variables used in calls in expression expr. - * - * \param expr the expression. - * - * \return List of all global variables called in expr. - */ -TVM_DLL tvm::Array CalledGlobalVars(const Expr& expr); - /*! * \brief Get all global variables from expression expr. * diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 72aab6684ebf..42aa591a95b7 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -110,6 +110,13 @@ TVM_DLL tvm::relax::Var Emit( TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, const tvm::relax::StructInfo& struct_info); +/*! + * \brief Emit a binding to the last binding block frame. + * \param binding The binding to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding); + ///////////////////////////// If Then Else ///////////////////////////// /*! diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 43918ce7ec83..63efea135c15 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -25,7 +25,7 @@ import tvm from tvm import DataType, relax from tvm.ir import PrimExpr -from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const +from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding, const ############################### Operators ############################### from tvm.relax.op import ( @@ -342,6 +342,20 @@ def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore +def emit_var_binding(value: VarBinding) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: VarBinding + The binding to be emitted. + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.EmitVarBinding(value) # type: ignore + + ############################# If Then Else ############################# @@ -497,6 +511,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "divide", "dtype", "emit", + "emit_var_binding", "emit_match_cast", "equal", "ewise_fma", diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index e5e5bb2743e1..e1af1c1df346 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -96,8 +96,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: annotation = annotation() if isinstance(annotation, StructInfoProxy): return annotation - else: - raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") except Exception as err: self.report_error(node, str(err)) raise err @@ -112,6 +111,38 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St raise err +def is_called(node: Any, func_name: str) -> bool: + # Check if it calls into a func + if isinstance(node, doc.Call): + # Recursive call was found + if isinstance(node.func, doc.Name) and node.func.id == func_name: + return True + elif isinstance(node, (list, tuple)): + for stmt in node: + if is_called(stmt, func_name): + return True + elif isinstance(node, (doc.AnnAssign, doc.Assign, doc.Return, doc.Expr)): + return is_called(node.value, func_name) + elif isinstance(node, doc.With): + return is_called(node.body, func_name) + elif isinstance(node, doc.If): + smts = [] + if node.body is not None: + smts = smts + list(node.body) + if node.orelse is not None: + smts = smts + list(node.orelse) + return is_called(smts, func_name) + return False + + +def is_recursive(node: doc.FunctionDef) -> bool: + # Check if it is a recursive function + for stmt in node.body: + if is_called(stmt, node.name): + return True + return False + + def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: # Collect symbolic vars from parameters symbolic_vars = set() @@ -128,6 +159,24 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non @dispatch.register(token="relax", type_name="FunctionDef") def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + # reserve a var for local function + func_val = self.var_table.get().get(node.name) + if not func_val and is_recursive(node): + collect_symbolic_var_from_params(self, node) + if node.returns is None: + ret_sinfo = relax.TupleStructInfo([]) + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params_sinfo = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params_sinfo.append(param_sinfo) + # created a var for the local function, the same var could be used for recursive call + local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo)) + self.var_table.add(node.name, local_func_var) + with self.var_table.with_frame(): with self.with_dispatch_token("relax"): with R.function(): @@ -164,12 +213,10 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: else: ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) params = [] - params_sinfo = [] for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) - params_sinfo.append(param_sinfo) params.append(relax.Var(arg.arg, param_sinfo)) func_signature = relax.Function.create_empty(params, ret_sinfo) @@ -188,7 +235,12 @@ def post_token_switch(self: Parser, node: doc.Expr) -> None: ir_builder = IRBuilder.current() result = ir_builder.get() ir_builder.__exit__(None, None, None) - var = R.emit(result) + # reuse var if it is reserved + reserved_var = self.var_table.get().get(node.name) + if reserved_var: + var = R.emit_var_binding(relax.VarBinding(reserved_var, result)) + else: + var = R.emit(result) IRBuilder.name(node.name, var) self.var_table.add(node.name, var, allow_shadowing=False) diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index 33197308fa1b..4132039a5e34 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -87,15 +87,6 @@ class VarVisitor : protected ExprVisitor { return ret; } - Array CalledGlobalVars(const Expr& expr) { - this->VisitExpr(expr); - Array ret; - for (const auto& v : called_global_vars_.data) { - ret.push_back(v); - } - return ret; - } - void MarkBounded(const Var& v) { bound_vars_.Insert(v); vars_.Insert(v); @@ -123,10 +114,6 @@ class VarVisitor : protected ExprVisitor { for (Expr arg : call_node->args) { VisitExpr(arg); } - - if (const GlobalVarNode* global_var_node = call_node->op.as()) { - called_global_vars_.Insert(GetRef(global_var_node)); - } } void VisitBinding_(const VarBindingNode* binding) final { @@ -144,7 +131,6 @@ class VarVisitor : protected ExprVisitor { InsertionSet vars_; InsertionSet bound_vars_; InsertionSet global_vars_; - InsertionSet called_global_vars_; }; tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } @@ -155,10 +141,6 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } -tvm::Array CalledGlobalVars(const Expr& expr) { - return VarVisitor().CalledGlobalVars(expr); -} - TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); @@ -167,7 +149,5 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); -TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars); - } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 05ad0954bbfc..25b9155d7740 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -177,7 +177,7 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitExpr_(const VarNode* op) final { Var var = GetRef(op); - if (var_set_.count(var) == 0) { + if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) { Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); } CheckStructInfo(op); @@ -316,12 +316,20 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitBinding_(const VarBindingNode* binding) final { + bool is_lambda = false; + if (binding->value->IsInstance()) { + is_lambda = true; + recur_vars_.insert(binding->var); + } if (binding->value->IsInstance()) { Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR."); } else { this->VisitExpr(binding->value); } this->VisitVarDef(binding->var); + if (is_lambda) { + recur_vars_.erase(binding->var); + } } void VisitBinding_(const MatchCastNode* binding) final { @@ -451,6 +459,7 @@ class WellFormedChecker : public relax::ExprVisitor, VisitMode mode_ = VisitMode::kDefault; // set of context variables. std::unordered_set var_set_; + std::unordered_set recur_vars_; std::unordered_set dataflow_var_set_; std::unordered_set symbolic_var_set_; std::unordered_map param_var_func_map_; diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index f08499036b1c..74920823100a 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -46,35 +46,72 @@ class LambdaLifter : public ExprMutator { using ExprMutator::VisitExpr_; + void VisitBinding_(const VarBindingNode* binding) final { + bool is_lambda = false; + if (binding->value->IsInstance()) { + is_lambda = true; + recur_vars_.push_back(binding->var); + } + Expr new_value = this->VisitExpr(binding->value); + if (new_value->struct_info_.defined() && + !new_value->struct_info_.same_as(binding->var->struct_info_)) { + binding->var->struct_info_ = GetStructInfo(new_value); + binding->var->checked_type_ = new_value->checked_type_; + } + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } + if (is_lambda) { + recur_vars_.pop_back(); + } + } + Expr VisitExpr_(const CallNode* call_node) final { auto call = Downcast(ExprMutator::VisitExpr_(call_node)); - if (auto const* var = call_node->op.as()) { - bool has_closure = HasClosure(GetRef(var)); - auto val = builder_->LookupBinding(GetRef(var)); + if (const auto* var_node = call_node->op.as()) { + auto var = GetRef(var_node); + bool has_closure = HasClosure(var); + auto val = builder_->LookupBinding(var); + if (const auto* fsinfo_node = GetStructInfo(var).as()) { + auto fsinfo = GetRef(fsinfo_node); + if (!GetStructInfo(call).same_as(fsinfo)) { + call->struct_info_ = fsinfo->ret; + call->checked_type_ = GetStaticType(fsinfo->ret); + } + } // Call "relax.invoke_closure" to invoke closure - if (has_closure && val.as()) { - Var clo_arg = GetRef(var); + Var clo_arg = var; + if (has_closure && val->IsInstance()) { if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { clo_arg = this->var_remap_.at(var->vid); } return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, {GetStructInfo(GetRef(call_node))}); } - } - if (auto global_var_node = call_node->op.as()) { - String rec_name = global_var_node->name_hint; - auto global_var = GetRef(global_var_node); - auto it = lambda_map_.find(global_var); + auto it = lambda_map_.find(var); if (it != lambda_map_.end()) { // flatten nested call, e.g. call(y)(x) -> call(x, y)) Array new_args; + Array params; for (const auto arg : call->args) { new_args.push_back(arg); + params.push_back(StructInfoFromType(arg->checked_type())); } if (const auto* nest_call = it->second.as()) { + // Update the StructInfo accordingly for (const auto arg : nest_call->args) { new_args.push_back(arg); + params.push_back(StructInfoFromType(arg->checked_type())); } + StructInfo new_func_sinfo; + if (const auto* fsinfo = GetStructInfo(nest_call->op).as()) { + auto func_sinfo = GetRef(fsinfo); + new_func_sinfo = FuncStructInfo(params, func_sinfo->ret); + } + nest_call->op->struct_info_ = new_func_sinfo; + nest_call->op->checked_type_ = GetStaticType(new_func_sinfo); return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args); } return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args); @@ -89,11 +126,19 @@ class LambdaLifter : public ExprMutator { // TODO(@yongwww): consider appending inner func name into the lifted func name String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); auto global = GlobalVar(lift_func_name); - Array captured_vars = FreeVars(func); - recur_vars_ = CalledGlobalVars(func); - auto all_global_vars = AllGlobalVars(func); + Array free_vars = FreeVars(func); + Array captured_vars; Array typed_captured_vars; + bool recursive = false; + for (const auto& var : free_vars) { + if (!recur_vars_.empty() && var == recur_vars_.back()) { + recursive = true; + } else { + captured_vars.push_back(var); + } + } + Map rebinding_map; for (auto free_var : captured_vars) { Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); @@ -102,12 +147,14 @@ class LambdaLifter : public ExprMutator { } // recursive call - if (!recur_vars_.empty()) { + if (recursive) { if (!captured_vars.empty()) { Array fvs; for (auto fv : captured_vars) { fvs.push_back(fv); } + // it is required by block_blocker, will be updated later + UpdateStructInfo(global, GetStructInfo(recur_vars_.back())); lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); } else { if (recur_vars_.size() > 0) { @@ -162,18 +209,17 @@ class LambdaLifter : public ExprMutator { /*attrs=*/new_func->attrs, /*span=*/func->span); - Array param_types; for (Var param : closure_params) { CHECK(param->checked_type_.defined()) << "relax.Function requires params to contain checked_type_"; - param_types.push_back(param->checked_type_); } } ICHECK(lifted_func.defined()); // Add the lifted function to the module. - UpdateStructInfo(global, GetStructInfo(lifted_func)); + global->struct_info_ = GetStructInfo(lifted_func); + global->checked_type_ = lifted_func->checked_type_; builder_->UpdateFunction(global, lifted_func); if (!is_closure) { @@ -242,8 +288,8 @@ class LambdaLifter : public ExprMutator { } private: - std::unordered_map lambda_map_; - Array recur_vars_; + std::unordered_map lambda_map_; + Array recur_vars_; IRModule mod_; size_t lift_func_num_ = 0; /*! \brief Cache ops that would be used later to reduce lookup overhead. */ diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index ece645243c82..ddfb1ddfa35f 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -203,8 +203,17 @@ tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, return var; } +tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + block_builder->EmitNormalized(binding); + block_frame->emitted_vars.push_back(binding->var); + return binding->var; +} + TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); ///////////////////////////// If Then Else ///////////////////////////// diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 67da77274188..ee5814eb7bfc 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -173,6 +173,33 @@ def test_seq_expr(): assert not rx.analysis.well_formed(mod, check_struct_info=False) +def test_recursive(): + scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_struct_info) + f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info], scalar_struct_info)) + ipt = rx.Var("ipt", scalar_struct_info) + x0 = rx.Var("x0", scalar_struct_info) + x1 = rx.Var("x1", scalar_struct_info) + x2 = rx.Var("x2", scalar_struct_info) + y = rx.Var("y", scalar_struct_info) + inner_block = rx.BindingBlock( + [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f, [x0]))] + ) + inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_struct_info) + outer_block = rx.BindingBlock( + [ + rx.VarBinding(f, inner_func), + rx.VarBinding(x1, rx.const(1, "int32")), + rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, [x1]))), + rx.VarBinding(gv0, x2), + ] + ) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + mod = tvm.IRModule.from_expr(func) + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) + + def test_if(): # Error: Var defined in true/false branch is invisible in the outer scope # except the return Var, i.e the var in the last stmt diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index c9bbc0fb91e7..5a137f22cb5f 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -114,7 +114,9 @@ def main( x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): @R.function - def outer_func(c1: R.Tensor((2, 3), "float32")): + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3), "float32")): @R.function def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): s: R.Tensor((2, 3), "float32") = R.add(x1, c1) @@ -133,7 +135,6 @@ def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): _check_save_roundtrip(after) -@pytest.mark.skip(reason="Need fix after parser switch over") def test_recursive(): # the expected IRModule @tvm.script.ir_module @@ -149,18 +150,19 @@ def lifted_func_0( if cond: new_i: R.Tensor((), "int32") = R.add(i, c) new_s: R.Tensor((2, 3), "float32") = R.add(s, x) - r = lifted_func_0(new_i, new_s, x) + new_r = lifted_func_0(new_i, new_s, x) + r = new_r else: r = s return r @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"): while_loop = R.make_closure(lifted_func_0, (x,)) - gv = R.invoke_closure( + gv: R.Tensor((2, 3), dtype="float32") = R.invoke_closure( while_loop, - (relax.const(0), x), - sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + (R.const(0), x), + sinfo_args=(R.Tensor((2, 3), dtype="float32")), ) return gv @@ -185,11 +187,14 @@ def while_loop( r: R.Tensor((2, 3), "float32") = s return r - gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x) + gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x) return gv before = Before expected = Expected + # check well-formness of recursive call + assert relax.analysis.well_formed(before) + # Perform Lambda Lifting after = transform.LambdaLift()(before) assert len(after.functions) == 2 @@ -198,7 +203,6 @@ def while_loop( _check_save_roundtrip(after) -@pytest.mark.skip(reason="Need fix after parser switch over") def test_multi_func(): # expected IRModule @tvm.script.ir_module @@ -207,29 +211,29 @@ class Expected: def glob_func_1( x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") ) -> R.Tensor(None, "float32", ndim=2): - inner = lifted_func_1 - gv1 = inner(x1, y1) + inner = lifted_func_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) return gv1 @R.function def glob_func_2( x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32") ) -> R.Tensor(None, "float32", ndim=2): - inner1 = lifted_func_0 - gv11 = inner1(x11, y11) + inner = lifted_func_1 + gv11: R.Tensor((10, 5), "float32") = inner(x11, y11) return gv11 @R.function def lifted_func_0( x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") - ) -> R.Tensor(None, "float32", ndim=2): + ) -> R.Tensor((10, 5), "float32"): s: R.Tensor((10, 5), "float32") = R.add(x2, y2) return s @R.function def lifted_func_1( x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") - ) -> R.Tensor(None, "float32", ndim=2): + ) -> R.Tensor((10, 5), "float32"): s1: R.Tensor((10, 5), "float32") = R.add(x21, y21) return s1 diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index fbeb57564fb5..15122dab3771 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -69,7 +69,7 @@ class Actual: @R.function def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): @R.function - def inner(x: R.Tensor((3,), "float32")): + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, x) return gv @@ -81,7 +81,7 @@ class Expected: @R.function def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): @R.function - def inner(x: R.Tensor((3,), "float32")): + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, x) return gv @@ -91,7 +91,7 @@ def inner(x: R.Tensor((3,), "float32")): @R.function def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): @R.function - def inner(x: R.Tensor((3,), "float32")): + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, x) return gv