From 678f07375caacdc66d74edc1643c23d121c159f3 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 23 Dec 2022 14:46:07 -0500 Subject: [PATCH] [Analysis] Optionally check structure info in well-formedness check (#321) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With the introduction of structure info in #314, the well-formedness check will report malformed whenever an Expr doesn’t have defined structure info. However, when writing tests for well-formedness check and normalizer, usually we will manually construct the Exprs, which means their structure info are not defined most of the time. As a consequence, the well-formedness check will always complain “the Expr xxx doesn’t have structure info populated.” Therefore, when the checker fails to complain about the original reason of malformed, which means the checker is not working, the tests will still pass and we won’t be able to realize there is something wrong with the checker. Thus, in this PR we add an optional flag to the well-formedness check. In well-formedness tests, we will turn off the structure info check so that the original reason of being malformed will be revealed correctly. --- This PR also cleans up the DiagnosticContext parameter in the WellFormed API - the diag_ctx has been unused since the merge of #99. --- include/tvm/relax/analysis.h | 21 ++- python/tvm/relax/analysis/analysis.py | 14 +- src/relax/analysis/well_formed.cc | 154 +++++++++--------- .../python/relax/test_analysis_well_formed.py | 44 ++--- 4 files changed, 127 insertions(+), 106 deletions(-) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 4467d51f8a..4e38c055e1 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -200,8 +200,8 @@ TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map diag_ctx = Optional()); +TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); /*! * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. @@ -346,7 +349,7 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); TVM_DLL tvm::Array AllVars(const Expr& expr); /*! - * \brief Get all glabal variables used in calls in expression expr. + * \brief Get all global variables used in calls in expression expr. * * \param expr the expression. * @@ -355,7 +358,7 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); TVM_DLL tvm::Array CalledGlobalVars(const Expr& expr); /*! - * \brief Get all glabal variables from expression expr. + * \brief Get all global variables from expression expr. * * AllVars is a superset of BoundVars and FreeVars. * The union of BoundVars and FreeVars is Allvars. @@ -402,7 +405,7 @@ TVM_DLL Map> NameToBinding(const Function& fn); * \brief Get the use-def chain of variables inside a dataflow block. * * \param dfb The dataflow block to be analyzed. - * \return A map mapping variable definitoins to a set of uses. + * \return A map mapping variable definitions to a set of uses. */ TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); @@ -410,7 +413,7 @@ TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); * \brief Get the use-def chain of variables inside a function. * * \param fn The function to be analyzed. - * \return A map from variable definitoins to a set of uses and variables needed by return value. + * \return A map from variable definitions to a set of uses and variables needed by return value. */ std::pair>, Array> FunctionUseDef(const Function& fn); diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 59824c16ad..c5b9850e8b 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -197,7 +197,7 @@ def post_order_visit(expr, fvisit): return _ffi_api.post_order_visit(expr, fvisit) # type: ignore -def well_formed(mod: tvm.IRModule) -> bool: +def well_formed(mod: tvm.IRModule, check_struct_info: bool = True) -> bool: """Check if the IRModule is well formed. Parameters @@ -205,12 +205,22 @@ def well_formed(mod: tvm.IRModule) -> bool: mod : tvm.IRModule The input IRModule. + check_struct_info : bool + A boolean flag indicating if the property "every Expr must + have defined structure info" will be checked. + Returns ------- ret: bool True if the IRModule is well formed, False if not. + + Note + ---- + By default the structure info is always checked. It is only in test cases + where `check_struct_info` might be false, so that other well-formed requirements + will be well tested and will not be blocked by not having structure info. """ - return _ffi_api.well_formed(mod) # type: ignore + return _ffi_api.well_formed(mod, check_struct_info) # type: ignore def get_var2val(func: Function) -> Dict[Var, Expr]: diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index fcadfd8daa..f694727003 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -24,19 +24,21 @@ * This pass is supposed to be applied to normalized Relax AST. * If it's malformed, messages will be logged as Warning. * This pass will check: - * 1. GlobalVars are defined before use. - * 2. When a Function has a corresponding GlobalVar and a `global_symbol` + * 1. Each Expr should have `struct_info_` field already populated, when + * `check_struct_info` is true. + * 2. GlobalVars are defined before use. + * 3. When a Function has a corresponding GlobalVar and a `global_symbol` * attribute, the name of the GlobalVar must equal the value of the * `global_symbol` attribute value. - * 3. Vars are defined before use. - * 4. Vars are defined exactly once. - * 5. Symbolic Vars are defined before use. - * 6. DataflowVars cannot be defined inside BindingBlock. - * 7. Vars defined in IfNode, except the return Var, are invisible + * 4. Vars are defined before use. + * 5. Vars are defined exactly once. + * 6. Symbolic Vars are defined before use. + * 7. DataflowVars cannot be defined inside BindingBlock. + * 8. Vars defined in IfNode, except the return Var, are invisible * out of the If body.(May change for new AST designs) - * 8. SeqExpr only serves as function body, or in the true and + * 9. SeqExpr only serves as function body, or in the true and * false branches in IfNode. - * 9. The IR is in ANF: + * 10. The IR is in ANF: * (a) Expressions cannot contain nested complex expressions. * Here are the expressions that may be nested inside other expressions: * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape, @@ -51,7 +53,7 @@ * * The cond field of If nodes * * The op or args fields of Call nodes * * Inside the fields of Tuple nodes - * 10. Expr always has checked_type_ (with the exception of Op). + * 11. Expr always has checked_type_ (with the exception of Op). */ #include #include @@ -73,33 +75,24 @@ class WellFormedChecker : public relax::ExprVisitor, public relax::StructInfoVisitor, public tir::ExprVisitor { public: - bool well_formed = true; - - void Malformed(Diagnostic diag) { - well_formed = false; - LOG(WARNING) << "This IR is not well formed: " << diag->message; - } - - void VisitExpr(const Expr& expr) override { - if (!expr.as() && !expr->checked_type_.defined()) { - Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr."); - } - relax::ExprVisitor::VisitExpr(expr); - } - - void RegisterGlobalVar(GlobalVar var) { global_var_set_.insert(var); } - - void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) { - // check name in global var and gsymbol - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (gsymbol.defined() && gsymbol != var->name_hint) { - Malformed(Diagnostic::Error(func->span) - << "Name in GlobalVar is not equal to name in gsymbol: " << var->name_hint - << " != " << gsymbol.value()); + static bool Check(IRModule mod, bool check_struct_info) { + WellFormedChecker well_formed_checker = WellFormedChecker(mod, check_struct_info); + + for (const auto& it : mod->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); + well_formed_checker.VisitExpr(func); + } } + return well_formed_checker.well_formed_; } private: + explicit WellFormedChecker(IRModule mod, bool check_struct_info) + : mod_(std::move(mod)), check_struct_info_(check_struct_info) {} + // Possible mode of visitor enum class VisitMode { /*! @@ -113,9 +106,32 @@ class WellFormedChecker : public relax::ExprVisitor, kMatchVarDef }; - void VisitExpr_(const GlobalVarNode* op) { + void Malformed(Diagnostic diag) { + well_formed_ = false; + LOG(WARNING) << "This IR is not well formed: " << diag->message; + } + + void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) { + // check name in global var and gsymbol + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol != var->name_hint) { + Malformed(Diagnostic::Error(func->span) + << "Name in GlobalVar is not equal to name in gsymbol: " << var->name_hint + << " != " << gsymbol.value()); + } + } + + void VisitExpr(const Expr& expr) final { + if (!expr.as() && !expr->checked_type_.defined()) { + Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr."); + } + relax::ExprVisitor::VisitExpr(expr); + } + + void VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); - if (global_var_set_.count(var) == 0) { + if (!(mod_->ContainGlobalVar(var->name_hint) && + mod_->GetGlobalVar(var->name_hint).same_as(var))) { Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << " is not defined."); } @@ -130,7 +146,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const TupleNode* op) { + void VisitExpr_(const TupleNode* op) final { for (size_t i = 0; i < op->fields.size(); i++) { Expr expr = op->fields[i]; if (IsLeafExpr(expr)) { @@ -144,7 +160,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const TupleGetItemNode* op) { + void VisitExpr_(const TupleGetItemNode* op) final { if (IsLeafExpr(op->tuple)) { this->VisitExpr(op->tuple); } else { @@ -154,7 +170,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const VarNode* op) { + void VisitExpr_(const VarNode* op) final { Var var = GetRef(op); if (var_set_.count(var) == 0) { Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); @@ -162,7 +178,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const DataflowVarNode* op) { + void VisitExpr_(const DataflowVarNode* op) final { DataflowVar var = GetRef(op); if (!is_dataflow_) { Malformed(Diagnostic::Error(var) @@ -174,7 +190,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const FunctionNode* op) { + void VisitExpr_(const FunctionNode* op) final { // save the var_set_ for local function auto prev_var_set = var_set_; auto prev_symbolic_var_set = symbolic_var_set_; @@ -204,7 +220,7 @@ class WellFormedChecker : public relax::ExprVisitor, symbolic_var_set_ = prev_symbolic_var_set; } - void VisitExpr_(const CallNode* op) { + void VisitExpr_(const CallNode* op) final { if (IsLeafExpr(op->op)) { this->VisitExpr(op->op); } else { @@ -223,7 +239,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const IfNode* op) { + void VisitExpr_(const IfNode* op) final { if (IsLeafExpr(op->cond)) { this->VisitExpr(op->cond); } else { @@ -247,7 +263,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const ShapeExprNode* op) { + void VisitExpr_(const ShapeExprNode* op) final { for (PrimExpr expr : op->values) { // check if the symbolic vars in the expr are defined, e.g, 2 * m tir::ExprVisitor::VisitExpr(expr); @@ -259,7 +275,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitExpr_(const SeqExprNode* op) { + void VisitExpr_(const SeqExprNode* op) final { Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function body in FunctionNode, " "or the true/false branch body in IfNode."); } @@ -277,12 +293,12 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(op); } - void VisitBinding_(const VarBindingNode* binding) { + void VisitBinding_(const VarBindingNode* binding) final { this->VisitExpr(binding->value); this->VisitVarDef(binding->var); } - void VisitBinding_(const MatchShapeNode* binding) { + void VisitBinding_(const MatchShapeNode* binding) final { this->VisitExpr(binding->value); // define the vars WithMode(VisitMode::kMatchVarDef, [&]() { @@ -300,7 +316,7 @@ class WellFormedChecker : public relax::ExprVisitor, } } - void VisitBindingBlock_(const DataflowBlockNode* block) { + void VisitBindingBlock_(const DataflowBlockNode* block) final { is_dataflow_ = true; for (Binding binding : block->bindings) { this->VisitBinding(binding); @@ -309,7 +325,7 @@ class WellFormedChecker : public relax::ExprVisitor, dataflow_var_set_.clear(); } - void VisitVarDef_(const DataflowVarNode* var) { + void VisitVarDef_(const DataflowVarNode* var) final { if (!is_dataflow_) { Malformed(Diagnostic::Error(var) << "DataflowVar " << var->name_hint() << " is defined outside DataflowBlock."); @@ -324,7 +340,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(var); } - void VisitVarDef_(const VarNode* var) { + void VisitVarDef_(const VarNode* var) final { Var gv = GetRef(var); if (var_set_.count(gv) == 1) { Malformed(Diagnostic::Error(var) @@ -335,7 +351,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(var); } - void VisitVarDef(const Var& var) { + void VisitVarDef(const Var& var) final { if (const DataflowVarNode* lv_node = var.as()) { VisitVarDef_(lv_node); } else if (const VarNode* gv_node = var.as()) { @@ -356,7 +372,7 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitStructInfoExprField(const Expr& expr) final { if (mode_ == VisitMode::kMatchVarDef) { - // populate symbolic var in first occurance + // populate symbolic var in first occurrence if (auto* op = expr.as()) { auto var = GetRef(op); if (var_set_.count(var) == 0) { @@ -375,7 +391,7 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitStructInfoExprField(const PrimExpr& expr) final { if (mode_ == VisitMode::kMatchVarDef) { - // populate symbolic var in first occurance + // populate symbolic var in first occurrence if (auto* op = expr.as()) { auto var = GetRef(op); if (symbolic_var_set_.count(var) == 0) { @@ -388,6 +404,10 @@ class WellFormedChecker : public relax::ExprVisitor, } void CheckStructInfo(const ExprNode* op) { + if (!check_struct_info_) { + return; + } + auto* sinfo = op->struct_info_.as(); if (sinfo != nullptr) { this->VisitStructInfo(GetRef(sinfo)); @@ -405,38 +425,26 @@ class WellFormedChecker : public relax::ExprVisitor, std::swap(mode_, mode); } + IRModule mod_; + const bool check_struct_info_; + bool well_formed_ = true; bool is_dataflow_ = false; // Current visit mode. VisitMode mode_ = VisitMode::kDefault; // set of context variables. - std::unordered_set global_var_set_; std::unordered_set var_set_; std::unordered_set dataflow_var_set_; std::unordered_set symbolic_var_set_; }; -bool WellFormed(const IRModule& m, Optional diag_ctx) { - WellFormedChecker well_formed_checker = WellFormedChecker(); - for (const auto& it : m->functions) { - // register GlobalVar in the IRModule first - well_formed_checker.RegisterGlobalVar(it.first); - } - - for (const auto& it : m->functions) { - // visit relax.Function - if (auto* n = it.second.as()) { - Function func = GetRef(n); - well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); - well_formed_checker.VisitExpr(func); - } - } - - return well_formed_checker.well_formed; +bool WellFormed(IRModule m, bool check_struct_info) { + return WellFormedChecker::Check(std::move(m), check_struct_info); } -TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed([](IRModule m) { - return WellFormed(m); -}); +TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")) + .set_body_typed([](IRModule m, bool check_struct_info) { + return WellFormed(m, check_struct_info); + }); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index b05ab64b44..900b28dbea 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -44,7 +44,7 @@ def test_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # Error: Var gv0 is defined more than once gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) @@ -54,7 +54,7 @@ def test_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_dataflow_var(): @@ -66,7 +66,7 @@ def test_dataflow_var(): blocks = [rx.DataflowBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # Error: DataflowVar gv0 is defined more than once lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) @@ -76,7 +76,7 @@ def test_dataflow_var(): blocks = [rx.DataflowBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # Error: DataflowVar lv0 is defined outside DataflowBlock lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) @@ -85,7 +85,7 @@ def test_dataflow_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # Error: DataflowVar lv0 is used outside DataflowBlock lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) @@ -95,7 +95,7 @@ def test_dataflow_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_global_var(): @@ -110,7 +110,7 @@ def test_global_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_symbolic_var(): @@ -122,7 +122,7 @@ def test_symbolic_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_symbolic_var_invalid_type(): @@ -137,7 +137,7 @@ def test_symbolic_var_invalid_type(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks, [y]) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_seq_expr(): @@ -154,7 +154,7 @@ def test_seq_expr(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_if(): @@ -186,7 +186,7 @@ def test_if(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=True) def test_if_non_seq_body(): @@ -204,7 +204,7 @@ def test_if_non_seq_body(): ] func = build_function(blocks) mod = tvm.IRModule.from_expr(func) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # on the other hand, if they're wrapped in a seq node, it's fine seq = rx.SeqExpr([], x) @@ -223,7 +223,7 @@ def test_if_non_seq_body(): new_mod = tvm.IRModule.from_expr(new_func) # apply normalization to fill in checked_type_ normalized = rx.transform.Normalize()(new_mod) - assert rx.analysis.well_formed(normalized) + assert rx.analysis.well_formed(normalized, check_struct_info=True) def test_if_complex_condition(): @@ -243,7 +243,7 @@ def test_if_complex_condition(): ] func = build_function(blocks) mod = tvm.IRModule.from_expr(func) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) cond_var = rx.Var("q", R.Tensor([], "bool")) new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x)) @@ -262,7 +262,7 @@ def test_if_complex_condition(): mod = tvm.IRModule.from_expr(func) # apply normalization to fill in checked_type_ normalized = rx.transform.Normalize()(mod) - assert rx.analysis.well_formed(normalized) + assert rx.analysis.well_formed(normalized, check_struct_info=True) def test_tuple_get_item_nested(): @@ -279,7 +279,7 @@ def test_tuple_get_item_nested(): ) f = f.with_attr("global_symbol", "f") mod = tvm.IRModule.from_expr(f) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # okay with an intermediate binding first_idx = rx.TupleGetItem(nested_tup, 0) @@ -301,7 +301,7 @@ def test_tuple_get_item_nested(): mod = tvm.IRModule.from_expr(new_f) # normalize in order to fill in checked type normalized = rx.transform.Normalize()(mod) - assert rx.analysis.well_formed(normalized) + assert rx.analysis.well_formed(normalized, check_struct_info=True) def test_complex_seq_body(): @@ -314,7 +314,7 @@ def test_complex_seq_body(): R.Tensor(ndim=0, dtype="int32"), ).with_attr("global_symbol", "foo") mod = tvm.IRModule.from_expr(func) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # but if the result is bound, then it's okay z = rx.Var("z", R.Tensor([], "int32")) @@ -338,7 +338,7 @@ def test_complex_seq_body(): new_mod = tvm.IRModule.from_expr(new_func) # normalize in order to fill in checked type normalized = rx.transform.Normalize()(new_mod) - assert rx.analysis.well_formed(normalized) + assert rx.analysis.well_formed(normalized, check_struct_info=True) def test_ANF(): @@ -349,7 +349,7 @@ def test_ANF(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) # Error: Call Node in Tuple gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) @@ -357,7 +357,7 @@ def test_ANF(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) def test_global_var_vs_gsymbol(): @@ -371,7 +371,7 @@ def test_global_var_vs_gsymbol(): R.Tensor(ndim=2, dtype="float32"), ).with_attr("global_symbol", "main1") mod = tvm.IRModule({rx.GlobalVar("main"): func}) - assert not rx.analysis.well_formed(mod) + assert not rx.analysis.well_formed(mod, check_struct_info=False) if __name__ == "__main__":