From bf883e023c97d4daa1f4bf1bb4b59f65ec37d0a0 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 30 Nov 2022 15:54:52 -0500 Subject: [PATCH 01/22] Factor out repeated leaf node check --- include/tvm/relax/utils.h | 17 +++++++++++++++++ src/relax/analysis/well_formed.cc | 5 ++--- src/relax/ir/block_builder.cc | 10 ++-------- src/relax/utils.cc | 7 +++++++ 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 6bfd0d0daa..1ddefe7118 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -122,6 +122,23 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, bool permit_unknown_dtype = true); + +/*! + * \brief Check if the given expression is a "leaf" node for normalization purposes. + * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr, + * GlobalVar, RuntimeDepShape, Op, ExternFunc, and Tuple. + * Tuples are included in this list mainly for convenience in grouping operator arguments. + * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that + * values nested inside them are also leaves. + * + * \param expr The input expression + * + * \return True iff the input expression is a "leaf" node (a value allowed to appear + * inline without being bound to a var during normalization). + */ +TVM_DLL bool IsLeafExpr(const Expr& expr); + + } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 9773493131..ecd3b6c696 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -41,6 +41,7 @@ #include #include #include +#include #include #include @@ -170,9 +171,7 @@ class WellFormedChecker : public relax::ExprVisitor { void VisitExpr_(const CallNode* op) { for (size_t i = 0; i < op->args.size(); i++) { Expr arg = op->args[i]; - if (arg.as() || arg.as() || arg.as() || - arg.as() || arg.as() || arg.as() || - arg.as()) { + if (IsLeafExpr(arg)) { this->VisitExpr(arg); } else { Malformed(Diagnostic::Error(arg->span) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 82349e7ef2..66c6ddf885 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace tvm { @@ -631,13 +632,6 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { return NullOpt; } - static bool IsLeaf(const Expr& expr) { - // NB: tuples are treated as leaf nodes for ergonomics - return expr.as() || expr.as() || expr.as() || - expr.as() || expr.as() || - expr.as() || expr.as() || expr.as(); - } - Expr VisitWithNewScope(const Expr& expr) { builder_->BeginBindingBlock(); Expr post = this->VisitExpr(expr); @@ -650,7 +644,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { Expr Bind(const Expr& expr) { Expr post = this->VisitExpr(expr); - if (!IsLeaf(post)) { + if (!IsLeafExpr(post)) { post = builder_->Emit(post); expr_memo_.Set(expr, post); } diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 75a882de45..7e316a4d9d 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -77,5 +77,12 @@ bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unkn return correct_dtype && correct_rank; } +bool IsLeafExpr(const Expr& expr) { + // NB: tuples are treated as leaf nodes for ergonomics + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} + } // namespace relax } // namespace tvm From 6c284183fedd2c5ffb0fa7f4fd391d2303b2c3f1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 30 Nov 2022 17:07:36 -0500 Subject: [PATCH 02/22] Fix printing for if nodes (no shape_) --- python/tvm/relax/testing/ast_printer.py | 17 ++++++++++------- tests/python/relax/test_ast_printer.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index b7bdbc92d6..0807950192 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -208,13 +208,16 @@ def visit_seq_expr_(self, op: relax.SeqExpr) -> str: ) def visit_if_(self, op: relax.If) -> str: - return self.build_expr( - op, - "If", - cond=self.visit_expr(op.cond), - true_branch=self.visit_expr(op.true_branch), - false_branch=self.visit_expr(op.false_branch), - ) + # if is copied from Relay's AST, so it cannot have a Shape + # TODO(@relax-team): We should eventually assign a shape_ to If + fields = { + "cond": self.visit_expr(op.cond), + "true_branch": self.visit_expr(op.true_branch), + "false_branch": self.visit_expr(op.false_branch), + } + if op._checked_type_ and self.include_type_annotations: + fields["checked_type_"] = self.visit_type_(op.checked_type) + return self.build_ast_node("If", **fields) def visit_op_(self, op: tvm.ir.Op) -> str: # TODO: List other attributes? diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index f4582a1723..09a9b19d57 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -435,5 +435,23 @@ def f() -> R.Shape: assert type_str in call_str +def test_if(): + @R.function + def f(cond: R.Tensor((), dtype="bool")) -> R.Tensor((), dtype="int32"): + if cond: + x = R.const(1) + else: + x = R.const(2) + return x + + body = normalize(f).body + assert isinstance(body, rx.SeqExpr) + body_str = strip_whitespace(dump_ast(body)) + # we expect both branches to be seq exprs + assert "If" in body_str + assert "true_branch=SeqExpr(" in body_str + assert "false_branch=SeqExpr(" in body_str + + if __name__ == "__main__": pytest.main([__file__]) From 04bd8f0ab0044d8ff747aa6f1d3e73fd39d0c72b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 30 Nov 2022 17:08:43 -0500 Subject: [PATCH 03/22] Require scope bodies to be SeqExpr, also normalize body of SeqExpr --- src/relax/ir/block_builder.cc | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 66c6ddf885..0d1124d239 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -27,8 +27,8 @@ #include #include #include -#include #include +#include #include namespace tvm { @@ -214,7 +214,8 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { } builder_->BeginBindingBlock(); - Expr new_body = this->VisitExpr(op->body); + // the body may not be a leaf expression, so check for that + Expr new_body = this->Bind(op->body); unchanged &= new_body.same_as(op->body); BindingBlock prologue = builder_->EndBlock(); @@ -228,8 +229,9 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { SeqExpr seq_expr; if (unchanged) { seq_expr = GetRef(op); + } else { + seq_expr = SeqExpr(new_blocks, new_body); } - seq_expr = SeqExpr(new_blocks, new_body); // only do shape/type inference if the SeqExpr does not have shape/type if (seq_expr->shape_ && seq_expr->checked_type_.defined()) { @@ -636,10 +638,18 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { builder_->BeginBindingBlock(); Expr post = this->VisitExpr(expr); BindingBlock prologue = builder_->EndBlock(); + // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs. + // Don't wrap if it's already a seq and there are no bindings to add + if (post.as() && prologue->bindings.empty()) { + return post; + } + Array bindings; if (!prologue->bindings.empty()) { - post = SeqExpr({prologue}, post); + bindings.push_back(prologue); } - return post; + auto seq = SeqExpr(bindings, post); + // visit in case post is not a leaf and we need to bind it too + return this->VisitExpr(seq); } Expr Bind(const Expr& expr) { From c972ac77ccc5c05bbe93fc599d3686ecc516349a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 17:47:13 -0500 Subject: [PATCH 04/22] Reset the memo to prevent false sharing during second visit --- include/tvm/relax/block_builder.h | 9 +++++++++ src/relax/ir/block_builder.cc | 9 +++++++++ src/script/ir_builder/relax/frame.cc | 3 +++ 3 files changed, 21 insertions(+) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 409399713e..f2140cde93 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -133,6 +133,15 @@ class BlockBuilderNode : public Object { */ bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); + /*! + * \brief Resets the memo in the normalizer to prevent false hits when visiting + * the same expression more than once. + * Use if before visiting a given expression again. + */ + // TODO(@relax-team): Memoization should be tied to the scope tracking to prevent memo hits + // when the associated var is out of scope + void ResetMemo(); + /*! * \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes. * \param expr The input expression. diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 0d1124d239..9a4c16b67b 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -392,6 +392,8 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { return new_block; } + void ResetMemo() { expr_memo_.Reset(); } + private: /*! * \brief Memoization map for expressions using Id for equality of variables. @@ -421,6 +423,11 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { } } + void Reset() { + var_memo_ = std::unordered_map(); + expr_memo_ = std::unordered_map(); + } + private: std::unordered_map var_memo_; std::unordered_map expr_memo_; @@ -827,6 +834,8 @@ bool BlockBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { return false; } +void BlockBuilderNode::ResetMemo() { normalizer_->ResetMemo(); } + // TODO(@altanh, @yuchen): need an internal Emit_ that doesn't call normalize Expr BlockBuilderNode::Normalize(const Expr& expr) { Expr normalized = normalizer_->VisitExpr(expr); diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 451e01d317..f250e84781 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -54,6 +54,9 @@ void FunctionFrameNode::ExitWithScope() { // Step 1: Create the function. CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " "`return` to return an Expr"; + // Normalizing a second time could result in false hits to the memo + // TODO(relax-team): We should fix the memoization not to require this + this->block_builder->ResetMemo(); Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); Expr func_shape = ret_shape.value_or(tvm::relax::RuntimeDepShape()); if (func_shape->IsInstance()) { From b700e84a373e7d0df0ba28516c6ca5c2e191f8d1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 17:48:14 -0500 Subject: [PATCH 05/22] Update parsing tests for new normalizer behavior --- tests/python/relax/test_parser.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 935cb4acc3..c593a430a7 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -238,13 +238,17 @@ def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor: body = ite.true_branch.body assert w_bind.var.name_hint == "w" check_call(w_bind.value, "relax.add", [x, x]) - check_call(body, "relax.multiply", [w_bind.var, w_bind.var]) + body_bind = ite.true_branch.blocks[1].bindings[0] + check_call(body_bind.value, "relax.multiply", [w_bind.var, w_bind.var]) + assert ite.true_branch.body == body_bind.var w_bind = ite.false_branch.blocks[0].bindings[0] body = ite.false_branch.body assert w_bind.var.name_hint == "w" check_call(w_bind.value, "relax.multiply", [x, x]) - check_call(body, "relax.add", [w_bind.var, w_bind.var]) + body_bind = ite.false_branch.blocks[1].bindings[0] + check_call(body_bind.value, "relax.add", [w_bind.var, w_bind.var]) + assert ite.false_branch.body == body_bind.var def test_func_type_annotation_fail(): @@ -766,9 +770,15 @@ def k(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.T func_j = my_module[var_j] func_k = my_module[var_k] - assert len(func_f.body.blocks) == 0 - assert func_f.body.body.op == var_g - assert func_g.body.body.args[0] == var_my_matmul + assert len(func_f.body.blocks) == 1 + assert len(func_f.body.blocks[0].bindings) == 1 + f_call_var = func_f.body.blocks[0].bindings[0].var + assert func_f.body.blocks[0].bindings[0].value.op == var_g + assert func_f.body.body == f_call_var + + g_call_var = func_g.body.blocks[0].bindings[-1].var + assert func_g.body.blocks[0].bindings[-1].value.args[0] == var_my_matmul + assert func_g.body.body == g_call_var gv_bind = func_j.body.blocks[0].bindings[0] assert gv_bind.value.checked_type.ndim == 2 From 726dbe63efdd41f5ff5f5c0f6e1f37b32e53d399 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 18:02:17 -0500 Subject: [PATCH 06/22] Require seq expr bodies to be leaves and require func/branch bodies to be seq exprs --- src/relax/analysis/well_formed.cc | 39 +++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index ecd3b6c696..c823ca415a 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -163,9 +163,13 @@ class WellFormedChecker : public relax::ExprVisitor { this->VisitVarDef(param); } - this->VisitBody(op->body); - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_.clear(); + if (auto seq = op->body.as()) { + this->VisitSeqExpr(seq); + var_set_ = previous_var_set_; + prim_expr_visitor_.symbolic_var_set_.clear(); + } else { + Malformed(Diagnostic::Error(op->span) << "Function bodies must be sequence expressions"); + } } void VisitExpr_(const CallNode* op) { @@ -189,12 +193,18 @@ class WellFormedChecker : public relax::ExprVisitor { std::unordered_set previous_var_set_ = var_set_; std::unordered_set previous_symbolic_var_set_ = prim_expr_visitor_.symbolic_var_set_; - this->VisitBody(op->true_branch); - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; - this->VisitBody(op->false_branch); - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; + auto true_seq = op->true_branch.as(); + auto false_seq = op->false_branch.as(); + if (true_seq && false_seq) { + this->VisitSeqExpr(true_seq); + var_set_ = previous_var_set_; + prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; + this->VisitSeqExpr(false_seq); + var_set_ = previous_var_set_; + prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; + } else { + Malformed(Diagnostic::Error(op->span) << "If node branches must be seq exprs"); + } } void VisitExpr_(const ShapeExprNode* op) { @@ -220,15 +230,10 @@ class WellFormedChecker : public relax::ExprVisitor { for (BindingBlock block : op->blocks) { this->VisitBindingBlock(block); } - this->VisitExpr(op->body); - } - - void VisitBody(const Expr& expr) { - if (const SeqExprNode* seq_expr = expr.as()) { - this->VisitSeqExpr(seq_expr); - } else { - this->VisitExpr(expr); + if (!IsLeafExpr(op->body)) { + Malformed(Diagnostic::Error(op->span) << "SeqExpr bodies must be leaf expressions."); } + this->VisitExpr(op->body); } void VisitBinding_(const VarBindingNode* binding) { From 8443430d8dc15e524d41319dfe185268d01976bc Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 18:18:50 -0500 Subject: [PATCH 07/22] Add well formedness test cases, fix handling of if nodes --- src/relax/analysis/well_formed.cc | 6 +- .../python/relax/test_analysis_well_formed.py | 74 +++++++++++++++++++ 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index c823ca415a..ea1580b695 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -190,12 +190,12 @@ class WellFormedChecker : public relax::ExprVisitor { void VisitExpr_(const IfNode* op) { this->VisitExpr(op->cond); - std::unordered_set previous_var_set_ = var_set_; - std::unordered_set previous_symbolic_var_set_ = - prim_expr_visitor_.symbolic_var_set_; auto true_seq = op->true_branch.as(); auto false_seq = op->false_branch.as(); if (true_seq && false_seq) { + std::unordered_set previous_var_set_ = var_set_; + std::unordered_set previous_symbolic_var_set_ = + prim_expr_visitor_.symbolic_var_set_; this->VisitSeqExpr(true_seq); var_set_ = previous_var_set_; prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index a9b84ad970..2d809c083a 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -191,6 +191,80 @@ def test_if(): assert not rx.analysis.well_formed(mod) +def test_if_non_seq_body(): + # Error: If node has a body that is not a seq node + if_node = rx.If(cond=cond, true_branch=x, false_branch=x) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", [m, n], type_anno), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod) + + # on the other hand, if they're wrapped in a seq node, it's fine + seq = rx.SeqExpr([], x) + new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq) + new_blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", [m, n], type_anno), + new_if_node, + ) + ] + ) + ] + new_func = build_function(new_blocks) + new_mod = tvm.IRModule.from_expr(new_func) + assert rx.analysis.well_formed(new_mod) + + +def test_complex_seq_body(): + # Error: seq expr with a body that is not a leaf expression is not permitted + x = rx.Var("x", [], rx.DynTensorType(ndim=0, dtype="int32")) + y = rx.Var("y", [], rx.DynTensorType(ndim=0, dtype="int32")) + ret_type = rx.DynTensorType(ndim=0, dtype="float32") + ret_shape = rx.RuntimeDepShape() + func = rx.Function( + [x, y], + rx.SeqExpr([], rx.op.add(x, y)), + ret_type, + ret_shape, + ).with_attr("global_symbol", "foo") + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod) + + # but if the result is bound, then it's okay + z = rx.Var("z", [], rx.DynTensorType(ndim=0, dtype="int32")) + new_func = rx.Function( + [x, y], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=z, + value=rx.op.add(x, y), + ) + ] + ) + ], + z, + ), + ret_type, + ret_shape, + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + assert rx.analysis.well_formed(new_mod) + + def test_ANF(): # Error: Nested Call gv0 = rx.Var("gv0", [m, n], type_anno) From 95cdded6f5a2683658d18132e6b68a776bf66ea4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 18:47:15 -0500 Subject: [PATCH 08/22] Add new normalization tests --- tests/python/relax/test_transform.py | 97 ++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index e1318d2132..eaebfa3536 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -589,5 +589,102 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert_structural_equal(mod, mod_post) +def test_normalize_seq_body(): + # a seq expression with a non-leaf body should bind the body to a var as well + x = relax.Var("x", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32")) + y = relax.Var("y", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32")) + seq = relax.SeqExpr([], relax.op.add(x, y)) + f = relax.Function( + [x, y], + seq, + ret_type=relax.DynTensorType(ndim=0, dtype="int32"), + ret_shape=relax.RuntimeDepShape(), + ) + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def f( + x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") + ) -> R.Tensor(ndim=0, dtype="int32"): + # normalization inserts a binding like this + z = R.add(x, y) + return z + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_func_body(): + # a function with a body that is not a seq expr should have it wrapped in a seq expr + x = relax.Var("x", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32")) + y = relax.Var("y", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32")) + f = relax.Function( + [x, y], + relax.op.add(x, y), + ret_type=relax.DynTensorType(ndim=0, dtype="int32"), + ret_shape=relax.RuntimeDepShape(), + ) + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def f( + x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") + ) -> R.Tensor(ndim=0, dtype="int32"): + # result will be a seq expr where the body is a var + z = R.add(x, y) + return z + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_if_branches(): + # an if node's branches must be seq exprs + x = relax.Var("x", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32")) + y = relax.Var("y", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32")) + # TODO(@relax-team): z has a shape of () and type of DynTensorType(ndim=0), + # but normalization fails to infer these even though it should + z = relax.Var("z") + cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(ndim=0, dtype="bool")) + plus = relax.op.add(x, y) + mult = relax.op.multiply(x, y) + if_node = relax.If(cond, plus, mult) + seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z) + f = relax.Function( + [cond, x, y], + seq, + ret_type=relax.DynTensorType(ndim=0, dtype="int32"), + ret_shape=relax.RuntimeDepShape(), + ) + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def f( + cond: R.Tensor((), dtype="bool"), + x: R.Tensor((), dtype="int32"), + y: R.Tensor((), dtype="int32"), + ) -> R.Tensor(ndim=0, dtype="int32"): + # the bodies of the branches will be seq exprs with a binding + if cond: + w = R.add(x, y) + z = w + else: + w = R.multiply(x, y) + z = w + return z + + assert_structural_equal(after_mod, Expected) + + if __name__ == "__main__": pytest.main([__file__]) From cacd586ea8dca92bc23cf2c9381c86262b978a30 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 18:48:07 -0500 Subject: [PATCH 09/22] Factor out leaf expr in tuple case --- src/relax/analysis/well_formed.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index ea1580b695..99e2f0e064 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -108,8 +108,7 @@ class WellFormedChecker : public relax::ExprVisitor { void VisitExpr_(const TupleNode* op) { for (size_t i = 0; i < op->fields.size(); i++) { Expr expr = op->fields[i]; - if (expr.as() || expr.as() || expr.as() || - expr.as() || expr.as() || expr.as()) { + if (IsLeafExpr(expr)) { this->VisitExpr(expr); } else { Malformed(Diagnostic::Error(expr->span) From 83d03dfea78af59738e95148df0056db3404e0cd Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 19:47:15 -0500 Subject: [PATCH 10/22] Fix AST printer test for new normalization --- tests/python/relax/test_ast_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 09a9b19d57..a1c2579f8c 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -422,7 +422,7 @@ def f() -> R.Shape: body = normalize(f).body assert isinstance(body, rx.SeqExpr) - call = body.body + call = body.blocks[-1].bindings[-1].value assert isinstance(call, rx.Call) arg = call.args[0] arg_str = strip_whitespace(dump_ast(arg)) From be3274122185ea6e7376f0cc0db41adcc072f8b6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 19:47:51 -0500 Subject: [PATCH 11/22] Do not attempt to print shape_ for TupleGetItem --- python/tvm/relax/testing/ast_printer.py | 15 +++++++++------ tests/python/relax/test_ast_printer.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 0807950192..34cb44dd49 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -230,12 +230,15 @@ def visit_prim_expr_(self, prim_expr: PrimExpr) -> str: return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`") def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str: - return self.build_expr( - op, - "TupleGetItem", - tuple_value=self.visit_expr(op.tuple_value), - index=str(op.index), - ) + # TupleGetItem is copied from Relay's AST, so it cannot have a Shape + # TODO(@relax-team): We should eventually assign a shape_ to TupleGetItem + fields = { + "tuple_value": self.visit_expr(op.tuple_value), + "index": str(op.index), + } + if op._checked_type_ and self.include_type_annotations: + fields["checked_type_"] = self.visit_type_(op.checked_type) + return self.build_ast_node("TupleGetItem", **fields) def visit_type_(self, type_node: relax.Type) -> str: """ diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index a1c2579f8c..f96e9c4241 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -453,5 +453,19 @@ def f(cond: R.Tensor((), dtype="bool")) -> R.Tensor((), dtype="int32"): assert "false_branch=SeqExpr(" in body_str +def test_tuple_get_item(): + @R.function + def f(x: R.Tuple(R.Tensor((), dtype="int32"))) -> R.Tensor((), dtype="int32"): + return x[0] + + body = normalize(f).body + assert isinstance(body, rx.SeqExpr) + body_str = strip_whitespace(dump_ast(body)) + + assert "TupleGetItem" in body_str + assert 'tuple_value=Var(name_hint="x"' in body_str + assert "index=0" in body_str + + if __name__ == "__main__": pytest.main([__file__]) From 06dcab6fbb5f9bf33527cf434c136110e5c6b76a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 19:52:08 -0500 Subject: [PATCH 12/22] Check for and normalize nesting in TupleGetItem and If conditions --- src/relax/analysis/well_formed.cc | 16 ++- src/relax/ir/block_builder.cc | 4 +- .../python/relax/test_analysis_well_formed.py | 85 ++++++++++++++ tests/python/relax/test_transform.py | 105 ++++++++++++++++++ 4 files changed, 207 insertions(+), 3 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 99e2f0e064..ebf205eb6b 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -121,6 +121,15 @@ class WellFormedChecker : public relax::ExprVisitor { } } + void VisitExpr_(const TupleGetItemNode* op) { + if (IsLeafExpr(op->tuple)) { + this->VisitExpr(op->tuple); + } else { + Malformed(Diagnostic::Error(op->span) + << "The tuple value in a TupleGetItem node must be a leaf expression."); + } + } + void VisitExpr_(const VarNode* op) { Var var = GetRef(op); if (var_set_.count(var) == 0) { @@ -188,7 +197,12 @@ class WellFormedChecker : public relax::ExprVisitor { } void VisitExpr_(const IfNode* op) { - this->VisitExpr(op->cond); + if (IsLeafExpr(op->cond)) { + this->VisitExpr(op->cond); + } else { + Malformed(Diagnostic::Error(op->span) + << "The condition for an if node must be a leaf expression."); + } auto true_seq = op->true_branch.as(); auto false_seq = op->false_branch.as(); if (true_seq && false_seq) { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 9a4c16b67b..6b22de5f82 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -274,7 +274,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { } Expr VisitExpr_(const IfNode* op) final { - Expr new_cond = this->VisitExpr(op->cond); + Expr new_cond = this->Bind(op->cond); Expr new_true = this->VisitWithNewScope(op->true_branch); Expr new_false = this->VisitWithNewScope(op->false_branch); @@ -306,7 +306,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { } Expr VisitExpr_(const TupleGetItemNode* op) final { - Expr new_tuple = this->VisitExpr(op->tuple); + Expr new_tuple = this->Bind(op->tuple); TupleGetItem node; if (new_tuple.same_as(op->tuple)) { node = GetRef(op); diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 2d809c083a..f3f2297b39 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -226,6 +226,91 @@ def test_if_non_seq_body(): assert rx.analysis.well_formed(new_mod) +def test_if_complex_condition(): + # Error: If condition must be a leaf expression + cond_tuple = rx.Tuple([cond]) + cond_idx = rx.TupleGetItem(cond_tuple, 0) + if_node = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", [m, n], type_anno), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod) + + cond_var = rx.Var("q", [], bool_type_anno) + new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding(cond_var, cond_idx), + rx.VarBinding( + rx.Var("gv1", [m, n], type_anno), + new_if, + ), + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert rx.analysis.well_formed(mod) + + +def test_tuple_get_item_nested(): + # Error: The tuple value in tuple get item must be a leaf expression + nested_tup = rx.Var( + "t", + type_annotation=rx.TupleType( + [ + rx.TupleType( + [ + rx.DynTensorType(ndim=0, dtype="int32"), + ] + ) + ] + ), + ) + double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0) + ret_var = rx.Var("r", [], rx.DynTensorType(ndim=0, dtype="int32")) + f = rx.Function( + [nested_tup], + rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var), + ret_type=rx.DynTensorType(ndim=0, dtype="int32"), + ret_shape=rx.RuntimeDepShape(), + ) + f = f.with_attr("global_symbol", "f") + mod = tvm.IRModule.from_expr(f) + assert not rx.analysis.well_formed(mod) + + # okay with an intermediate binding + first_idx = rx.TupleGetItem(nested_tup, 0) + idx_var = rx.Var("v", type_annotation=rx.TupleType([rx.DynTensorType(ndim=0, dtype="int32")])) + second_idx = rx.TupleGetItem(idx_var, 0) + new_f = rx.Function( + [nested_tup], + rx.SeqExpr( + [ + rx.BindingBlock( + [rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)] + ) + ], + ret_var, + ), + ret_type=rx.DynTensorType(ndim=0, dtype="int32"), + ret_shape=rx.RuntimeDepShape(), + ) + new_f = new_f.with_attr("global_symbol", "new_f") + mod = tvm.IRModule.from_expr(new_f) + assert rx.analysis.well_formed(mod) + + def test_complex_seq_body(): # Error: seq expr with a body that is not a leaf expression is not permitted x = rx.Var("x", [], rx.DynTensorType(ndim=0, dtype="int32")) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index eaebfa3536..61682b4405 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -25,6 +25,7 @@ import tvm.script from tvm.script import tir as T, relax as R +from tvm.relax.testing import dump_ast def test_fma_rewrite(): @tvm.script.ir_module @@ -686,5 +687,109 @@ def f( assert_structural_equal(after_mod, Expected) +def test_normalize_if_condition(): + cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool")) + x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # The condition is wrapped in a tuple and then indexed + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + relax.TupleGetItem(relax.Tuple([cond]), 0), + relax.op.add(x, x), + relax.op.multiply(x, x), + ), + ) + ] + ) + ], + y, + ), + ret_type=relax.DynTensorType(1, "float32"), + ret_shape=relax.RuntimeDepShape(), + ) + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def f( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor(dtype="float32", ndim=1): + c = R.TupleGetItem(R.Tuple(cond), 0) + if c: + gv = R.add(x, x) + y = gv + else: + gv = R.multiply(x, x) + y = gv + return y + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_tuple_get_item(): + x = relax.Var("x", [], relax.DynTensorType(ndim=0, dtype="int32")) + f = relax.Function( + [x], + relax.TupleGetItem( + relax.TupleGetItem( + relax.Tuple([relax.Tuple([x])]), + 0, + ), + 0, + ), + ret_type=relax.DynTensorType(ndim=0, dtype="int32"), + ret_shape=relax.RuntimeDepShape(), + ) + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + # TODO: Revisit once we canonicalize SeqExprs (part of normalization?) + # Not using the parser this time because writing it out correctly results in + # *one* binding block, whereas the normalized version has *two* + idx_var = relax.Var( + "idx_var", + shape_annotation=relax.Tuple([relax.ShapeExpr([])]), + type_annotation=relax.TupleType([relax.DynTensorType(ndim=0, dtype="int32")]) + ) + ret_var = relax.Var("ret", [], relax.DynTensorType(ndim=0, dtype="int32")) + expected_f = relax.Function( + [x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0) + ) + ] + ), + relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]), + ], + ret_var, + ), + ret_type=relax.DynTensorType(ndim=0, dtype="int32"), + ret_shape=relax.RuntimeDepShape(), + ) + expected_f = expected_f.with_attr("global_symbol", "f") + expected_mod = tvm.IRModule.from_expr(expected_f) + # apply normalization to fill in type and shape annotations (tedious otherwise) + final_mod = relax.transform.Normalize()(expected_mod) + + assert_structural_equal(after_mod, final_mod) + + if __name__ == "__main__": pytest.main([__file__]) From a4d554354e784f528a40af03a30389b3326fc043 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:09:25 -0500 Subject: [PATCH 13/22] Fix new well-formed test cases to also have checked types --- .../python/relax/test_analysis_well_formed.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index f3f2297b39..8c45298f50 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -21,7 +21,7 @@ m = tir.Var("m", "int64") n = tir.Var("n", "int64") -type_anno = rx.DynTensorType(ndim=2, dtype="float16") +type_anno = rx.DynTensorType(ndim=2, dtype="float32") bool_type_anno = rx.DynTensorType(ndim=0, dtype="bool") x = rx.Var("x", [m, n], type_anno) cond = rx.Var("cond", [], bool_type_anno) @@ -223,7 +223,9 @@ def test_if_non_seq_body(): ] new_func = build_function(new_blocks) new_mod = tvm.IRModule.from_expr(new_func) - assert rx.analysis.well_formed(new_mod) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized) def test_if_complex_condition(): @@ -260,7 +262,9 @@ def test_if_complex_condition(): ] func = build_function(blocks) mod = tvm.IRModule.from_expr(func) - assert rx.analysis.well_formed(mod) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) def test_tuple_get_item_nested(): @@ -308,14 +312,16 @@ def test_tuple_get_item_nested(): ) new_f = new_f.with_attr("global_symbol", "new_f") mod = tvm.IRModule.from_expr(new_f) - assert rx.analysis.well_formed(mod) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) def test_complex_seq_body(): # Error: seq expr with a body that is not a leaf expression is not permitted x = rx.Var("x", [], rx.DynTensorType(ndim=0, dtype="int32")) y = rx.Var("y", [], rx.DynTensorType(ndim=0, dtype="int32")) - ret_type = rx.DynTensorType(ndim=0, dtype="float32") + ret_type = rx.DynTensorType(ndim=0, dtype="int32") ret_shape = rx.RuntimeDepShape() func = rx.Function( [x, y], @@ -347,7 +353,9 @@ def test_complex_seq_body(): ret_shape, ).with_attr("global_symbol", "foo") new_mod = tvm.IRModule.from_expr(new_func) - assert rx.analysis.well_formed(new_mod) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized) def test_ANF(): From b36899c11f0a5044480441af8fccdd42453bbebd Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:18:30 -0500 Subject: [PATCH 14/22] Lint --- tests/python/relax/test_transform.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 61682b4405..7063c6cfdf 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -27,6 +27,7 @@ from tvm.relax.testing import dump_ast + def test_fma_rewrite(): @tvm.script.ir_module class Before: @@ -760,9 +761,9 @@ def test_normalize_tuple_get_item(): # Not using the parser this time because writing it out correctly results in # *one* binding block, whereas the normalized version has *two* idx_var = relax.Var( - "idx_var", + "idx_var", shape_annotation=relax.Tuple([relax.ShapeExpr([])]), - type_annotation=relax.TupleType([relax.DynTensorType(ndim=0, dtype="int32")]) + type_annotation=relax.TupleType([relax.DynTensorType(ndim=0, dtype="int32")]), ) ret_var = relax.Var("ret", [], relax.DynTensorType(ndim=0, dtype="int32")) expected_f = relax.Function( From 175e7f4572607bc83a630a92dd466c93f8d7a3ec Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:24:00 -0500 Subject: [PATCH 15/22] Fix whitespace --- include/tvm/relax/block_builder.h | 2 +- include/tvm/relax/utils.h | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index f2140cde93..3c132c926c 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -134,7 +134,7 @@ class BlockBuilderNode : public Object { bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); /*! - * \brief Resets the memo in the normalizer to prevent false hits when visiting + * \brief Resets the memo in the normalizer to prevent false hits when visiting * the same expression more than once. * Use if before visiting a given expression again. */ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 1ddefe7118..f781372f23 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -122,7 +122,6 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, bool permit_unknown_dtype = true); - /*! * \brief Check if the given expression is a "leaf" node for normalization purposes. * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr, @@ -132,13 +131,12 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, * values nested inside them are also leaves. * * \param expr The input expression - * - * \return True iff the input expression is a "leaf" node (a value allowed to appear + * + * \return True iff the input expression is a "leaf" node (a value allowed to appear * inline without being bound to a var during normalization). */ TVM_DLL bool IsLeafExpr(const Expr& expr); - } // namespace relax } // namespace tvm From b7441f82ce5c0ae42562ab49652a59ed87a9f7cd Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:29:53 -0500 Subject: [PATCH 16/22] Fix expr functor tests for new normalizer --- tests/python/relax/test_expr_functor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index f4159c837c..6af89439f6 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -385,6 +385,9 @@ def visit(f, expr): # skip normalize GlobalVar since it requires context IRModule to get the checked_type_ if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): expr = bb.normalize(expr) + print(dump_ast(visit(basic_mutator, expr))) + print() + print(dump_ast(expr)) assert_structural_equal(visit(basic_mutator, expr), expr) # check the output log and return value @@ -464,7 +467,7 @@ def test_if(): basic_check( if_node, "\n".join(["If", "\tVar", "\tVar", "\tVar"]), - "\n".join(["Var", "Var", "Var", "If"]), + "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]), ) @@ -565,7 +568,7 @@ def test_function(): bindings = [relax.VarBinding(x, relax.const(1))] blocks = [relax.BindingBlock(bindings)] seq_expr = relax.SeqExpr(blocks, x) - ret_type = relax.DynTensorType(-1, "float32") + ret_type = relax.DynTensorType(1, "float32") ret_shape = relax.RuntimeDepShape() func = relax.Function([x], seq_expr, ret_type, ret_shape) basic_check( From f3d6df65b62f368a44775c04c79687846c0e4654 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:41:03 -0500 Subject: [PATCH 17/22] Fix management of var sets in well-formed --- src/relax/analysis/well_formed.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index ebf205eb6b..25b0218930 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -173,11 +173,11 @@ class WellFormedChecker : public relax::ExprVisitor { } if (auto seq = op->body.as()) { this->VisitSeqExpr(seq); - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_.clear(); } else { Malformed(Diagnostic::Error(op->span) << "Function bodies must be sequence expressions"); } + var_set_ = previous_var_set_; + prim_expr_visitor_.symbolic_var_set_.clear(); } void VisitExpr_(const CallNode* op) { From d5e5f49c310dc9a99bf2f7beb2e001326317feb4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 23:20:32 -0500 Subject: [PATCH 18/22] Remove debug prints --- tests/python/relax/test_expr_functor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index 6af89439f6..b3a6e04347 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -385,9 +385,6 @@ def visit(f, expr): # skip normalize GlobalVar since it requires context IRModule to get the checked_type_ if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): expr = bb.normalize(expr) - print(dump_ast(visit(basic_mutator, expr))) - print() - print(dump_ast(expr)) assert_structural_equal(visit(basic_mutator, expr), expr) # check the output log and return value From 9b0b3a5a9a288a625d39b1181012e1ad6305b0f5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 2 Dec 2022 15:12:55 -0500 Subject: [PATCH 19/22] Fix analysis tests --- tests/python/relax/test_analysis.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 1ec0adcf05..d0e94de9c6 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -304,6 +304,7 @@ def test_derive_func_ret_shape_free(): class VarExample: @R.function def func(a: R.Tensor) -> R.Tensor: + # normalized into assigning R.add(a, a) to a var and returning it return R.add(a, a) @R.function @@ -322,8 +323,10 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: def test_all_vars(): vars = all_vars(VarExample["func"]) - assert len(vars) == 1 + assert len(vars) == 2 assert vars[0].name_hint == "a" + # the body of the seq expr in the func body is a var + assert vars[1] == VarExample["func"].body.body var_names = var_name_set(all_vars(VarExample["main"])) assert var_names == {"x", "y", "z", "p", "q", "r", "s"} @@ -331,8 +334,10 @@ def test_all_vars(): def test_bound_vars(): vars = bound_vars(VarExample["func"]) - assert len(vars) == 1 + assert len(vars) == 2 assert vars[0].name_hint == "a" + # the body of the seq expr in the func body is a bound var + assert vars[1] == VarExample["func"].body.body # all the vars are bound var_names = var_name_set(bound_vars(VarExample["main"])) @@ -342,8 +347,10 @@ def test_bound_vars(): body_names = var_name_set(bound_vars(VarExample["main"].body)) assert body_names == {"z", "p", "q", "r", "s"} - # if the argument isn't bound, then nothing is - assert len(bound_vars(VarExample["func"].body)) == 0 + # only binding is in the (normalized) body + simple_body_vars = bound_vars(VarExample["func"].body) + assert len(simple_body_vars) == 1 + assert simple_body_vars[0] == VarExample["func"].body.body def test_free_vars(): From 6c3a8284edb2ab03f645b6191105c60f27a3c3bb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 2 Dec 2022 15:15:44 -0500 Subject: [PATCH 20/22] Fix TVMScript parser test --- tests/python/relax/test_tvmscript_parser.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 51df903afa..d54b6dff28 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -668,16 +668,18 @@ def check_call(call, op, args): tvm.ir.assert_structural_equal(call.args, args) w_bind = ite.true_branch.blocks[0].bindings[0] - body = ite.true_branch.body + # the seq exprts in the branches are normalized to bind any call + # in the seq expr "body" to a var + y_bind = ite.true_branch.blocks[-1].bindings[-1] assert w_bind.var.name_hint == "w" check_call(w_bind.value, "relax.add", [x, x]) - check_call(body, "relax.multiply", [w_bind.var, w_bind.var]) + check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var]) w_bind = ite.false_branch.blocks[0].bindings[0] - body = ite.false_branch.body + y_bind = ite.false_branch.blocks[-1].bindings[-1] assert w_bind.var.name_hint == "w" check_call(w_bind.value, "relax.multiply", [x, x]) - check_call(body, "relax.add", [w_bind.var, w_bind.var]) + check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var]) def test_if_inside_dataflow(): From b3a4eab0f7aa3aab64b001a8bcd953cf95248a64 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 2 Dec 2022 15:23:27 -0500 Subject: [PATCH 21/22] Update comment detailing what the well-formed pass checks --- src/relax/analysis/well_formed.cc | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 25b0218930..9a1a8d3b85 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -33,9 +33,20 @@ * 6. SeqExpr only serves as function body, or in the true and * false branches in IfNode. * 7. The IR is in ANF: - * (a) No nested call - * (b) The fields of the Tuple can only be Var/DataflowVar/Constant/ - * ShapeExpr/RuntimeDepShape/Tuple + * (a) Expressions cannot contain nested complex expressions. + * Here are the expressions that may be nested inside other expressions: + * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape, + * Op, Tuple (we call these "leaf" expressions). + * (b) The right-hand side of a binding may contain a non-leaf expression + * (where all expressions nested in it are leaf expressions), + * other than SeqExprs (see rule 6) + * (c) Exceptions: The body of a Function node and the true branch + * and false branch of If nodes *must* be SeqExprs. + * (d) Places where non-leaf expressions cannot appear: + * * The tuple_value field of TupleGetItem nodes + * * The cond field of If nodes + * * The op or args fields of Call nodes + * * Inside the fields of Tuple nodes * 8. Expr always has checked_type_ (with the exception of Op). */ #include From 20da86641ab04dad5878f56e185df9d6648ab44b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 2 Dec 2022 15:31:51 -0500 Subject: [PATCH 22/22] Also check the op argument to call nodes --- src/relax/analysis/well_formed.cc | 9 +++++++-- src/relax/ir/block_builder.cc | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 9a1a8d3b85..0bbb207e96 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -37,7 +37,7 @@ * Here are the expressions that may be nested inside other expressions: * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape, * Op, Tuple (we call these "leaf" expressions). - * (b) The right-hand side of a binding may contain a non-leaf expression + * (b) The right-hand side of a binding may contain a non-leaf expression * (where all expressions nested in it are leaf expressions), * other than SeqExprs (see rule 6) * (c) Exceptions: The body of a Function node and the true branch @@ -90,7 +90,7 @@ class WellFormedChecker : public relax::ExprVisitor { } void VisitExpr(const Expr& expr) override { - if (!expr->checked_type_.defined()) { + if (!expr.as() && !expr->checked_type_.defined()) { Malformed(Diagnostic::Error(expr->span) << "The checked_type_ of Expr " << expr << " is nullptr."); } @@ -192,6 +192,11 @@ class WellFormedChecker : public relax::ExprVisitor { } void VisitExpr_(const CallNode* op) { + if (IsLeafExpr(op->op)) { + this->VisitExpr(op->op); + } else { + Malformed(Diagnostic::Error(op->span) << "The called expression must be a leaf expression"); + } for (size_t i = 0; i < op->args.size(); i++) { Expr arg = op->args[i]; if (IsLeafExpr(arg)) { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 6b22de5f82..a2c3b7eab1 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -158,7 +158,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { } Expr VisitExpr_(const CallNode* op) final { - Expr new_op = this->VisitExpr(op->op); + Expr new_op = this->Bind(op->op); bool unchanged = new_op.same_as(op->op); Array new_args;