diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index 4b993ad3c9..10b6be0991 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -157,12 +157,8 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
         params_sinfo.append(param_sinfo)
         params.append(relax.Var(arg.arg, param_shape, param_type))
 
-    # TODO(relax-team): remove the following line when fixing ret_shape issue in block builder
-    ret_shape = relax.RuntimeDepShape()
-
     func_signature = relax.Function.create_unchecked(params, None, ret_type, ret_shape)
     global_var = I.decl_function(node.name, func_signature)
-    relax.expr._update_struct_info(global_var, relax.FuncStructInfo(params_sinfo, ret_sinfo))
     self.var_table.add(node.name, global_var)
diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc
index 5057d3dbf7..d9c4028dc1 100644
--- a/src/printer/relax_script_printer.cc
+++ b/src/printer/relax_script_printer.cc
@@ -581,7 +581,6 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
     param << Print(var) << PrintVarAnnotation(var);
     params.push_back(param);
   }
-  print_symbolic_shape_as_str_ = false;
 
   if (is_global) {
     ICHECK(symbolic_vars_.empty());
@@ -591,9 +590,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
   doc << "@R.function" << Doc::NewLine();
   doc << "def " << name << "(" << Doc::Concat(params, Doc::Text(", ")) << ")";
   if (func->ret_type.defined()) {
-    doc << " -> " << Print(func->ret_type);
+    doc << " -> ";
+    if (const relax::DynTensorTypeNode* tty = func->ret_type.as<relax::DynTensorTypeNode>()) {
+      doc << PrintTensorAnnotation(GetRef<relax::DynTensorType>(tty), func->ret_shape);
+    } else if (const TupleTypeNode* tty = func->ret_type.as<TupleTypeNode>()) {
+      doc << PrintTupleAnnotation(GetRef<TupleType>(tty), func->ret_shape);
+    } else {
+      doc << Print(func->ret_type);
+    }
   }
   doc << ":" << Doc::NewLine(4);
+  print_symbolic_shape_as_str_ = false;
+
   // Step 3: print function attr
   Doc header_attr;
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 18d134cef5..81236b7640 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -92,6 +92,11 @@ class LegacyShapeDeriver : public StructInfoFunctor<Optional<Expr>(const StructI
   }
 
   Optional<Expr> VisitStructInfo_(const TupleStructInfoNode* op) final {
+    if (op->fields.empty()) {
+      // Corner case to prevent infinite recursion.
+      return Tuple(Array<Expr>());
+    }
+
     bool valid = true;
     Array<Expr> fields = op->fields.Map([this, &valid](const StructInfo& sinfo) {
       Optional<Expr> shape = this->VisitStructInfo(sinfo);
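A quick illustration of the corner case the `LegacyShapeDeriver` change above guards against: a zero-field `TupleStructInfo` (relax's "void") has no fields to recurse into, so the deriver now returns an empty tuple expression directly. A minimal sketch, assuming the `tvm.relax` Python bindings exercised by the tests later in this patch:

```python
from tvm import relax as rx

# A zero-field tuple struct info: the "void" value that the deriver now
# special-cases instead of visiting its (nonexistent) fields.
void_sinfo = rx.TupleStructInfo([])
print(void_sinfo)
```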
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 48064c7998..1045cb4e0e 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -428,7 +428,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr
     if (!normalized->IsInstance<OpNode>()) {
       ICHECK(normalized->struct_info_.defined())
           << "The struct_info_ of an Expr except OpNode after "
-             "normalization must not be nullptr. However, this Expr does not have checked_type_: "
+             "normalization must not be nullptr. However, this Expr does not have struct_info_: "
          << normalized;
     }
 
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index b50f220933..2af6a55dfc 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -454,12 +454,14 @@ Function::Function(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
     param_sinfo.push_back(GetStructInfo(param));
   }
 
-  StructInfo ret_info;
+  StructInfo ret_info = GetStructInfo(body);
   if (ret_type.defined()) {
-    ret_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape);
-  } else {
-    ret_info = GetStructInfo(body);
+    StructInfo given_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape);
+    CHECK(IsBaseOf(given_info, ret_info))
+        << "relax.Function requires the deduced body->struct_info to be a subtype of the "
+           "annotated struct_info, but got body->struct_info: "
+        << ret_info << ", ret_info: " << given_info;
   }
 
   FuncStructInfo func_sinfo(param_sinfo, ret_info);
@@ -468,8 +470,8 @@ Function::Function(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
   ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
   n->params = std::move(params);
   n->body = std::move(body);
-  n->ret_type = std::move(ret_type);
-  n->ret_shape = std::move(ret_shape);
+  n->ret_type = GetStaticType(ret_info);
+  n->ret_shape = GetLegacyShapeHint(ret_info).value_or(ret_shape);
   n->struct_info_ = func_sinfo;
   n->checked_type_ = GetStaticType(func_sinfo);
   n->attrs = std::move(attrs);
@@ -485,15 +487,29 @@ TVM_REGISTER_GLOBAL("relax.Function")
 
 Function Function::CreateUnchecked(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
                                    DictAttrs attrs, Span span) {
+  // TODO(@Hzfengsy): revisit `CreateUnchecked` after parser_v1 is removed
+
+  Array<StructInfo> param_sinfo;
+
   for (Var param : params) {
     ICHECK(param->checked_type_.defined())
         << "relax.Function requires params to contain checked_type_.";
+    param_sinfo.push_back(GetStructInfo(param));
+  }
+
+  StructInfo ret_info;
+
+  if (ret_type.defined()) {
+    ret_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape);
+  } else {
+    ret_info = FuncStructInfo::OpaqueFunc();
   }
 
   // set the fields
   ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
   n->params = std::move(params);
   n->body = std::move(body);
+  n->struct_info_ = FuncStructInfo(param_sinfo, ret_info);
   n->ret_type = std::move(ret_type);
   n->ret_shape = std::move(ret_shape);
   n->attrs = std::move(attrs);
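The constructor change above always deduces the return struct info from the body and treats an explicit annotation as a constraint rather than an override. A Python-level sketch of that logic; `is_base_of` is a hypothetical stand-in for the C++ `IsBaseOf` helper shown in the diff:

```python
def deduce_ret_info(body_sinfo, annotated_sinfo, is_base_of):
    """Sketch of the relax.Function constructor logic above.

    `is_base_of` is a hypothetical callable standing in for the C++
    `IsBaseOf`: it returns True when its first argument subsumes its second.
    """
    if annotated_sinfo is not None:
        # The annotation must be a (possibly looser) supertype of the
        # struct info deduced from the function body.
        if not is_base_of(annotated_sinfo, body_sinfo):
            raise ValueError(
                "annotated struct info must subsume the deduced body struct info"
            )
    # Either way, the deduced body info is what lands in FuncStructInfo.
    return body_sinfo
```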
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index caefefb53f..06d0bc185c 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -55,7 +55,7 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
 }
 
 StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) {
-  return TensorStructInfo({});
+  return TupleStructInfo(Array<StructInfo>());
 }
 
 StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) {
@@ -142,21 +142,21 @@ TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);
 
 // can't actually name it assert or else Python will consider it a syntax error
-Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) {
+StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) {
   // Ensure that the condition argument is a boolean scalar.
   // Also permitted is a tensor with unknown shape and unknown dtype
   // (checked dynamically in that case). Returns void.
   if (call->args.size() < 1) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "Assert must have at least one argument (the condition).");
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << "Assert must have at least one argument (the condition).");
   }
   Type arg_type = call->args[0]->checked_type();
   if (!IsBoolScalarType(arg_type)) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "The argument to assert must be a boolean scalar type, but received "
-                       << arg_type);
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << "The argument to assert must be a boolean scalar type, but received "
+                     << arg_type);
   }
-  return VoidType();
+  return ReturnVoidStructInfo(call, ctx);
 }
 
 TVM_REGISTER_NODE_TYPE(AssertOpAttrs);
@@ -167,7 +167,7 @@ RELAY_REGISTER_OP("relax.assert_op")
     .add_argument("vals", "Array<Expr>",
                   "The first value is used as the assertion condition. The others are used as "
                   "format arguments if there is an error.")
-    .set_attr<FInferType>("FInferType", InferAssertType)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferAssertStructInfo)
     .set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");
 
 Expr MakeAssertOp(Expr condition, Array<Expr> vals, std::string format) {
@@ -277,9 +277,9 @@ TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocSto
 
 StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<MemAllocTensorAttrs>();
   ICHECK(attrs != nullptr) << "must be MemAllocTensorAttrs, but got " << call->attrs->GetTypeKey();
-  ICHECK(call->args[0].as<ShapeExprNode>())
-      << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey();
-  return TensorStructInfo(call->args[0], attrs->dtype);
+  ICHECK(GetStructInfoAs<ShapeStructInfoNode>(call->args[1]))
+      << "must be an Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey();
+  return TensorStructInfo(call->args[1], attrs->dtype);
 }
 
 RELAY_REGISTER_OP("relax.memory.alloc_tensor")
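Since `ReturnVoidStructInfo` now produces a zero-field `TupleStructInfo`, its static type should be an empty `TupleType`. A small check using the analysis binding exercised by the tests below; the structural-equality expectation is an assumption of mine, not something this patch asserts:

```python
import tvm
from tvm import relax as rx

# void == zero-field tuple; get_static_type should map it to an empty TupleType.
void_type = rx.analysis.get_static_type(rx.TupleStructInfo([]))
tvm.ir.assert_structural_equal(void_type, rx.TupleType([]))
```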
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index ddbddd4b1d..ded3909360 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <tvm/ir/module.h>
+#include <tvm/relax/struct_info.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/script/ir_builder/ir/ir.h>
 
@@ -39,6 +40,22 @@ GlobalVar DeclFunction(const String& func_name, const Optional<BaseFunc>& func_s
   CHECK(!frame->global_var_map.count(func_name))
       << "ValueError: function " << func_name << " already exists";
   GlobalVar gv = GlobalVar(func_name);
+  if (func_signature.defined()) {
+    const BaseFunc& func = func_signature.value();
+    if (func->struct_info_.defined()) {
+      gv->struct_info_ = tvm::relax::GetStructInfo(func);
+    } else if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
+      // NOTE: use a slightly different struct info than the checked type
+      // of the PrimFunc so that a handle can turn into a Tensor.
+      // TODO(relax-team): add fine-grained PrimFunc struct info signature generation.
+      gv->struct_info_ = tvm::relax::FuncStructInfo::OpaqueFunc(
+          tvm::relax::StructInfoFromType(prim_func->ret_type));
+    } else {
+      LOG(FATAL) << "Unsupported function: " << func;
+    }
+  } else {
+    gv->struct_info_ = tvm::relax::FuncStructInfo::OpaqueFunc();
+  }
   CHECK(frame->functions.find(gv) == frame->functions.end())
       << "ValueError: function " << func_name << " has already been defined.";
   frame->global_var_map.Set(func_name, gv);
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index 161b5c3f3e..288e76e2f6 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -120,8 +120,6 @@ void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) {
   frame->ret_sinfo = ret_sinfo;
   frame->ret_type = GetStaticType(ret_sinfo);
   frame->ret_shape = GetLegacyShapeHint(ret_sinfo);
-  // TODO(@Hzfengsy): remove it
-  frame->ret_shape = tvm::relax::RuntimeDepShape();
 }
 
 void FuncRetValue(const tvm::relax::Expr& value) {
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index d725f11dfa..dba7e72dd6 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -16,11 +16,11 @@
 # under the License.
 """Tests analysis functions of struct info"""
 
-import pytest
 import tvm
+import tvm.testing
+from tvm import relax as rx
 from tvm import tir
-from tvm import relax as rx, TVMError
 
 
 def test_get_static_type_basic():
@@ -40,9 +40,9 @@ def test_get_static_type_shape():
     s2 = rx.ShapeStructInfo([1, n + 1, m])
     s3 = rx.ShapeStructInfo(ndim=2)
 
-    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType())
+    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3))
 
-    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType())
+    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2))
 
 
 def test_get_static_type_tensor():
@@ -68,7 +68,7 @@ def test_get_static_type_tuple():
         rx.TupleType(
             [
                 rx.TupleType([rx.DynTensorType(ndim=3, dtype="int64"), rx.ObjectType()]),
-                rx.ShapeType(),
+                rx.ShapeType(ndim=3),
             ]
         ),
     )
@@ -123,8 +123,7 @@ def test_erase_to_well_defined_shape():
 def test_erase_to_well_defined_tensor():
     n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
-    rshape = rx.Var("shape", type_annotation=rx.ShapeType())
-    rx.expr._update_struct_info(rshape, rx.ShapeStructInfo(ndim=2))
+    rshape = rx.Var("shape", type_annotation=rx.ShapeType(ndim=2))
     s0 = rx.TensorStructInfo(rshape, dtype="int32")
 
     # undefined
@@ -382,3 +381,7 @@ def fn_info_erased():
     _check_lca(fopaque0(), fopaque1(), fopaque0())
     _check_lca(fopaque0(), fn_info_shape(1), fopaque0())
     _check_lca(fopaque2(), fn_info_shape(1), fopaque2())
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
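The test updates above reflect that `ShapeType` now carries an `ndim`; with no argument it still denotes a shape of unknown rank. A short sketch of the distinction; the default-`ndim` value and the `.ndim` accessor are inferred from the updated assertions, not spelled out in this patch:

```python
from tvm import relax as rx

ranked = rx.ShapeType(ndim=2)  # a shape known to have exactly two dimensions
unranked = rx.ShapeType()      # rank left unknown (ndim presumably defaults to -1)
assert ranked.ndim == 2
```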
dtype="int64") bb = relax.BlockBuilder() @@ -237,7 +237,7 @@ def foo(x: R.Tensor("float32", ndim=2)): x0 = R.match_shape(x, (n, m)) return (x0, (n + 1, m, 1)) - x = relax.Var("x", None, DynTensorType(2, "float32")) + x = relax.Var("x", RuntimeDepShape(), DynTensorType(2, "float32")) n, m = tir.Var("n", "int64"), tir.Var("m", "int64") bb = relax.BlockBuilder() with bb.function("foo", (x,)): @@ -256,7 +256,7 @@ def foo(x: R.Tensor("float32", ndim=2)): t1 = (x, (n, m), t0) return t1 - x = relax.Var("x", None, DynTensorType(2, "float32")) + x = relax.Var("x", RuntimeDepShape(), DynTensorType(2, "float32")) n, m = tir.Var("n", "int64"), tir.Var("m", "int64") bb = relax.BlockBuilder() with bb.function("foo", (x,)): @@ -493,9 +493,9 @@ def _check_type_shape(binding, expected_type, expected_shape): relax.DynTensorType(ndim=2, dtype="float32"), relax.ShapeExpr([tvm.tir.IntImm("int64", 32), m]), ) - _check_type_shape(bindings[1], relax.DynTensorType(dtype=""), None) - _check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), None) - _check_type_shape(bindings[3], relax.DynTensorType(dtype=""), None) + _check_type_shape(bindings[1], relax.DynTensorType(dtype=""), RuntimeDepShape()) + _check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), RuntimeDepShape()) + _check_type_shape(bindings[3], relax.DynTensorType(dtype=""), RuntimeDepShape()) _check_type_shape(bindings[4], relax.ShapeType(), None) _check_type_shape( bindings[5],