This repository has been archived by the owner on May 22, 2023. It is now read-only.

[IR][Bugfix] Improvements to the normalizer and well-formed checker #288

Merged 22 commits on Dec 2, 2022
Commits
bf883e0
Factor out repeated leaf node check
slyubomirsky Nov 30, 2022
6c28418
Fix printing for if nodes (no shape_)
slyubomirsky Nov 30, 2022
04bd8f0
Require scope bodies to be SeqExpr, also normalize body of SeqExpr
slyubomirsky Nov 30, 2022
c972ac7
Reset the memo to prevent false sharing during second visit
slyubomirsky Dec 1, 2022
b700e84
Update parsing tests for new normalizer behavior
slyubomirsky Dec 1, 2022
726dbe6
Require seq expr bodies to be leaves and require func/branch bodies t…
slyubomirsky Dec 1, 2022
8443430
Add well formedness test cases, fix handling of if nodes
slyubomirsky Dec 1, 2022
95cdded
Add new normalization tests
slyubomirsky Dec 1, 2022
cacd586
Factor out leaf expr in tuple case
slyubomirsky Dec 1, 2022
83d03df
Fix AST printer test for new normalization
slyubomirsky Dec 2, 2022
be32741
Do not attempt to print shape_ for TupleGetItem
slyubomirsky Dec 2, 2022
06dcab6
Check for and normalize nesting in TupleGetItem and If conditions
slyubomirsky Dec 2, 2022
a4d5543
Fix new well-formed test cases to also have checked types
slyubomirsky Dec 2, 2022
b36899c
Lint
slyubomirsky Dec 2, 2022
175e7f4
Fix whitespace
slyubomirsky Dec 2, 2022
b7441f8
Fix expr functor tests for new normalizer
slyubomirsky Dec 2, 2022
f3d6df6
Fix management of var sets in well-formed
slyubomirsky Dec 2, 2022
d5e5f49
Remove debug prints
slyubomirsky Dec 2, 2022
9b0b3a5
Fix analysis tests
slyubomirsky Dec 2, 2022
6c3a828
Fix TVMScript parser test
slyubomirsky Dec 2, 2022
b3a4eab
Update comment detailing what the well-formed pass checks
slyubomirsky Dec 2, 2022
20da866
Also check the op argument to call nodes
slyubomirsky Dec 2, 2022
9 changes: 9 additions & 0 deletions include/tvm/relax/block_builder.h
@@ -133,6 +133,15 @@ class BlockBuilderNode : public Object {
    */
   bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);

+  /*!
+   * \brief Resets the memo in the normalizer to prevent false hits when visiting
+   * the same expression more than once.
+   * Use it before visiting a given expression again.
+   */
+  // TODO(@relax-team): Memoization should be tied to the scope tracking to prevent memo hits
+  // when the associated var is out of scope
+  void ResetMemo();
+
   /*!
    * \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes.
    * \param expr The input expression.
15 changes: 15 additions & 0 deletions include/tvm/relax/utils.h
@@ -122,6 +122,21 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
 TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
                               bool permit_unknown_dtype = true);

+/*!
+ * \brief Check if the given expression is a "leaf" node for normalization purposes.
+ * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr,
+ * GlobalVar, RuntimeDepShape, Op, ExternFunc, and Tuple.
+ * Tuples are included in this list mainly for convenience in grouping operator arguments.
+ * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that
+ * values nested inside them are also leaves.
+ *
+ * \param expr The input expression
+ *
+ * \return True iff the input expression is a "leaf" node (a value allowed to appear
+ * inline without being bound to a var during normalization).
+ */
+TVM_DLL bool IsLeafExpr(const Expr& expr);
+
 } // namespace relax
 } // namespace tvm

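The note in the `IsLeafExpr` doc comment (a tuple counts as a leaf, but its fields still have to be checked) can be illustrated with a small Python model. This is only a sketch under stated assumptions: the `Var`, `Const`, `Call`, and `Tuple` classes and the `is_leaf` / `is_deep_leaf` helpers are hypothetical stand-ins, not TVM's classes or API, and the toy leaf set is shorter than the real one.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


@dataclass(frozen=True)
class Tuple:
    fields: tuple


def is_leaf(expr) -> bool:
    """Shallow check in the spirit of IsLeafExpr: only the node kind matters."""
    return isinstance(expr, (Var, Const, Tuple))


def is_deep_leaf(expr) -> bool:
    """Stricter check: a tuple qualifies only if every nested field is a leaf too."""
    if isinstance(expr, Tuple):
        return all(is_deep_leaf(field) for field in expr.fields)
    return is_leaf(expr)


x = Var("x")
nested = Tuple((x, Call("add", (x, Const(1)))))
print(is_leaf(nested))       # the shallow check accepts the tuple node itself
print(is_deep_leaf(nested))  # the call nested inside the tuple is rejected
```

The shallow check is cheap and matches a per-node test; callers that accept tuples are then responsible for visiting the fields, which is what the well-formed checker's tuple visitor does.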
32 changes: 19 additions & 13 deletions python/tvm/relax/testing/ast_printer.py
@@ -208,13 +208,16 @@ def visit_seq_expr_(self, op: relax.SeqExpr) -> str:
         )

     def visit_if_(self, op: relax.If) -> str:
-        return self.build_expr(
-            op,
-            "If",
-            cond=self.visit_expr(op.cond),
-            true_branch=self.visit_expr(op.true_branch),
-            false_branch=self.visit_expr(op.false_branch),
-        )
+        # if is copied from Relay's AST, so it cannot have a Shape
+        # TODO(@relax-team): We should eventually assign a shape_ to If
+        fields = {
+            "cond": self.visit_expr(op.cond),
+            "true_branch": self.visit_expr(op.true_branch),
+            "false_branch": self.visit_expr(op.false_branch),
+        }
+        if op._checked_type_ and self.include_type_annotations:
+            fields["checked_type_"] = self.visit_type_(op.checked_type)
+        return self.build_ast_node("If", **fields)
Comment on lines -211 to +220
Collaborator Author

@junrushao @YuchenJin This is a good reason to decouple the Relax AST from the Relay one 😅 (see also the change I needed for TupleGetItem).

Collaborator

Totally agree we should decouple the AST from Relay someday!

Currently we have directly added a shape_ field to RelayExpr: https://github.com/tlc-pack/relax/blob/relax/include/tvm/ir/expr.h#L377, so until the AST is decoupled, we can still rely on the AST nodes borrowed from Relay having a shape_ field.

Collaborator Author

Odd, not sure why I was getting failures then. It would throw an exception saying that there was no shape_ defined.

Collaborator Author

Example:

>>> import tvm
>>> from tvm import relax
>>> v = relax.Var("v")
>>> tgi = relax.TupleGetItem(v, 0)
>>> tgi.shape_
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/slyubomirsky/code/relax/python/tvm/runtime/object.py", line 67, in __getattr__
    raise AttributeError("%s has no attribute %s" % (str(type(self)), name)) from None
AttributeError: <class 'tvm.relay.expr.TupleGetItem'> has no attribute shape_. Did you mean: 'shape'?

Collaborator

It's because we forgot to add shape_ to its VisitAttrs (just noticed now 😄): https://github.com/tlc-pack/relax/blob/relax/include/tvm/relay/expr.h#L565, so the shape_ attribute is not reflected on the Python side. Compare with CallNode, which does have its shape_ visited: https://github.com/tlc-pack/relax/blob/relax/include/tvm/relay/expr.h#L334.

Given that we will decouple the Relax AST from the Relay AST soon (#292), we can leave it as is and fix it in a batch.
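
As a rough illustration of why a field that exists on the node can still be invisible from Python, here is a toy model. It is only an analogy, not TVM's reflection mechanism: the `ReflectedNode` base class and the `_visited_attrs` tuple are hypothetical names standing in for the set of fields a node registers in its `VisitAttrs`.

```python
class ReflectedNode:
    """Toy node: Python sees only the fields named in `_visited_attrs`."""

    _visited_attrs = ()

    def __init__(self, **fields):
        # Everything stored on the node, whether or not it is exposed.
        self._fields = fields

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        fields = self.__dict__.get("_fields", {})
        if name in type(self)._visited_attrs and name in fields:
            return fields[name]
        raise AttributeError(f"{type(self).__name__} has no attribute {name}")


class CallNode(ReflectedNode):
    _visited_attrs = ("op", "args", "shape_")


class TupleGetItemNode(ReflectedNode):
    # shape_ is stored on the node but was never registered for reflection.
    _visited_attrs = ("tuple_value", "index")


call = CallNode(op="f", args=(), shape_=None)
tgi = TupleGetItemNode(tuple_value="v", index=0, shape_=None)

print(hasattr(call, "shape_"))
print(hasattr(tgi, "shape_"))
```

In this model both nodes carry a `shape_` value, but only the one that lists it is reachable by attribute access, which mirrors the `AttributeError` in the session above.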

Collaborator Author (@slyubomirsky, Dec 2, 2022)

Okay, great. Happy to update the printer alongside those changes; just let me know.


     def visit_op_(self, op: tvm.ir.Op) -> str:
         # TODO: List other attributes?
@@ -227,12 +230,15 @@ def visit_prim_expr_(self, prim_expr: PrimExpr) -> str:
         return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`")

     def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str:
-        return self.build_expr(
-            op,
-            "TupleGetItem",
-            tuple_value=self.visit_expr(op.tuple_value),
-            index=str(op.index),
-        )
+        # TupleGetItem is copied from Relay's AST, so it cannot have a Shape
+        # TODO(@relax-team): We should eventually assign a shape_ to TupleGetItem
+        fields = {
+            "tuple_value": self.visit_expr(op.tuple_value),
+            "index": str(op.index),
+        }
+        if op._checked_type_ and self.include_type_annotations:
+            fields["checked_type_"] = self.visit_type_(op.checked_type)
+        return self.build_ast_node("TupleGetItem", **fields)

     def visit_type_(self, type_node: relax.Type) -> str:
         """
89 changes: 61 additions & 28 deletions src/relax/analysis/well_formed.cc
@@ -33,14 +33,26 @@
  * 6. SeqExpr only serves as function body, or in the true and
  *    false branches in IfNode.
  * 7. The IR is in ANF:
- *    (a) No nested call
- *    (b) The fields of the Tuple can only be Var/DataflowVar/Constant/
- *        ShapeExpr/RuntimeDepShape/Tuple
+ *    (a) Expressions cannot contain nested complex expressions.
+ *        Here are the expressions that may be nested inside other expressions:
+ *        Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape,
+ *        Op, Tuple (we call these "leaf" expressions).
+ *    (b) The right-hand side of a binding may contain a non-leaf expression
+ *        (where all expressions nested in it are leaf expressions),
+ *        other than SeqExprs (see rule 6)
+ *    (c) Exceptions: The body of a Function node and the true branch
+ *        and false branch of If nodes *must* be SeqExprs.
+ *    (d) Places where non-leaf expressions cannot appear:
+ *        * The tuple_value field of TupleGetItem nodes
+ *        * The cond field of If nodes
+ *        * The op or args fields of Call nodes
+ *        * Inside the fields of Tuple nodes
  * 8. Expr always has checked_type_ (with the exception of Op).
  */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/utils.h>
 #include <tvm/tir/expr_functor.h>

 #include <unordered_set>
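
Rule 7(d) in the comment above can be sketched as a small recursive check. This is a toy model, not the checker from this PR: the `Var`, `Tuple`, `Call`, and `TupleGetItem` classes and the `check_anf` function are hypothetical, and the leaf set is reduced to variables and tuples. Like the C++ visitor, it descends into a child only when that child is a leaf and reports an error otherwise.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Tuple:
    fields: tuple


@dataclass(frozen=True)
class Call:
    op: object
    args: tuple


@dataclass(frozen=True)
class TupleGetItem:
    tuple_value: object
    index: int


def is_leaf(expr) -> bool:
    # Toy leaf set: variables and tuples (the real list is longer).
    return isinstance(expr, (Var, Tuple))


def check_anf(expr) -> list:
    """Return one message per non-leaf expression found in a leaf-only position."""
    if isinstance(expr, Call):
        children = [("call op", expr.op)] + [("call argument", a) for a in expr.args]
    elif isinstance(expr, TupleGetItem):
        children = [("tuple_value", expr.tuple_value)]
    elif isinstance(expr, Tuple):
        children = [("tuple field", f) for f in expr.fields]
    else:
        children = []  # Var has no children

    errors = []
    for where, child in children:
        if is_leaf(child):
            # A leaf may still be a tuple, so keep checking its fields.
            errors.extend(check_anf(child))
        else:
            errors.append(f"{where} must be a leaf, got {type(child).__name__}")
    return errors


f, x, y = Var("f"), Var("x"), Var("y")
good = Call(f, (x, Tuple((x, y))))
bad = Call(f, (Call(f, (x,)), Tuple((TupleGetItem(y, 0),))))
print(check_anf(good))
print(check_anf(bad))
```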
@@ -78,7 +90,7 @@ class WellFormedChecker : public relax::ExprVisitor {
   }

   void VisitExpr(const Expr& expr) override {
-    if (!expr->checked_type_.defined()) {
+    if (!expr.as<OpNode>() && !expr->checked_type_.defined()) {
       Malformed(Diagnostic::Error(expr->span)
                 << "The checked_type_ of Expr " << expr << " is nullptr.");
     }
@@ -107,8 +119,7 @@ class WellFormedChecker : public relax::ExprVisitor {
   void VisitExpr_(const TupleNode* op) {
     for (size_t i = 0; i < op->fields.size(); i++) {
       Expr expr = op->fields[i];
-      if (expr.as<VarNode>() || expr.as<DataflowVarNode>() || expr.as<ShapeExprNode>() ||
-          expr.as<RuntimeDepShapeNode>() || expr.as<ConstantNode>() || expr.as<TupleNode>()) {
+      if (IsLeafExpr(expr)) {
         this->VisitExpr(expr);
       } else {
         Malformed(Diagnostic::Error(expr->span)
@@ -121,6 +132,15 @@ class WellFormedChecker : public relax::ExprVisitor {
     }
   }

+  void VisitExpr_(const TupleGetItemNode* op) {
+    if (IsLeafExpr(op->tuple)) {
+      this->VisitExpr(op->tuple);
+    } else {
+      Malformed(Diagnostic::Error(op->span)
+                << "The tuple value in a TupleGetItem node must be a leaf expression.");
+    }
+  }
+
   void VisitExpr_(const VarNode* op) {
     Var var = GetRef<Var>(op);
     if (var_set_.count(var) == 0) {
@@ -162,17 +182,24 @@ class WellFormedChecker : public relax::ExprVisitor {

       this->VisitVarDef(param);
     }
-    this->VisitBody(op->body);
+    if (auto seq = op->body.as<SeqExprNode>()) {
+      this->VisitSeqExpr(seq);
+    } else {
+      Malformed(Diagnostic::Error(op->span) << "Function bodies must be sequence expressions");
+    }
     var_set_ = previous_var_set_;
     prim_expr_visitor_.symbolic_var_set_.clear();
   }

   void VisitExpr_(const CallNode* op) {
+    if (IsLeafExpr(op->op)) {
+      this->VisitExpr(op->op);
+    } else {
+      Malformed(Diagnostic::Error(op->span) << "The called expression must be a leaf expression");
+    }
     for (size_t i = 0; i < op->args.size(); i++) {
       Expr arg = op->args[i];
-      if (arg.as<GlobalVarNode>() || arg.as<ExternFuncNode>() || arg.as<TupleNode>() ||
-          arg.as<ShapeExprNode>() || arg.as<VarNode>() || arg.as<DataflowVarNode>() ||
-          arg.as<ConstantNode>()) {
+      if (IsLeafExpr(arg)) {
         this->VisitExpr(arg);
       } else {
         Malformed(Diagnostic::Error(arg->span)
@@ -186,16 +213,27 @@ class WellFormedChecker : public relax::ExprVisitor {
   }

   void VisitExpr_(const IfNode* op) {
-    this->VisitExpr(op->cond);
-    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
-    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set_ =
-        prim_expr_visitor_.symbolic_var_set_;
-    this->VisitBody(op->true_branch);
-    var_set_ = previous_var_set_;
-    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
-    this->VisitBody(op->false_branch);
-    var_set_ = previous_var_set_;
-    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+    if (IsLeafExpr(op->cond)) {
+      this->VisitExpr(op->cond);
+    } else {
+      Malformed(Diagnostic::Error(op->span)
+                << "The condition for an if node must be a leaf expression.");
+    }
+    auto true_seq = op->true_branch.as<SeqExprNode>();
+    auto false_seq = op->false_branch.as<SeqExprNode>();
+    if (true_seq && false_seq) {
+      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
+      std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set_ =
+          prim_expr_visitor_.symbolic_var_set_;
+      this->VisitSeqExpr(true_seq);
+      var_set_ = previous_var_set_;
+      prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+      this->VisitSeqExpr(false_seq);
+      var_set_ = previous_var_set_;
+      prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+    } else {
+      Malformed(Diagnostic::Error(op->span) << "If node branches must be seq exprs");
+    }
   }

   void VisitExpr_(const ShapeExprNode* op) {
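
The save-and-restore of the variable sets around each branch in the If visitor can be pictured with a short Python sketch. It is a simplified model under stated assumptions: `check_scopes` is a hypothetical helper, and each branch is represented as a list of `(defined_var, used_vars)` pairs rather than a real SeqExpr.

```python
def check_scopes(branches, outer_vars):
    """Check each branch against a fresh copy of the enclosing scope.

    Each branch is a list of (defined_var, used_vars) bindings, a
    hypothetical stand-in for the bindings of a SeqExpr.
    """
    errors = []
    for i, bindings in enumerate(branches):
        # Restore the saved set before each branch, so that definitions
        # made in one branch do not leak into the next.
        in_scope = set(outer_vars)
        for defined, used in bindings:
            for var in used:
                if var not in in_scope:
                    errors.append(f"branch {i}: {var} used before definition")
            in_scope.add(defined)
    return errors


true_branch = [("a", ["x"]), ("b", ["a"])]
false_branch = [("c", ["a"])]  # `a` was bound only in the true branch
print(check_scopes([true_branch, false_branch], {"x"}))
```

Without the reset at the top of the loop, the use of `a` in the second branch would be accepted, which is the kind of leak the restored sets are meant to rule out.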
@@ -221,15 +259,10 @@ class WellFormedChecker : public relax::ExprVisitor {
     for (BindingBlock block : op->blocks) {
       this->VisitBindingBlock(block);
     }
-    this->VisitExpr(op->body);
-  }
-
-  void VisitBody(const Expr& expr) {
-    if (const SeqExprNode* seq_expr = expr.as<SeqExprNode>()) {
-      this->VisitSeqExpr(seq_expr);
-    } else {
-      this->VisitExpr(expr);
+    if (!IsLeafExpr(op->body)) {
+      Malformed(Diagnostic::Error(op->span) << "SeqExpr bodies must be leaf expressions.");
     }
+    this->VisitExpr(op->body);
   }

   void VisitBinding_(const VarBindingNode* binding) {
43 changes: 28 additions & 15 deletions src/relax/ir/block_builder.cc
@@ -27,6 +27,7 @@
 #include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/type.h>
 #include <tvm/relax/type_analysis.h>
+#include <tvm/relax/utils.h>
 #include <tvm/relay/op.h>
 #include <tvm/tir/function.h>

@@ -157,7 +158,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
   }

   Expr VisitExpr_(const CallNode* op) final {
-    Expr new_op = this->VisitExpr(op->op);
+    Expr new_op = this->Bind(op->op);
     bool unchanged = new_op.same_as(op->op);

     Array<Expr> new_args;
@@ -213,7 +214,8 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
     }

     builder_->BeginBindingBlock();
-    Expr new_body = this->VisitExpr(op->body);
+    // the body may not be a leaf expression, so check for that
+    Expr new_body = this->Bind(op->body);
     unchanged &= new_body.same_as(op->body);
     BindingBlock prologue = builder_->EndBlock();

@@ -227,8 +229,9 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
     SeqExpr seq_expr;
     if (unchanged) {
       seq_expr = GetRef<SeqExpr>(op);
+    } else {
+      seq_expr = SeqExpr(new_blocks, new_body);
     }
-    seq_expr = SeqExpr(new_blocks, new_body);

     // only do shape/type inference if the SeqExpr does not have shape/type
     if (seq_expr->shape_ && seq_expr->checked_type_.defined()) {
@@ -271,7 +274,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
   }

   Expr VisitExpr_(const IfNode* op) final {
-    Expr new_cond = this->VisitExpr(op->cond);
+    Expr new_cond = this->Bind(op->cond);
     Expr new_true = this->VisitWithNewScope(op->true_branch);
     Expr new_false = this->VisitWithNewScope(op->false_branch);

@@ -303,7 +306,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
   }

   Expr VisitExpr_(const TupleGetItemNode* op) final {
-    Expr new_tuple = this->VisitExpr(op->tuple);
+    Expr new_tuple = this->Bind(op->tuple);
     TupleGetItem node;
     if (new_tuple.same_as(op->tuple)) {
       node = GetRef<TupleGetItem>(op);
@@ -389,6 +392,8 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
     return new_block;
   }

+  void ResetMemo() { expr_memo_.Reset(); }
+
  private:
   /*!
    * \brief Memoization map for expressions using Id for equality of variables.
@@ -418,6 +423,11 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
       }
     }

+    void Reset() {
+      var_memo_ = std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual>();
+      expr_memo_ = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>();
+    }
+
    private:
     std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
     std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
@@ -631,26 +641,27 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
     return NullOpt;
   }

-  static bool IsLeaf(const Expr& expr) {
-    // NB: tuples are treated as leaf nodes for ergonomics
-    return expr.as<VarNode>() || expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
-           expr.as<ShapeExprNode>() || expr.as<RuntimeDepShapeNode>() ||
-           expr.as<ExternFuncNode>() || expr.as<OpNode>() || expr.as<TupleNode>();
-  }
-
   Expr VisitWithNewScope(const Expr& expr) {
     builder_->BeginBindingBlock();
     Expr post = this->VisitExpr(expr);
     BindingBlock prologue = builder_->EndBlock();
+    // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs.
+    // Don't wrap if it's already a seq and there are no bindings to add
+    if (post.as<SeqExprNode>() && prologue->bindings.empty()) {
+      return post;
+    }
+    Array<BindingBlock> bindings;
     if (!prologue->bindings.empty()) {
-      post = SeqExpr({prologue}, post);
+      bindings.push_back(prologue);
     }
-    return post;
+    auto seq = SeqExpr(bindings, post);
+    // visit in case post is not a leaf and we need to bind it too
+    return this->VisitExpr(seq);
   }

   Expr Bind(const Expr& expr) {
     Expr post = this->VisitExpr(expr);
-    if (!IsLeaf(post)) {
+    if (!IsLeafExpr(post)) {
       post = builder_->Emit(post);
       expr_memo_.Set(expr, post);
     }
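
The role of `Bind` in the normalizer (visit a sub-expression, then emit a binding if the result is not a leaf) can be sketched with a toy A-normal-form pass. This is an illustrative model only: the `Var`, `Call`, and `Normalizer` classes are hypothetical, the leaf check is reduced to "is a variable", and memoization and scoping are left out.

```python
from dataclasses import dataclass
from itertools import count


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: object
    args: tuple


class Normalizer:
    def __init__(self):
        self.bindings = []     # emitted (var, expr) pairs, in order
        self._fresh = count()  # source of fresh variable names

    def emit(self, expr) -> Var:
        """Bind `expr` to a fresh variable and record the binding."""
        var = Var(f"t{next(self._fresh)}")
        self.bindings.append((var, expr))
        return var

    def visit(self, expr):
        """Normalize the children of `expr`; the node itself stays inline."""
        if isinstance(expr, Call):
            new_op = self.bind(expr.op)
            new_args = tuple(self.bind(arg) for arg in expr.args)
            return Call(new_op, new_args)
        return expr

    def bind(self, expr):
        """Visit `expr`, then replace it with a variable if it is not a leaf."""
        post = self.visit(expr)
        if not isinstance(post, Var):  # toy leaf check
            post = self.emit(post)
        return post


f, g, x = Var("f"), Var("g"), Var("x")
n = Normalizer()
result = n.visit(Call(f, (Call(g, (x,)), x)))
print(n.bindings)
print(result)
```

Calling `bind` rather than `visit` on the op, tuple, and condition positions is what guarantees that only leaves remain there afterwards, which is the change this part of the diff makes.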
@@ -823,6 +834,8 @@ bool BlockBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) {
   return false;
 }

+void BlockBuilderNode::ResetMemo() { normalizer_->ResetMemo(); }
+
 // TODO(@altanh, @yuchen): need an internal Emit_ that doesn't call normalize
 Expr BlockBuilderNode::Normalize(const Expr& expr) {
   Expr normalized = normalizer_->VisitExpr(expr);
7 changes: 7 additions & 0 deletions src/relax/utils.cc
@@ -77,5 +77,12 @@ bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unkn
   return correct_dtype && correct_rank;
 }

+bool IsLeafExpr(const Expr& expr) {
+  // NB: tuples are treated as leaf nodes for ergonomics
+  return expr.as<VarNode>() || expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
+         expr.as<ShapeExprNode>() || expr.as<RuntimeDepShapeNode>() || expr.as<ExternFuncNode>() ||
+         expr.as<OpNode>() || expr.as<TupleNode>();
+}
+
 } // namespace relax
 } // namespace tvm
3 changes: 3 additions & 0 deletions src/script/ir_builder/relax/frame.cc
@@ -54,6 +54,9 @@ void FunctionFrameNode::ExitWithScope() {
   // Step 1: Create the function.
   CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use "
                              "`return` to return an Expr";
+  // Normalizing a second time could result in false hits to the memo
+  // TODO(relax-team): We should fix the memoization not to require this
+  this->block_builder->ResetMemo();
Collaborator Author

Note: I had to include this seemingly strange step after getting some bizarre errors when normalizing SeqExprs once I changed the normalizer to use Bind on the body. This was apparently because the parser was normalizing the same expression more than once. On the first visit, Bind would register the SeqExpr body in the memo and insert a call. On the second visit, the call to VisitExpr would check the memo and retrieve the var, but no binding would be inserted, resulting in an AST that had a var with no binding. It took a lot of debugging to figure out that this was due to a false hit in the memo.

Manually resetting the memo like this is not a good solution, IMO. I think it might make more sense to restructure the memo in the normalizer to be tied to the scope, since the block builder is already tracking the scope. It should not be possible for the memo to return a var that is not currently in scope.

(Debugged thanks to the help of @yongwww and @YuchenJin)
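
The false memo hit described in this comment can be reproduced in a few lines with a toy normalizer. This is a sketch under stated assumptions, not the block builder's code: expressions are plain strings, and the `Normalizer` class with its `bind`, `normalize`, and `reset_memo` methods is hypothetical.

```python
class Normalizer:
    def __init__(self):
        self.memo = {}      # expression -> variable it was bound to
        self.bindings = []  # bindings emitted into the current block

    def bind(self, expr: str) -> str:
        if expr in self.memo:
            # Memo hit: the variable is returned, but no binding is emitted.
            return self.memo[expr]
        var = f"v{len(self.memo)}"
        self.bindings.append((var, expr))
        self.memo[expr] = var
        return var

    def normalize(self, body: str):
        """Normalize `body` into a fresh block; return (bindings, result var)."""
        self.bindings = []
        result = self.bind(body)
        return self.bindings, result

    def reset_memo(self):
        self.memo = {}


n = Normalizer()
first = n.normalize("f(x)")   # emits a binding for the body
second = n.normalize("f(x)")  # false hit: result var has no binding in this block
n.reset_memo()
third = n.normalize("f(x)")   # binding is emitted again after the reset
print(first, second, third)
```

In the second call the returned variable refers to a binding that lives in a different block, which corresponds to the "var with no binding" symptom above. Tying the memo to the scope, as suggested in the comment, would avoid the need for a manual reset.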

   Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value()));
   Expr func_shape = ret_shape.value_or(tvm::relax::RuntimeDepShape());
   if (func_shape->IsInstance<tvm::relax::RuntimeDepShapeNode>()) {