diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 8cb74a595c..5c05e7c9aa 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -33,8 +33,8 @@ namespace relax { using Expr = RelayExpr; using ExprNode = RelayExprNode; -using relay::Id; using relay::Call; +using relay::Id; using relay::Tuple; using relay::TupleGetItem; @@ -53,8 +53,7 @@ class ShapeExprNode : public ExprNode { } bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { - return equal(values, other->values) && - equal(checked_type_, other->checked_type_) && + return equal(values, other->values) && equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } @@ -72,15 +71,15 @@ class ShapeExprNode : public ExprNode { class ShapeExpr : public Expr { public: - TVM_DLL ShapeExpr(Array values); + TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode); }; - /*! \brief The variable class for all Relax bindings. */ class VarNode : public ExprNode { public: - /*! \brief The identifier of the variable, is used for comparing stable equality across transformations. */ + /*! \brief The identifier of the variable, which is used for comparing stable equality across + * transformations. */ Id vid; /*! \brief The type annotation, used by binding sites and parameter declarations. */ runtime::Optional type_annotation; @@ -97,11 +96,9 @@ class VarNode : public ExprNode { } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - return equal(vid, other->vid) && - equal(type_annotation, other->type_annotation) && + return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) && // Do we use the analysis information in equality? - equal(checked_type_, other->checked_type_) && - equal(shape_, other->shape_); + equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -120,16 +117,12 @@ class VarNode : public ExprNode { class Var : public Expr { public: - TVM_DLL Var(String name_hint, - runtime::Optional shape_annotation, - runtime::Optional type_annotation, - Span span = Span()) - : Var(Id(name_hint), shape_annotation, type_annotation, span) {} - - TVM_DLL Var(Id vid, - runtime::Optional shape_annotation, - runtime::Optional type_annotation, - Span span = Span()); + TVM_DLL explicit Var(String name_hint, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()) + : Var(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL explicit Var(Id vid, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); }; @@ -147,10 +140,8 @@ class DataflowVarNode : public VarNode { } bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { - return equal(vid, other->vid) && - equal(type_annotation, other->type_annotation) && - equal(shape_, other->shape_) && - equal(checked_type_, other->checked_type_); + return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) && + equal(shape_, other->shape_) && equal(checked_type_, other->checked_type_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -168,15 +159,22 @@ class DataflowVarNode : public VarNode { class DataflowVar : public Var { public: - using Var::Var; // inherit constructors from Var + TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional shape_annotation, + 
runtime::Optional type_annotation, Span span = Span()) + : DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL explicit DataflowVar(Id vid, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); }; - /*! \brief The base class of a variable binding in Relax. */ class BindingNode : public Object { public: - void VisitAttrs(AttrVisitor* v) {} + mutable Span span; + + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; } void SHashReduce(SHashReducer hash_reduce) const {} @@ -188,10 +186,10 @@ class BindingNode : public Object { class Binding : public ObjectRef { public: + TVM_DLL explicit Binding(Span span); TVM_DEFINE_OBJECT_REF_METHODS(Binding, ObjectRef, BindingNode); }; - /*! \brief Symbolic shape match, binds the variables of the LHS with the rhs. */ class MatchShape; class MatchShapeNode : public BindingNode { @@ -202,6 +200,7 @@ class MatchShapeNode : public BindingNode { void VisitAttrs(AttrVisitor* v) { v->Visit("pattern", &pattern); v->Visit("value", &value); + v->Visit("span", &span); } bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const { @@ -221,7 +220,7 @@ class MatchShapeNode : public BindingNode { class MatchShape : public Binding { public: - TVM_DLL MatchShape(Array pattern, Expr value); + TVM_DLL explicit MatchShape(Array pattern, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode); }; @@ -234,6 +233,7 @@ class VarBindingNode : public BindingNode { void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); + v->Visit("span", &span); } bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { @@ -251,23 +251,28 @@ class VarBindingNode : public BindingNode { class VarBinding : public Binding { public: - TVM_DLL VarBinding(Var var, Expr value); + TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); }; - class BindingBlock; class BindingBlockNode : public Object { public: + mutable Span span; Array bindings; + void VisitAttrs(AttrVisitor* v) { + v->Visit("span", &span); v->Visit("bindings", &bindings); } + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { return equal(bindings, other->bindings); } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr const char* _type_key = "relax.expr.BindingBlock"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -276,21 +281,17 @@ class BindingBlockNode : public Object { class BindingBlock : public ObjectRef { public: - TVM_DLL BindingBlock(Array bindings); + TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); }; - class DataflowBlock; class DataflowBlockNode : public BindingBlockNode { public: - void VisitAttrs(AttrVisitor* v) { - v->Visit("bindings", &bindings); - } bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { return equal(bindings, other->bindings); } - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; static constexpr const bool 
_type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -299,7 +300,7 @@ class DataflowBlockNode : public BindingBlockNode { class DataflowBlock : public BindingBlock { public: - TVM_DLL DataflowBlock(Array bindings); + TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); }; @@ -340,11 +341,10 @@ class SeqExprNode : public ExprNode { class SeqExpr : public Expr { public: - TVM_DLL SeqExpr(Array blocks, Expr body); + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); }; - /*! \brief A Relax function, eventually to replace the current Relay function definition. */ class FunctionNode : public BaseFuncNode { public: @@ -372,8 +372,7 @@ class FunctionNode : public BaseFuncNode { bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return equal.DefEqual(params, other->params) && - equal(body, other->body) && + return equal.DefEqual(params, other->params) && equal(body, other->body) && equal(ret_type, other->ret_type) && equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } @@ -396,12 +395,11 @@ class FunctionNode : public BaseFuncNode { class Function : public Expr { public: - TVM_DLL Function(runtime::Optional name, Array params, - Expr body, Type ret_type); + TVM_DLL explicit Function(runtime::Optional name, Array params, Expr body, + Type ret_type, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); }; - /*! \brief The extern function, which can represent packed function. */ class ExternFuncNode : public BaseFuncNode { public: @@ -410,15 +408,14 @@ class ExternFuncNode : public BaseFuncNode { void VisitAttrs(AttrVisitor* v) { v->Visit("global_symbol", &global_symbol); + v->Visit("span", &span); } bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { return equal(global_symbol, other->global_symbol); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(global_symbol); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); } static constexpr const char* _type_key = "relax.expr.ExternFunc"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -428,7 +425,7 @@ class ExternFuncNode : public BaseFuncNode { class ExternFunc : public Expr { public: - TVM_DLL ExternFunc(String global_symbol); + TVM_DLL ExternFunc(String global_symbol, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode); }; diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 9967c29a5f..4b5b1640f3 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -39,7 +39,10 @@ namespace relax { class ShapeTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("span", &span); + } bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return true; @@ -53,16 +56,9 @@ class ShapeTypeNode : public TypeNode { class ShapeType : public Type { public: - explicit ShapeType(); - explicit ShapeType(runtime::ObjectPtr n) : Type(n) {} - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType); - const ShapeTypeNode* operator->() const { - return static_cast(data_.get()); - } - const ShapeTypeNode* get() const { - return operator->(); - } - using ContainerType = ShapeTypeNode; + TVM_DLL 
ShapeType(Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); }; class DynTensorTypeNode : public BaseTensorTypeNode { @@ -108,11 +104,34 @@ class DynTensorType : public Type { * \param shape The shape of the tensor. * \param dtype The runtime dtype of the tensor's elements. */ - TVM_DLL DynTensorType(int rank, DataType dtype); + TVM_DLL DynTensorType(int rank, DataType dtype, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); }; +class DimTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("span", &span); + } + + bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const { + return true; + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.DimType"; + TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode); +}; + +class DimType : public Type { + public: + TVM_DLL DimType(Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode); +}; + } // namespace relax } // namespace tvm #endif // TVM_RELAX_TYPE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index bd094a7f69..00f45f2591 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -131,6 +131,7 @@ class TupleNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("shape_", &shape_); } bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 4fe28f1d72..3748b7dcec 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -19,6 +19,7 @@ import tvm import tvm._ffi +from . import Span from .base import Node from . import _ffi_api @@ -166,8 +167,8 @@ class TupleType(Type): The fields in the tuple """ - def __init__(self, fields): - self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) + def __init__(self, fields, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span) @tvm._ffi.register_object("TypeConstraint") diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index b81df5c883..e3b3c61c5e 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -19,9 +19,9 @@ from . import expr from . import ty from . import vm -from . import op from . import ir_builder from . import op +from . 
import parser # Expr @@ -42,14 +42,18 @@ Tuple = expr.Tuple Function = expr.Function ExternFunc = expr.ExternFunc +Call = expr.Call +If = expr.If # helper functions const = expr.const extern = expr.extern # Type +Type = ty.Type ShapeType = ty.ShapeType DynTensorType = ty.DynTensorType +DimType = ty.DimType # VM ExecBuilder = exec_builder.ExecBuilder @@ -61,3 +65,6 @@ # IRBuilder IRBuilder = ir_builder.IRBuilder + +# Parser +from .parser import script diff --git a/python/tvm/relax/base.py b/python/tvm/relax/base.py new file mode 100644 index 0000000000..b85ac77c6e --- /dev/null +++ b/python/tvm/relax/base.py @@ -0,0 +1,4 @@ +# Skeleton AST so we can get prototype working before this PR is merged +class rxNode: + def __init__(self, span): + self.span = span diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 172cf6dee4..d4ed3cc797 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -27,6 +27,7 @@ Type = relay.Type GlobalVar = relay.GlobalVar Call = relay.Call +If = relay.If const = relay.const @@ -34,8 +35,8 @@ class ShapeExpr(Expr): values: List[PrimExpr] - def __init__(self, values: List[PrimExpr]) -> None: - self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values) + def __init__(self, values: List[PrimExpr], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) def __getitem__(self, index): if index >= len(self): @@ -45,6 +46,7 @@ def __getitem__(self, index): def __len__(self): return len(self.values) + def make_shape(shape: List[PrimExpr]) -> ShapeExpr: if isinstance(shape, (list, tuple)): return ShapeExpr(shape) @@ -54,17 +56,19 @@ def make_shape(shape: List[PrimExpr]) -> ShapeExpr: @tvm._ffi.register_object("relax.expr.Var") class Var(Expr): - id: Id + vid: Id type_annotation: Optional[Type] - def __init__(self, name_hint: str, - shape_annotation: Optional[Expr] = None, - type_annotation: Optional[Type] = None) -> None: - if shape_annotation is not None: - shape_annotation = make_shape(shape_annotation) - self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, - shape_annotation, - type_annotation) + def __init__( + self, + name_hint: str, + shape_annotation: Optional[Expr] = None, + type_annotation: Optional[Type] = None, + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Var, name_hint, shape_annotation, type_annotation, span + ) @property def name_hint(self): @@ -75,12 +79,22 @@ def name_hint(self): @tvm._ffi.register_object("relax.expr.DataflowVar") class DataflowVar(Var): - pass + def __init__( + self, + name_hint: str, + shape_annotation: Optional[Expr] = None, + type_annotation: Optional[Type] = None, + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span + ) @tvm._ffi.register_object("relax.expr.Binding") class Binding(Node): - pass + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Binding, span) @tvm._ffi.register_object("relax.expr.MatchShape") @@ -88,8 +102,8 @@ class MatchShape(Binding): pattern: List[PrimExpr] value: Expr - def __init__(self, pattern: List[PrimExpr], value: Expr) -> None: - self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value) + def __init__(self, pattern: List[PrimExpr], value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value, span) @tvm._ffi.register_object("relax.expr.VarBinding") @@ -97,16 +111,16 @@ class 
VarBinding(Binding): var: Var value: Expr - def __init__(self, var: Var, value: Expr) -> None: - self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value) + def __init__(self, var: Var, value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) @tvm._ffi.register_object("relax.expr.BindingBlock") class BindingBlock(Node): bindings: List[Binding] - def __init__(self, bindings: List[Binding]) -> None: - self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings) + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span) @tvm._ffi.register_object("relax.expr.DataflowBlock") @@ -119,8 +133,8 @@ class SeqExpr(Expr): blocks: List[BindingBlock] body: Expr - def __init__(self, blocks: List[BindingBlock], body: Expr) -> None: - self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body) + def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) @tvm._ffi.register_object("relax.expr.Function") @@ -130,18 +144,24 @@ class Function(BaseFunc): body: Expr ret_type: Type - def __init__(self, params: List[Var], body: Expr, - ret_type: Type, name: Optional[GlobalVar] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.Function, name, params, - body, ret_type) + def __init__( + self, + params: List[Var], + body: Expr, + ret_type: Type, + name: Optional[GlobalVar] = None, + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__(_ffi_api.Function, name, params, body, ret_type, span) @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): global_symbol: String - def __init__(self, global_symbol: String) -> None: - self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol) + def __init__(self, global_symbol: String, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol, span) + -def extern(name): - return ExternFunc(name) +def extern(name, span: Span = None): + return ExternFunc(name, span) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 767dbdea15..2fa88cb954 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -18,9 +18,10 @@ from . import _ffi_api from typing import Union, List -def call_dps(shape: Union[ShapeExpr, List[int]], - func: Expr, - args: Union[Tuple, List[Expr]]) -> Call: + +def call_dps( + shape: Union[ShapeExpr, List[int]], func: Expr, args: Union[Tuple, List[Expr]] +) -> Call: """ Call a destination-passing-style function and return the output. 
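Note on the span plumbing above: each of the Python constructors in this patch now takes an optional Span and forwards it to the underlying C++ node. A minimal sketch of what that enables when building Relax fragments by hand is below; the source name, line/column numbers, and variable names are placeholders for illustration and are not part of this patch:

import tvm
from tvm import relax as rx
from tvm import tir

# Placeholder source location (file name and positions are made up).
span = tvm.ir.Span(tvm.ir.SourceName("example.py"), 1, 1, 1, 1)

n = tir.Var("n", "int32")
t = rx.DynTensorType(rank=1, dtype="float32", span=span)
x = rx.Var("x", rx.ShapeExpr([n], span), t, span)  # parameter
y = rx.Var("y", rx.ShapeExpr([n], span), t, span)  # bound result
block = rx.BindingBlock([rx.VarBinding(y, x, span)], span)
func = rx.Function([x], rx.SeqExpr([block], y, span), t, rx.GlobalVar("identity"), span)
# Every node above records the span it was constructed with (func.span, block.span, ...).

The parser added below relies on the same constructors, converting synr spans through to_tvm_span so diagnostics can point back at the original Python source.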
diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py
new file mode 100644
index 0000000000..62e587b708
--- /dev/null
+++ b/python/tvm/relax/parser.py
@@ -0,0 +1,860 @@
+from __future__ import annotations
+
+import inspect
+from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional
+from io import StringIO
+from enum import Enum
+
+import tvm
+import tvm.script
+from tvm.ir.module import IRModule
+from tvm.relay.base import Id
+from tvm.ir import diagnostics
+from tvm import tir
+
+import numpy as np
+
+import synr
+from synr import ast, Transformer
+from synr.diagnostic_context import DiagnosticContext
+from tvm.relay.op.strategy.generic import conv1d_strategy
+
+import tvm.relay as relay
+import tvm.relax as rx
+
+
+# TODO(@altanh): Replace with a real pretty print method once we have the real AST
+def pretty_print(f):
+    print(f)
+
+
+def is_registered(op_name: str, op_set=None) -> bool:
+    if op_set is None:
+        op_set = tvm.ir._ffi_api.ListOpNames()
+    return op_name in op_set
+
+
+def _tir_from_synr(
+    synr_ast: ast.Node, diag_ctx: tvm.script.diagnostics.TVMDiagnosticCtx
+) -> tir.PrimFunc:
+    parser = tvm.script.parser.TVMScriptParser(synr_ast.span.start_line)
+    return parser.do_transform(synr_ast, diag_ctx)
+
+
+# NOTE: call_dps is an actual registered operator
+class SpecialOp(Enum):
+    MATCH_SHAPE = "relax.match_shape"
+    CALL_PACKED = "relax.call_packed"
+    DATAFLOW = "relax.dataflow"
+
+
+class RelaxTransformer(Transformer):
+    def __init__(self, definition_scope):
+        super().__init__()
+        self.definition_scope = definition_scope
+        self.module = {}
+        self._scopes = [{}]  # str -> Var
+        self._registered_ops = set(tvm.ir._ffi_api.ListOpNames())  # cached
+
+    def to_tvm_span(self, span: ast.Span) -> tvm.ir.Span:
+        """Helper method for converting synr span to TVM span.
+
+        Parameters
+        ----------
+        span : ast.Span
+            The synr span
+
+        Returns
+        -------
+        tvm.ir.Span
+            The corresponding TVM span
+        """
+        return self._diagnostic_context.to_tvm_span(self._diagnostic_context.source_name, span)
+
+    def report_error(self, msg: str, span: ast.Span):
+        """Helper method for emitting and immediately rendering an error.
+
+        Parameters
+        ----------
+        msg : str
+            The error message
+        span : ast.Span
+            The span to report the error at
+        """
+        self._diagnostic_context.emit("error", msg, span)
+        self._diagnostic_context.render()
+
+    def new_scope(self):
+        """Helper method for creating a new scope context object
+
+        Returns
+        -------
+        _Scope
+            An internal scope context object used in a with block to create a new scope
+        """
+
+        class _Scope:
+            def __init__(self, transformer: "RelaxTransformer"):
+                self.transformer = transformer
+
+            def __enter__(self):
+                self.transformer._scopes.append(self.transformer._scopes[-1].copy())
+
+            def __exit__(self, *exc):
+                assert len(self.transformer._scopes) > 1, "cannot pop root scope"
+                self.transformer._scopes.pop()
+
+        return _Scope(self)
+
+    @property
+    def scope(self):
+        """Returns the current definition scope.
+
+        Returns
+        -------
+        Dict[str, Union[rx.Var, tir.Var]]
+            The scope of all currently defined variables (Relax and TIR).
+        """
+        return self._scopes[-1]
+
+    def decl_var(
+        self,
+        name: str,
+        type_annotation: Optional[rx.Type],
+        shape: Optional[rx.Expr],
+        span: ast.Span,
+        is_dataflow: bool = False,
+    ) -> rx.Var:
+        """Introduces a variable with the given name and annotations to the current scope.
+ + Parameters + ---------- + name : str + The name of the variable + type_annotation : Optional[rx.Type] + The type annotation + shape : Optional[rx.Expr] + The shape annotation + span : ast.Span + The span where the variable is declared + + Returns + ------- + Union[rx.Var, rx.DataflowVar] + The declared variable + """ + if name in self.scope: + # TODO(@altanh): maybe emit an error at the declaration site and report it together + self.report_error("variable has already been declared in the current scope", span) + if is_dataflow: + var = rx.DataflowVar(name, shape, type_annotation, self.to_tvm_span(span)) + else: + var = rx.Var(name, shape, type_annotation, self.to_tvm_span(span)) + self.scope[name] = var + return var + + def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[rx.Type, rx.Expr]: + """Transforms the given synr type annotation to a Relax type and shape expression. + + Parameters + ---------- + ty : ast.Type + The synr type + bind_free_vars : bool + Whether or not the shape annotation can introduce new dimension variables + + Returns + ------- + Tuple[rx.Type, rx.Expr]: + The corresponding Relax type and shape expression + """ + if ty is None: + return (None, None) + + span = self.to_tvm_span(ty.span) + + # simple annotation with no type arguments + if isinstance(ty, ast.TypeVar): + if ty.id.name == "Tensor": + return (rx.DynTensorType(rank=-1, dtype=None, span=span), None) + elif ty.id.name == "Shape": + return (rx.ShapeType(span), None) + elif ty.id.name == "Dim": + return (rx.DimType(span), None) + else: + self.report_error("unknown type in annotation", span) + + # annotation with type arguments/shape annotation + if isinstance(ty, ast.TypeApply): + if ty.id.name == "Tensor": + # TODO(@altanh): forgetting dtype like "Tensor[(n, m)]" ends up getting parsed as + # Tensor[n, m] which makes correct errors difficult here... + if len(ty.params) != 2: + self.report_error( + "Tensor type annotations must have 2 fields (shape and dtype)", + span, + ) + + shape_annotation, dtype_annotation = ty.params + shape, dtype = None, None + + # parse the shape annotation + if isinstance(shape_annotation, ast.TypeVar): + if shape_annotation.id.name != "_": + # TODO(@altanh): handle variable annotations, e.g. x: Tensor[my_shape, _] + self.report_error( + "variable Tensor shape annotations not yet supported", + shape_annotation.span, + ) + else: + # FIXME(@altanh): use a special node for unknown shape vs no shape? + pass # shape = None + elif isinstance(shape_annotation, ast.TypeTuple): + shape = rx.ShapeExpr( + self.parse_shape(shape_annotation, bind_free_vars), + span=self.to_tvm_span(shape_annotation.span), + ) + else: + self.report_error( + "unsupported shape annotation", + shape_annotation.span, + ) + + # parse the dtype annotation + if isinstance(dtype_annotation, ast.TypeVar) and dtype_annotation.id.name == "_": + pass # dtype = None + elif isinstance(dtype_annotation, ast.TypeConstant): + dtype = dtype_annotation.value + else: + self.report_error( + "Tensor dtype annotations must be concrete or erased", + dtype_annotation.span, + ) + + rank = len(shape) if shape is not None else -1 + return (rx.DynTensorType(rank=rank, dtype=dtype, span=span), shape) + elif ty.id.name == "Tuple": + field_types = [] + field_shapes = [] + for field in ty.params: + fty, fsh = self.transform_type(field, bind_free_vars=False) + field_types.append(fty) + field_shapes.append(fsh) + return (relay.TupleType(field_types, span), None) + # TODO(@altanh): other types with args, e.g. 
Ref[T], func types + self.report_error("invalid type", span) + + def parse_shape( + self, + shape_annotation: Union[ast.TypeTuple, ast.Tuple], + bind_free_vars: bool, + ) -> List[tir.PrimExpr]: + """Parses the given shape annotation to a list of PrimExprs. + + Parameters + ---------- + shape_annotation : Union[ast.TypeTuple, ast.Tuple] + The shape annotation in synr + bind_free_vars : bool + Whether or not the annotation can bind previously free variables + + Returns + ------- + List[tir.PrimExpr] + The parsed shape as a list of PrimExprs + """ + return [self.parse_primexpr(field, bind_free_vars) for field in shape_annotation.values] + + def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr: + """Parses the given expression to a PrimExpr. + + Parameters + ---------- + expr : ast.Expr + The input expression + bind_free_vars : bool + Whether or not the expression can bind previously free variables + + Returns + ------- + tir.PrimExpr + The result PrimExpr + """ + if isinstance(expr, ast.Var): + var_name = expr.id.name + if var_name in self.scope: + var = self.scope[var_name] + if not isinstance(var, tir.Var): + # TODO(@altanh): we may wish to relax this in the future to support constructing + # shapes from Dim-typed Relax expressions + self.report_error( + "non-dimension variables cannot appear in dimension expressions", + expr.span, + ) + return var + elif bind_free_vars: + # introduce TIR variable to scope, e.g. for func params or rx.call_packed + var = tir.Var(var_name, "int32", self.to_tvm_span(expr.span)) + self.scope[var_name] = var + return var + else: + self.report_error( + "cannot introduce new dimension variables in this expression", + expr.span, + ) + elif isinstance(expr, ast.Constant): + if not isinstance(expr.value, int): + self.report_error("only integer constants are supported", expr.span) + return tir.const(expr.value, "int32", self.to_tvm_span(expr.span)) + else: + # TODO(@altanh): parse (simple) PrimExprs + self.report_error( + "only dimension variable expressions are currently supported", + expr.span, + ) + + def transform_module(self, mod: ast.Module) -> IRModule: + """Transforms the given synr Module to a Relax IRModule. + + Parameters + ---------- + mod : ast.Module + The input synr Module + + Returns + ------- + IRModule + The parsed Relax IRModule + """ + for func_name in mod.funcs: + func = mod.funcs[func_name] + self.module[func_name] = self.transform_function(func, is_global=True) + return self.module + + def transform_function(self, func: ast.Function, is_global: bool = False) -> rx.Function: + """Transforms the given synr Function to a Relax Function. 
+ + Parameters + ---------- + func : ast.Function + The input synr Function + is_global : bool, optional + Whether or not the input function is global/module-level, by default False + + Returns + ------- + rx.Function + The parsed Relax Function + """ + if ( + len(func.decorators) == 1 + and isinstance(func.decorators[0], ast.Var) + and func.decorators[0].id.name == "tir" + ): + return _tir_from_synr(func, self._diagnostic_context) + + with self.new_scope(): + params = [] + for param in func.params: + ty, shape = self.transform_type(param.ty, bind_free_vars=True) + param = self.decl_var(param.name, ty, shape, param.span) + params.append(param) + new_body = self.transform_block(func.body) + ret_type, _ = self.transform_type(func.ret_type, bind_free_vars=False) + + func_name = rx.GlobalVar(func.name) if is_global else None + return rx.Function( + params, new_body, ret_type, name=func_name, span=self.to_tvm_span(func.span) + ) + + def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.Binding: + """Parses the input synr statement to the corresponding Relax binding. + + Parameters + ---------- + stmt : ast.Stmt + The input synr statement (either an assignment or a unassigned call) + is_dataflow : bool, optional + Whether or not the binding is in a dataflow block, by default False + + Returns + ------- + rx.Binding + The parsed Relax binding + """ + assert isinstance(stmt, (ast.Assign, ast.UnassignedCall)) + if isinstance(stmt, ast.Assign): + return self.parse_var_binding(stmt, is_dataflow=is_dataflow) + else: + return self.parse_shape_binding(stmt) + + def parse_shape_binding(self, stmt: ast.UnassignedCall) -> rx.MatchShape: + """Parses the input synr statement to a Relax shape binding. + + Parameters + ---------- + stmt : ast.UnassignedCall + The input synr statement + + Returns + ------- + rx.MatchShape + The parsed Relax shape binding + """ + call: synr.ast.Call = stmt.call + op = self.transform_expr(call.func_name) + if op != SpecialOp.MATCH_SHAPE: + self.report_error("the results of calls must be bound or used", stmt.span) + if len(stmt.call.params) != 2: + self.report_error(op.value + " takes exactly two arguments", stmt.span) + + lhs = stmt.call.params[0] + rhs = stmt.call.params[1] + + rhs_expr = self.transform_expr(rhs) + if not isinstance(lhs, ast.Tuple): + self.report_error( + "the pattern (lhs) of " + op.value + " must be a tuple", + lhs.span, + ) + lhs_expr = self.parse_shape(lhs, bind_free_vars=True) + return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span)) + + def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBinding: + """Parses the input synr assignment to a Relax variable binding. 
+ + Parameters + ---------- + stmt : ast.Assign + The input synr assignment + is_dataflow : bool, optional + Whether or not the binding is in a dataflow block, by default False + + Returns + ------- + rx.VarBinding + The prased Relax variable binding + """ + if not isinstance(stmt.lhs, ast.Var): + self.report_error( + "the left hand side of a binding must be a variable", + stmt.lhs.span, + ) + rhs = self.transform_expr(stmt.rhs) + # an ExternFunc call comes from call_packed + if isinstance(rhs, relay.Call) and isinstance(rhs.op, rx.ExternFunc): + bind_free_vars = True + else: + bind_free_vars = False + ty, shape = self.transform_type(stmt.ty, bind_free_vars) + lhs = self.decl_var(stmt.lhs.id.name, ty, shape, stmt.lhs.span, is_dataflow=is_dataflow) + return rx.VarBinding(lhs, rhs, self.to_tvm_span(stmt.span)) + + # Stmts: + # - Assert: probably unsupported for now + # - Assign: VarBinding + # - For: ?? + # - If: IfThenElse, must check no empty false branch + # - Return: just the returned expression, must terminate blocks? (special case if-else) + # - UnassignedCall: match_shape + # - With: rx.dataflow + def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.DataflowBlock]: + """Transforms the given synr statement to the corresponding Relax node. + + Parameters + ---------- + stmt : ast.Stmt + The input synr statement + + Returns + ------- + Union[rx.Expr, rx.Binding, rx.DataflowBlock] + The parsed Relax node + """ + if isinstance(stmt, ast.Assign): + # dataflow bindings are handled separately in parse_dataflow + return self.parse_binding(stmt) + + elif isinstance(stmt, ast.If): + # check branches are non-empty + if len(stmt.true.stmts) == 0 or len(stmt.false.stmts) == 0: + self.report_error("both branches of an if-else block must be non-empty", stmt.span) + true_assign = stmt.true.stmts[-1] + false_assign = stmt.false.stmts[-1] + + # check last statement in each branch lines up + if not isinstance(true_assign, ast.Assign) or not isinstance(true_assign.lhs, ast.Var): + self.report_error( + "each branch of an if-else statement must end in a variable assignment", + true_assign.span, + ) + if not isinstance(false_assign, ast.Assign) or not isinstance( + false_assign.lhs, ast.Var + ): + self.report_error( + "each branch of an if-else statement must end in a variable assignment", + false_assign.span, + ) + union_span = ast.Span.union([true_assign.span, false_assign.span]) + if true_assign.lhs.id.name != false_assign.lhs.id.name: + self.report_error( + "the final assignment of both branches must have the same variable", + union_span, + ) + + var_name = true_assign.lhs.id.name + + # rewrite branches to have a return statement so the blocks properly parse to SeqExprs + true_block = synr.ast.Block( + span=stmt.true.span, + stmts=stmt.true.stmts[:-1] + + [synr.ast.Return(span=true_assign.span, value=true_assign.rhs)], + ) + false_block = synr.ast.Block( + span=stmt.false.span, + stmts=stmt.false.stmts[:-1] + + [synr.ast.Return(span=false_assign.span, value=false_assign.rhs)], + ) + + # parse the branches, build the final expression and binding + cond = self.transform_expr(stmt.condition) + with self.new_scope(): + true_branch = self.transform_block(true_block) + with self.new_scope(): + false_branch = self.transform_block(false_block) + # TODO(@altanh): the spans here are all sorts of messed up, not sure how to fix + ite_expr = relay.If(cond, true_branch, false_branch, self.to_tvm_span(stmt.span)) + # TODO(@altanh): type and shape of return var + var = self.decl_var(var_name, None, 
None, union_span) + return rx.VarBinding(var, ite_expr, self.to_tvm_span(union_span)) + + elif isinstance(stmt, ast.Return): + return self.transform_expr(stmt.value) + + elif isinstance(stmt, ast.UnassignedCall): + # FIXME: when we add ref support, ref_write can be unassigned + return self.parse_shape_binding(stmt) + + elif isinstance(stmt, ast.With): + if not isinstance(stmt.rhs, ast.Call): + self.report_error("unsupported with block", stmt.span) + + call = stmt.rhs + op = self.transform_expr(call.func_name) + + # TODO(@altanh): perhaps this ought to be more general + + if op == SpecialOp.DATAFLOW: + if len(call.params) > 0: + self.report_error( + "dataflow block constructor takes no arguments", + call.params[0].span, + ) + if len(stmt.lhs) > 0: + self.report_error( + "dataflow blocks don't bind any patterns", + stmt.lhs[0].span, + ) + return self.parse_dataflow(stmt.body) + else: + self.report_error("unsupported with block type", call.span) + + elif isinstance(stmt, ast.Function): + func = self.transform_function(stmt) + func_var = self.decl_var(stmt.name, None, None, stmt.span) + return rx.VarBinding(func_var, func, self.to_tvm_span(stmt.span)) + + else: + self.report_error( + "unsupported statement", + stmt.span, + ) + + def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: + """Parses the input synr block to a Relax dataflow block. + + Parameters + ---------- + block : ast.Block + The input synr block + + Returns + ------- + rx.DataflowBlock + The parsed Relax dataflow block + """ + assert len(block.stmts) > 0, "should never have an empty dataflow block" + bindings = [] + output_vars = [] + + with self.new_scope(): + # parse the return statement first to figure out which bindings assign normal Vars + output_stmt = block.stmts[-1] + if not isinstance(output_stmt, ast.Return): + self.report_error( + "dataflow blocks must end with returning the output variables", + output_stmt.span, + ) + + ret_val = output_stmt.value + if isinstance(ret_val, ast.Var): + ret_val = ast.Tuple(values=[ret_val], span=ret_val.span) + + if not isinstance(ret_val, ast.Tuple) or any( + [not isinstance(f, ast.Var) for f in ret_val.values] + ): + self.report_error( + "the returned values must be variables", + ret_val.span, + ) + + # output variables are bound to normal (not data flow) Vars + output_var_names = {var.id.name for var in ret_val.values} + + for binding_stmt in block.stmts[:-1]: + if not isinstance(binding_stmt, (ast.Assign, ast.UnassignedCall)): + self.report_error( + "only bindings are supported in dataflow blocks", + binding_stmt.span, + ) + is_match_shape = isinstance(binding_stmt, ast.UnassignedCall) + is_dataflow = not is_match_shape and ( + binding_stmt.lhs.id.name not in output_var_names + ) + binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow) + bindings.append(binding) + if not is_dataflow: + if is_match_shape: + for var in binding.pattern: + output_vars.append(var) + else: + output_vars.append(binding.var) + + # make output variables visible in parent scope + for v in output_vars: + # v could already be in scope if it was a previously bound dimension variable + v_name = v.name if isinstance(v, tir.Var) else v.name_hint + if v not in self.scope: + self.scope[v_name] = v + + return rx.DataflowBlock(bindings, self.to_tvm_span(block.span)) + + # Exprs: + # - ArrayLiteral: unsupported for now? 
+ # - Attr: use for .shape, and intrinsic/special operator namespace + # - Call + # - Constant + # - DictLiteral: unsupported for now + # - Slice: unsupported for now, could desugar to slice op + # - Tuple + # - Var + def transform_expr(self, expr: ast.Expr) -> rx.Expr: + """Transforms the input synr expression to a Relax expression. + + Parameters + ---------- + expr : ast.Expr + The input synr + + Returns + ------- + rx.Expr + The corresponding Relax expression + """ + if isinstance(expr, ast.Attr): + if expr.field.name == "shape": + obj = self.transform_expr(expr.object) + return relay.op.shape_of(obj) + else: + # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps) + op_name = [] + attr = expr + while isinstance(attr, ast.Attr): + op_name.append(expr.field.name) + attr = attr.object + if not isinstance(attr, ast.Var): + self.report_error("unsupported field access", expr.span) + op_name.append(attr.id.name) + op_name = ".".join(reversed(op_name)) + # NOTE: at least for now, all special operators are namespaced + try: + return SpecialOp(op_name) + except ValueError: + # TODO(@altanh): maybe diagnostics here in case this fails? + return relay.op.get(op_name) + + if isinstance(expr, ast.Call): + op = self.transform_expr(expr.func_name) + if op == SpecialOp.CALL_PACKED: + if len(expr.params) != 2: + self.report_error( + op.value + " takes an extern function name and a tuple of arguments", + expr.span, + ) + extern_func = expr.params[0] + if not ( + isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str) + ): + self.report_error( + "the first argument of " + op.value + " must be the extern function name", + extern_func.span, + ) + op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span)) + args = [self.transform_expr(expr.params[1])] + else: + args = [self.transform_expr(arg) for arg in expr.params] + # TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass? + return relay.Call(op, args, span=self.to_tvm_span(expr.span)) + + elif isinstance(expr, ast.Tuple): + fields = [self.transform_expr(field) for field in expr.values] + + # TODO(@altanh): this check might be too weak; we really only accept integral PrimExprs + # (e.g. int constants, dim vars, and integer operations on these) + + # coerce to ShapeExpr when fields are all PrimExprs + if all([isinstance(f, tir.PrimExpr) for f in fields]): + return rx.ShapeExpr(fields, span=self.to_tvm_span(expr.span)) + return relay.Tuple(fields, span=self.to_tvm_span(expr.span)) + + elif isinstance(expr, ast.Var): + var_name = expr.id.name + if is_registered(var_name, op_set=self._registered_ops): + return relay.op.get(var_name) + if var_name not in self.scope: + self.report_error("undefined variable", expr.span) + return self.scope[var_name] + + elif isinstance(expr, ast.Constant): + # FIXME(@altanh): use internal representation that doesn't have precision limits here + if isinstance(expr.value, int): + return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span)) + elif isinstance(expr.value, float): + return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span)) + else: + self.report_error( + "unsupported constant expression (we currently only support int and float)", + expr.span, + ) + + else: + self.report_error("unsupported expression", expr.span) + + def transform_block(self, block: ast.Block) -> rx.SeqExpr: + """Transforms the given synr block to a Relax SeqExpr (sequence of Blocks with a final + expression). 
+ + Parameters + ---------- + block : ast.Block + The input synr block + + Returns + ------- + rx.SeqExpr + The parsed SeqExpr + """ + # a block of statements needs to be converted to a SeqExpr of binding blocks + blocks = [] + current_block = [] + for stmt in block.stmts[:-1]: + parsed_stmt = self.transform_stmt(stmt) + if isinstance(parsed_stmt, rx.DataflowBlock): + if current_block: + # FIXME: span + blocks.append(rx.BindingBlock(current_block, self.to_tvm_span(stmt.span))) + current_block = [] + blocks.append(parsed_stmt) + else: + assert isinstance(parsed_stmt, rx.Binding) + current_block.append(parsed_stmt) + if len(current_block) > 0: + blocks.append(rx.BindingBlock(current_block, self.to_tvm_span(stmt.span))) + + ret_stmt = block.stmts[-1] + if not isinstance(ret_stmt, ast.Return): + self.report_error( + "block must end with a returned expression", + ret_stmt.span, + ) + ret_expr = self.transform_stmt(ret_stmt) + + return rx.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span)) + + +# class TVMDiagnosticContext(synr.DiagnosticContext): +# def __init__(self, tvm_diag_ctx): +# self.tvm_diag_ctx = tvm_diag_ctx +# self.str_to_source_name = {} + +# def add_source(self, name: str, source: str) -> None: +# """Add a file with source code to the context. This will be called +# before any call to :py:func:`emit` that contains a span in this +# file. +# """ +# src_name = self.tvm_diag_ctx.module.source_map.add(name, source) +# self.str_to_source_name[name] = src_name + +# def emit(self, level: str, message: str, span: tvm.ir.Span) -> None: +# """Called when an error has occured.""" + +# if level == "error": +# level = diagnostics.DiagnosticLevel.ERROR +# elif level == "bug": +# level = diagnostics.DiagnosticLevel.BUG +# elif level == "warning": +# level = diagnostics.DiagnosticLevel.WARNING +# else: +# level = "error" + +# assert span, "Span must not be null" +# assert isinstance(span, tvm.ir.Span), "Expected tvm.ir.Span, but got " + str(type(span)) + +# diag = diagnostics.Diagnostic(level, span, message) + +# self.tvm_diag_ctx.emit(diag) + +# def render(self) -> Optional[Any]: +# """Render out all error messages. Can either return a value or raise +# and execption. +# """ +# self.tvm_diag_ctx.render() + + +# TODO(@altanh, @jroesch): revisit this? +class RelaxDecoratedFn: + def __init__(self, fn_name, relax_module, diag_ctx): + self.fn_name = fn_name + self.module = relax_module + self.diag_ctx = diag_ctx + + def __call__(self, *args): + pretty_print(self.module[self.fn_name]) + # compiler = Compiler(self.diag_ctx, self.module, self.fn_name) + # compiled_f = compiler.compile(execute=True) + # # Actually compute needed buffer sizes. 
+ # out = tvm.nd.array(np.random.rand(10).astype('float32')) + # compiled_f(*(list(args) + [out])) + # return out + + +def script(f) -> RelaxDecoratedFn: + """Parses the decorated Relax function (in Relax IR) to the Relax AST + + Parameters + ---------- + f : function + The function to be parsed, written in the Relax IR + + Returns + ------- + RelaxDecoratedFn + The parsed Relax function + """ + # ir_module = tvm.IRModule({}) + # diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) + diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() + ast = synr.to_ast(f, diag_ctx) + definition_scope = inspect.getmodule(f) + module = RelaxTransformer(definition_scope).do_transform(ast, diag_ctx) + return RelaxDecoratedFn(f.__name__, module, diag_ctx) diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index 0c34d2797d..e8f96fc935 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -17,15 +17,15 @@ # pylint: disable=invalid-name, unused-import """The type nodes of the Relax language.""" import tvm._ffi -from tvm.ir import Type, TensorType +from tvm.ir import Type, TensorType, Span from . import _ffi_api @tvm._ffi.register_object("relax.ShapeType") class ShapeType(Type): - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.ShapeType) + def __init__(self, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.ShapeType, span) @tvm._ffi.register_object("relax.DynTensorType") @@ -43,5 +43,12 @@ class DynTensorType(TensorType): The content data type. """ - def __init__(self, rank=-1, dtype="float32"): - self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype) + def __init__(self, rank=-1, dtype="float32", span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype, span) + + +@tvm._ffi.register_object("relax.DimType") +class DimType(Type): + """The type of indices/shape dimensions in Relax.""" + def __init__(self, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.DimType, span) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88b84bbe7e..b046e98239 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -23,7 +23,7 @@ import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar, Node +from tvm.ir import RelayExpr, GlobalVar, Node, Span from .base import RelayNode from . import _ffi_api diff --git a/src/ir/type.cc b/src/ir/type.cc index fe8e00329b..86dda2a274 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -144,8 +144,8 @@ TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { - return TupleType(fields); +TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields, Span span) { + return TupleType(fields, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relax/expr.cc b/src/relax/expr.cc index 42f84adf6c..08e1a01f1d 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -1,21 +1,21 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ #include namespace tvm { @@ -29,8 +29,7 @@ RelayExpr RelayExprNode::shape() const { return relay::Call(op, {self}, {}, {}); } -TVM_REGISTER_GLOBAL("ir.RelayExprShape") -.set_body_typed([](RelayExpr expr) { +TVM_REGISTER_GLOBAL("ir.RelayExprShape").set_body_typed([](RelayExpr expr) { return expr->shape(); }); @@ -40,24 +39,20 @@ using tvm::runtime::Optional; TVM_REGISTER_NODE_TYPE(ShapeExprNode); -ShapeExpr::ShapeExpr(Array values) { +ShapeExpr::ShapeExpr(Array values, Span span) { ObjectPtr n = make_object(); n->values = std::move(values); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeExpr") -.set_body_typed([](Array values) { - return ShapeExpr(values); +TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { + return ShapeExpr(values, span); }); - TVM_REGISTER_NODE_TYPE(VarNode); -Var::Var(Id vid, - Optional shape_annotation, - Optional type_annotation, - Span span) { +Var::Var(Id vid, Optional shape_annotation, Optional type_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->shape_ = std::move(shape_annotation); @@ -70,141 +65,139 @@ Var::Var(Id vid, } TVM_REGISTER_GLOBAL("relax.Var") -.set_body_typed([](String name_hint, - Optional shape_annotation, - Optional type_annotation) { - return Var(name_hint, shape_annotation, type_annotation); -}); - + .set_body_typed([](String name_hint, Optional shape_annotation, + Optional type_annotation, Span span) { + return Var(name_hint, shape_annotation, type_annotation, span); + }); TVM_REGISTER_NODE_TYPE(DataflowVarNode); -TVM_REGISTER_GLOBAL("relax.DataflowVar") -.set_body_typed([](String name_hint, - Optional shape_annotation, - Optional type_annotation) { - return DataflowVar(name_hint, shape_annotation, type_annotation); -}); +DataflowVar::DataflowVar(Id vid, Optional shape_annotation, Optional type_annotation, + Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + n->shape_ = std::move(shape_annotation); + n->type_annotation = std::move(type_annotation); + n->span = std::move(span); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.DataflowVar") + .set_body_typed([](String name_hint, Optional shape_annotation, + Optional type_annotation, Span span) { + return DataflowVar(name_hint, shape_annotation, type_annotation, 
span); + }); + +Binding::Binding(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} TVM_REGISTER_NODE_TYPE(BindingNode); -TVM_REGISTER_GLOBAL("relax.Binding") -.set_body_typed([]() { - return Binding(); -}); - +TVM_REGISTER_GLOBAL("relax.Binding").set_body_typed([](Span span) { return Binding(span); }); TVM_REGISTER_NODE_TYPE(MatchShapeNode); -MatchShape::MatchShape(Array pattern, - Expr value) { +MatchShape::MatchShape(Array pattern, Expr value, Span span) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->value = std::move(value); + n->span = span; data_ = std::move(n); } TVM_REGISTER_GLOBAL("relax.MatchShape") -.set_body_typed([](Array pattern, Expr value) { - return MatchShape(pattern, value); -}); - + .set_body_typed([](Array pattern, Expr value, Span span) { + return MatchShape(pattern, value, span); + }); TVM_REGISTER_NODE_TYPE(VarBindingNode); -VarBinding::VarBinding(Var var, - Expr value) { +VarBinding::VarBinding(Var var, Expr value, Span span) { ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.VarBinding") -.set_body_typed([](Var var,Expr value) { - return VarBinding(var,value); +TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { + return VarBinding(var, value, span); }); - TVM_REGISTER_NODE_TYPE(BindingBlockNode); -BindingBlock::BindingBlock(Array bindings) { +BindingBlock::BindingBlock(Array bindings, Span span) { ObjectPtr n = make_object(); n->bindings = std::move(bindings); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.BindingBlock") -.set_body_typed([](Array bindings) { - return BindingBlock(bindings); +TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); }); - TVM_REGISTER_NODE_TYPE(DataflowBlockNode); -DataflowBlock::DataflowBlock(Array bindings) { +DataflowBlock::DataflowBlock(Array bindings, Span span) { ObjectPtr n = make_object(); n->bindings = std::move(bindings); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowBlock") -.set_body_typed([](Array bindings) { - return DataflowBlock(bindings); +TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); }); - TVM_REGISTER_NODE_TYPE(SeqExprNode); -SeqExpr::SeqExpr(Array blocks, - Expr body) { +SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { ObjectPtr n = make_object(); n->blocks = std::move(blocks); n->body = std::move(body); + n->span = span; data_ = std::move(n); } TVM_REGISTER_GLOBAL("relax.SeqExpr") -.set_body_typed([](Array blocks, Expr body) { - return SeqExpr(blocks, body); -}); - + .set_body_typed([](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); TVM_REGISTER_NODE_TYPE(FunctionNode); -Function::Function(runtime::Optional name, - Array params, - Expr body, - Type ret_type) { +Function::Function(runtime::Optional name, Array params, Expr body, Type ret_type, + Span span) { ObjectPtr n = make_object(); n->name = std::move(name); n->params = std::move(params); n->body = std::move(body); n->ret_type = std::move(ret_type); + n->span = span; data_ = std::move(n); } TVM_REGISTER_GLOBAL("relax.Function") -.set_body_typed([](runtime::Optional name, - Array params, - Expr body, - Type ret_type) { - return Function(name, params, body, ret_type); -}); + .set_body_typed([](runtime::Optional 
name, Array params, Expr body, + Type ret_type, + Span span) { return Function(name, params, body, ret_type, span); }); TVM_REGISTER_NODE_TYPE(ExternFuncNode); -ExternFunc::ExternFunc(String global_symbol) { +ExternFunc::ExternFunc(String global_symbol, Span span) { ObjectPtr n = make_object(); n->global_symbol = std::move(global_symbol); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ExternFunc") -.set_body_typed([](String global_symbol) { - return ExternFunc(global_symbol); +TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) { + return ExternFunc(global_symbol, span); }); -} // namespace relax -} // namespace tvm +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 676dfb22a9..242bb3249e 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1,23 +1,23 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ -#include + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
 #include
+#include

 namespace tvm {
 namespace relax {
@@ -35,8 +35,7 @@ Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) {
   return Call(op, {shape, func, args}, {}, {});
 }

-TVM_REGISTER_GLOBAL("relax.op.call_dps")
-.set_body_typed(MakeCallDPS);
+TVM_REGISTER_GLOBAL("relax.op.call_dps").set_body_typed(MakeCallDPS);

 // shape_of

@@ -51,5 +50,6 @@ Expr MakeShapeOf(Expr expr) {
 TVM_REGISTER_GLOBAL("relax.op.shape_of")
 .set_body_typed(MakeShapeOf);

-} // namespace relax
-} // namespace tvm
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/type.cc b/src/relax/type.cc
index 498d6082de..9c398c23ea 100644
--- a/src/relax/type.cc
+++ b/src/relax/type.cc
@@ -29,30 +29,37 @@ namespace relax {

 TVM_REGISTER_NODE_TYPE(ShapeTypeNode);

-ShapeType::ShapeType() {
+ShapeType::ShapeType(Span span) {
   ObjectPtr<ShapeTypeNode> n = make_object<ShapeTypeNode>();
+  n->span = span;
   data_ = std::move(n);
 }

-TVM_REGISTER_GLOBAL("relax.ShapeType")
-.set_body_typed([]() {
-  return ShapeType();
-});
+TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](Span span) { return ShapeType(span); });

-DynTensorType::DynTensorType(int rank, DataType dtype) {
+DynTensorType::DynTensorType(int rank, DataType dtype, Span span) {
   ObjectPtr<DynTensorTypeNode> n = make_object<DynTensorTypeNode>();
   n->rank = std::move(rank);
   n->dtype = std::move(dtype);
+  n->span = span;
   data_ = std::move(n);
 }

 TVM_REGISTER_NODE_TYPE(DynTensorTypeNode);

-TVM_REGISTER_GLOBAL("relax.DynTensorType")
-.set_body_typed([](int rank, DataType dtype) {
-  return DynTensorType(rank, dtype);
+TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int rank, DataType dtype, Span span) {
+  return DynTensorType(rank, dtype, span);
 });

+DimType::DimType(Span span) {
+  ObjectPtr<DimTypeNode> n = make_object<DimTypeNode>();
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DimTypeNode);
+
+TVM_REGISTER_GLOBAL("relax.DimType").set_body_typed([](Span span) { return DimType(span); });

 } // namespace relax
 } // namespace tvm
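In the same illustrative spirit (not part of the diff), the type constructors above now thread a trailing Span through the FFI as well, including the newly registered "relax.DimType". The argument orders follow the registrations in type.cc; the file name and the rank/dtype values are arbitrary.

import tvm
from tvm.ir import Span, SourceName

span = Span(SourceName("types.py"), 5, 5, 0, 12)

shape_ty = tvm.get_global_func("relax.ShapeType")(span)
tensor_ty = tvm.get_global_func("relax.DynTensorType")(2, "float32", span)  # rank-2 float32
dim_ty = tvm.get_global_func("relax.DimType")(span)

assert tensor_ty.rank == 2 and tensor_ty.dtype == "float32"
assert isinstance(shape_ty, tvm.ir.Type) and isinstance(dim_ty, tvm.ir.Type)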
diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py
new file mode 100644
index 0000000000..12ba128e25
--- /dev/null
+++ b/tests/python/relax/parser.py
@@ -0,0 +1,390 @@
+from __future__ import annotations  # must import to defer parsing of annotations
+import pytest
+import tvm
+from tvm import relax as rx
+from tvm import tir, relay
+from tvm.ir import structural_equal
+
+# TODO: replace xfails with proper diagnostics checking.
+# c.f. tests/python/unittest/test_tvmscript_error_report.py
+
+
+def rx_func(func):
+    return func.module[func.fn_name]
+
+
+def check_shape(e, s):
+    if isinstance(e, rx.Expr):
+        e = e.shape_
+
+    if s is None:
+        assert e is None
+        return
+
+    assert len(e) == len(s)
+
+    for edim, sdim in zip(e, s):
+        if isinstance(sdim, str):
+            assert isinstance(edim, tir.Var)
+            assert edim.name == sdim
+        else:
+            assert isinstance(edim, tir.IntImm)
+            assert edim.value == sdim
+
+
+def check_tensor_var(v, s, d):
+    assert isinstance(v.type_annotation, rx.ty.DynTensorType)
+    assert v.type_annotation.dtype == d
+    if isinstance(s, (list, tuple)):
+        assert v.type_annotation.rank == len(s)
+    check_shape(v, s)
+
+
+def check_call(call, op, args):
+    assert isinstance(call, rx.Call)
+    if isinstance(op, str):
+        op = relay.op.get(op)
+    assert call.op == op
+    assert structural_equal(call.args, args)
+
+
+def test_annotations():
+    @rx.script
+    def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
+        z: Tensor[(32, k), "float32"] = nn.matmul(x, y)
+        w: Tensor[_, _] = multiply(z, z)
+        t = subtract(w, z)
+        sh: Shape = t.shape
+        return t
+
+    f = rx_func(foo)
+    x, y = f.params
+    z_bind, w_bind, t_bind, sh_bind = f.body.blocks[0].bindings
+    z, mm = z_bind.var, z_bind.value
+    w, mul = w_bind.var, w_bind.value
+    t, sub = t_bind.var, t_bind.value
+    sh, shape_of = sh_bind.var, sh_bind.value
+
+    check_tensor_var(x, (32, "m"), "float32")
+    check_tensor_var(y, ("m", "k"), "float32")
+    check_tensor_var(z, (32, "k"), "float32")
+    check_tensor_var(w, None, "")
+    assert t.type_annotation is None
+    assert isinstance(sh.type_annotation, rx.ty.ShapeType)
+
+    check_call(mm, "nn.matmul", [x, y])
+    check_call(mul, "multiply", [z, z])
+    check_call(sub, "subtract", [w, z])
+    check_call(shape_of, "shape_of", [t])
+
+    assert f.body.body == t
+
+    assert isinstance(f.ret_type, rx.ty.DynTensorType)
+
+
+def test_match_shape():
+    @rx.script
+    def foo(x: Tensor[_, "float32"]):
+        relax.match_shape((n, m), x.shape)
+        y: Tensor[(n, m), "float32"] = add(x, x)
+        return x
+
+    f = rx_func(foo)
+    match_sh = f.body.blocks[0].bindings[0]
+    pattern, value = match_sh.pattern, match_sh.value
+
+    check_shape(pattern, ("n", "m"))
+    check_call(value, "shape_of", [f.params[0]])
+
+
+@pytest.mark.xfail
+def test_dim_var_intro_fail():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        y: Tensor[(n, m), "float32"] = x
+        return y
+
+
+def test_if():
+    @rx.script
+    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+        if cond:
+            w = add(x, x)
+            y = multiply(w, w)
+        else:
+            w = multiply(x, x)
+            y = add(w, w)
+        return y
+
+    f = rx_func(foo)
+    cond, x = f.params
+    y_bind = f.body.blocks[0].bindings[0]
+    y, ite = y_bind.var, y_bind.value
+
+    check_tensor_var(cond, tuple(), "bool")
+    check_tensor_var(x, (1,), "float32")
+
+    assert isinstance(y, rx.Var)
+    assert y.name_hint == "y"
+
+    assert isinstance(ite, rx.If)
+    assert isinstance(ite.true_branch, rx.SeqExpr)
+    assert isinstance(ite.false_branch, rx.SeqExpr)
+
+    w_bind = ite.true_branch.blocks[0].bindings[0]
+    body = ite.true_branch.body
+    assert w_bind.var.name_hint == "w"
+    check_call(w_bind.value, "add", [x, x])
+    check_call(body, "multiply", [w_bind.var, w_bind.var])
+
+    w_bind = ite.false_branch.blocks[0].bindings[0]
+    body = ite.false_branch.body
+    assert w_bind.var.name_hint == "w"
+    check_call(w_bind.value, "multiply", [x, x])
+    check_call(body, "add", [w_bind.var, w_bind.var])
+
+
+# TODO: figure out if-else binding type and shape
+
+
+@pytest.mark.xfail
+def test_var_redefine_fail():
+    @rx.script
+    def foo(x, y):
+        z = add(x, y)
+        y = z
+        return y
+
+
+@pytest.mark.xfail
+def test_var_redefine_fail_if():
+    @rx.script
+    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+        y = x
+        if cond:
+            w = add(x, x)
+            y = multiply(w, w)
+        else:
+            w = multiply(x, x)
+            y = add(w, w)
+        return y
+
+
+@pytest.mark.xfail
+def test_var_if_scoping_fail():
+    @rx.script
+    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+        if cond:
+            w = add(x, x)
+            y = multiply(w, w)
+        else:
+            w = multiply(x, x)
+            y = add(w, w)
+        return w
+
+
+@pytest.mark.xfail
+def test_if_mismatch_var_fail():
+    @rx.script
+    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+        if cond:
+            w = add(x, x)
+            y = multiply(w, w)
+        else:
+            w = multiply(x, x)
+            z = add(w, w)
+        return z
+
+
+@pytest.mark.xfail
+def test_unassigned_call_fail():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        add(x, x)
+        return x
+
+
+def test_tuple():
+    @rx.script
+    def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
+        t: Tuple[Tensor[_, _], Tensor[(32,), "float32"]] = (x, y)
+        return t
+
+    f = rx_func(foo)
+    x, y = f.params
+    t_bind = f.body.blocks[0].bindings[0]
+    t, tup = t_bind.var, t_bind.value
+
+    assert isinstance(t.type_annotation, relay.TupleType)
+    annot = t.type_annotation
+    assert isinstance(annot.fields[0], rx.ty.DynTensorType) and annot.fields[0].dtype == ""
+    assert isinstance(annot.fields[1], rx.ty.DynTensorType) and annot.fields[1].dtype == "float32"
+
+    assert t.shape_ is None
+
+    assert isinstance(tup, rx.Tuple)
+    assert structural_equal(tup.fields, [x, y])
+    assert tup.shape_ is None
+    check_shape(tup.fields[0], None)
+    check_shape(tup.fields[1], (32,))
+
+
+def test_local_func():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        def bar(y: Tensor[_, _]):
+            return y
+
+        y = bar(x)  # tests local function variable scoping
+        return y
+
+    f = rx_func(foo)
+    bar_bind, y_bind = f.body.blocks[0].bindings
+    bar, bar_fn = bar_bind.var, bar_bind.value
+    bar_x = y_bind.value
+
+    assert isinstance(bar_fn, rx.Function)
+    assert bar_fn.body.body == bar_fn.params[0]
+
+    assert bar_x.op == bar
+
+
+def test_dataflow():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow():
+            # TODO: parse this
+            # nonlocal y, w
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            return y, w
+        t = divide(y, w)
+        return t
+
+    f = rx_func(foo)
+    assert len(f.body.blocks) == 2
+    df_block = f.body.blocks[0]
+    y_bind, z_bind, w_bind = df_block.bindings
+    (t_bind,) = f.body.blocks[1].bindings
+    x = f.params[0]
+    y, z, w, t = map(lambda b: b.var, [y_bind, z_bind, w_bind, t_bind])
+
+    assert isinstance(y, rx.Var)
+    assert isinstance(z, rx.DataflowVar)
+    assert isinstance(w, rx.Var)
+
+    check_call(y_bind.value, "add", [x, x])
+    check_call(z_bind.value, "multiply", [y, x])
+    check_call(w_bind.value, "subtract", [z, x])
+    check_call(t_bind.value, "divide", [y, w])
+
+    assert f.body.body == t
+
+
+def test_dataflow_match_shape():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow():
+            y = add(x, x)
+            z = multiply(y, x)
+            relax.match_shape((n, m), z.shape)
+            w: Tensor[(n, m), _] = subtract(z, x)
+            return y, w
+        t: Tensor[(n, m), _] = divide(y, w)
+        return t
+
+
+@pytest.mark.xfail
+def test_dataflow_scope_fail():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow():
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            return y, w
+        t = divide(y, z)
+        return t
+
+
+@pytest.mark.xfail
+def test_dataflow_syntax_fail_pattern():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow() as df:
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            return y, w
+        t = divide(y, z)
+        return t
+
+
+@pytest.mark.xfail
+def test_dataflow_syntax_fail_params():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow(x) as df:
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            return y, w
+        t = divide(y, z)
+        return t
+
+
+@pytest.mark.xfail
+def test_func_no_return_fail():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        y = add(x, x)
+
+
+def test_inline_tir():
+    @rx.script
+    def foo(x: Tensor[(B, 128), "float32"], y: Tensor[(128, 128), "float32"]):
+        @tir
+        def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+            A = tir.match_buffer(a, [128, 128])
+            B = tir.match_buffer(b, [128, 128])
+            C = tir.match_buffer(c, [128, 128])
+
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+                with tir.init():
+                    C[vi, vj] = tir.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        z = relax.call_dps((B, 128), my_matmul, (x, y))
+        return z
+
+    f = rx_func(foo)
+    x, y = f.params
+    B = x.shape_[0]
+    mm_bind, z_bind = f.body.blocks[0].bindings
+
+    assert mm_bind.var.name_hint == "my_matmul"
+    assert isinstance(mm_bind.value, tir.PrimFunc)
+
+    check_call(
+        z_bind.value,
+        "relax.call_dps",
+        [rx.ShapeExpr([B, tir.IntImm("int32", 128)]), mm_bind.var, rx.Tuple([x, y])],
+    )
+
+
+def test_call_packed():
+    @rx.script
+    def foo(x: Tensor[(3, 4), "float32"]):
+        # test that we can intro dim vars
+        z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x))
+        return z
+
+    f = rx_func(foo)
+    x = f.params[0]
+    (z_bind,) = f.body.blocks[0].bindings
+    check_tensor_var(z_bind.var, ("n", "m"), "float32")
+
+    assert isinstance(z_bind.value.op, rx.ExternFunc)
+    assert z_bind.value.op.global_symbol == "contrib.my_matmul"
+    assert structural_equal(z_bind.value.args, [rx.Tuple([x, x])])
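A final illustrative addendum, not part of the diff: the op constructors from op.cc can also be driven directly, mirroring the relax.call_dps calls that the parser tests above check for. The rx.ShapeExpr / rx.Tuple usage follows the test file; the "contrib.my_matmul" symbol and the empty argument tuple are placeholders, and reaching ExternFunc through its packed function (with the new trailing Span) is an assumption for this sketch.

import tvm
from tvm import tir
from tvm import relax as rx
from tvm.ir import Span, SourceName

span = Span(SourceName("addendum.py"), 1, 1, 0, 0)

shape = rx.ShapeExpr([tir.IntImm("int32", 32), tir.IntImm("int32", 32)])
func = tvm.get_global_func("relax.ExternFunc")("contrib.my_matmul", span)
args = rx.Tuple([])  # no tensor arguments in this sketch

# MakeCallDPS(shape, func, args) builds Call(relax.call_dps, {shape, func, args})
call = tvm.get_global_func("relax.op.call_dps")(shape, func, args)
assert call.op.name == "relax.call_dps"
assert len(call.args) == 3  # shape, func, packed args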