diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h
index 11bf7d4740d0..5e96f2de5d3f 100644
--- a/include/tvm/ir/type_functor.h
+++ b/include/tvm/ir/type_functor.h
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include <tvm/relax/type.h>
 #include
 #include
 
@@ -89,6 +90,9 @@ class TypeFunctor<R(const Type& n, Args...)> {
   virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const relax::ShapeTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const relax::DynTensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const relax::DimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitTypeDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     throw;  // unreachable, written to stop compiler warning
@@ -112,6 +116,9 @@ class TypeFunctor<R(const Type& n, Args...)> {
     TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
     TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
     TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(relax::ShapeTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(relax::DynTensorTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(relax::DimTypeNode);
     return vtable;
   }
 };
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 5c05e7c9aaec..9868fa699e76 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -96,6 +96,7 @@ class VarNode : public ExprNode {
   }
 
   bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
     return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
            // Do we use the analysis information in equality?
           equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_);
@@ -140,6 +141,7 @@ class DataflowVarNode : public VarNode {
   }
 
   bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
     return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
            equal(shape_, other->shape_) && equal(checked_type_, other->checked_type_);
   }
diff --git a/include/tvm/relax/ir_functor.h b/include/tvm/relax/ir_functor.h
new file mode 100644
index 000000000000..b9a17f19ef0e
--- /dev/null
+++ b/include/tvm/relax/ir_functor.h
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/ir_functor.h
+ * \brief A generic visitor for traversing Relax IR nodes.
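+ *
+ * IRFunctor dispatches on the runtime type of the visited node through a
+ * NodeFunctor vtable, so a single visitor covers both the Relay nodes that
+ * Relax reuses and the Relax-specific nodes declared below.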
+ */
+#ifndef TVM_RELAX_IR_FUNCTOR_H_
+#define TVM_RELAX_IR_FUNCTOR_H_
+
+#include <tvm/node/functor.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relay/expr.h>
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+template <typename FType>
+class IRFunctor;
+
+#define IR_FUNCTOR_DEFAULT \
+  { return VisitNodeDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAX_IR_FUNCTOR_DISPATCH(OP)                                                        \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {       \
+    return self->VisitNode_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...);   \
+  });
+
+template <typename R, typename... Args>
+class IRFunctor<R(const ObjectRef& n, Args...)> {
+ private:
+  using TSelf = IRFunctor<R(const ObjectRef& n, Args...)>;
+  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  using result_type = R;
+  virtual ~IRFunctor() {}
+
+  R operator()(const ObjectRef& n, Args... args) {
+    return VisitNode(n, std::forward<Args>(args)...);
+  }
+
+  virtual R VisitNode(const ObjectRef& n, Args... args) {
+    ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
+                           "have generated invalid data.";
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+
+  // IR nodes inherited from Relay
+  virtual R VisitNode_(const relay::ConstantNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relay::TupleNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relay::GlobalVarNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relay::CallNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relay::IfNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const OpNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relay::TupleGetItemNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+
+  // IR nodes introduced by Relax
+  virtual R VisitNode_(const relax::VarNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::DataflowVarNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::ShapeExprNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::MatchShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::VarBindingNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::BindingBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::DataflowBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::SeqExprNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::FunctionNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+  virtual R VisitNode_(const relax::ExternFuncNode* op, Args... args) IR_FUNCTOR_DEFAULT;
+
+  virtual R VisitNodeDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "no default visitor implemented for " << op->GetTypeKey();
+    throw;  // unreachable, written to stop compiler warning
+  }
+
+ private:
+  static FType InitVTable() {
+    FType vtable;
+    RELAX_IR_FUNCTOR_DISPATCH(relay::ConstantNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relay::TupleNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relay::GlobalVarNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relay::CallNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relay::IfNode);
+    RELAX_IR_FUNCTOR_DISPATCH(OpNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relay::TupleGetItemNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::VarNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowVarNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::ShapeExprNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::MatchShapeNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::VarBindingNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::BindingBlockNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowBlockNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::SeqExprNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::FunctionNode);
+    RELAX_IR_FUNCTOR_DISPATCH(relax::ExternFuncNode);
+    return vtable;
+  }
+};
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_IR_FUNCTOR_H_
\ No newline at end of file
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index d7743c002626..061e78c6c05d 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -129,7 +129,8 @@ def __init__(self, bindings: List[Binding], span: Span = None) -> None:
 
 @tvm._ffi.register_object("relax.expr.DataflowBlock")
 class DataflowBlock(BindingBlock):
-    pass
+    def __init__(self, bindings: List[Binding], span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span)
 
 
 @tvm._ffi.register_object("relax.expr.SeqExpr")
diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py
index 62e587b70862..441da049c976 100644
--- a/python/tvm/relax/parser.py
+++ b/python/tvm/relax/parser.py
@@ -23,12 +23,49 @@
 import tvm.relax as rx
 
 
-# TODO(@altanh): Replace with a real pretty print method once we have the real AST
-def pretty_print(f):
-    print(f)
+def pretty_print(node):
+    """Prints the given Relax IR node in the Relax text format.
+
+    Parameters
+    ----------
+    node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
+        The Relax IR node to print.
+    """
+    print(tvm.script._ffi_api.AsRelaxScript(node))
 
 
-def is_registered(op_name: str, op_set=None) -> bool:
+def astext(node) -> str:
+    """Returns the Relax text format representation of the given Relax IR node.
+
+    Parameters
+    ----------
+    node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
+        The Relax IR node to convert to text.
+
+    Returns
+    -------
+    str
+        The text format representation of the given Relax IR node.
+    """
+    return tvm.script._ffi_api.AsRelaxScript(node)
+
+
+def _is_registered(op_name: str, op_set=None) -> bool:
+    """Returns whether or not the given operator is registered.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator.
+    op_set : Union[Container, Iterable], optional
+        The collection of registered operator names to check against. If None, the global TVM
+        operator registry is queried.
+
+    Returns
+    -------
+    bool
+        True if the specified operator is registered, else False.
+    """
     if op_set is None:
         op_set = tvm.ir._ffi_api.ListOpNames()
     return op_name in op_set
@@ -37,12 +74,28 @@ def is_registered(op_name: str, op_set=None) -> bool:
 def _tir_from_synr(
     synr_ast: ast.Node, diag_ctx: tvm.script.diagnostics.TVMDiagnosticCtx
 ) -> tir.PrimFunc:
+    """Parses the given synr AST into a TIR PrimFunc using the TVMScript parser.
+
+    Parameters
+    ----------
+    synr_ast : ast.Node
+        The synr AST to be parsed.
+    diag_ctx : tvm.script.diagnostics.TVMDiagnosticCtx
+        The diagnostic context for TVMScript parser error reporting.
+
+    Returns
+    -------
+    tir.PrimFunc
+        The parsed TIR PrimFunc.
+    """
     parser = tvm.script.parser.TVMScriptParser(synr_ast.span.start_line)
     return parser.do_transform(synr_ast, diag_ctx)
 
 
 # NOTE: call_dps is an actual registered operator
 class SpecialOp(Enum):
+    """Relax operator calls that have special semantics handled by the parser."""
+
     MATCH_SHAPE = "relax.match_shape"
     CALL_PACKED = "relax.call_packed"
     DATAFLOW = "relax.dataflow"
@@ -196,7 +249,7 @@ def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[rx.Type, rx.Expr]:
             )
             shape_annotation, dtype_annotation = ty.params
-            shape, dtype = None, None
+            shape, dtype, rank = None, None, -1
 
             # parse the shape annotation
             if isinstance(shape_annotation, ast.TypeVar):
@@ -210,13 +263,25 @@ def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[rx.Type, rx.Expr]:
                 # FIXME(@altanh): use a special node for unknown shape vs no shape?
                 pass  # shape = None
             elif isinstance(shape_annotation, ast.TypeTuple):
-                shape = rx.ShapeExpr(
-                    self.parse_shape(shape_annotation, bind_free_vars),
-                    span=self.to_tvm_span(shape_annotation.span),
+                # the syntax for a fixed rank k but unknown/unmatched shape is a tuple of length
+                # k, where each element is "_" (e.g. "(_, _)" for rank 2)
+                is_unmatched = all(
+                    map(
+                        lambda v: isinstance(v, ast.Var) and v.id.name == "_",
+                        shape_annotation.values,
+                    )
                 )
+                if len(shape_annotation.values) > 0 and is_unmatched:
+                    rank = len(shape_annotation.values)
+                else:
+                    shape = rx.ShapeExpr(
+                        self.parse_shape(shape_annotation, bind_free_vars),
+                        span=self.to_tvm_span(shape_annotation.span),
+                    )
+                    rank = len(shape)
             else:
                 self.report_error(
-                    "unsupported shape annotation",
+                    f"unsupported shape annotation {shape_annotation}",
                     shape_annotation.span,
                 )
@@ -231,7 +296,6 @@ def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[rx.Type, rx.Expr]:
                     dtype_annotation.span,
                 )
 
-            rank = len(shape) if shape is not None else -1
             return (rx.DynTensorType(rank=rank, dtype=dtype, span=span), shape)
         elif ty.id.name == "Tuple":
             field_types = []
@@ -242,7 +306,7 @@ def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[rx.Type, rx.Expr]:
                 field_shapes.append(fsh)
             return (relay.TupleType(field_types, span), None)
         # TODO(@altanh): other types with args, e.g. Ref[T], func types
-        self.report_error("invalid type", span)
+        self.report_error("invalid type", ty.span)
 
     def parse_shape(
         self,
@@ -331,6 +395,18 @@ def transform_module(self, mod: ast.Module) -> IRModule:
             self.module[func_name] = self.transform_function(func, is_global=True)
         return self.module
 
+    def _parse_attrs_to_str(self, expr: ast.Attr) -> str:
+        strs = []
+        attr = expr
+        while isinstance(attr, ast.Attr):
+            strs.append(attr.field.name)
+            attr = attr.object
+        if not isinstance(attr, ast.Var):
+            self.report_error("unsupported attribute access", expr.span)
+        strs.append(attr.id.name)
+        result = ".".join(reversed(strs))
+        return result
+
     def transform_function(self, func: ast.Function, is_global: bool = False) -> rx.Function:
         """Transforms the given synr Function to a Relax Function.
@@ -348,8 +424,7 @@ def transform_function(self, func: ast.Function, is_global: bool = False) -> rx.Function:
         """
         if (
             len(func.decorators) == 1
-            and isinstance(func.decorators[0], ast.Var)
-            and func.decorators[0].id.name == "tir"
+            and self._parse_attrs_to_str(func.decorators[0]) == "tvm.script.tir"
         ):
             return _tir_from_synr(func, self._diagnostic_context)
@@ -664,18 +739,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
         if isinstance(expr, ast.Attr):
             if expr.field.name == "shape":
                 obj = self.transform_expr(expr.object)
-                return relay.op.shape_of(obj)
+                return relay.Call(relay.op.get("shape_of"), [obj], span=self.to_tvm_span(expr.span))
             else:
                 # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
-                op_name = []
-                attr = expr
-                while isinstance(attr, ast.Attr):
-                    op_name.append(expr.field.name)
-                    attr = attr.object
-                if not isinstance(attr, ast.Var):
-                    self.report_error("unsupported field access", expr.span)
-                op_name.append(attr.id.name)
-                op_name = ".".join(reversed(op_name))
+                op_name = self._parse_attrs_to_str(expr)
                 # NOTE: at least for now, all special operators are namespaced
                 try:
                     return SpecialOp(op_name)
@@ -684,6 +751,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
                     return relay.op.get(op_name)
 
         if isinstance(expr, ast.Call):
+            # TODO(@altanh): support parsing kwargs as attributes?
             op = self.transform_expr(expr.func_name)
             if op == SpecialOp.CALL_PACKED:
                 if len(expr.params) != 2:
@@ -719,7 +787,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
         elif isinstance(expr, ast.Var):
             var_name = expr.id.name
-            if is_registered(var_name, op_set=self._registered_ops):
+            if _is_registered(var_name, op_set=self._registered_ops):
                 return relay.op.get(var_name)
             if var_name not in self.scope:
                 self.report_error("undefined variable", expr.span)
@@ -839,7 +907,7 @@ def __call__(self, *args):
 
 def script(f) -> RelaxDecoratedFn:
-    """Parses the decorated Relax function (in Relax IR) to the Relax AST
+    """Parses the decorated Relax function (in Relax IR) to a Relax AST.
 
     Parameters
     ----------
@@ -858,3 +926,24 @@ def script(f) -> RelaxDecoratedFn:
     definition_scope = inspect.getmodule(f)
     module = RelaxTransformer(definition_scope).do_transform(ast, diag_ctx)
     return RelaxDecoratedFn(f.__name__, module, diag_ctx)
+
+
+def fromtext(source: str, source_name: str = "from_string"):
+    """Parses the given input string (in the Relax text format) to a Relax AST.
+
+    Parameters
+    ----------
+    source : str
+        The input source string.
+    source_name : str, optional
+        A descriptive name for error reporting, by default "from_string".
+
+    Returns
+    -------
+    Relax AST
+        The parsed Relax AST.
+    """
+    diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
+    ast = synr.to_ast(source, diag_ctx)
+    module = RelaxTransformer(None).do_transform(ast, diag_ctx)
+    return module
diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc
new file mode 100644
index 000000000000..3181d38ba18c
--- /dev/null
+++ b/src/printer/relax_script_printer.cc
@@ -0,0 +1,410 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/relax_script_printer.cc
+ * \brief Printer class to print Relax IR to parsable Python.
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relax/ir_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "text_printer.h"
+
+namespace tvm {
+namespace relax {
+
+class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
+                           public TypeFunctor<Doc(const Type&)> {
+ public:
+  TVM_DLL Doc Print(const ObjectRef& node);
+
+ private:
+  std::unordered_map<std::string, int> name_alloc_map_;
+  std::unordered_map<relax::Id, Doc, ObjectPtrHash, ObjectPtrEqual> var_id_map_;
+  std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> dim_var_map_;
+
+  // IR nodes inherited from Relay
+  // Doc VisitNode_(const relay::ConstantNode* op) override;
+  Doc VisitNode_(const relay::TupleNode* op) override;
+  Doc VisitNode_(const relay::GlobalVarNode* op) override;
+  Doc VisitNode_(const relay::CallNode* op) override;
+  // Doc VisitNode_(const relay::IfNode* op) override;
+  Doc VisitNode_(const OpNode* op) override;
+  Doc VisitNode_(const relay::TupleGetItemNode* op) override;
+
+  // IR nodes introduced by Relax
+  Doc VisitNode_(const relax::VarNode* op) override;
+  Doc VisitNode_(const relax::DataflowVarNode* op) override;
+  Doc VisitNode_(const relax::ShapeExprNode* op) override;
+  Doc VisitNode_(const relax::MatchShapeNode* op) override;
+  Doc VisitNode_(const relax::VarBindingNode* op) override;
+  Doc VisitNode_(const relax::BindingBlockNode* op) override;
+  Doc VisitNode_(const relax::DataflowBlockNode* op) override;
+  Doc VisitNode_(const relax::SeqExprNode* op) override;
+  Doc VisitNode_(const relax::FunctionNode* op) override;
+  Doc VisitNode_(const relax::ExternFuncNode* op) override;
+
+  Doc PrintDimVar(const tir::Var& var);
+  Doc PrintIfStmt(const relax::Var& var, const relay::If& ite);
+  Doc PrintFunctionDef(const Doc& name, const relax::Function& func);
+
+  Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional<ObjectRef>& shape);
+
+  Doc VisitType_(const relax::ShapeTypeNode* node) override;
+  Doc VisitType_(const relax::DynTensorTypeNode* node) override;
+  Doc VisitType_(const relay::TupleTypeNode* node) override;
+
+  Doc GetUniqueName(std::string prefix, std::string fallback);
+};
+
+Doc RelaxScriptPrinter::Print(const ObjectRef& node) {
+  if (node->IsInstance<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  } else {
+    return VisitNode(node);
+  }
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relay::TupleNode* op) {
+  size_t num_fields = op->fields.size();
+
+  if (num_fields == 0) {
+    return Doc::Text("tuple()");
+  }
+
+  Doc doc;
+  std::vector<Doc> fields;
+
+  for (size_t i = 0; i < num_fields; ++i) {
+    fields.push_back(Print(op->fields[i]));
+  }
+  doc << "(" << Doc::Concat(fields, Doc::Text(", "));
+  if (num_fields == 1) {
+    // single-element tuples need a trailing comma to parse as tuples in Python
+    doc << ",";
+  }
+  doc << ")";
+
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relay::GlobalVarNode* op) {
+  return Doc::Text(op->name_hint);
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) {
+  Doc doc;
+
+  if (const relax::ExternFuncNode* ext = op->op.as<relax::ExternFuncNode>()) {
+    ICHECK_EQ(op->args.size(), 1) << "extern calls should only have one argument";
+    doc << "relax.call_packed(" << Print(op->op) << ", " << Print(op->args[0]) << ")";
+    return doc;
+  }
+
+  // TODO(@altanh): how to support when func cannot be printed as Python expr?
+  // e.g. Function or If
+  doc << Print(op->op);
+  if (op->args.empty()) {
+    doc << "()";
+    return doc;
+  }
+
+  std::vector<Doc> args;
+  for (size_t i = 0; i < op->args.size(); ++i) {
+    args.push_back(Print(op->args[i]));
+  }
+  doc << "(" << Doc::Concat(args, Doc::Text(", ")) << ")";
+
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const OpNode* op) { return Doc::Text(op->name); }
+
+Doc RelaxScriptPrinter::VisitNode_(const relay::TupleGetItemNode* op) {
+  Doc doc;
+  doc << Print(op->tuple) << "[" << op->index << "]";
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::VarNode* op) {
+  if (!var_id_map_.count(op->vid)) {
+    var_id_map_[op->vid] = GetUniqueName(op->name_hint(), "v");
+  }
+
+  return var_id_map_[op->vid];
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowVarNode* op) {
+  if (!var_id_map_.count(op->vid)) {
+    var_id_map_[op->vid] = GetUniqueName(op->name_hint(), "dv");
+  }
+
+  return var_id_map_[op->vid];
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) {
+  // TODO(@altanh): support more PrimExpr printing, and check that empty tuple
+  // is never ambiguously printed as "()"
+  Doc doc;
+
+  std::vector<Doc> fields;
+  for (size_t i = 0; i < op->values.size(); ++i) {
+    auto val = op->values[i];
+    if (const tir::VarNode* var = val.as<tir::VarNode>()) {
+      fields.push_back(PrintDimVar(GetRef<tir::Var>(var)));
+    } else if (const tir::IntImmNode* num = val.as<tir::IntImmNode>()) {
+      fields.push_back(Doc::Text(std::to_string(num->value)));
+    } else {
+      LOG(FATAL) << "cannot print PrimExpr: " << val->GetTypeKey();
+    }
+  }
+  doc << "(" << Doc::Concat(fields, Doc::Text(", "));
+  if (fields.size() == 1) {
+    doc << ",";
+  }
+  doc << ")";
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) {
+  Doc doc;
+  doc << "relax.match_shape(";
+  // TODO(@altanh): maybe op->pattern should just be a ShapeExpr?
+  doc << Print(relax::ShapeExpr(op->pattern)) << ", " << Print(op->value);
+  doc << ")";
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) {
+  // TODO(@altanh): think deeper about normal form (need to be strict about block exprs)
+  if (const relay::IfNode* ite = op->value.as<relay::IfNode>()) {
+    return PrintIfStmt(op->var, GetRef<relay::If>(ite));
+  } else if (const relax::FunctionNode* func = op->value.as<relax::FunctionNode>()) {
+    return PrintFunctionDef(Print(op->var), GetRef<relax::Function>(func));
+  } else if (const tir::PrimFuncNode* prim_func = op->value.as<tir::PrimFuncNode>()) {
+    // we need the mod for TVMScriptPrinter to properly print the function name; maybe it's worth
+    // refactoring to avoid this?
+    tir::PrimFunc prim_func_ref = GetRef<tir::PrimFunc>(prim_func);
+    IRModule mod;
+    mod->Add(relay::GlobalVar(op->var->name_hint()), prim_func_ref);
+    return tir::AsTVMScriptDoc(mod, false, prim_func_ref);
+  } else {
+    Doc doc;
+    doc << Print(op->var);
+    if (op->var->type_annotation.defined()) {
+      doc << ": ";
+      if (const relax::DynTensorTypeNode* tty =
+              op->var->type_annotation.as<relax::DynTensorTypeNode>()) {
+        doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), op->var->shape_);
+      } else {
+        doc << Print(op->var->type_annotation);
+      }
+    }
+    doc << " = " << Print(op->value);
+    return doc;
+  }
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::BindingBlockNode* op) {
+  Doc doc;
+  for (size_t i = 0; i < op->bindings.size(); ++i) {
+    doc << Print(op->bindings[i]) << Doc::NewLine();
+  }
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) {
+  Doc block;
+  Doc body;
+  std::vector<Doc> return_vars;
+  for (size_t i = 0; i < op->bindings.size(); ++i) {
+    body << Print(op->bindings[i]) << Doc::NewLine();
+    // bindings to non-dataflow vars are the visible outputs of the block
+    if (const relax::VarBindingNode* binding = op->bindings[i].as<relax::VarBindingNode>()) {
+      if (!binding->var.as<relax::DataflowVarNode>()) {
+        return_vars.push_back(Print(binding->var));
+      }
+    }
+  }
+  ICHECK(!return_vars.empty()) << "dataflow blocks should have at least one output variable";
+  body << "return " << Doc::Concat(return_vars, Doc::Text(", "));
+  block << "with relax.dataflow():" << Doc::NewLine(4);
+  block << Doc::Indent(4, body) << Doc::NewLine();
+  return block;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::SeqExprNode* op) {
+  Doc doc;
+  for (size_t i = 0; i < op->blocks.size(); ++i) {
+    doc << Print(op->blocks[i]);
+  }
+  // NOTE: the body expression is printed in the parent, since SeqExprs are used for both Function
+  // bodies and If expr bodies (which don't have a "return" statement but instead a binding)
+  return doc;
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::FunctionNode* op) {
+  ICHECK(op->name.defined());
+  return PrintFunctionDef(Doc::Text(op->name.value()->name_hint), GetRef<relax::Function>(op));
+}
+
+Doc RelaxScriptPrinter::VisitNode_(const relax::ExternFuncNode* op) {
+  return Doc::StrLiteral(op->global_symbol);
+}
+
+Doc RelaxScriptPrinter::VisitType_(const relax::ShapeTypeNode* node) { return Doc::Text("Shape"); }
+
+Doc RelaxScriptPrinter::VisitType_(const relax::DynTensorTypeNode* node) {
+  // NOTE: to print shape information, use PrintTensorAnnotation
+  return PrintTensorAnnotation(GetRef<DynTensorType>(node), NullOpt);
+}
+
+Doc RelaxScriptPrinter::VisitType_(const relay::TupleTypeNode* node) {
+  if (node->fields.empty()) {
+    return Doc::Text("Tuple[]");
+  }
+
+  Doc doc;
+
+  std::vector<Doc> fields;
+  for (size_t i = 0; i < node->fields.size(); ++i) {
+    fields.push_back(Print(node->fields[i]));
+  }
+  doc << "Tuple[" << Doc::Concat(fields, Doc::Text(", ")) << "]";
+
+  return doc;
+}
+
+Doc RelaxScriptPrinter::PrintDimVar(const tir::Var& var) {
+  if (!dim_var_map_.count(var)) {
+    dim_var_map_[var] = GetUniqueName(var->name_hint, "dim");
+  }
+
+  return dim_var_map_[var];
+}
+
+Doc RelaxScriptPrinter::PrintIfStmt(const relax::Var& var, const relay::If& ite) {
+  const relax::SeqExprNode* true_branch = ite->true_branch.as<relax::SeqExprNode>();
+  const relax::SeqExprNode* false_branch = ite->false_branch.as<relax::SeqExprNode>();
+  // TODO(@altanh): this invariant must be maintained by the normal form
+  ICHECK(true_branch && false_branch)
+      << "in the Relax IR normal form, each branch of an If expression should be a SeqExpr";
+
+  Doc doc;
+  doc << "if " << Print(ite->cond) << ":" << Doc::NewLine(4);
+  doc << Doc::Indent(4, Print(GetRef<SeqExpr>(true_branch)));
+  doc << Doc::Indent(4, Print(relax::VarBinding(var, true_branch->body)));
+  doc << Doc::NewLine();
+  doc << "else:" << Doc::NewLine(4);
+  doc << Doc::Indent(4, Print(GetRef<SeqExpr>(false_branch)));
+  doc << Doc::Indent(4, Print(relax::VarBinding(var, false_branch->body)));
+  return doc;
+}
+
+Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& func) {
+  Doc doc;
+
+  std::vector<Doc> params;
+  for (size_t i = 0; i < func->params.size(); ++i) {
+    relax::Var var = func->params[i];
+    Doc param;
+    param << Print(var);
+    if (var->type_annotation.defined()) {
+      param << ": ";
+      if (const relax::DynTensorTypeNode* tty =
+              var->type_annotation.as<relax::DynTensorTypeNode>()) {
+        param << PrintTensorAnnotation(GetRef<DynTensorType>(tty), var->shape_);
+      } else {
+        param << Print(var->type_annotation);
+      }
+    }
+    params.push_back(param);
+  }
+
+  doc << "def " << name << "(" << Doc::Concat(params, Doc::Text(", ")) << ")";
+  if (func->ret_type.defined()) {
+    doc << " -> " << Print(func->ret_type);
+  }
+  doc << ":" << Doc::NewLine(4);
+
+  const relax::SeqExprNode* body = func->body.as<relax::SeqExprNode>();
+  ICHECK(body) << "in the Relax IR normal form, the body of a Function should be a SeqExpr";
+
+  doc << Doc::Indent(4, Print(func->body));
+  doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine();
+  return doc;
+}
+
+Doc RelaxScriptPrinter::PrintTensorAnnotation(const relax::DynTensorType& ty,
+                                              const Optional<ObjectRef>& shape) {
+  Doc doc;
+  doc << "Tensor[";
+  if (shape.defined()) {
+    doc << Print(Downcast<relax::ShapeExpr>(shape.value()));
+  } else if (ty->rank != -1) {
+    // a rank-k tensor with unmatched shape prints as a tuple of k underscores
+    ICHECK_GE(ty->rank, 0) << "DynTensor ranks must be -1 (unknown) or nonnegative";
+    std::vector<Doc> dims(ty->rank, Doc::Text("_"));
+    doc << "(" << Doc::Concat(dims, Doc::Text(", "));
+    if (ty->rank == 1) {
+      doc << ",";
+    }
+    doc << ")";
+  } else {
+    doc << "_";
+  }
+  doc << ", ";
+  if (ty->dtype.is_void()) {
+    doc << "_";
+  } else {
+    doc << Doc::StrLiteral(runtime::DLDataType2String(ty->dtype));
+  }
+  doc << "]";
+  return doc;
+}
+
+Doc RelaxScriptPrinter::GetUniqueName(std::string prefix, std::string fallback = "x") {
+  if (prefix.empty()) {
+    prefix = fallback;
+  }
+  // TODO(@altanh): more robust name legalization
+  std::replace(prefix.begin(), prefix.end(), '.', '_');
+  std::string unique_prefix = prefix;
+  auto it = name_alloc_map_.find(prefix);
+  if (it != name_alloc_map_.end()) {
+    // append an increasing numeric suffix until the name is unique
+    while (name_alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
+    }
+  }
+  name_alloc_map_[unique_prefix] = 0;
+  return Doc::Text(unique_prefix);
+}
+
+String AsRelaxScript(const ObjectRef& mod) {
+  ICHECK(mod->IsInstance<FunctionNode>());
+  return "@tvm.script.relax\n" + RelaxScriptPrinter().Print(mod).str() + "\n";
+}
+
+TVM_REGISTER_GLOBAL("script.AsRelaxScript").set_body_typed(AsRelaxScript);
+
+}  // namespace relax
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index 2dc0997f82ec..4f525cec6014 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -410,6 +410,8 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false);
 String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
                                  runtime::TypedPackedFunc<std::string(Stmt)> annotate);
 
+Doc AsTVMScriptDoc(const ObjectRef& mod, bool show_meta = false, const PrimFunc& func = PrimFunc());
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 0dc6240bc6ca..8f54fe533779 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -2000,6 +2000,15 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
   return doc.str() + "\n";
 }
 
+Doc AsTVMScriptDoc(const ObjectRef& mod, bool show_meta, const PrimFunc& func) {
+  ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
+  TVMScriptPrinter printer = TVMScriptPrinter(show_meta);
+  Doc mod_doc = printer.Print(mod);
+  Doc doc = Doc::Text("@tvm.script.tir") << Doc::NewLine();
+  // print only the requested PrimFunc when one is given, else the whole module
+  doc << (func.defined() ? printer.PrintPrimFunc(func) : mod_doc) << Doc::NewLine();
+  return doc;
+}
+
 TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);
 
 String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
diff --git a/tests/python/relax/parser.py b/tests/python/relax/test_parser.py
similarity index 96%
rename from tests/python/relax/parser.py
rename to tests/python/relax/test_parser.py
index 12ba128e2595..62aab2336bab 100644
--- a/tests/python/relax/parser.py
+++ b/tests/python/relax/test_parser.py
@@ -32,11 +32,13 @@ def check_shape(e, s):
         assert edim.value == sdim
 
 
-def check_tensor_var(v, s, d):
+def check_tensor_var(v, s, d, rank=None):
     assert isinstance(v.type_annotation, rx.ty.DynTensorType)
     assert v.type_annotation.dtype == d
     if isinstance(s, (list, tuple)):
         assert v.type_annotation.rank == len(s)
+    if rank is not None:
+        assert v.type_annotation.rank == rank
     check_shape(v, s)
 
 
@@ -53,15 +55,17 @@ def test_annotations():
     def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
         z: Tensor[(32, k), "float32"] = nn.matmul(x, y)
         w: Tensor[_, _] = multiply(z, z)
+        q: Tensor[(_, _), _] = add(w, w)
         t = subtract(w, z)
         sh: Shape = t.shape
         return t
 
     f = rx_func(foo)
     x, y = f.params
-    z_bind, w_bind, t_bind, sh_bind = f.body.blocks[0].bindings
+    z_bind, w_bind, q_bind, t_bind, sh_bind = f.body.blocks[0].bindings
     z, mm = z_bind.var, z_bind.value
     w, mul = w_bind.var, w_bind.value
+    q, add = q_bind.var, q_bind.value
     t, sub = t_bind.var, t_bind.value
     sh, shape_of = sh_bind.var, sh_bind.value
 
@@ -69,6 +73,7 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
     check_tensor_var(y, ("m", "k"), "float32")
     check_tensor_var(z, (32, "k"), "float32")
     check_tensor_var(w, None, "")
+    check_tensor_var(q, None, "", rank=2)
     assert t.type_annotation is None
     assert isinstance(sh.type_annotation, rx.ty.ShapeType)
@@ -344,7 +349,7 @@ def foo(x: Tensor[_, _]):
 def test_inline_tir():
     @rx.script
     def foo(x: Tensor[(B, 128), "float32"], y: Tensor[(128, 128), "float32"]):
-        @tir
+        @tvm.script.tir
         def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
             A = tir.match_buffer(a, [128, 128])
             B = tir.match_buffer(b, [128, 128])
diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py
new file mode 100644
index 000000000000..3e7ba528d524
--- /dev/null
+++ b/tests/python/relax/test_printer.py
@@ -0,0 +1,135 @@
+from __future__ import annotations  # must import to defer parsing of annotations
+
+import pytest
+import tvm
+from tvm import relax as rx
+from tvm import tir, relay
+from tvm.ir import structural_equal, assert_structural_equal
+
+
+def rx_func(func):
+    return func.module[func.fn_name]
+
+
+def check_roundtrip(fn):
+    f_pre = rx_func(fn)
+    f_post = rx.parser.fromtext(rx.parser.astext(f_pre))[fn.fn_name]
+    assert_structural_equal(f_pre, f_post, map_free_vars=True)
+
+
+def test_annotations():
+    @rx.script
+    def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
+        z: Tensor[(32, k), "float32"] = nn.matmul(x, y)
+        w: Tensor[_, _] = multiply(z, z)
+        t = subtract(w, z)
+        sh: Shape = t.shape
+        return t
+
+    check_roundtrip(foo)
+
+
+def test_match_shape():
+    @rx.script
+    def foo(x: Tensor[_, "float32"]):
+        relax.match_shape((n, m), x.shape)
+        y: Tensor[(n, m), "float32"] = add(x, x)
+        return x
+
+    check_roundtrip(foo)
+
+
+def test_if():
+    @rx.script
+    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+        if cond:
+            w = add(x, x)
+            y = multiply(w, w)
+        else:
+            w = multiply(x, x)
+            y = add(w, w)
+        return y
+
+    check_roundtrip(foo)
+
+
+def test_tuple():
+    @rx.script
+    def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
+        t: Tuple[Tensor[_, _], Tensor[(32,), "float32"]] = (x, y)
+        return t
+
+    check_roundtrip(foo)
+
+
+def test_local_func():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        def bar(y: Tensor[_, _]):
+            return y
+
+        y = bar(x)  # tests local function variable scoping
+        return y
+
+    check_roundtrip(foo)
+
+
+def test_dataflow():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow():
+            # TODO: parse this
+            # nonlocal y, w
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            return y, w
+        t = divide(y, w)
+        return t
+
+    check_roundtrip(foo)
+
+
+def test_dataflow_match_shape():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow():
+            y = add(x, x)
+            z = multiply(y, x)
+            relax.match_shape((n, m), z.shape)
+            w: Tensor[(n, m), _] = subtract(z, x)
+            return y, w
+        t: Tensor[(n, m), _] = divide(y, w)
+        return t
+
+    check_roundtrip(foo)
+
+
+def test_inline_tir():
+    @rx.script
+    def foo(x: Tensor[(B, 128), "float32"], y: Tensor[(128, 128), "float32"]):
+        @tvm.script.tir
+        def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+            A = tir.match_buffer(a, [128, 128])
+            B = tir.match_buffer(b, [128, 128])
+            C = tir.match_buffer(c, [128, 128])
+
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+                with tir.init():
+                    C[vi, vj] = tir.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        z = relax.call_dps((B, 128), my_matmul, (x, y))
+        return z
+
+    check_roundtrip(foo)
+
+
+def test_call_packed():
+    @rx.script
+    def foo(x: Tensor[(3, 4), "float32"]):
+        # test that we can intro dim vars
+        z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x))
+        return z
+
+    check_roundtrip(foo)
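+
+
+def test_astext_fromtext_compose():
+    # Minimal sketch of driving the new textual APIs directly (the test name
+    # and the choice of the `add` op are illustrative): astext emits a string
+    # that fromtext re-parses into a structurally equal function, mirroring
+    # what check_roundtrip above does in one step.
+    @rx.script
+    def foo(x: Tensor[(3, 4), "float32"]):
+        y = add(x, x)
+        return y
+
+    f = rx_func(foo)
+    text = rx.parser.astext(f)
+    assert_structural_equal(f, rx.parser.fromtext(text)[foo.fn_name], map_free_vars=True)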