Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Relax IR Parser #6

Merged
merged 20 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 46 additions & 49 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace relax {

using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Id;
using relay::Call;
using relay::Id;
using relay::Tuple;
using relay::TupleGetItem;

Expand All @@ -53,8 +53,7 @@ class ShapeExprNode : public ExprNode {
}

bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const {
return equal(values, other->values) &&
equal(checked_type_, other->checked_type_) &&
return equal(values, other->values) && equal(checked_type_, other->checked_type_) &&
equal(shape_, other->shape_);
}

Expand All @@ -72,15 +71,15 @@ class ShapeExprNode : public ExprNode {

class ShapeExpr : public Expr {
public:
TVM_DLL ShapeExpr(Array<PrimExpr> values);
TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode);
};


/*! \brief The variable class for all Relax bindings. */
class VarNode : public ExprNode {
public:
/*! \brief The identifier of the variable, is used for comparing stable equality across transformations. */
/*! \brief The identifier of the variable, which is used for comparing stable equality across
* transformations. */
Id vid;
/*! \brief The type annotation, used by binding sites and parameter declarations. */
runtime::Optional<Type> type_annotation;
Expand All @@ -97,11 +96,9 @@ class VarNode : public ExprNode {
}

bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return equal(vid, other->vid) &&
equal(type_annotation, other->type_annotation) &&
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
// Do we use the analysis information in equality?
equal(checked_type_, other->checked_type_) &&
equal(shape_, other->shape_);
equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -120,16 +117,12 @@ class VarNode : public ExprNode {

class Var : public Expr {
public:
TVM_DLL Var(String name_hint,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span())
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL Var(Id vid,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span());
TVM_DLL explicit Var(String name_hint, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span())
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
};

Expand All @@ -147,10 +140,8 @@ class DataflowVarNode : public VarNode {
}

bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
return equal(vid, other->vid) &&
equal(type_annotation, other->type_annotation) &&
equal(shape_, other->shape_) &&
equal(checked_type_, other->checked_type_);
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
equal(shape_, other->shape_) && equal(checked_type_, other->checked_type_);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -168,15 +159,22 @@ class DataflowVarNode : public VarNode {

class DataflowVar : public Var {
public:
using Var::Var; // inherit constructors from Var
TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span())
: DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL explicit DataflowVar(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode);
};


/*! \brief The base class of a variable binding in Relax. */
class BindingNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {}
mutable Span span;

void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; }
void SHashReduce(SHashReducer hash_reduce) const {}

Expand All @@ -188,10 +186,10 @@ class BindingNode : public Object {

class Binding : public ObjectRef {
public:
TVM_DLL explicit Binding(Span span);
TVM_DEFINE_OBJECT_REF_METHODS(Binding, ObjectRef, BindingNode);
};


/*! \brief Symbolic shape match, binds the variables of the LHS with the rhs. */
class MatchShape;
class MatchShapeNode : public BindingNode {
Expand All @@ -202,6 +200,7 @@ class MatchShapeNode : public BindingNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
Expand All @@ -221,7 +220,7 @@ class MatchShapeNode : public BindingNode {

class MatchShape : public Binding {
public:
TVM_DLL MatchShape(Array<PrimExpr> pattern, Expr value);
TVM_DLL explicit MatchShape(Array<PrimExpr> pattern, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
};

Expand All @@ -234,6 +233,7 @@ class VarBindingNode : public BindingNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const {
Expand All @@ -251,23 +251,28 @@ class VarBindingNode : public BindingNode {

class VarBinding : public Binding {
public:
TVM_DLL VarBinding(Var var, Expr value);
TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode);
};


class BindingBlock;

class BindingBlockNode : public Object {
public:
mutable Span span;
Array<Binding> bindings;

void VisitAttrs(AttrVisitor* v) {
v->Visit("span", &span);
v->Visit("bindings", &bindings);
}

bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const {
return equal(bindings, other->bindings);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }

static constexpr const char* _type_key = "relax.expr.BindingBlock";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand All @@ -276,21 +281,17 @@ class BindingBlockNode : public Object {

class BindingBlock : public ObjectRef {
public:
TVM_DLL BindingBlock(Array<Binding> bindings);
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
};


class DataflowBlock;
class DataflowBlockNode : public BindingBlockNode {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("bindings", &bindings);
}
bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
return equal(bindings, other->bindings);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }

static constexpr const char* _type_key = "relax.expr.DataflowBlock";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand All @@ -299,7 +300,7 @@ class DataflowBlockNode : public BindingBlockNode {

class DataflowBlock : public BindingBlock {
public:
TVM_DLL DataflowBlock(Array<Binding> bindings);
TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode);
};

Expand Down Expand Up @@ -340,11 +341,10 @@ class SeqExprNode : public ExprNode {

class SeqExpr : public Expr {
public:
TVM_DLL SeqExpr(Array<BindingBlock> blocks, Expr body);
TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
};


/*! \brief A Relax function, eventually to replace the current Relay function definition. */
class FunctionNode : public BaseFuncNode {
public:
Expand Down Expand Up @@ -372,8 +372,7 @@ class FunctionNode : public BaseFuncNode {

bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal.DefEqual(params, other->params) &&
equal(body, other->body) &&
return equal.DefEqual(params, other->params) && equal(body, other->body) &&
equal(ret_type, other->ret_type) && equal(checked_type_, other->checked_type_) &&
equal(shape_, other->shape_);
}
Expand All @@ -396,12 +395,11 @@ class FunctionNode : public BaseFuncNode {

class Function : public Expr {
public:
TVM_DLL Function(runtime::Optional<GlobalVar> name, Array<Var> params,
Expr body, Type ret_type);
TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
Type ret_type, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
};


/*! \brief The extern function, which can represent packed function. */
class ExternFuncNode : public BaseFuncNode {
public:
Expand All @@ -410,15 +408,14 @@ class ExternFuncNode : public BaseFuncNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("global_symbol", &global_symbol);
v->Visit("span", &span);
}

bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const {
return equal(global_symbol, other->global_symbol);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(global_symbol);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); }

static constexpr const char* _type_key = "relax.expr.ExternFunc";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand All @@ -428,7 +425,7 @@ class ExternFuncNode : public BaseFuncNode {

class ExternFunc : public Expr {
public:
TVM_DLL ExternFunc(String global_symbol);
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
};

Expand Down
43 changes: 31 additions & 12 deletions include/tvm/relax/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ namespace relax {

class ShapeTypeNode : public TypeNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}

bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
return true;
Expand All @@ -53,16 +56,9 @@ class ShapeTypeNode : public TypeNode {

class ShapeType : public Type {
public:
explicit ShapeType();
explicit ShapeType(runtime::ObjectPtr<runtime::Object> n) : Type(n) {}
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType);
const ShapeTypeNode* operator->() const {
return static_cast<const ShapeTypeNode*>(data_.get());
}
const ShapeTypeNode* get() const {
return operator->();
}
using ContainerType = ShapeTypeNode;
TVM_DLL ShapeType(Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode);
};

class DynTensorTypeNode : public BaseTensorTypeNode {
Expand Down Expand Up @@ -108,11 +104,34 @@ class DynTensorType : public Type {
* \param shape The shape of the tensor.
* \param dtype The runtime dtype of the tensor's elements.
*/
TVM_DLL DynTensorType(int rank, DataType dtype);
TVM_DLL DynTensorType(int rank, DataType dtype, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode);
};

class DimTypeNode : public TypeNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}

bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const {
return true;
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }

static constexpr const char* _type_key = "relax.DimType";
TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode);
};

class DimType : public Type {
public:
TVM_DLL DimType(Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode);
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_TYPE_H_
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class TupleNode : public ExprNode {
v->Visit("fields", &fields);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
v->Visit("shape_", &shape_);
}

bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm
import tvm._ffi

from . import Span
from .base import Node
from . import _ffi_api

Expand Down Expand Up @@ -166,8 +167,8 @@ class TupleType(Type):
The fields in the tuple
"""

def __init__(self, fields):
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields)
def __init__(self, fields, span: Span = None):
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span)


@tvm._ffi.register_object("TypeConstraint")
Expand Down
Loading