Skip to content

Commit

Permalink
add fatal
Browse files Browse the repository at this point in the history
lint

lint

lint

do

make completeness check an error

lint

drop changes

address review comment

lint

address review

add asf header

fix
  • Loading branch information
MarisaKirisame committed Sep 26, 2019
1 parent 01e5393 commit 116bdf7
Show file tree
Hide file tree
Showing 28 changed files with 258 additions and 79 deletions.
26 changes: 26 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,32 @@ class RefWriteNode : public ExprNode {

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

/*! \brief A fatal error has occurred. Stop all execution and report with a message. */
class Fatal;
class FatalNode : public ExprNode {
public:
/*! \brief The Message. */
std::string msg;
Type type_annotation;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("msg", &msg);
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Fatal make(std::string msg, Type type_annotation);

static constexpr const char* _type_key = "relay.Fatal";
TVM_DECLARE_NODE_TYPE_INFO(FatalNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Fatal, FatalNode, Expr);

/*! \brief the fatal message for case unhandled in match. */
TVM_DLL std::string NoMatchMsg();

/*!
* \brief Base class of the temporary expression.
*
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FatalNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -140,6 +141,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FatalNode);
return vtable;
}
};
Expand Down Expand Up @@ -170,6 +172,7 @@ class ExprVisitor
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
void VisitExpr_(const FatalNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
Expand Down Expand Up @@ -212,6 +215,7 @@ class ExprMutator
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
Expr VisitExpr_(const FatalNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/relay/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ enum Feature : int {
fRefWrite = 12,
fConstructor = 13,
fMatch = 14,
fFatal = 15,
/*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */
fGraph = 15,
fGraph = 16,
/*! \brief Whether there is local fixpoint in the program. */
fLetRec = 16
fLetRec = 17
};

constexpr size_t feature_count = 17;
Expand Down
18 changes: 18 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
[pytest]
xfail_strict=true
4 changes: 4 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
RefCreate = expr.RefCreate
RefRead = expr.RefRead
RefWrite = expr.RefWrite
Fatal = expr.Fatal

# ADT
PatternWildcard = adt.PatternWildcard
Expand Down Expand Up @@ -142,3 +143,6 @@

# Feature
Feature = feature.Feature

# Fatal Messages
NO_MATCH_MSG = expr.NO_MATCH_MSG
21 changes: 21 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def __init__(self, tuple_value, index):
@register_relay_node
class RefCreate(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
Expand All @@ -407,6 +408,7 @@ def __init__(self, value):
@register_relay_node
class RefRead(Expr):
"""Get the value inside the reference.
Parameters
----------
ref: tvm.relay.Expr
Expand All @@ -421,6 +423,7 @@ class RefWrite(Expr):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
Parameters
----------
ref: tvm.relay.Expr
Expand All @@ -432,6 +435,24 @@ def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)


@register_relay_node
class Fatal(Expr):
"""
Abort the execution with a fatal error message.
Parameters
----------
msg: String
The message
type_annotation: Optional[tvm.relay.Type]
The type of Fatal. Leave none to be inferred.
"""
def __init__(self, msg, ty=None):
self.__init_handle_by_constructor__(_make.Fatal, msg, ty)

NO_MATCH_MSG = _expr.NoMatchMsg()

class TempExpr(Expr):
"""Baseclass of all TempExpr.
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""The expression functor of Relay."""

from .expr import Function, Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant
from .expr import If, Tuple, TupleGetItem, Constant, Fatal
from .expr import RefCreate, RefRead, RefWrite
from .adt import Constructor, Match, Clause
from .op import Op
Expand Down Expand Up @@ -69,6 +69,8 @@ def visit(self, expr):
res = self.visit_constructor(expr)
elif isinstance(expr, Match):
res = self.visit_match(expr)
elif isinstance(expr, Fatal):
res = self.visit_fatal(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))

Expand Down Expand Up @@ -124,6 +126,9 @@ def visit_constructor(self, _):
def visit_match(self, _):
raise NotImplementedError()

def visit_fatal(self, _):
raise NotImplementedError()


class ExprVisitor(ExprFunctor):
"""
Expand Down Expand Up @@ -186,6 +191,9 @@ def visit_match(self, m):
for c in m.clauses:
self.visit(c.rhs)

def visit_fatal(self, r):
pass


class ExprMutator(ExprFunctor):
"""
Expand Down Expand Up @@ -249,7 +257,9 @@ def visit_constructor(self, con):
return con

def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses])
return Match(self.visit(m.data),
[Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
m.complete)

def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))
Expand All @@ -259,3 +269,6 @@ def visit_ref_write(self, r):

def visit_ref_read(self, r):
return RefRead(self.visit(r.ref))

def visit_fatal(self, r):
return Fatal(r.msg, r.type_annotation)
5 changes: 3 additions & 2 deletions python/tvm/relay/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class Feature(IntEnum):
fRefWrite = 12
fConstructor = 13
fMatch = 14
fFatal = 15
""" Whether any non-atom fragment of the program is shared, making the program a graph. """
fGraph = 15
fGraph = 16
""" Whether there is local fixpoint in the program. """
fLetRec = 16
fLetRec = 17
7 changes: 7 additions & 0 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ def visit_global_var(self, gvar: Expr):
return (Name(gvar.name_hint, Load()), [])


def visit_fatal(self, fatal: Expr):
thunk_name = self.generate_function_name('_fatal_thunk')
thunk = self.create_def(thunk_name, [], [
ast.Raise(ast.Call(Name("Exception", Load()), [ast.Str(fatal.msg)], []), None)])
return (self.create_call(thunk_name, []), [thunk])


def visit_let(self, letexp: Expr):
# To properly account for scoping and ensure that the entire node produces an expression,
# we translate the let binding as a function that we call with the value we intend to bind.
Expand Down
5 changes: 5 additions & 0 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,11 @@ class Interpreter :
return Value();
}

Value VisitExpr_(const FatalNode* op) final {
LOG(FATAL) << "fatal message recieved: " << op->msg;
return Value();
}

bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
CHECK(cvn) << "need to be a constructor for match";
Expand Down
8 changes: 8 additions & 0 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,14 @@ class AlphaEqualHandler:
return false;
}

bool VisitExpr_(const FatalNode* lhs, const Expr& other) final {
if (const FatalNode* rhs = other.as<FatalNode>()) {
return lhs->msg == rhs->msg;
} else {
return false;
}
}

bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
}
Expand Down
25 changes: 25 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,31 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});


Fatal FatalNode::make(std::string msg, Type type_annotation) {
NodePtr<FatalNode> n = make_node<FatalNode>();
n->msg = std::move(msg);
n->type_annotation = std::move(type_annotation);
return Fatal(n);
}

TVM_REGISTER_NODE_TYPE(FatalNode);

TVM_REGISTER_API("relay._make.Fatal")
.set_body_typed(FatalNode::make);

std::string NoMatchMsg() {
return "No case Match";
}

TVM_REGISTER_API("relay._expr.NoMatchMsg")
.set_body_typed(NoMatchMsg);

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FatalNode>([](const FatalNode* node, tvm::IRPrinter* p) {
p->stream << "FatalNode(" << node->msg << ")";
});

TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
}

Expr ExprMutator::VisitExpr_(const FatalNode* f) {
return GetRef<Expr>(f);
}

Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs);
return ClauseNode::make(p, VisitExpr(c->rhs));
Expand Down Expand Up @@ -318,6 +322,8 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) {
}
}

void ExprVisitor::VisitExpr_(const FatalNode* op) { }

void ExprVisitor::VisitClause(const Clause& op) {
this->VisitPattern(op->lhs);
this->VisitExpr(op->rhs);
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ class RelayHashHandler:
return hash;
}

size_t VisitExpr_(const FatalNode* fn) final {
size_t hash = std::hash<std::string>()(FatalNode::_type_key);
hash = Combine(hash, std::hash<std::string>()(fn->msg));
return hash;
}

size_t VisitType_(const TypeCallNode* tcn) final {
size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
hash = Combine(hash, TypeHash(tcn->func));
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,12 @@ class PrettyPrinter :
return printed_pattern;
}

Doc VisitExpr_(const FatalNode* op) final {
Doc doc;
doc << "Fatal(" << PrintString(op->msg) << ")";
return doc;
}

Doc VisitPattern_(const PatternConstructorNode* p) final {
Doc doc;
doc << p->constructor->name_hint;
Expand Down
2 changes: 2 additions & 0 deletions src/relay/pass/dependency_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
void VisitExpr_(const OpNode* o) final { }

void VisitExpr_(const ConstructorNode* c) final { }

void VisitExpr_(const FatalNode* c) final { }
};

DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
Expand Down
1 change: 1 addition & 0 deletions src/relay/pass/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ FeatureSet DetectFeature(const Expr& expr) {
DETECT_DEFAULT_CONSTRUCT(RefWrite)
DETECT_DEFAULT_CONSTRUCT(Constructor)
DETECT_DEFAULT_CONSTRUCT(Match)
DETECT_DEFAULT_CONSTRUCT(Fatal)
#undef DETECT_DEFAULT_CONSTRUCT
} fd;
fd(expr);
Expand Down
12 changes: 11 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
namespace tvm {
namespace relay {

struct EmitFatal : dmlc::Error {
explicit EmitFatal(const std::string& msg) : dmlc::Error(msg) { }
};

/*!
* \brief LetList allow you to transform expression into variables, so you can copy them around.
* one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
Expand Down Expand Up @@ -134,7 +138,13 @@ class LetList {
template<typename F>
static Expr With(F&& f) {
LetList ll;
return ll.Get(f(&ll));
Expr ret;
try {
ret = f(&ll);
} catch (const EmitFatal& ef) {
ret = FatalNode::make(ef.what(), Type());
}
return ll.Get(ret);
}

static Expr Let(const Expr& e, const std::function<Expr(const Var&)>& f) {
Expand Down
Loading

0 comments on commit 116bdf7

Please sign in to comment.