diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h
index 409399713e..3c132c926c 100644
--- a/include/tvm/relax/block_builder.h
+++ b/include/tvm/relax/block_builder.h
@@ -133,6 +133,15 @@ class BlockBuilderNode : public Object {
    */
   bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);
 
+  /*!
+   * \brief Resets the memo in the normalizer to prevent false hits when visiting
+   * the same expression more than once.
+   * Use it before visiting a given expression again.
+   */
+  // TODO(@relax-team): Memoization should be tied to the scope tracking to prevent memo hits
+  // when the associated var is out of scope
+  void ResetMemo();
+
   /*!
    * \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes.
    * \param expr The input expression.
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 6bfd0d0daa..f781372f23 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -122,6 +122,21 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
 TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
                               bool permit_unknown_dtype = true);
 
+/*!
+ * \brief Check if the given expression is a "leaf" node for normalization purposes.
+ * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr,
+ * GlobalVar, RuntimeDepShape, Op, ExternFunc, and Tuple.
+ * Tuples are included in this list mainly for convenience in grouping operator arguments.
+ * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that
+ * values nested inside them are also leaves.
+ *
+ * \param expr The input expression
+ *
+ * \return True iff the input expression is a "leaf" node (a value allowed to appear
+ * inline without being bound to a var during normalization).
+ */
+TVM_DLL bool IsLeafExpr(const Expr& expr);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py
index b7bdbc92d6..34cb44dd49 100644
--- a/python/tvm/relax/testing/ast_printer.py
+++ b/python/tvm/relax/testing/ast_printer.py
@@ -208,13 +208,16 @@ def visit_seq_expr_(self, op: relax.SeqExpr) -> str:
         )
 
     def visit_if_(self, op: relax.If) -> str:
-        return self.build_expr(
-            op,
-            "If",
-            cond=self.visit_expr(op.cond),
-            true_branch=self.visit_expr(op.true_branch),
-            false_branch=self.visit_expr(op.false_branch),
-        )
+        # If is copied from Relay's AST, so it cannot have a Shape
+        # TODO(@relax-team): We should eventually assign a shape_ to If
+        fields = {
+            "cond": self.visit_expr(op.cond),
+            "true_branch": self.visit_expr(op.true_branch),
+            "false_branch": self.visit_expr(op.false_branch),
+        }
+        if op._checked_type_ and self.include_type_annotations:
+            fields["checked_type_"] = self.visit_type_(op.checked_type)
+        return self.build_ast_node("If", **fields)
 
     def visit_op_(self, op: tvm.ir.Op) -> str:
         # TODO: List other attributes?
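As a quick illustration of the leaf rule documented in `utils.h` above, here is a sketch assembled from the tests added later in this diff (the names `f`, `x`, and `y` are illustrative; the API calls are the ones the tests themselves use): a SeqExpr whose body is a Call, a non-leaf, is rejected by the well-formed check until `Normalize()` binds the call to a var.

```python
import tvm
from tvm import relax

# A function whose SeqExpr body is a Call (a non-leaf expression).
x = relax.Var("x", [], relax.DynTensorType(ndim=0, dtype="int32"))
y = relax.Var("y", [], relax.DynTensorType(ndim=0, dtype="int32"))
f = relax.Function(
    [x, y],
    relax.SeqExpr([], relax.op.add(x, y)),  # body is not a leaf
    ret_type=relax.DynTensorType(ndim=0, dtype="int32"),
    ret_shape=relax.RuntimeDepShape(),
).with_attr("global_symbol", "f")
mod = tvm.IRModule.from_expr(f)

# Rejected: SeqExpr bodies must be leaf expressions (see the ANF rules below).
assert not relax.analysis.well_formed(mod)

# Normalization binds the call to a fresh var; the module is then well formed.
normalized = relax.transform.Normalize()(mod)
assert relax.analysis.well_formed(normalized)
```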
@@ -227,12 +230,15 @@ def visit_prim_expr_(self, prim_expr: PrimExpr) -> str:
         return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`")
 
     def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str:
-        return self.build_expr(
-            op,
-            "TupleGetItem",
-            tuple_value=self.visit_expr(op.tuple_value),
-            index=str(op.index),
-        )
+        # TupleGetItem is copied from Relay's AST, so it cannot have a Shape
+        # TODO(@relax-team): We should eventually assign a shape_ to TupleGetItem
+        fields = {
+            "tuple_value": self.visit_expr(op.tuple_value),
+            "index": str(op.index),
+        }
+        if op._checked_type_ and self.include_type_annotations:
+            fields["checked_type_"] = self.visit_type_(op.checked_type)
+        return self.build_ast_node("TupleGetItem", **fields)
 
     def visit_type_(self, type_node: relax.Type) -> str:
         """
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 9773493131..0bbb207e96 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -33,14 +33,26 @@
  * 6. SeqExpr only serves as function body, or in the true and
  *    false branches in IfNode.
  * 7. The IR is in ANF:
- *    (a) No nested call
- *    (b) The fields of the Tuple can only be Var/DataflowVar/Constant/
- *        ShapeExpr/RuntimeDepShape/Tuple
+ *    (a) Expressions cannot contain nested complex expressions.
+ *        Here are the expressions that may be nested inside other expressions:
+ *        Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape,
+ *        Op, Tuple (we call these "leaf" expressions).
+ *    (b) The right-hand side of a binding may contain a non-leaf expression
+ *        (where all expressions nested in it are leaf expressions),
+ *        other than SeqExprs (see rule 6)
+ *    (c) Exceptions: The body of a Function node and the true branch
+ *        and false branch of If nodes *must* be SeqExprs.
+ *    (d) Places where non-leaf expressions cannot appear:
+ *        * The tuple_value field of TupleGetItem nodes
+ *        * The cond field of If nodes
+ *        * The op or args fields of Call nodes
+ *        * Inside the fields of Tuple nodes
  * 8. Expr always has checked_type_ (with the exception of Op).
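+ *    For example, rule 7 rules out `add(x, add(y, z))`: the nested call is not
+ *    a leaf, so it must first be bound to a var, as in `w = add(y, z); add(x, w)`.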
 */
 
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/utils.h>
#include <tvm/tir/expr_functor.h>
 
#include <unordered_set>
 
@@ -78,7 +90,7 @@ class WellFormedChecker : public relax::ExprVisitor {
   }
 
   void VisitExpr(const Expr& expr) override {
-    if (!expr->checked_type_.defined()) {
+    if (!expr.as<OpNode>() && !expr->checked_type_.defined()) {
       Malformed(Diagnostic::Error(expr->span)
                 << "The checked_type_ of Expr " << expr << " is nullptr.");
     }
@@ -107,8 +119,7 @@
   void VisitExpr_(const TupleNode* op) {
     for (size_t i = 0; i < op->fields.size(); i++) {
       Expr expr = op->fields[i];
-      if (expr.as<VarNode>() || expr.as<DataflowVarNode>() || expr.as<ConstantNode>() ||
-          expr.as<ShapeExprNode>() || expr.as<RuntimeDepShapeNode>() || expr.as<TupleNode>()) {
+      if (IsLeafExpr(expr)) {
         this->VisitExpr(expr);
       } else {
         Malformed(Diagnostic::Error(expr->span)
@@ -121,6 +132,15 @@
     }
   }
 
+  void VisitExpr_(const TupleGetItemNode* op) {
+    if (IsLeafExpr(op->tuple)) {
+      this->VisitExpr(op->tuple);
+    } else {
+      Malformed(Diagnostic::Error(op->span)
+                << "The tuple value in a TupleGetItem node must be a leaf expression.");
+    }
+  }
+
   void VisitExpr_(const VarNode* op) {
     Var var = GetRef<Var>(op);
     if (var_set_.count(var) == 0) {
@@ -162,17 +182,24 @@
       this->VisitVarDef(param);
     }
 
-    this->VisitBody(op->body);
+    if (auto seq = op->body.as<SeqExprNode>()) {
+      this->VisitSeqExpr(seq);
+    } else {
+      Malformed(Diagnostic::Error(op->span) << "Function bodies must be sequence expressions");
+    }
     var_set_ = previous_var_set_;
     prim_expr_visitor_.symbolic_var_set_.clear();
   }
 
   void VisitExpr_(const CallNode* op) {
+    if (IsLeafExpr(op->op)) {
+      this->VisitExpr(op->op);
+    } else {
+      Malformed(Diagnostic::Error(op->span) << "The called expression must be a leaf expression");
+    }
     for (size_t i = 0; i < op->args.size(); i++) {
       Expr arg = op->args[i];
-      if (arg.as<VarNode>() || arg.as<DataflowVarNode>() || arg.as<GlobalVarNode>() ||
-          arg.as<ConstantNode>() || arg.as<ShapeExprNode>() || arg.as<RuntimeDepShapeNode>() ||
-          arg.as<TupleNode>()) {
+      if (IsLeafExpr(arg)) {
         this->VisitExpr(arg);
       } else {
         Malformed(Diagnostic::Error(arg->span)
@@ -186,16 +213,27 @@
   }
 
   void VisitExpr_(const IfNode* op) {
-    this->VisitExpr(op->cond);
-    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
-    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set_ =
-        prim_expr_visitor_.symbolic_var_set_;
-    this->VisitBody(op->true_branch);
-    var_set_ = previous_var_set_;
-    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
-    this->VisitBody(op->false_branch);
-    var_set_ = previous_var_set_;
-    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+    if (IsLeafExpr(op->cond)) {
+      this->VisitExpr(op->cond);
+    } else {
+      Malformed(Diagnostic::Error(op->span)
+                << "The condition for an if node must be a leaf expression.");
+    }
+    auto true_seq = op->true_branch.as<SeqExprNode>();
+    auto false_seq = op->false_branch.as<SeqExprNode>();
+    if (true_seq && false_seq) {
+      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
+      std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set_ =
+          prim_expr_visitor_.symbolic_var_set_;
+      this->VisitSeqExpr(true_seq);
+      var_set_ = previous_var_set_;
+      prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+      this->VisitSeqExpr(false_seq);
+      var_set_ = previous_var_set_;
+      prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+    } else {
+      Malformed(Diagnostic::Error(op->span) << "If node branches must be seq exprs");
+    }
   }
 
   void VisitExpr_(const ShapeExprNode* op) {
@@ -221,15 +259,10 @@
     for (BindingBlock block : op->blocks) {
       this->VisitBindingBlock(block);
     }
-
-    this->VisitExpr(op->body);
-  }
-
-  void VisitBody(const Expr& expr) {
-    if (const SeqExprNode* seq_expr = expr.as<SeqExprNode>()) {
-      this->VisitSeqExpr(seq_expr);
-    } else {
-      this->VisitExpr(expr);
+    if (!IsLeafExpr(op->body)) {
+      Malformed(Diagnostic::Error(op->span) << "SeqExpr bodies must be leaf expressions.");
     }
+    this->VisitExpr(op->body);
   }
 
   void VisitBinding_(const VarBindingNode* binding) {
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 82349e7ef2..a2c3b7eab1 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -27,6 +27,7 @@
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
+#include <tvm/relax/utils.h>
#include <tvm/relay/op.h>
#include <tvm/tir/function.h>
 
@@ -157,7 +158,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
   }
 
   Expr VisitExpr_(const CallNode* op) final {
-    Expr new_op = this->VisitExpr(op->op);
+    Expr new_op = this->Bind(op->op);
     bool unchanged = new_op.same_as(op->op);
 
     Array<Expr> new_args;
@@ -213,7 +214,8 @@
     }
 
     builder_->BeginBindingBlock();
-    Expr new_body = this->VisitExpr(op->body);
+    // the body may not be a leaf expression, so check for that
+    Expr new_body = this->Bind(op->body);
     unchanged &= new_body.same_as(op->body);
     BindingBlock prologue = builder_->EndBlock();
 
@@ -227,8 +229,9 @@
     SeqExpr seq_expr;
     if (unchanged) {
       seq_expr = GetRef<SeqExpr>(op);
+    } else {
+      seq_expr = SeqExpr(new_blocks, new_body);
     }
-    seq_expr = SeqExpr(new_blocks, new_body);
 
     // only do shape/type inference if the SeqExpr does not have shape/type
     if (seq_expr->shape_ && seq_expr->checked_type_.defined()) {
@@ -271,7 +274,7 @@
   }
 
   Expr VisitExpr_(const IfNode* op) final {
-    Expr new_cond = this->VisitExpr(op->cond);
+    Expr new_cond = this->Bind(op->cond);
     Expr new_true = this->VisitWithNewScope(op->true_branch);
     Expr new_false = this->VisitWithNewScope(op->false_branch);
 
@@ -303,7 +306,7 @@
   }
 
   Expr VisitExpr_(const TupleGetItemNode* op) final {
-    Expr new_tuple = this->VisitExpr(op->tuple);
+    Expr new_tuple = this->Bind(op->tuple);
     TupleGetItem node;
     if (new_tuple.same_as(op->tuple)) {
       node = GetRef<TupleGetItem>(op);
@@ -389,6 +392,8 @@
     return new_block;
   }
 
+  void ResetMemo() { expr_memo_.Reset(); }
+
 private:
  /*!
   * \brief Memoization map for expressions using Id for equality of variables.
@@ -418,6 +423,11 @@
      }
    }
 
+    void Reset() {
+      var_memo_ = std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual>();
+      expr_memo_ = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>();
+    }
+
   private:
    std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
    std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
@@ -631,26 +641,27 @@
     return NullOpt;
   }
 
-  static bool IsLeaf(const Expr& expr) {
-    // NB: tuples are treated as leaf nodes for ergonomics
-    return expr.as<VarNode>() || expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
-           expr.as<ShapeExprNode>() || expr.as<RuntimeDepShapeNode>() ||
-           expr.as<ExternFuncNode>() || expr.as<OpNode>() || expr.as<TupleNode>();
-  }
-
  Expr VisitWithNewScope(const Expr& expr) {
    builder_->BeginBindingBlock();
    Expr post = this->VisitExpr(expr);
    BindingBlock prologue = builder_->EndBlock();
+    // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs.
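+    //   (e.g., a bare Call used as an If branch becomes a SeqExpr that binds the
+    //   call to a fresh var and returns that var)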
+    // Don't wrap if it's already a seq and there are no bindings to add
+    if (post.as<SeqExprNode>() && prologue->bindings.empty()) {
+      return post;
+    }
+
+    Array<BindingBlock> bindings;
    if (!prologue->bindings.empty()) {
-      post = SeqExpr({prologue}, post);
+      bindings.push_back(prologue);
    }
 
-    return post;
+    auto seq = SeqExpr(bindings, post);
+    // visit in case post is not a leaf and we need to bind it too
+    return this->VisitExpr(seq);
  }
 
  Expr Bind(const Expr& expr) {
    Expr post = this->VisitExpr(expr);
-    if (!IsLeaf(post)) {
+    if (!IsLeafExpr(post)) {
      post = builder_->Emit(post);
      expr_memo_.Set(expr, post);
    }
@@ -823,6 +834,8 @@ bool BlockBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) {
   return false;
 }
 
+void BlockBuilderNode::ResetMemo() { normalizer_->ResetMemo(); }
+
 // TODO(@altanh, @yuchen): need an internal Emit_ that doesn't call normalize
 Expr BlockBuilderNode::Normalize(const Expr& expr) {
   Expr normalized = normalizer_->VisitExpr(expr);
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 75a882de45..7e316a4d9d 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -77,5 +77,12 @@ bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
   return correct_dtype && correct_rank;
 }
 
+bool IsLeafExpr(const Expr& expr) {
+  // NB: tuples are treated as leaf nodes for ergonomics
+  return expr.as<VarNode>() || expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
+         expr.as<ShapeExprNode>() || expr.as<RuntimeDepShapeNode>() || expr.as<ExternFuncNode>() ||
+         expr.as<OpNode>() || expr.as<TupleNode>();
+}
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index 451e01d317..f250e84781 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -54,6 +54,9 @@ void FunctionFrameNode::ExitWithScope() {
   // Step 1: Create the function.
   CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use "
Please use " "`return` to return an Expr"; + // Normalizing a second time could result in false hits to the memo + // TODO(relax-team): We should fix the memoization not to require this + this->block_builder->ResetMemo(); Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); Expr func_shape = ret_shape.value_or(tvm::relax::RuntimeDepShape()); if (func_shape->IsInstance()) { diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 1ec0adcf05..d0e94de9c6 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -304,6 +304,7 @@ def test_derive_func_ret_shape_free(): class VarExample: @R.function def func(a: R.Tensor) -> R.Tensor: + # normalized into assigning R.add(a, a) to a var and returning it return R.add(a, a) @R.function @@ -322,8 +323,10 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: def test_all_vars(): vars = all_vars(VarExample["func"]) - assert len(vars) == 1 + assert len(vars) == 2 assert vars[0].name_hint == "a" + # the body of the seq expr in the func body is a var + assert vars[1] == VarExample["func"].body.body var_names = var_name_set(all_vars(VarExample["main"])) assert var_names == {"x", "y", "z", "p", "q", "r", "s"} @@ -331,8 +334,10 @@ def test_all_vars(): def test_bound_vars(): vars = bound_vars(VarExample["func"]) - assert len(vars) == 1 + assert len(vars) == 2 assert vars[0].name_hint == "a" + # the body of the seq expr in the func body is a bound var + assert vars[1] == VarExample["func"].body.body # all the vars are bound var_names = var_name_set(bound_vars(VarExample["main"])) @@ -342,8 +347,10 @@ def test_bound_vars(): body_names = var_name_set(bound_vars(VarExample["main"].body)) assert body_names == {"z", "p", "q", "r", "s"} - # if the argument isn't bound, then nothing is - assert len(bound_vars(VarExample["func"].body)) == 0 + # only binding is in the (normalized) body + simple_body_vars = bound_vars(VarExample["func"].body) + assert len(simple_body_vars) == 1 + assert simple_body_vars[0] == VarExample["func"].body.body def test_free_vars(): diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index a9b84ad970..8c45298f50 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -21,7 +21,7 @@ m = tir.Var("m", "int64") n = tir.Var("n", "int64") -type_anno = rx.DynTensorType(ndim=2, dtype="float16") +type_anno = rx.DynTensorType(ndim=2, dtype="float32") bool_type_anno = rx.DynTensorType(ndim=0, dtype="bool") x = rx.Var("x", [m, n], type_anno) cond = rx.Var("cond", [], bool_type_anno) @@ -191,6 +191,173 @@ def test_if(): assert not rx.analysis.well_formed(mod) +def test_if_non_seq_body(): + # Error: If node has a body that is not a seq node + if_node = rx.If(cond=cond, true_branch=x, false_branch=x) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", [m, n], type_anno), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod) + + # on the other hand, if they're wrapped in a seq node, it's fine + seq = rx.SeqExpr([], x) + new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq) + new_blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", [m, n], type_anno), + new_if_node, + ) + ] + ) + ] + new_func = build_function(new_blocks) + new_mod = tvm.IRModule.from_expr(new_func) + # apply normalization to fill in 
+    normalized = rx.transform.Normalize()(new_mod)
+    assert rx.analysis.well_formed(normalized)
+
+
+def test_if_complex_condition():
+    # Error: If condition must be a leaf expression
+    cond_tuple = rx.Tuple([cond])
+    cond_idx = rx.TupleGetItem(cond_tuple, 0)
+    if_node = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x))
+    blocks = [
+        rx.BindingBlock(
+            [
+                rx.VarBinding(
+                    rx.Var("gv1", [m, n], type_anno),
+                    if_node,
+                )
+            ]
+        )
+    ]
+    func = build_function(blocks)
+    mod = tvm.IRModule.from_expr(func)
+    assert not rx.analysis.well_formed(mod)
+
+    cond_var = rx.Var("q", [], bool_type_anno)
+    new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x))
+    blocks = [
+        rx.BindingBlock(
+            [
+                rx.VarBinding(cond_var, cond_idx),
+                rx.VarBinding(
+                    rx.Var("gv1", [m, n], type_anno),
+                    new_if,
+                ),
+            ]
+        )
+    ]
+    func = build_function(blocks)
+    mod = tvm.IRModule.from_expr(func)
+    # apply normalization to fill in checked_type_
+    normalized = rx.transform.Normalize()(mod)
+    assert rx.analysis.well_formed(normalized)
+
+
+def test_tuple_get_item_nested():
+    # Error: The tuple value in tuple get item must be a leaf expression
+    nested_tup = rx.Var(
+        "t",
+        type_annotation=rx.TupleType(
+            [
+                rx.TupleType(
+                    [
+                        rx.DynTensorType(ndim=0, dtype="int32"),
+                    ]
+                )
+            ]
+        ),
+    )
+    double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0)
+    ret_var = rx.Var("r", [], rx.DynTensorType(ndim=0, dtype="int32"))
+    f = rx.Function(
+        [nested_tup],
+        rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var),
+        ret_type=rx.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=rx.RuntimeDepShape(),
+    )
+    f = f.with_attr("global_symbol", "f")
+    mod = tvm.IRModule.from_expr(f)
+    assert not rx.analysis.well_formed(mod)
+
+    # okay with an intermediate binding
+    first_idx = rx.TupleGetItem(nested_tup, 0)
+    idx_var = rx.Var("v", type_annotation=rx.TupleType([rx.DynTensorType(ndim=0, dtype="int32")]))
+    second_idx = rx.TupleGetItem(idx_var, 0)
+    new_f = rx.Function(
+        [nested_tup],
+        rx.SeqExpr(
+            [
+                rx.BindingBlock(
+                    [rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)]
+                )
+            ],
+            ret_var,
+        ),
+        ret_type=rx.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=rx.RuntimeDepShape(),
+    )
+    new_f = new_f.with_attr("global_symbol", "new_f")
+    mod = tvm.IRModule.from_expr(new_f)
+    # normalize in order to fill in checked type
+    normalized = rx.transform.Normalize()(mod)
+    assert rx.analysis.well_formed(normalized)
+
+
+def test_complex_seq_body():
+    # Error: seq expr with a body that is not a leaf expression is not permitted
+    x = rx.Var("x", [], rx.DynTensorType(ndim=0, dtype="int32"))
+    y = rx.Var("y", [], rx.DynTensorType(ndim=0, dtype="int32"))
+    ret_type = rx.DynTensorType(ndim=0, dtype="int32")
+    ret_shape = rx.RuntimeDepShape()
+    func = rx.Function(
+        [x, y],
+        rx.SeqExpr([], rx.op.add(x, y)),
+        ret_type,
+        ret_shape,
+    ).with_attr("global_symbol", "foo")
+    mod = tvm.IRModule.from_expr(func)
+    assert not rx.analysis.well_formed(mod)
+
+    # but if the result is bound, then it's okay
+    z = rx.Var("z", [], rx.DynTensorType(ndim=0, dtype="int32"))
+    new_func = rx.Function(
+        [x, y],
+        rx.SeqExpr(
+            [
+                rx.BindingBlock(
+                    [
+                        rx.VarBinding(
+                            var=z,
+                            value=rx.op.add(x, y),
+                        )
+                    ]
+                )
+            ],
+            z,
+        ),
+        ret_type,
+        ret_shape,
+    ).with_attr("global_symbol", "foo")
+    new_mod = tvm.IRModule.from_expr(new_func)
+    # normalize in order to fill in checked type
+    normalized = rx.transform.Normalize()(new_mod)
+    assert rx.analysis.well_formed(normalized)
+
+
 def test_ANF():
     # Error: Nested Call
     gv0 = rx.Var("gv0", [m, n], type_anno)
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index f4582a1723..f96e9c4241 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -422,7 +422,7 @@ def f() -> R.Shape:
     body = normalize(f).body
     assert isinstance(body, rx.SeqExpr)
 
-    call = body.body
+    call = body.blocks[-1].bindings[-1].value
     assert isinstance(call, rx.Call)
 
     arg = call.args[0]
     arg_str = strip_whitespace(dump_ast(arg))
@@ -435,5 +435,37 @@ def f() -> R.Shape:
     assert type_str in call_str
 
 
+def test_if():
+    @R.function
+    def f(cond: R.Tensor((), dtype="bool")) -> R.Tensor((), dtype="int32"):
+        if cond:
+            x = R.const(1)
+        else:
+            x = R.const(2)
+        return x
+
+    body = normalize(f).body
+    assert isinstance(body, rx.SeqExpr)
+    body_str = strip_whitespace(dump_ast(body))
+    # we expect both branches to be seq exprs
+    assert "If" in body_str
+    assert "true_branch=SeqExpr(" in body_str
+    assert "false_branch=SeqExpr(" in body_str
+
+
+def test_tuple_get_item():
+    @R.function
+    def f(x: R.Tuple(R.Tensor((), dtype="int32"))) -> R.Tensor((), dtype="int32"):
+        return x[0]
+
+    body = normalize(f).body
+    assert isinstance(body, rx.SeqExpr)
+    body_str = strip_whitespace(dump_ast(body))
+
+    assert "TupleGetItem" in body_str
+    assert 'tuple_value=Var(name_hint="x"' in body_str
+    assert "index=0" in body_str
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py
index f4159c837c..b3a6e04347 100644
--- a/tests/python/relax/test_expr_functor.py
+++ b/tests/python/relax/test_expr_functor.py
@@ -464,7 +464,7 @@ def test_if():
     basic_check(
         if_node,
         "\n".join(["If", "\tVar", "\tVar", "\tVar"]),
-        "\n".join(["Var", "Var", "Var", "If"]),
+        "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]),
     )
 
 
@@ -565,7 +565,7 @@ def test_function():
     bindings = [relax.VarBinding(x, relax.const(1))]
     blocks = [relax.BindingBlock(bindings)]
     seq_expr = relax.SeqExpr(blocks, x)
-    ret_type = relax.DynTensorType(-1, "float32")
+    ret_type = relax.DynTensorType(1, "float32")
     ret_shape = relax.RuntimeDepShape()
     func = relax.Function([x], seq_expr, ret_type, ret_shape)
     basic_check(
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index 935cb4acc3..c593a430a7 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -238,13 +238,17 @@ def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor:
     body = ite.true_branch.body
     assert w_bind.var.name_hint == "w"
     check_call(w_bind.value, "relax.add", [x, x])
-    check_call(body, "relax.multiply", [w_bind.var, w_bind.var])
+    body_bind = ite.true_branch.blocks[1].bindings[0]
+    check_call(body_bind.value, "relax.multiply", [w_bind.var, w_bind.var])
+    assert ite.true_branch.body == body_bind.var
 
     w_bind = ite.false_branch.blocks[0].bindings[0]
     body = ite.false_branch.body
     assert w_bind.var.name_hint == "w"
     check_call(w_bind.value, "relax.multiply", [x, x])
-    check_call(body, "relax.add", [w_bind.var, w_bind.var])
+    body_bind = ite.false_branch.blocks[1].bindings[0]
+    check_call(body_bind.value, "relax.add", [w_bind.var, w_bind.var])
+    assert ite.false_branch.body == body_bind.var
 
 
 def test_func_type_annotation_fail():
@@ -766,9 +770,15 @@ def k(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.T
     func_j = my_module[var_j]
     func_k = my_module[var_k]
 
-    assert len(func_f.body.blocks) == 0
-    assert func_f.body.body.op == var_g
-    assert func_g.body.body.args[0] == var_my_matmul
+    assert len(func_f.body.blocks) == 1
+    assert len(func_f.body.blocks[0].bindings) == 1
+    f_call_var = func_f.body.blocks[0].bindings[0].var
+    assert func_f.body.blocks[0].bindings[0].value.op == var_g
+    assert func_f.body.body == f_call_var
+
+    g_call_var = func_g.body.blocks[0].bindings[-1].var
+    assert func_g.body.blocks[0].bindings[-1].value.args[0] == var_my_matmul
+    assert func_g.body.body == g_call_var
 
     gv_bind = func_j.body.blocks[0].bindings[0]
     assert gv_bind.value.checked_type.ndim == 2
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index e1318d2132..7063c6cfdf 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -25,6 +25,8 @@
 import tvm.script
 from tvm.script import tir as T, relax as R
 
+from tvm.relax.testing import dump_ast
+
 
 def test_fma_rewrite():
     @tvm.script.ir_module
@@ -589,5 +591,206 @@ def foo(x: R.Tensor(("m", "n"), "float32")):
     assert_structural_equal(mod, mod_post)
 
 
+def test_normalize_seq_body():
+    # a seq expression with a non-leaf body should bind the body to a var as well
+    x = relax.Var("x", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32"))
+    y = relax.Var("y", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32"))
+    seq = relax.SeqExpr([], relax.op.add(x, y))
+    f = relax.Function(
+        [x, y],
+        seq,
+        ret_type=relax.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=relax.RuntimeDepShape(),
+    )
+    f = f.with_attr("global_symbol", "f")
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def f(
+            x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
+        ) -> R.Tensor(ndim=0, dtype="int32"):
+            # normalization inserts a binding like this
+            z = R.add(x, y)
+            return z
+
+    assert_structural_equal(after_mod, Expected)
+
+
+def test_normalize_func_body():
+    # a function with a body that is not a seq expr should have it wrapped in a seq expr
+    x = relax.Var("x", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32"))
+    y = relax.Var("y", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32"))
+    f = relax.Function(
+        [x, y],
+        relax.op.add(x, y),
+        ret_type=relax.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=relax.RuntimeDepShape(),
+    )
+    f = f.with_attr("global_symbol", "f")
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def f(
+            x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
+        ) -> R.Tensor(ndim=0, dtype="int32"):
+            # result will be a seq expr where the body is a var
+            z = R.add(x, y)
+            return z
+
+    assert_structural_equal(after_mod, Expected)
+
+
+def test_normalize_if_branches():
+    # an if node's branches must be seq exprs
+    x = relax.Var("x", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32"))
+    y = relax.Var("y", [], type_annotation=relax.DynTensorType(ndim=0, dtype="int32"))
+    # TODO(@relax-team): z has a shape of () and type of DynTensorType(ndim=0),
+    # but normalization fails to infer these even though it should
+    z = relax.Var("z")
+    cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(ndim=0, dtype="bool"))
+    plus = relax.op.add(x, y)
+    mult = relax.op.multiply(x, y)
+    if_node = relax.If(cond, plus, mult)
+    seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z)
+    f = relax.Function(
+        [cond, x, y],
+        seq,
+        ret_type=relax.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=relax.RuntimeDepShape(),
+    )
+    f = f.with_attr("global_symbol", "f")
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def f(
+            cond: R.Tensor((), dtype="bool"),
+            x: R.Tensor((), dtype="int32"),
+            y: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor(ndim=0, dtype="int32"):
+            # the bodies of the branches will be seq exprs with a binding
+            if cond:
+                w = R.add(x, y)
+                z = w
+            else:
+                w = R.multiply(x, y)
+                z = w
+            return z
+
+    assert_structural_equal(after_mod, Expected)
+
+
+def test_normalize_if_condition():
+    cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool"))
+    x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32"))
+    # TODO(relax-team): add type and shape inference for IfNode
+    y = relax.Var("y")
+
+    # The condition is wrapped in a tuple and then indexed
+    f = relax.Function(
+        [cond, x],
+        relax.SeqExpr(
+            [
+                relax.BindingBlock(
+                    [
+                        relax.VarBinding(
+                            y,
+                            relax.If(
+                                relax.TupleGetItem(relax.Tuple([cond]), 0),
+                                relax.op.add(x, x),
+                                relax.op.multiply(x, x),
+                            ),
+                        )
+                    ]
+                )
+            ],
+            y,
+        ),
+        ret_type=relax.DynTensorType(1, "float32"),
+        ret_shape=relax.RuntimeDepShape(),
+    )
+    f = f.with_attr("global_symbol", "f")
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def f(
+            cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
+        ) -> R.Tensor(dtype="float32", ndim=1):
+            c = R.TupleGetItem(R.Tuple(cond), 0)
+            if c:
+                gv = R.add(x, x)
+                y = gv
+            else:
+                gv = R.multiply(x, x)
+                y = gv
+            return y
+
+    assert_structural_equal(after_mod, Expected)
+
+
+def test_normalize_tuple_get_item():
+    x = relax.Var("x", [], relax.DynTensorType(ndim=0, dtype="int32"))
+    f = relax.Function(
+        [x],
+        relax.TupleGetItem(
+            relax.TupleGetItem(
+                relax.Tuple([relax.Tuple([x])]),
+                0,
+            ),
+            0,
+        ),
+        ret_type=relax.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=relax.RuntimeDepShape(),
+    )
+    f = f.with_attr("global_symbol", "f")
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    # TODO: Revisit once we canonicalize SeqExprs (part of normalization?)
+    # Not using the parser this time because writing it out correctly results in
+    # *one* binding block, whereas the normalized version has *two*
+    idx_var = relax.Var(
+        "idx_var",
+        shape_annotation=relax.Tuple([relax.ShapeExpr([])]),
+        type_annotation=relax.TupleType([relax.DynTensorType(ndim=0, dtype="int32")]),
+    )
+    ret_var = relax.Var("ret", [], relax.DynTensorType(ndim=0, dtype="int32"))
+    expected_f = relax.Function(
+        [x],
+        relax.SeqExpr(
+            [
+                relax.BindingBlock(
+                    [
+                        relax.VarBinding(
+                            idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0)
+                        )
+                    ]
+                ),
+                relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]),
+            ],
+            ret_var,
+        ),
+        ret_type=relax.DynTensorType(ndim=0, dtype="int32"),
+        ret_shape=relax.RuntimeDepShape(),
+    )
+    expected_f = expected_f.with_attr("global_symbol", "f")
+    expected_mod = tvm.IRModule.from_expr(expected_f)
+    # apply normalization to fill in type and shape annotations (tedious otherwise)
+    final_mod = relax.transform.Normalize()(expected_mod)
+
+    assert_structural_equal(after_mod, final_mod)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 51df903afa..d54b6dff28 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -668,16 +668,18 @@ def check_call(call, op, args):
     tvm.ir.assert_structural_equal(call.args, args)
 
     w_bind = ite.true_branch.blocks[0].bindings[0]
-    body = ite.true_branch.body
+    # the seq exprs in the branches are normalized to bind any call
+    # in the seq expr "body" to a var
+    y_bind = ite.true_branch.blocks[-1].bindings[-1]
     assert w_bind.var.name_hint == "w"
     check_call(w_bind.value, "relax.add", [x, x])
-    check_call(body, "relax.multiply", [w_bind.var, w_bind.var])
+    check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var])
 
     w_bind = ite.false_branch.blocks[0].bindings[0]
-    body = ite.false_branch.body
+    y_bind = ite.false_branch.blocks[-1].bindings[-1]
     assert w_bind.var.name_hint == "w"
     check_call(w_bind.value, "relax.multiply", [x, x])
-    check_call(body, "relax.add", [w_bind.var, w_bind.var])
+    check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var])
 
 
 def test_if_inside_dataflow():
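To summarize the branch normalization that these test updates check, here is a sketch assembled from the parser tests in this diff (the function shape and binding positions follow the assertions above; the names `f`, `w`, and `y` are illustrative):

```python
from tvm import relax
from tvm.script import relax as R


@R.function
def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor:
    if cond:
        w = R.add(x, x)
        y = R.multiply(w, w)
    else:
        w = R.multiply(x, x)
        y = R.add(w, w)
    return y


# the parsed If node sits in the first binding of the function body
ite = f.body.blocks[0].bindings[0].value
# after this patch, each branch is a SeqExpr whose body is the var bound to the
# branch's final call, rather than the bare call itself
assert isinstance(ite.true_branch, relax.SeqExpr)
y_bind = ite.true_branch.blocks[-1].bindings[-1]
assert ite.true_branch.body == y_bind.var
```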