From 208732699974996027eb718c7cac81d2a60312c4 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 29 Jul 2021 16:57:56 -0700 Subject: [PATCH 01/20] Copy jared's frontend --- python/tvm/relax/parser.py | 277 +++++++++++++++++++++++++++++++++++++ 1 file changed, 277 insertions(+) create mode 100644 python/tvm/relax/parser.py diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py new file mode 100644 index 0000000000..69f22deee0 --- /dev/null +++ b/python/tvm/relax/parser.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import inspect +from typing import TypeVar, Generic, Union, Dict +from io import StringIO + +import tvm +from tvm.relay.base import Id +from tvm.relax import expr, op +from tvm.relax.pprint import pretty_print +from tvm.ir import diagnostics +from tvm import tir, relax + +import numpy as np + +import synr +from synr import ast, Transformer +from synr.diagnostic_context import DiagnosticContext + +def print_ty(ty): + if isinstance(ty, expr.Dim): + return "Dim" + elif isinstance(ty, expr.Tensor): + return "Tensor" + else: + return "UNKNOWN" + +def print_fn(func): + buffer = StringIO("") + param_str = "" + for param in func.params: + param_str += f"{param.id.name_hint}: {print_ty(param.ty)}, " + + buffer.write(f"fn {func.name}({param_str}) {{\n") + buffer.write(f"{func.body}\n") + buffer.write("}") + return buffer.getvalue() + + +expr.Function.__str__ = print_fn # type: ignore + +# Module = Dict[str, relax.Function] +# Transformer[Module, relax.Function, relax.Expr, relax.Expr, relax.Expr, relax.Expr, relax.Type]): +class R2Transformer(Transformer): # Transformer[Module, relax.Function, relax.Expr, relax.Expr, relax.Expr, relax.Expr, relax.Type]): + def __init__(self, definition_scope, diag_ctx): + self.definition_scope = definition_scope + self.diag_ctx = diag_ctx + self.str_to_var = {} + self.blocks = [] + self.module = {} + super().__init__() + + def span_to_span(self, span: synr.Span) -> tvm.ir.Span: + src_name = self.diag_ctx.str_to_source_name[span.filename] + tvm_span = tvm.ir.Span(src_name, span.start_line, span.end_line, span.start_column, span.end_column) + return tvm_span + + + def decl_var(self, name, ty, span=None): + identifier = Id(name) + var = expr.Var(identifier, ty, span) + self.str_to_var[name] = var + return var + + def to_type(self, ty): + if ty is None: + return None + + if isinstance(ty, ast.TypeVar): + if ty.id.name == "Tensor": + span = self.span_to_span(ty.span) + return expr.Tensor(None, None, span) + + if isinstance(ty, ast.TypeApply): + if ty.id.name == "Tensor": + dims = [] + # TODO(@jroesch): add support for dtype + for param in ty.params: + if isinstance(param, ast.TypeConstant): + dim = expr.TIRExpr(tir.IntImm("int32", param.value), None) + dims.append(dim) + + return expr.Tensor(expr.Tuple(dims, span=None), None, None) + + # import pdb; pdb.set_trace() + + self._diagnostic_context.emit('error', "invalid type", self.span_to_span(ty.span)) + self._diagnostic_context.render() + + def transform_module(self, mod: ast.Module) -> Dict[str, relax.Function]: + print(mod) + for func_name in mod.funcs: + func = mod.funcs[func_name] + self.module[func_name] = self.transform_function(func) + return self.module + + def transform_function(self, func: ast.Function) -> relax.Function: + params = [] + for param in func.params: + ty = self.to_type(param.ty) + param = self.decl_var(param.name, ty, None) + params.append(param) + new_body = self.transform_block(func.body) + ret_type = self.to_type(func.ret_type) + print(new_body) + 
return expr.Function(func.name, params, new_body, ret_type, None) + + def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: + if isinstance(stmt, ast.Assign): + assert isinstance(stmt.lhs, ast.Var) + lhs = self.decl_var(stmt.lhs.id.name, None, None) + rhs = self.transform_expr(stmt.rhs) + self.blocks[-1].append(expr.Binding(lhs, rhs)) + return None + elif isinstance(stmt, ast.Return): + return self.transform_expr(stmt.value) + else: + self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", stmt.span) + self._diagnostic_context.render() + + def transform_expr(self, exp: ast.Expr) -> relax.Expr: + if isinstance(exp, ast.Call): + if isinstance(exp.func_name, ast.Var): + params = [] + for arg in exp.params: + params.append(self.transform_expr(arg)) + + if exp.func_name.id.name == "broadcast_shape": + if len(params) != 2: + self._diagnostic_context.emit('error', f"broadcast_shape only takes 2 arguments {params.len()}", exp.span) + self._diagnostic_context.render() + return expr.BroadcastShape(params[0], params[1], span=None) + elif exp.func_name.id.name == "compute": + if len(params) != 2: + self._diagnostic_context.emit('error', f"compute only takes 2 arguments {params.len()}", exp.span) + self._diagnostic_context.render() + return expr.Compute(params[0], params[1], span=None) + else: + if exp.func_name.id.name in self.str_to_var: + return self.str_to_var[exp.func_name.id.name] + else: + name = exp.func_name.id.name + relax_fn = getattr(self.definition_scope, name, None) + # builtin operator + if relax_fn is None: + return expr.Call(op.Op.get(name), params, None) + else: + self.module[name] = relax_fn.module[name] + # todo: globalvar equality? use global str -> id map? + ident = Id(exp.func_name.id.name) + return expr.Call(expr.GlobalVar(ident, None, None), params, None) + + self._diagnostic_context.emit('error', f"unknown functionc all {len(params)}", exp.span) + self._diagnostic_context.render() + elif isinstance(exp.func_name, ast.Op): + if exp.func_name.name == ast.BuiltinOp.Subscript: + tensor = self.transform_expr(exp.params[0]) + indicies = [] + for index in exp.params[1].values: + indicies.append(self.transform_expr(index)) + return expr.TensorSlice(tensor, indicies, None) + elif exp.func_name.name == ast.BuiltinOp.Add: + params = [] + for arg in exp.params: + params.append(self.transform_expr(arg)) + return expr.Add(params[0], params[1], None) + + self._diagnostic_context.emit('error', "unsupported function", exp.span) + self._diagnostic_context.render() + elif isinstance(exp, ast.Attr): + field_name = exp.field.name + tensor = self.transform_expr(exp.object) + + if field_name == "shape": + return expr.ShapeOf(tensor, None) + else: + self._diagnostic_context.emit('error', "unsupported function", exp.span) + self._diagnostic_context.render() + elif isinstance(exp, ast.Function): + print(exp) + return self.transform_function(exp) + elif isinstance(exp, ast.Tuple): + assert False + elif isinstance(exp, ast.Var): + return self.str_to_var[exp.id.name] + else: + self._diagnostic_context.emit('error', f"don't support this construct {type(exp)}", exp.span) + self._diagnostic_context.render() + + def enter_block(self): + self.blocks.append([]) + + def exit_block(self): + back = self.blocks[-1] + self.blocks.pop() + return back + + def transform_block(self, block: ast.Block) -> relax.expr: + self.enter_block() + + for stmt in block.stmts[:-1]: + assert self.transform_stmt(stmt) is None + + ret_expr = self.transform_stmt(block.stmts[-1]) + # assert 
ret_expr is not None + + bindings = self.exit_block() + return expr.Let(bindings, ret_expr, span=None) + + def transform_parameter(self, expr: ast.Parameter) -> relax.Expr: + pass + + def transform_type(self, ty: ast.Type) -> relax.Type: + pass + +class TVMDiagnosticContext(synr.DiagnosticContext): + def __init__(self, tvm_diag_ctx): + self.tvm_diag_ctx = tvm_diag_ctx + self.str_to_source_name = {} + + def add_source(self, name: str, source: str) -> None: + """Add a file with source code to the context. This will be called + before any call to :py:func:`emit` that contains a span in this + file. + """ + src_name = self.tvm_diag_ctx.module.source_map.add(name, source) + self.str_to_source_name[name] = src_name + + def emit(self, level: str, message: str, span: tvm.ir.Span) -> None: + """Called when an error has occured.""" + + if level == "error": + level = diagnostics.DiagnosticLevel.ERROR + elif level == "bug": + level = diagnostics.DiagnosticLevel.BUG + elif level == "warning": + level = diagnostics.DiagnosticLevel.WARNING + else: + level = "error" + + assert span, "Span must not be null" + assert isinstance(span, tvm.ir.span), "Expected tvm.ir.span, but got " + str(type(span)) + + diag = diagnostics.Diagnostic(level, span, message) + + self.tvm_diag_ctx.emit(diag) + + def render(self) -> Optional[Any]: + """Render out all error messages. Can either return a value or raise + and execption. + """ + self.tvm_diag_ctx.render() + +class RelaxDecoratedFn: + def __init__(self, fn_name, relax_module, diag_ctx): + self.fn_name = fn_name + self.module = relax_module + self.diag_ctx = diag_ctx + + def __call__(self, *args): + pretty_print(self.module[self.fn_name]) + # compiler = Compiler(self.diag_ctx, self.module, self.fn_name) + # compiled_f = compiler.compile(execute=True) + # # Actually compute needed buffer sizes. + # out = tvm.nd.array(np.random.rand(10).astype('float32')) + # compiled_f(*(list(args) + [out])) + # return out + +def r2(f): + ir_module = tvm.IRModule({}) + diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) + diag_ctx = TVMDiagnosticContext(diag_ctx) + ast = synr.to_ast(f, diag_ctx) + definition_scope = inspect.getmodule(f) + # Why have diag context at transform time? TK? 
+ module = R2Transformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) + return RelaxDecoratedFn(f.__name__, module, diag_ctx) \ No newline at end of file From a8664e70fb50d6b100b491949a2f71840e9c1154 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 29 Jul 2021 17:15:33 -0700 Subject: [PATCH 02/20] Remove some extraneous code + add TODOs --- python/tvm/relax/parser.py | 63 +++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 69f22deee0..a2f2ab3b7c 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -5,9 +5,8 @@ from io import StringIO import tvm +from tvm.ir.module import IRModule from tvm.relay.base import Id -from tvm.relax import expr, op -from tvm.relax.pprint import pretty_print from tvm.ir import diagnostics from tvm import tir, relax @@ -17,31 +16,14 @@ from synr import ast, Transformer from synr.diagnostic_context import DiagnosticContext -def print_ty(ty): - if isinstance(ty, expr.Dim): - return "Dim" - elif isinstance(ty, expr.Tensor): - return "Tensor" - else: - return "UNKNOWN" - -def print_fn(func): - buffer = StringIO("") - param_str = "" - for param in func.params: - param_str += f"{param.id.name_hint}: {print_ty(param.ty)}, " - - buffer.write(f"fn {func.name}({param_str}) {{\n") - buffer.write(f"{func.body}\n") - buffer.write("}") - return buffer.getvalue() - - +# TODO: What is this doing? expr.Function.__str__ = print_fn # type: ignore -# Module = Dict[str, relax.Function] -# Transformer[Module, relax.Function, relax.Expr, relax.Expr, relax.Expr, relax.Expr, relax.Type]): -class R2Transformer(Transformer): # Transformer[Module, relax.Function, relax.Expr, relax.Expr, relax.Expr, relax.Expr, relax.Type]): +# TODO: Replace with a real pretty print method once we have the real AST +def pretty_print(f): + print(f) + +class RelaxTransformer(Transformer): def __init__(self, definition_scope, diag_ctx): self.definition_scope = definition_scope self.diag_ctx = diag_ctx @@ -58,6 +40,7 @@ def span_to_span(self, span: synr.Span) -> tvm.ir.Span: def decl_var(self, name, ty, span=None): identifier = Id(name) + # TODO: Replace with relax node var = expr.Var(identifier, ty, span) self.str_to_var[name] = var return var @@ -69,6 +52,7 @@ def to_type(self, ty): if isinstance(ty, ast.TypeVar): if ty.id.name == "Tensor": span = self.span_to_span(ty.span) + # TODO: Replace with relax node return expr.Tensor(None, None, span) if isinstance(ty, ast.TypeApply): @@ -77,24 +61,22 @@ def to_type(self, ty): # TODO(@jroesch): add support for dtype for param in ty.params: if isinstance(param, ast.TypeConstant): + # TODO: Replace with relax node dim = expr.TIRExpr(tir.IntImm("int32", param.value), None) dims.append(dim) - + # TODO: Replace with relax node return expr.Tensor(expr.Tuple(dims, span=None), None, None) - # import pdb; pdb.set_trace() - self._diagnostic_context.emit('error', "invalid type", self.span_to_span(ty.span)) self._diagnostic_context.render() - def transform_module(self, mod: ast.Module) -> Dict[str, relax.Function]: - print(mod) + def transform_module(self, mod: ast.Module) -> IRModule: for func_name in mod.funcs: func = mod.funcs[func_name] self.module[func_name] = self.transform_function(func) return self.module - def transform_function(self, func: ast.Function) -> relax.Function: + def transform_function(self, func: ast.Function) -> relax.Function: # TODO: update once relax ast finalized params = [] for param in func.params: ty = 
self.to_type(param.ty) @@ -103,6 +85,7 @@ def transform_function(self, func: ast.Function) -> relax.Function: new_body = self.transform_block(func.body) ret_type = self.to_type(func.ret_type) print(new_body) + # TODO: Replace with relax node return expr.Function(func.name, params, new_body, ret_type, None) def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: @@ -110,6 +93,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: assert isinstance(stmt.lhs, ast.Var) lhs = self.decl_var(stmt.lhs.id.name, None, None) rhs = self.transform_expr(stmt.rhs) + # TODO: Replace with relax node self.blocks[-1].append(expr.Binding(lhs, rhs)) return None elif isinstance(stmt, ast.Return): @@ -118,7 +102,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", stmt.span) self._diagnostic_context.render() - def transform_expr(self, exp: ast.Expr) -> relax.Expr: + def transform_expr(self, exp: ast.Expr) -> relax.Expr: # TODO: update once we have real relax AST if isinstance(exp, ast.Call): if isinstance(exp.func_name, ast.Var): params = [] @@ -129,11 +113,13 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: if len(params) != 2: self._diagnostic_context.emit('error', f"broadcast_shape only takes 2 arguments {params.len()}", exp.span) self._diagnostic_context.render() + # TODO: Replace with relax node return expr.BroadcastShape(params[0], params[1], span=None) elif exp.func_name.id.name == "compute": if len(params) != 2: self._diagnostic_context.emit('error', f"compute only takes 2 arguments {params.len()}", exp.span) self._diagnostic_context.render() + # TODO: Replace with relax node return expr.Compute(params[0], params[1], span=None) else: if exp.func_name.id.name in self.str_to_var: @@ -143,13 +129,15 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: relax_fn = getattr(self.definition_scope, name, None) # builtin operator if relax_fn is None: + # TODO: Replace with relax node return expr.Call(op.Op.get(name), params, None) else: self.module[name] = relax_fn.module[name] # todo: globalvar equality? use global str -> id map? ident = Id(exp.func_name.id.name) + # TODO: Replace with relax node return expr.Call(expr.GlobalVar(ident, None, None), params, None) - + # TODO: Where is this supposed to be?? 
self._diagnostic_context.emit('error', f"unknown functionc all {len(params)}", exp.span) self._diagnostic_context.render() elif isinstance(exp.func_name, ast.Op): @@ -158,11 +146,13 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: indicies = [] for index in exp.params[1].values: indicies.append(self.transform_expr(index)) + # TODO: Replace with relax node return expr.TensorSlice(tensor, indicies, None) elif exp.func_name.name == ast.BuiltinOp.Add: params = [] for arg in exp.params: params.append(self.transform_expr(arg)) + # TODO: Replace with relax node return expr.Add(params[0], params[1], None) self._diagnostic_context.emit('error', "unsupported function", exp.span) @@ -172,6 +162,7 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: tensor = self.transform_expr(exp.object) if field_name == "shape": + # TODO: Replace with relax node return expr.ShapeOf(tensor, None) else: self._diagnostic_context.emit('error', "unsupported function", exp.span) @@ -205,6 +196,7 @@ def transform_block(self, block: ast.Block) -> relax.expr: # assert ret_expr is not None bindings = self.exit_block() + # TODO: Replace with relax node return expr.Let(bindings, ret_expr, span=None) def transform_parameter(self, expr: ast.Parameter) -> relax.Expr: @@ -266,12 +258,13 @@ def __call__(self, *args): # compiled_f(*(list(args) + [out])) # return out -def r2(f): +def relax(f): ir_module = tvm.IRModule({}) diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) diag_ctx = TVMDiagnosticContext(diag_ctx) ast = synr.to_ast(f, diag_ctx) definition_scope = inspect.getmodule(f) # Why have diag context at transform time? TK? - module = R2Transformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) + # TODO: Replace RelaxTransformer with new transformation + module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) return RelaxDecoratedFn(f.__name__, module, diag_ctx) \ No newline at end of file From b8fc32949d301ee545eac0af00a63bcb21ae1d9a Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 4 Aug 2021 17:44:42 -0700 Subject: [PATCH 03/20] Skeleton AST --- python/tvm/relax/parser.py | 77 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index a2f2ab3b7c..ea6a252848 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -15,6 +15,83 @@ import synr from synr import ast, Transformer from synr.diagnostic_context import DiagnosticContext +from tvm.relay.op.strategy.generic import conv1d_strategy + + +# TODO: make this better +var_table = {} + +# Skeleton AST so we can get prototype working before this PR is merged +class rxNode: + pass + +class rxExpr(rxNode): + def __init__(self): + self.shape = None + self.checked_type = None + +class rxVar(rxExpr): + + def __init__(self, name): + super.__init__(self) + self.shape_annotation = None + self.type_annotation = None + if name not in var_table: + self.id = name + var_table.add(name) + else: + assert False, "All variable names must be unique, name is: " + name + +class rxDataflowVar(rxVar): + pass + +class rxBinding(rxNode): + + def __init__(self, var, rhs): + self.var = var + self.rhs = rhs + +class rxMatchShape(rxNode): + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + +# TODO: is dim a tir var or any algebraic PrimExpr? 
+class Dim: + def __init__(self, name): + self.name = name + +class ShapeTuple(rxExpr): + def __init__(self, dims): + self.dims = dims + +class rxFunction(rxExpr): + def __init__(self, args, body): + self.args = args + self.body = body + +class rxBlock(rxExpr): + def __init__(self, body): + self.body = body + +class rxDataflowBlock(rxBlock): + def __init__(self, body): + super.__init__(self, body) + +class rxBasicBlock(rxBlock): + def __init__(self, body): + super.__init__() + +class rxIfThenElse(rxExpr): + def __init__(self, cond, if_true, then_else): + self.cond = cond + self.if_true = if_true + self.then_else = then_else + + + + # TODO: What is this doing? expr.Function.__str__ = print_fn # type: ignore From b313ed9a4ebeba3636d3536fbf0f3d382827aefd Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 5 Aug 2021 16:38:25 -0700 Subject: [PATCH 04/20] Added more skeleton AST, worked on parsing shape annotations. Something is wrong with span_to_span --- python/tvm/relax/parser.py | 196 ++++++++++++++++++++++++------------- 1 file changed, 130 insertions(+), 66 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index ea6a252848..a182eb2a02 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -26,48 +26,60 @@ class rxNode: pass class rxExpr(rxNode): - def __init__(self): + def __init__(self, span): self.shape = None self.checked_type = None + self.span = span class rxVar(rxExpr): - def __init__(self, name): - super.__init__(self) - self.shape_annotation = None - self.type_annotation = None - if name not in var_table: - self.id = name - var_table.add(name) - else: - assert False, "All variable names must be unique, name is: " + name + def __init__(self, id, ty, shape_annotation, span): + super().__init__(span) + self.shape_annotation = shape_annotation + self.type_annotation = ty + self.id = id + +class rxGlobalVar(rxVar): + def __init__(self, id, span): + super().__init__(self, id, span) class rxDataflowVar(rxVar): pass class rxBinding(rxNode): - def __init__(self, var, rhs): + def __init__(self, var, rhs, span): self.var = var self.rhs = rhs + super().__init__(self, span) class rxMatchShape(rxNode): - def __init__(self, lhs, rhs): + def __init__(self, lhs, rhs, span): self.lhs = lhs self.rhs = rhs + super().__init__(self, span) # TODO: is dim a tir var or any algebraic PrimExpr? 
-class Dim: - def __init__(self, name): +class rxDim(rxExpr): + def __init__(self, value): + self.value = value + +class rxTIRVar(rxNode): + def __init__(self, name, span): self.name = name + super().__init__(self, span) + class ShapeTuple(rxExpr): - def __init__(self, dims): + def __init__(self, dims, span): self.dims = dims + super().__init__(self, span) + class rxFunction(rxExpr): - def __init__(self, args, body): + def __init__(self, name, args, body): + self.name = name self.args = args self.body = body @@ -77,11 +89,8 @@ def __init__(self, body): class rxDataflowBlock(rxBlock): def __init__(self, body): - super.__init__(self, body) + super().__init__(self, body) -class rxBasicBlock(rxBlock): - def __init__(self, body): - super.__init__() class rxIfThenElse(rxExpr): def __init__(self, cond, if_true, then_else): @@ -89,12 +98,40 @@ def __init__(self, cond, if_true, then_else): self.if_true = if_true self.then_else = then_else +class rxType: + def __init__(self, span): + self.span = span + +class rxTensor(rxType): + def __init__(self, dtype, span): + self.dtype = dtype + super().__init__(span) +class rxShapeOf(rxExpr): + def __init__(self, expr): + self.expr = expr + +class rxLet(rxExpr): + def __init__(self, bindings, body): + self.bindings = bindings + self.body = body +class rxCall(rxExpr): + def __init__(self, function, arguments): + self.function = function + self.arguments = arguments +class rxGetBuiltin(rxExpr): + def __init__(self, builtin_name): + self.builtin_name = builtin_name + +class rxTensorSlice(rxExpr): + def __init__(self, tensor, indices): + self.tensor = tensor + self.indices = indices # TODO: What is this doing? -expr.Function.__str__ = print_fn # type: ignore +#expr.Function.__str__ = print_fn # type: ignore # TODO: Replace with a real pretty print method once we have the real AST def pretty_print(f): @@ -115,13 +152,12 @@ def span_to_span(self, span: synr.Span) -> tvm.ir.Span: return tvm_span - def decl_var(self, name, ty, span=None): - identifier = Id(name) - # TODO: Replace with relax node - var = expr.Var(identifier, ty, span) + def decl_var(self, name, ty, shape_annotation, span=None): + var = rxVar(name, ty, shape_annotation, span) self.str_to_var[name] = var return var + # Returns type, shape_annotation def to_type(self, ty): if ty is None: return None @@ -129,23 +165,48 @@ def to_type(self, ty): if isinstance(ty, ast.TypeVar): if ty.id.name == "Tensor": span = self.span_to_span(ty.span) - # TODO: Replace with relax node - return expr.Tensor(None, None, span) - + return rxTensor(None, span), None + if isinstance(ty, ast.TypeApply): if ty.id.name == "Tensor": + # TODO: add more dtypes + allowed_dtypes = ["int32", "float32", "int8", "fp16"] dims = [] - # TODO(@jroesch): add support for dtype - for param in ty.params: - if isinstance(param, ast.TypeConstant): - # TODO: Replace with relax node - dim = expr.TIRExpr(tir.IntImm("int32", param.value), None) + dtype = "" + assert len(ty.params) == 2 + shape_param = ty.params[0] + dtype_param = ty.params[1] + # Check whether the Tensor is shape / dtype erased + shape_erased = False + dtype_erased = False + + if (isinstance(shape_param, ast.TypeVar)) and shape_param.id.name == "_": + shape_erased = True + if (isinstance(dtype_param, ast.TypeVar)) and dtype_param.id.name == "_": + dtype_erased = True + if not shape_erased: + if not isinstance(shape_param, ast.Tuple): + self._diagnostic_context.emit('error', "Tensor shape must be erased or be a tuple", self.span_to_span(ty.span)) + self._diagnostic_context.render() + for 
shape_dim in shape_param: + # TODO: Turn this into a helper fn that only allows algebraic expressions + if isinstance(shape_dim, ast.Var): + dim = rxDim(tir.Var(shape_dim.name)) + elif isinstance(shape_dim, ast.Constant) and isinstance(shape_dim.value, int): + dim = rxDim(tir.IntImm("int32", shape_dim.value)) + else: + self._diagnostic_context.emit('error', "shape annotation must be only vars or consts for now", self.span_to_span(ty.span)) + self._diagnostic_context.render() dims.append(dim) - # TODO: Replace with relax node - return expr.Tensor(expr.Tuple(dims, span=None), None, None) + if not dtype_erased: + if not isinstance(shape_param, ast.TypeVar) and not shape_param.id.name in allowed_dtypes: + self._diagnostic_context.emit('error', "dtype must be erased or one of " + allowed_dtypes, self.span_to_span(ty.span)) + self._diagnostic_context.render() + dtype = shape_param.id.name + + return rxTensor(dtype, None), dims self._diagnostic_context.emit('error', "invalid type", self.span_to_span(ty.span)) - self._diagnostic_context.render() def transform_module(self, mod: ast.Module) -> IRModule: for func_name in mod.funcs: @@ -153,17 +214,16 @@ def transform_module(self, mod: ast.Module) -> IRModule: self.module[func_name] = self.transform_function(func) return self.module - def transform_function(self, func: ast.Function) -> relax.Function: # TODO: update once relax ast finalized + def transform_function(self, func: ast.Function) -> rxFunction: params = [] for param in func.params: - ty = self.to_type(param.ty) + ty, shape_dims = self.to_type(param.ty) param = self.decl_var(param.name, ty, None) params.append(param) new_body = self.transform_block(func.body) ret_type = self.to_type(func.ret_type) print(new_body) - # TODO: Replace with relax node - return expr.Function(func.name, params, new_body, ret_type, None) + return rxFunction(func.name, params, new_body, ret_type, None) def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: if isinstance(stmt, ast.Assign): @@ -171,7 +231,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: lhs = self.decl_var(stmt.lhs.id.name, None, None) rhs = self.transform_expr(stmt.rhs) # TODO: Replace with relax node - self.blocks[-1].append(expr.Binding(lhs, rhs)) + self.blocks[-1].append(rxBinding(lhs, rhs)) return None elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) @@ -179,25 +239,20 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", stmt.span) self._diagnostic_context.render() - def transform_expr(self, exp: ast.Expr) -> relax.Expr: # TODO: update once we have real relax AST + def transform_expr(self, exp: ast.Expr) -> rxExpr: if isinstance(exp, ast.Call): if isinstance(exp.func_name, ast.Var): params = [] for arg in exp.params: params.append(self.transform_expr(arg)) - if exp.func_name.id.name == "broadcast_shape": - if len(params) != 2: - self._diagnostic_context.emit('error', f"broadcast_shape only takes 2 arguments {params.len()}", exp.span) - self._diagnostic_context.render() - # TODO: Replace with relax node - return expr.BroadcastShape(params[0], params[1], span=None) - elif exp.func_name.id.name == "compute": - if len(params) != 2: - self._diagnostic_context.emit('error', f"compute only takes 2 arguments {params.len()}", exp.span) - self._diagnostic_context.render() - # TODO: Replace with relax node - return expr.Compute(params[0], params[1], span=None) + + if exp.func_name.id.name == "call_packed": + # TODO: deal 
with call_packed + pass + elif exp.func_name.id.name == "call_tir": + #TODO: deal with call_tir + pass else: if exp.func_name.id.name in self.str_to_var: return self.str_to_var[exp.func_name.id.name] @@ -206,14 +261,12 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: # TODO: update once we ha relax_fn = getattr(self.definition_scope, name, None) # builtin operator if relax_fn is None: - # TODO: Replace with relax node - return expr.Call(op.Op.get(name), params, None) + return rxCall(rxGetBuiltin(name), params, None) else: self.module[name] = relax_fn.module[name] # todo: globalvar equality? use global str -> id map? ident = Id(exp.func_name.id.name) - # TODO: Replace with relax node - return expr.Call(expr.GlobalVar(ident, None, None), params, None) + return rxCall(rxGlobalVar(ident, None, None), params, None) # TODO: Where is this supposed to be?? self._diagnostic_context.emit('error', f"unknown functionc all {len(params)}", exp.span) self._diagnostic_context.render() @@ -224,13 +277,13 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: # TODO: update once we ha for index in exp.params[1].values: indicies.append(self.transform_expr(index)) # TODO: Replace with relax node - return expr.TensorSlice(tensor, indicies, None) + return rxTensorSlice(tensor, indicies, None) elif exp.func_name.name == ast.BuiltinOp.Add: params = [] for arg in exp.params: params.append(self.transform_expr(arg)) # TODO: Replace with relax node - return expr.Add(params[0], params[1], None) + return rxCall("add", [params[0], params[1]], None) self._diagnostic_context.emit('error', "unsupported function", exp.span) self._diagnostic_context.render() @@ -239,8 +292,7 @@ def transform_expr(self, exp: ast.Expr) -> relax.Expr: # TODO: update once we ha tensor = self.transform_expr(exp.object) if field_name == "shape": - # TODO: Replace with relax node - return expr.ShapeOf(tensor, None) + return rxShapeOf(tensor, None) else: self._diagnostic_context.emit('error', "unsupported function", exp.span) self._diagnostic_context.render() @@ -273,8 +325,7 @@ def transform_block(self, block: ast.Block) -> relax.expr: # assert ret_expr is not None bindings = self.exit_block() - # TODO: Replace with relax node - return expr.Let(bindings, ret_expr, span=None) + return rxLet(bindings, ret_expr, span=None) def transform_parameter(self, expr: ast.Parameter) -> relax.Expr: pass @@ -308,7 +359,7 @@ def emit(self, level: str, message: str, span: tvm.ir.Span) -> None: level = "error" assert span, "Span must not be null" - assert isinstance(span, tvm.ir.span), "Expected tvm.ir.span, but got " + str(type(span)) + assert isinstance(span, tvm.ir.Span), "Expected tvm.ir.Span, but got " + str(type(span)) diag = diagnostics.Diagnostic(level, span, message) @@ -344,4 +395,17 @@ def relax(f): # Why have diag context at transform time? TK? 
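    # In outline: synr.to_ast(f, diag_ctx) produces the synr AST of the decorated function,
    # RelaxTransformer(...).do_transform lowers it to a dict mapping each function name to a
    # skeleton rxFunction, and RelaxDecoratedFn wraps that dict as its .module field.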
# TODO: Replace RelaxTransformer with new transformation module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) - return RelaxDecoratedFn(f.__name__, module, diag_ctx) \ No newline at end of file + return RelaxDecoratedFn(f.__name__, module, diag_ctx) + +@relax +def my_test(x : Tensor[_, _]): + return None + +""" +@relax +def my_test(x : Tensor[(1, 2, 3), "int32"], y: Tensor[_, _]): + return call_packed("my_func", x, y) +""" +ones = np.ones((1, 2, 3)) +y = ones +my_test(ones, y) \ No newline at end of file From b3357d0813e70fe6f0857c98d550f184f4470379 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 5 Aug 2021 16:59:24 -0700 Subject: [PATCH 05/20] Fix spans --- python/tvm/relax/parser.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index a182eb2a02..c6c3872bca 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -198,10 +198,10 @@ def to_type(self, ty): self._diagnostic_context.emit('error', "shape annotation must be only vars or consts for now", self.span_to_span(ty.span)) self._diagnostic_context.render() dims.append(dim) - if not dtype_erased: - if not isinstance(shape_param, ast.TypeVar) and not shape_param.id.name in allowed_dtypes: - self._diagnostic_context.emit('error', "dtype must be erased or one of " + allowed_dtypes, self.span_to_span(ty.span)) + if not dtype_erased and not shape_param.id.name in allowed_dtypes: + self._diagnostic_context.emit('error', "dtype must be erased or one of " + str(allowed_dtypes), self.span_to_span(ty.span)) self._diagnostic_context.render() + else: dtype = shape_param.id.name return rxTensor(dtype, None), dims @@ -268,7 +268,7 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: ident = Id(exp.func_name.id.name) return rxCall(rxGlobalVar(ident, None, None), params, None) # TODO: Where is this supposed to be?? 
- self._diagnostic_context.emit('error', f"unknown functionc all {len(params)}", exp.span) + self._diagnostic_context.emit('error', f"unknown functionc all {len(params)}", self.span_to_span(exp.span)) self._diagnostic_context.render() elif isinstance(exp.func_name, ast.Op): if exp.func_name.name == ast.BuiltinOp.Subscript: @@ -285,7 +285,7 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: # TODO: Replace with relax node return rxCall("add", [params[0], params[1]], None) - self._diagnostic_context.emit('error', "unsupported function", exp.span) + self._diagnostic_context.emit('error', "unsupported function", self.span_to_span(exp.span)) self._diagnostic_context.render() elif isinstance(exp, ast.Attr): field_name = exp.field.name @@ -294,7 +294,7 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: if field_name == "shape": return rxShapeOf(tensor, None) else: - self._diagnostic_context.emit('error', "unsupported function", exp.span) + self._diagnostic_context.emit('error', "unsupported function", self.span_to_span(exp.span)) self._diagnostic_context.render() elif isinstance(exp, ast.Function): print(exp) @@ -304,7 +304,7 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: elif isinstance(exp, ast.Var): return self.str_to_var[exp.id.name] else: - self._diagnostic_context.emit('error', f"don't support this construct {type(exp)}", exp.span) + self._diagnostic_context.emit('error', f"don't support this construct {type(exp)}", self.span_to_span(exp.span)) self._diagnostic_context.render() def enter_block(self): From 74e9543d312f4ec52095c5964a1ccc2472e8870a Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 6 Aug 2021 12:01:54 -0700 Subject: [PATCH 06/20] Type annotations parsing correctly --- python/tvm/relax/parser.py | 40 ++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index c6c3872bca..04232ebf97 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -60,7 +60,7 @@ def __init__(self, lhs, rhs, span): self.rhs = rhs super().__init__(self, span) -# TODO: is dim a tir var or any algebraic PrimExpr? +# TODO: is dim a tir var or any algebraic PrimExpr? or just a shape tuple with index? 
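# For example, with the to_type changes below, the annotation Tensor[(1, n, 3), "int32")
# lowers to rxDim(tir.IntImm("int32", 1)), rxDim(tir.Var("n")), rxDim(tir.IntImm("int32", 3))
# alongside an rxTensor("int32") type, so a dim is currently a wrapped TIR expression.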
class rxDim(rxExpr): def __init__(self, value): self.value = value @@ -184,25 +184,28 @@ def to_type(self, ty): shape_erased = True if (isinstance(dtype_param, ast.TypeVar)) and dtype_param.id.name == "_": dtype_erased = True + if not shape_erased: - if not isinstance(shape_param, ast.Tuple): + if isinstance(shape_param, ast.TypeTuple): + for shape_dim in shape_param.values: + # TODO: Turn this into a helper fn that only allows algebraic expressions + if isinstance(shape_dim, ast.Var): + dim = rxDim(tir.Var(shape_dim.name)) + elif isinstance(shape_dim, ast.Constant) and isinstance(shape_dim.value, int): + dim = rxDim(tir.IntImm("int32", shape_dim.value)) + else: + self._diagnostic_context.emit('error', "shape annotation must be only vars or consts for now", self.span_to_span(ty.span)) + self._diagnostic_context.render() + dims.append(dim) + else: self._diagnostic_context.emit('error', "Tensor shape must be erased or be a tuple", self.span_to_span(ty.span)) self._diagnostic_context.render() - for shape_dim in shape_param: - # TODO: Turn this into a helper fn that only allows algebraic expressions - if isinstance(shape_dim, ast.Var): - dim = rxDim(tir.Var(shape_dim.name)) - elif isinstance(shape_dim, ast.Constant) and isinstance(shape_dim.value, int): - dim = rxDim(tir.IntImm("int32", shape_dim.value)) - else: - self._diagnostic_context.emit('error', "shape annotation must be only vars or consts for now", self.span_to_span(ty.span)) - self._diagnostic_context.render() - dims.append(dim) - if not dtype_erased and not shape_param.id.name in allowed_dtypes: + if not dtype_erased: + if dtype_param.value in allowed_dtypes: + dtype = dtype_param.value + else: self._diagnostic_context.emit('error', "dtype must be erased or one of " + str(allowed_dtypes), self.span_to_span(ty.span)) self._diagnostic_context.render() - else: - dtype = shape_param.id.name return rxTensor(dtype, None), dims @@ -397,15 +400,10 @@ def relax(f): module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) return RelaxDecoratedFn(f.__name__, module, diag_ctx) -@relax -def my_test(x : Tensor[_, _]): - return None - -""" @relax def my_test(x : Tensor[(1, 2, 3), "int32"], y: Tensor[_, _]): return call_packed("my_func", x, y) -""" + ones = np.ones((1, 2, 3)) y = ones my_test(ones, y) \ No newline at end of file From 80970398c15731e35245233a24f67797e2e6070b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 6 Aug 2021 14:07:58 -0700 Subject: [PATCH 07/20] some match_shape support --- python/tvm/relax/parser.py | 132 +++++++++++++++++------ python/tvm/relax/parser_tests/failing.py | 4 + python/tvm/relax/parser_tests/passing.py | 22 ++++ 3 files changed, 124 insertions(+), 34 deletions(-) create mode 100644 python/tvm/relax/parser_tests/failing.py create mode 100644 python/tvm/relax/parser_tests/passing.py diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 04232ebf97..95225e4e9e 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -25,6 +25,10 @@ class rxNode: pass +# A node that will desugar into a different AST node in a subsequent pass +class rxFrontendNode: + pass + class rxExpr(rxNode): def __init__(self, span): self.shape = None @@ -53,13 +57,23 @@ def __init__(self, var, rhs, span): self.rhs = rhs super().__init__(self, span) -class rxMatchShape(rxNode): +# Allows arbitrary exprs on the left and the right +# Desugars into two rxMatchShapeBinding +# TODO: might be worth not parsing this into its own node.. 
+class rxFrontendMatchShapeExprs(rxFrontendNode): def __init__(self, lhs, rhs, span): self.lhs = lhs self.rhs = rhs super().__init__(self, span) +# +class rxMatchShapeBinding(rxNode): + def __init__(self, binding, shape, span): + self.binding = binding # Array[PrimExpr] + self.shape = shape # Expr (type is shape tuple) + super().__init__(self, span) + # TODO: is dim a tir var or any algebraic PrimExpr? or just a shape tuple with index? class rxDim(rxExpr): def __init__(self, value): @@ -71,7 +85,7 @@ def __init__(self, name, span): super().__init__(self, span) -class ShapeTuple(rxExpr): +class rxShapeTuple(rxExpr): def __init__(self, dims, span): self.dims = dims super().__init__(self, span) @@ -169,7 +183,7 @@ def to_type(self, ty): if isinstance(ty, ast.TypeApply): if ty.id.name == "Tensor": - # TODO: add more dtypes + # TODO: add more dtypes, maybe define elsewhere allowed_dtypes = ["int32", "float32", "int8", "fp16"] dims = [] dtype = "" @@ -188,7 +202,8 @@ def to_type(self, ty): if not shape_erased: if isinstance(shape_param, ast.TypeTuple): for shape_dim in shape_param.values: - # TODO: Turn this into a helper fn that only allows algebraic expressions + + # TODO: use to_primexpr or whatever if isinstance(shape_dim, ast.Var): dim = rxDim(tir.Var(shape_dim.name)) elif isinstance(shape_dim, ast.Constant) and isinstance(shape_dim.value, int): @@ -211,6 +226,37 @@ def to_type(self, ty): self._diagnostic_context.emit('error', "invalid type", self.span_to_span(ty.span)) + # Turns a tuple into an array of PrimExprs + # Allow arithmetic indicates whether we are letting there be + def expr_to_primexpr(self, expr: ast.Expr, allow_arithmetic=False) -> PrimExpr: + if not allow_arithmetic and not isinstance(expr, ast.Var): + #TODO: improve error message + self._diagnostic_context.emit('error', "You can only use single variables here, not an expression", self.span_to_span(expr.span)) + self._diagnostic_context.render() + else: + if isinstance(expr, ast.Var): + return tir.Var(expr.id.name, "int32") + + # TODO: do all the ops here + elif isinstance(expr, ast.Constant) and isinstance(expr.value, int): + assert False + elif isinstance(expr, ast.Call): + if exp.func_name.name == ast.BuiltinOp.Add: + # TODO: call this fn on args and return primexpr containing result + assert False + if exp.func_name.name == ast.BuiltinOp.Sub: + assert False + if exp.func_name.name == ast.BuiltinOp.Mul: + assert False + if exp.func_name.name == ast.BuiltinOp.Div: + assert False + if exp.func_name.name == ast.BuiltinOp.Mod: + assert False + else: + self._diagnostic_context.emit('error', "The shape expression can only contain arithmetic operators, integer constants and variables", self.span_to_span(expr.span)) + self._diagnostic_context.render() + + def transform_module(self, mod: ast.Module) -> IRModule: for func_name in mod.funcs: func = mod.funcs[func_name] @@ -238,8 +284,38 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: return None elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) + # match_shape is the ONLY node that doesn't have to be bound to an LHS variable! 
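        # Two forms of call are handled below; the first uses hypothetical variable names,
        # the second mirrors the test in parser_tests/passing.py:
        #   match_shape((b, n), x.shape)      -> a tuple of fresh vars on the left becomes an
        #                                        rxMatchShapeBinding introducing TIR vars b and n
        #   match_shape(x.shape, (1, 2, 3))   -> an arbitrary expr on the left becomes an
        #                                        rxFrontendMatchShapeExprs, to be desugared later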
+ elif (isinstance(stmt, ast.UnassignedCall) and isinstance(stmt.call.func_name, ast.Var) + and stmt.call.func_name.id.name == "match_shape"): + if len(stmt.call.params) != 2: + self._diagnostic_context.emit('error', "match_shape takes exactly two arguments", self.span_to_span(stmt.span)) + self._diagnostic_context.render() + + lhs = stmt.call.params[0] + rhs = stmt.call.params[1] + + # If RHS is a tuple, turn it into a ShapeTuple, otherwise, process normally + if isinstance(rhs, ast.Tuple): + arithmetic_primexprs = [] + for elem in rhs.values: + arithmetic_primexprs.append(self.expr_to_primexpr(elem, allow_arithmetic=True)) + rhs_expr = rxShapeTuple(arithmetic_primexprs) + else: + rhs_expr = self.transform_expr(rhs) + + # If LHS is a tuple of variables, then we use the binding match shape + # If it is an Expr, we use the sugared match_shape (and will insert bindings later) + if isinstance(lhs, ast.Tuple): + binding_tir_vars = [] + for elem in lhs.values: + # Here we only are defining variables so we don't allow arithmetic expressions + binding_tir_vars.append(self.expr_to_primexpr(elem)) + self.blocks[-1].append(rxMatchShapeBinding(binding_tir_vars, rhs_expr)) + else: + lhs_expr = self.transform_expr(lhs) + self.blocks[-1].append(rxFrontendMatchShapeExprs(lhs_expr, rhs_expr)) else: - self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", stmt.span) + self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", self.span_to_span(stmt.span)) self._diagnostic_context.render() def transform_expr(self, exp: ast.Expr) -> rxExpr: @@ -249,30 +325,20 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: for arg in exp.params: params.append(self.transform_expr(arg)) - - if exp.func_name.id.name == "call_packed": - # TODO: deal with call_packed - pass - elif exp.func_name.id.name == "call_tir": - #TODO: deal with call_tir - pass + if exp.func_name.id.name in self.str_to_var: + return self.str_to_var[exp.func_name.id.name] else: - if exp.func_name.id.name in self.str_to_var: - return self.str_to_var[exp.func_name.id.name] + name = exp.func_name.id.name + relax_fn = getattr(self.definition_scope, name, None) + # builtin operator + if relax_fn is None: + return rxCall(rxGetBuiltin(name), params, None) else: - name = exp.func_name.id.name - relax_fn = getattr(self.definition_scope, name, None) - # builtin operator - if relax_fn is None: - return rxCall(rxGetBuiltin(name), params, None) - else: - self.module[name] = relax_fn.module[name] - # todo: globalvar equality? use global str -> id map? - ident = Id(exp.func_name.id.name) - return rxCall(rxGlobalVar(ident, None, None), params, None) - # TODO: Where is this supposed to be?? - self._diagnostic_context.emit('error', f"unknown functionc all {len(params)}", self.span_to_span(exp.span)) - self._diagnostic_context.render() + self.module[name] = relax_fn.module[name] + # todo: globalvar equality? use global str -> id map? 
+ ident = Id(exp.func_name.id.name) + return rxCall(rxGlobalVar(ident, None, None), params, None) + elif isinstance(exp.func_name, ast.Op): if exp.func_name.name == ast.BuiltinOp.Subscript: tensor = self.transform_expr(exp.params[0]) @@ -295,7 +361,7 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: tensor = self.transform_expr(exp.object) if field_name == "shape": - return rxShapeOf(tensor, None) + return rxShapeOf(tensor) else: self._diagnostic_context.emit('error', "unsupported function", self.span_to_span(exp.span)) self._diagnostic_context.render() @@ -395,15 +461,13 @@ def relax(f): diag_ctx = TVMDiagnosticContext(diag_ctx) ast = synr.to_ast(f, diag_ctx) definition_scope = inspect.getmodule(f) - # Why have diag context at transform time? TK? - # TODO: Replace RelaxTransformer with new transformation module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) return RelaxDecoratedFn(f.__name__, module, diag_ctx) +@relax +def my_test(x: Tensor[_, "float32"]): + match_shape(x.shape, (1, 2, 3)) + @relax def my_test(x : Tensor[(1, 2, 3), "int32"], y: Tensor[_, _]): return call_packed("my_func", x, y) - -ones = np.ones((1, 2, 3)) -y = ones -my_test(ones, y) \ No newline at end of file diff --git a/python/tvm/relax/parser_tests/failing.py b/python/tvm/relax/parser_tests/failing.py new file mode 100644 index 0000000000..68ca0347f6 --- /dev/null +++ b/python/tvm/relax/parser_tests/failing.py @@ -0,0 +1,4 @@ +from .. import relax +@relax +def my_test(x : Tensor[_, _]): + return None \ No newline at end of file diff --git a/python/tvm/relax/parser_tests/passing.py b/python/tvm/relax/parser_tests/passing.py new file mode 100644 index 0000000000..966342b8c1 --- /dev/null +++ b/python/tvm/relax/parser_tests/passing.py @@ -0,0 +1,22 @@ + +# Type annotation tests +@relax +def my_test(x : Tensor[_, _]): + return None + +@relax +def my_test(x: Tensor[(1, 2, 3), "int32"]): + return None + +@relax +def my_test(x: Tensor[(1, 2, 3), _]): + return None + +@relax +def my_test(x: Tensor[_, "int32"]): + return None + +# Builtin functions +@relax +def my_test(x: Tensor[_, "float32"]): + match_shape(x, (1, 2, 3)) \ No newline at end of file From 4b6394ccae329c8511d60d790c5131185cdf40dd Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 6 Aug 2021 17:07:58 -0700 Subject: [PATCH 08/20] More bug fixes! Some stuff parses. Importing into tests is messed up. We probably need to restructure this code as well. 
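A minimal sketch of the round trip that parses at this point, mirroring the decorator test
and the print(my_test.module['my_test']) line in the diff below:

    @relax
    def my_test(x: Tensor[_, "float32"]):
        match_shape(x.shape, (1, 2, 3))

    # the decorator returns a RelaxDecoratedFn; the parsed skeleton AST is kept in
    # its .module dict, keyed by function name
    print(my_test.module["my_test"])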
--- python/tvm/relax/parser.py | 40 +++++++++-------------- python/tvm/relax/parser_tests/__init__.py | 0 python/tvm/relax/parser_tests/passing.py | 22 +++++++++++-- 3 files changed, 35 insertions(+), 27 deletions(-) create mode 100644 python/tvm/relax/parser_tests/__init__.py diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 95225e4e9e..19706bd8b3 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -19,6 +19,7 @@ # TODO: make this better +relax_scope = [] # A stack of dictionaries representing the scope var_table = {} # Skeleton AST so we can get prototype working before this PR is merged @@ -27,7 +28,8 @@ class rxNode: # A node that will desugar into a different AST node in a subsequent pass class rxFrontendNode: - pass + def __init__(self, span): + self.span = span class rxExpr(rxNode): def __init__(self, span): @@ -65,7 +67,7 @@ class rxFrontendMatchShapeExprs(rxFrontendNode): def __init__(self, lhs, rhs, span): self.lhs = lhs self.rhs = rhs - super().__init__(self, span) + super().__init__(span) # class rxMatchShapeBinding(rxNode): @@ -74,28 +76,19 @@ def __init__(self, binding, shape, span): self.shape = shape # Expr (type is shape tuple) super().__init__(self, span) -# TODO: is dim a tir var or any algebraic PrimExpr? or just a shape tuple with index? -class rxDim(rxExpr): - def __init__(self, value): - self.value = value - -class rxTIRVar(rxNode): - def __init__(self, name, span): - self.name = name - super().__init__(self, span) - - class rxShapeTuple(rxExpr): def __init__(self, dims, span): self.dims = dims - super().__init__(self, span) + super().__init__(span) class rxFunction(rxExpr): - def __init__(self, name, args, body): + def __init__(self, name, args, body, ret_type, span): self.name = name self.args = args self.body = body + self.ret_type = ret_type + super().__init__(span) class rxBlock(rxExpr): def __init__(self, body): @@ -116,6 +109,9 @@ class rxType: def __init__(self, span): self.span = span +class rxDim(rxType): + pass + class rxTensor(rxType): def __init__(self, dtype, span): self.dtype = dtype @@ -239,7 +235,7 @@ def expr_to_primexpr(self, expr: ast.Expr, allow_arithmetic=False) -> PrimExpr: # TODO: do all the ops here elif isinstance(expr, ast.Constant) and isinstance(expr.value, int): - assert False + return tir.IntImm("int32", expr.value) elif isinstance(expr, ast.Call): if exp.func_name.name == ast.BuiltinOp.Add: # TODO: call this fn on args and return primexpr containing result @@ -271,7 +267,6 @@ def transform_function(self, func: ast.Function) -> rxFunction: params.append(param) new_body = self.transform_block(func.body) ret_type = self.to_type(func.ret_type) - print(new_body) return rxFunction(func.name, params, new_body, ret_type, None) def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: @@ -299,7 +294,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: arithmetic_primexprs = [] for elem in rhs.values: arithmetic_primexprs.append(self.expr_to_primexpr(elem, allow_arithmetic=True)) - rhs_expr = rxShapeTuple(arithmetic_primexprs) + rhs_expr = rxShapeTuple(arithmetic_primexprs, self.span_to_span(rhs.span)) else: rhs_expr = self.transform_expr(rhs) @@ -313,7 +308,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: self.blocks[-1].append(rxMatchShapeBinding(binding_tir_vars, rhs_expr)) else: lhs_expr = self.transform_expr(lhs) - self.blocks[-1].append(rxFrontendMatchShapeExprs(lhs_expr, rhs_expr)) + self.blocks[-1].append(rxFrontendMatchShapeExprs(lhs_expr, rhs_expr, stmt.span)) else: 
self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", self.span_to_span(stmt.span)) self._diagnostic_context.render() @@ -366,7 +361,6 @@ def transform_expr(self, exp: ast.Expr) -> rxExpr: self._diagnostic_context.emit('error', "unsupported function", self.span_to_span(exp.span)) self._diagnostic_context.render() elif isinstance(exp, ast.Function): - print(exp) return self.transform_function(exp) elif isinstance(exp, ast.Tuple): assert False @@ -394,7 +388,7 @@ def transform_block(self, block: ast.Block) -> relax.expr: # assert ret_expr is not None bindings = self.exit_block() - return rxLet(bindings, ret_expr, span=None) + return rxLet(bindings, ret_expr) def transform_parameter(self, expr: ast.Parameter) -> relax.Expr: pass @@ -468,6 +462,4 @@ def relax(f): def my_test(x: Tensor[_, "float32"]): match_shape(x.shape, (1, 2, 3)) -@relax -def my_test(x : Tensor[(1, 2, 3), "int32"], y: Tensor[_, _]): - return call_packed("my_func", x, y) +print(my_test.module['my_test']) \ No newline at end of file diff --git a/python/tvm/relax/parser_tests/__init__.py b/python/tvm/relax/parser_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/tvm/relax/parser_tests/passing.py b/python/tvm/relax/parser_tests/passing.py index 966342b8c1..f5ac0a41e6 100644 --- a/python/tvm/relax/parser_tests/passing.py +++ b/python/tvm/relax/parser_tests/passing.py @@ -1,11 +1,13 @@ +from tvm.relax.parser import relax # Type annotation tests +""" @relax def my_test(x : Tensor[_, _]): return None @relax -def my_test(x: Tensor[(1, 2, 3), "int32"]): +def my_test(x: Tensor[(a, b, c), "int32"]): return None @relax @@ -15,8 +17,22 @@ def my_test(x: Tensor[(1, 2, 3), _]): @relax def my_test(x: Tensor[_, "int32"]): return None - +""" # Builtin functions + @relax def my_test(x: Tensor[_, "float32"]): - match_shape(x, (1, 2, 3)) \ No newline at end of file + match_shape(x.shape, (1, 2, 3)) + + +# These should pass in the future but don't right now +""" +@relax +def my_test(x: Tensor[_, "float32"]): + match_shape(x.shape, (1, 2, 3)) + + +@relax +def my_test(x : Tensor[(1, 2, 3), "int32"], y: Tensor[_, _]): + return call_packed("my_func", x, y) +""" \ No newline at end of file From 5f5893491613cc1dd8dd3f16428b1f51c6bcbf78 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 16 Aug 2021 08:57:43 -0700 Subject: [PATCH 09/20] refactor parser and fill out more stubs --- python/tvm/relax/base.py | 4 + python/tvm/relax/parser.py | 798 ++++++++++++++++++++++--------------- 2 files changed, 489 insertions(+), 313 deletions(-) create mode 100644 python/tvm/relax/base.py diff --git a/python/tvm/relax/base.py b/python/tvm/relax/base.py new file mode 100644 index 0000000000..b85ac77c6e --- /dev/null +++ b/python/tvm/relax/base.py @@ -0,0 +1,4 @@ +# Skeleton AST so we can get prototype working before this PR is merged +class rxNode: + def __init__(self, span): + self.span = span diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 19706bd8b3..20c657a89a 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TypeVar, Generic, Union, Dict +from typing import TypeVar, Generic, Union, Dict, List, Tuple from io import StringIO import tvm @@ -17,241 +17,309 @@ from synr.diagnostic_context import DiagnosticContext from tvm.relay.op.strategy.generic import conv1d_strategy +from tvm.relax.expr import * +from tvm.relax.ty import * -# TODO: make this better 
-relax_scope = [] # A stack of dictionaries representing the scope -var_table = {} -# Skeleton AST so we can get prototype working before this PR is merged -class rxNode: - pass +# TODO: make this better +# relax_scope = [] # A stack of dictionaries representing the scope +# var_table = {} # A node that will desugar into a different AST node in a subsequent pass -class rxFrontendNode: - def __init__(self, span): - self.span = span - -class rxExpr(rxNode): - def __init__(self, span): - self.shape = None - self.checked_type = None - self.span = span - -class rxVar(rxExpr): - - def __init__(self, id, ty, shape_annotation, span): - super().__init__(span) - self.shape_annotation = shape_annotation - self.type_annotation = ty - self.id = id +# class rxFrontendNode: +# def __init__(self, span): +# self.span = span -class rxGlobalVar(rxVar): - def __init__(self, id, span): - super().__init__(self, id, span) - -class rxDataflowVar(rxVar): - pass - -class rxBinding(rxNode): - - def __init__(self, var, rhs, span): - self.var = var - self.rhs = rhs - super().__init__(self, span) # Allows arbitrary exprs on the left and the right # Desugars into two rxMatchShapeBinding # TODO: might be worth not parsing this into its own node.. -class rxFrontendMatchShapeExprs(rxFrontendNode): - - def __init__(self, lhs, rhs, span): - self.lhs = lhs - self.rhs = rhs - super().__init__(span) - -# -class rxMatchShapeBinding(rxNode): - def __init__(self, binding, shape, span): - self.binding = binding # Array[PrimExpr] - self.shape = shape # Expr (type is shape tuple) - super().__init__(self, span) - -class rxShapeTuple(rxExpr): - def __init__(self, dims, span): - self.dims = dims - super().__init__(span) - - -class rxFunction(rxExpr): - def __init__(self, name, args, body, ret_type, span): - self.name = name - self.args = args - self.body = body - self.ret_type = ret_type - super().__init__(span) - -class rxBlock(rxExpr): - def __init__(self, body): - self.body = body - -class rxDataflowBlock(rxBlock): - def __init__(self, body): - super().__init__(self, body) - - -class rxIfThenElse(rxExpr): - def __init__(self, cond, if_true, then_else): - self.cond = cond - self.if_true = if_true - self.then_else = then_else - -class rxType: - def __init__(self, span): - self.span = span - -class rxDim(rxType): - pass - -class rxTensor(rxType): - def __init__(self, dtype, span): - self.dtype = dtype - super().__init__(span) - -class rxShapeOf(rxExpr): - def __init__(self, expr): - self.expr = expr - -class rxLet(rxExpr): - def __init__(self, bindings, body): - self.bindings = bindings - self.body = body - -class rxCall(rxExpr): - def __init__(self, function, arguments): - self.function = function - self.arguments = arguments - -class rxGetBuiltin(rxExpr): - def __init__(self, builtin_name): - self.builtin_name = builtin_name - -class rxTensorSlice(rxExpr): - def __init__(self, tensor, indices): - self.tensor = tensor - self.indices = indices +# class rxFrontendMatchShapeExprs(rxFrontendNode): +# def __init__(self, lhs, rhs, span): +# self.lhs = lhs +# self.rhs = rhs +# super().__init__(span) + +# class rxShapeTuple(rxExpr): +# def __init__(self, dims, span): +# self.dims = dims +# super().__init__(span) + # TODO: What is this doing? 
-#expr.Function.__str__ = print_fn # type: ignore +# expr.Function.__str__ = print_fn # type: ignore # TODO: Replace with a real pretty print method once we have the real AST def pretty_print(f): print(f) + class RelaxTransformer(Transformer): def __init__(self, definition_scope, diag_ctx): + super().__init__() self.definition_scope = definition_scope self.diag_ctx = diag_ctx - self.str_to_var = {} - self.blocks = [] self.module = {} - super().__init__() + self._scopes = [{}] # str -> Var + + def new_scope(self): + class _Scope: + def __init__(self, transformer: "RelaxTransformer"): + self.transformer = transformer + + def __enter__(self): + self.transformer._scopes.append(self.transformer._scopes[-1].copy()) + + def __exit__(self, *exc): + assert len(self.transformer._scopes) > 1, "cannot pop root scope" + self.transformer._scopes.pop() + + return _Scope(self) + + @property + def scope(self): + return self._scopes[-1] - def span_to_span(self, span: synr.Span) -> tvm.ir.Span: + def tvm_span(self, span: synr.Span) -> tvm.ir.Span: + """Converts the synr span to a TVM span + + Parameters + ---------- + span : synr.Span + The synr span + + Returns + ------- + tvm.ir.Span + The corresponding TVM span + """ src_name = self.diag_ctx.str_to_source_name[span.filename] - tvm_span = tvm.ir.Span(src_name, span.start_line, span.end_line, span.start_column, span.end_column) + tvm_span = tvm.ir.Span( + src_name, span.start_line, span.end_line, span.start_column, span.end_column + ) return tvm_span - - def decl_var(self, name, ty, shape_annotation, span=None): - var = rxVar(name, ty, shape_annotation, span) - self.str_to_var[name] = var + def decl_var( + self, + name: str, + type_annotation: Optional[rxType], + shape: Optional[rxExpr], + span: tvm.ir.Span, + ) -> rxVar: + """Introduces a variable with the given name and annotations to the current scope. + + Parameters + ---------- + name : str + The name of the variable + type_annotation : Optional[rxType] + The type annotation + shape : Optional[rxExpr] + The shape annotation + span : tvm.ir.Span + The span where the variable is declared + + Returns + ------- + rxVar + The declared variable + """ + if name in self.scope: + self._diagnostic_context.emit( + "error", "variable has already been declared in the current scope", span + ) + self._diagnostic_context.render() + var = rxVar(name, type_annotation, shape, span) + self.scope[name] = var return var - # Returns type, shape_annotation - def to_type(self, ty): + def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxShape]: + """Transforms the given synr type annotation to a Relax type and shape expression. 
+ + Parameters + ---------- + ty : ast.Type + The synr type + allow_intro : bool + Whether or not the shape annotation can introduce new dimension variables + + Returns + ------- + Tuple[rxType, rxExpr]: + The corresponding Relax type and shape expression + """ if ty is None: - return None + return (None, None) + span = self.tvm_span(ty.span) + + # simple annotation with no type arguments if isinstance(ty, ast.TypeVar): if ty.id.name == "Tensor": - span = self.span_to_span(ty.span) - return rxTensor(None, span), None - + return (rxTensor(None, span), None) + elif ty.id.name == "Shape": + return (rxShape(span), None) + elif ty.id.name == "Dim": + return (rxDim(span), None) + else: + self._diagnostic_context.emit("error", "unknown type in annotation", span) + self._diagnostic_context.render() + + # annotation with type arguments/shape annotation if isinstance(ty, ast.TypeApply): if ty.id.name == "Tensor": - # TODO: add more dtypes, maybe define elsewhere - allowed_dtypes = ["int32", "float32", "int8", "fp16"] - dims = [] - dtype = "" - assert len(ty.params) == 2 - shape_param = ty.params[0] - dtype_param = ty.params[1] - # Check whether the Tensor is shape / dtype erased - shape_erased = False - dtype_erased = False - - if (isinstance(shape_param, ast.TypeVar)) and shape_param.id.name == "_": - shape_erased = True - if (isinstance(dtype_param, ast.TypeVar)) and dtype_param.id.name == "_": - dtype_erased = True - - if not shape_erased: - if isinstance(shape_param, ast.TypeTuple): - for shape_dim in shape_param.values: - - # TODO: use to_primexpr or whatever - if isinstance(shape_dim, ast.Var): - dim = rxDim(tir.Var(shape_dim.name)) - elif isinstance(shape_dim, ast.Constant) and isinstance(shape_dim.value, int): - dim = rxDim(tir.IntImm("int32", shape_dim.value)) - else: - self._diagnostic_context.emit('error', "shape annotation must be only vars or consts for now", self.span_to_span(ty.span)) - self._diagnostic_context.render() - dims.append(dim) - else: - self._diagnostic_context.emit('error', "Tensor shape must be erased or be a tuple", self.span_to_span(ty.span)) - self._diagnostic_context.render() - if not dtype_erased: - if dtype_param.value in allowed_dtypes: - dtype = dtype_param.value - else: - self._diagnostic_context.emit('error', "dtype must be erased or one of " + str(allowed_dtypes), self.span_to_span(ty.span)) + if len(ty.params) != 2: + self._diagnostic_context.emit( + "error", + "Tensor type annotations must have 2 fields (shape and dtype)", + span, + ) + self._diagnostic_context.render() + + shape_annotation, dtype_annotation = ty.params + shape, dtype = None, None + + # parse the shape annotation + if isinstance(shape_annotation, ast.TypeVar): + if shape_annotation.id.name != "_": + # TODO: handle variable annotations, e.g. 
x: Tensor[my_shape, _] + self._diagnostic_context.emit( + "error", + "variable Tensor shape annotations not yet supported", + self.tvm_span(shape_annotation.span), + ) self._diagnostic_context.render() - - return rxTensor(dtype, None), dims - - self._diagnostic_context.emit('error', "invalid type", self.span_to_span(ty.span)) - - # Turns a tuple into an array of PrimExprs - # Allow arithmetic indicates whether we are letting there be - def expr_to_primexpr(self, expr: ast.Expr, allow_arithmetic=False) -> PrimExpr: - if not allow_arithmetic and not isinstance(expr, ast.Var): - #TODO: improve error message - self._diagnostic_context.emit('error', "You can only use single variables here, not an expression", self.span_to_span(expr.span)) - self._diagnostic_context.render() - else: - if isinstance(expr, ast.Var): - return tir.Var(expr.id.name, "int32") - - # TODO: do all the ops here - elif isinstance(expr, ast.Constant) and isinstance(expr.value, int): - return tir.IntImm("int32", expr.value) - elif isinstance(expr, ast.Call): - if exp.func_name.name == ast.BuiltinOp.Add: - # TODO: call this fn on args and return primexpr containing result - assert False - if exp.func_name.name == ast.BuiltinOp.Sub: - assert False - if exp.func_name.name == ast.BuiltinOp.Mul: - assert False - if exp.func_name.name == ast.BuiltinOp.Div: - assert False - if exp.func_name.name == ast.BuiltinOp.Mod: - assert False + else: + pass # shape = None + elif isinstance(shape_annotation, ast.TypeTuple): + shape = self.parse_shape(shape_annotation, allow_intro) + else: + self._diagnostic_context.emit( + "error", + "unsupported shape annotation", + self.tvm_span(shape_annotation.span), + ) + self._diagnostic_context.render() + + # parse the dtype annotation + if isinstance(dtype_annotation, ast.TypeVar) and dtype_annotation.id.name == "_": + pass # dtype = None + elif isinstance(dtype_annotation, ast.TypeConstant): + dtype = dtype_annotation.value # TODO: parse to TVM DType? + else: + self._diagnostic_context.emit( + "error", + "Tensor dtype annotations must be concrete or erased", + self.tvm_span(dtype_annotation.span), + ) + self._diagnostic_context.render() + + return (rxTensor(dtype, span), shape) + # TODO: other types with args, e.g. 
Ref[T], Tuple[Ts...], func types + self._diagnostic_context.emit("error", "invalid type", span) + + def parse_shape( + self, shape_annotation: Union[ast.TypeTuple, ast.Tuple], allow_intro: bool + ) -> List[tir.PrimExpr]: + """Parses the given shape annotation to a list of PrimExprs + + Parameters + ---------- + shape_annotation : ast.TypeTuple + The shape annotation in synr + allow_intro : bool + Whether or not the annotation can bind previously free variables + + Returns + ------- + List[tir.PrimExpr] + The parsed shape as a list of PrimExprs + """ + return [self.parse_primexpr(field, allow_intro) for field in shape_annotation.values] + + def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: + """Parses the given expression to a PrimExpr + + Parameters + ---------- + expr : ast.Expr + The input expression + allow_intro : bool + Whether or not the expression can bind previously free variables + + Returns + ------- + tir.PrimExpr + The result PrimExpr + """ + if isinstance(expr, ast.Var): + var_name = expr.id.name + if var_name in self.scope: + var = self.scope[var_name] + if not isinstance(var, tir.Var): + self._diagnostic_context.emit( + "error", + "non-dimension variables cannot appear in dimension expressions", + self.tvm_span(expr.span), + ) + self._diagnostic_context.render() + return var + elif allow_intro: + # introduce TIR variable to scope, e.g. for func params or rx.call_packed + var = tir.Var(var_name, "int32", self.tvm_span(expr.span)) + self.scope[var_name] = var + return var else: - self._diagnostic_context.emit('error', "The shape expression can only contain arithmetic operators, integer constants and variables", self.span_to_span(expr.span)) + self._diagnostic_context.emit( + "error", + "cannot introduce new dimension variables in this expression", + self.tvm_span(expr.span), + ) self._diagnostic_context.render() + else: + # TODO: parse (simple) PrimExprs + self._diagnostic_context.emit( + "error", "only dimension variable expressions are currently supported" + ) + self._diagnostic_context.render() + # Turns a tuple into an array of PrimExprs + # Allow arithmetic indicates whether we are letting there be + # def expr_to_primexpr(self, expr: ast.Expr, allow_arithmetic=False) -> PrimExpr: + # if not allow_arithmetic and not isinstance(expr, ast.Var): + # # TODO: improve error message + # self._diagnostic_context.emit( + # "error", + # "You can only use single variables here, not an expression", + # self.span_to_span(expr.span), + # ) + # self._diagnostic_context.render() + # else: + # if isinstance(expr, ast.Var): + # return tir.Var(expr.id.name, "int32") + + # # TODO: do all the ops here + # elif isinstance(expr, ast.Constant) and isinstance(expr.value, int): + # return tir.IntImm("int32", expr.value) + # elif isinstance(expr, ast.Call): + # if exp.func_name.name == ast.BuiltinOp.Add: + # # TODO: call this fn on args and return primexpr containing result + # assert False + # if exp.func_name.name == ast.BuiltinOp.Sub: + # assert False + # if exp.func_name.name == ast.BuiltinOp.Mul: + # assert False + # if exp.func_name.name == ast.BuiltinOp.Div: + # assert False + # if exp.func_name.name == ast.BuiltinOp.Mod: + # assert False + # else: + # self._diagnostic_context.emit( + # "error", + # "The shape expression can only contain arithmetic operators, integer constants and variables", + # self.tvm_span(expr.span), + # ) + # self._diagnostic_context.render() def transform_module(self, mod: ast.Module) -> IRModule: for func_name in mod.funcs: @@ -260,141 
+328,238 @@ def transform_module(self, mod: ast.Module) -> IRModule: return self.module def transform_function(self, func: ast.Function) -> rxFunction: - params = [] - for param in func.params: - ty, shape_dims = self.to_type(param.ty) - param = self.decl_var(param.name, ty, None) - params.append(param) - new_body = self.transform_block(func.body) - ret_type = self.to_type(func.ret_type) - return rxFunction(func.name, params, new_body, ret_type, None) - - def transform_stmt(self, stmt: ast.Stmt) -> relax.Expr: + with self.new_scope(): + params = [] + for param in func.params: + ty, shape = self.transform_type(param.ty, allow_intro=True) + param = self.decl_var(param.name, ty, shape, self.tvm_span(param.span)) + params.append(param) + new_body = self.transform_block(func.body) + ret_type, _ = self.transform_type(func.ret_type, allow_intro=False) + return rxFunction(func.name, params, new_body, ret_type, self.tvm_span(func.span)) + + # Stmts: + # - Assert: probably unsupported for now + # - Assign: VarBinding + # - For: ?? + # - If: IfThenElse, must check no empty false branch + # - Return: just the returned expression, must terminate blocks? (special case if-else) + # - UnassignedCall: match_shape + # - With: rx.dataflow + def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: if isinstance(stmt, ast.Assign): - assert isinstance(stmt.lhs, ast.Var) - lhs = self.decl_var(stmt.lhs.id.name, None, None) + if not isinstance(stmt.lhs, ast.Var): + self._diagnostic_context.emit( + "error", + "the left hand side of a binding must be a variable", + self.tvm_span(stmt.lhs.span), + ) + self._diagnostic_context.render() + # TODO: figure out proper way of doing this rhs = self.transform_expr(stmt.rhs) - # TODO: Replace with relax node - self.blocks[-1].append(rxBinding(lhs, rhs)) - return None + if isinstance(rhs, rxCall) and rhs.op == "rx.call_packed": + allow_intro = True + else: + allow_intro = False + ty, shape = self.transform_type(stmt.ty, allow_intro) + lhs = self.decl_var(stmt.lhs.id.name, ty, shape, self.tvm_span(stmt.lhs.span)) + return rxVarBinding(lhs, rhs, self.tvm_span(stmt.span)) + elif isinstance(stmt, ast.If): + # TODO: proper diagnostics + + # check branches are non-empty + assert stmt.true.stmts + assert stmt.false.stmts + true_assign = stmt.true.stmts[-1] + false_assign = stmt.false.stmts[-1] + + # check last statement in each branch lines up + assert isinstance(true_assign, ast.Assign) and isinstance(true_assign.lhs, ast.Var) + assert isinstance(false_assign, ast.Assign) and isinstance(false_assign.lhs, ast.Var) + assert true_assign.lhs.id.name == false_assign.lhs.id.name + var_name = true_assign.lhs.id.name + + # rewrite branches to have a return statement so the blocks properly parse to SeqExprs + true_block = synr.ast.Block( + span=stmt.true.span, + stmts=stmt.true.stmts[:-1] + + synr.ast.Return(span=true_assign.span, value=true_assign.rhs), + ) + false_block = synr.ast.Block( + span=stmt.false.span, + stmts=stmt.false.stmts[:-1] + + synr.ast.Return(span=false_assign.span, value=false_assign.rhs), + ) + + # parse the branches, build the final expression and binding + cond = self.transform_expr(stmt.condition) + with self.new_scope(): + true_branch = self.transform_block(true_block) + with self.new_scope(): + false_branch = self.transform_block(false_block) + # TODO: the spans here are all sorts of messed up, not sure how to fix + ite_expr = rxIfThenElse(cond, true_branch, false_branch, self.tvm_span(stmt.span)) + var = self.decl_var(var_name, None, None, 
self.tvm_span(false_assign.span)) + return rxVarBinding(var, ite_expr, self.tvm_span(stmt.span)) elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) # match_shape is the ONLY node that doesn't have to be bound to an LHS variable! - elif (isinstance(stmt, ast.UnassignedCall) and isinstance(stmt.call.func_name, ast.Var) - and stmt.call.func_name.id.name == "match_shape"): + elif isinstance(stmt, ast.UnassignedCall): + call: synr.ast.Call = stmt.call + op = self.transform_expr(call.func_name) + if op != "rx.match_shape": + self._diagnostic_context.emit( + "error", "the results of operator calls must be bound", self.tvm_span(stmt.span) + ) + self._diagnostic_context.render() if len(stmt.call.params) != 2: - self._diagnostic_context.emit('error', "match_shape takes exactly two arguments", self.span_to_span(stmt.span)) + self._diagnostic_context.emit( + "error", "rx.match_shape takes exactly two arguments", self.tvm_span(stmt.span) + ) self._diagnostic_context.render() lhs = stmt.call.params[0] rhs = stmt.call.params[1] - - # If RHS is a tuple, turn it into a ShapeTuple, otherwise, process normally - if isinstance(rhs, ast.Tuple): - arithmetic_primexprs = [] - for elem in rhs.values: - arithmetic_primexprs.append(self.expr_to_primexpr(elem, allow_arithmetic=True)) - rhs_expr = rxShapeTuple(arithmetic_primexprs, self.span_to_span(rhs.span)) - else: - rhs_expr = self.transform_expr(rhs) - - # If LHS is a tuple of variables, then we use the binding match shape - # If it is an Expr, we use the sugared match_shape (and will insert bindings later) - if isinstance(lhs, ast.Tuple): - binding_tir_vars = [] - for elem in lhs.values: - # Here we only are defining variables so we don't allow arithmetic expressions - binding_tir_vars.append(self.expr_to_primexpr(elem)) - self.blocks[-1].append(rxMatchShapeBinding(binding_tir_vars, rhs_expr)) - else: - lhs_expr = self.transform_expr(lhs) - self.blocks[-1].append(rxFrontendMatchShapeExprs(lhs_expr, rhs_expr, stmt.span)) + + rhs_expr = self.transform_expr(rhs) + if not isinstance(lhs, ast.Tuple): + self._diagnostic_context.emit( + "error", + "the pattern (lhs) of rx.match_shape must be a tuple", + self.tvm_span(lhs.span), + ) + self._diagnostic_context.render() + lhs_expr = self.parse_shape(lhs, allow_intro=True) + return rxMatchShape(lhs_expr, rhs_expr, self.tvm_span(stmt.span)) + elif isinstance(stmt, ast.With): + assert False, "todo with/dataflow" else: - self._diagnostic_context.emit('error', "only variable left-hand sides are supported in Relay", self.span_to_span(stmt.span)) + self._diagnostic_context.emit( + "error", + "unsupported statement", + self.tvm_span(stmt.span), + ) self._diagnostic_context.render() - def transform_expr(self, exp: ast.Expr) -> rxExpr: - if isinstance(exp, ast.Call): - if isinstance(exp.func_name, ast.Var): - params = [] - for arg in exp.params: - params.append(self.transform_expr(arg)) - - if exp.func_name.id.name in self.str_to_var: - return self.str_to_var[exp.func_name.id.name] + # Exprs: + # - ArrayLiteral + # - Attr + # - Call + # - Constant + # - DictLiteral + # - Slice + # - Tuple + # - Var + def transform_expr(self, expr: ast.Expr) -> rxExpr: + if isinstance(expr, ast.Attr): + obj = self.transform_expr(expr.object) + field_name = expr.field.name + # TODO: use some kind of proper identifier? str bad + if isinstance(obj, str): + return obj + "." 
+ field_name + elif field_name == "shape": + return rxCall("rx.shape_of", obj, self.tvm_span(expr.span)) + else: + self._diagnostic_context.emit( + "error", "unsupported attribute", self.tvm_span(expr.span) + ) + self._diagnostic_context.render() + if isinstance(expr, ast.Call): + op = expr.func_name + if isinstance(op, ast.Var): + args = [] + for arg in expr.params: + args.append(self.transform_expr(arg)) + if op in self.scope: + op = self.scope[op] else: - name = exp.func_name.id.name - relax_fn = getattr(self.definition_scope, name, None) - # builtin operator - if relax_fn is None: - return rxCall(rxGetBuiltin(name), params, None) - else: - self.module[name] = relax_fn.module[name] - # todo: globalvar equality? use global str -> id map? - ident = Id(exp.func_name.id.name) - return rxCall(rxGlobalVar(ident, None, None), params, None) - - elif isinstance(exp.func_name, ast.Op): - if exp.func_name.name == ast.BuiltinOp.Subscript: - tensor = self.transform_expr(exp.params[0]) - indicies = [] - for index in exp.params[1].values: - indicies.append(self.transform_expr(index)) - # TODO: Replace with relax node - return rxTensorSlice(tensor, indicies, None) - elif exp.func_name.name == ast.BuiltinOp.Add: - params = [] - for arg in exp.params: - params.append(self.transform_expr(arg)) - # TODO: Replace with relax node - return rxCall("add", [params[0], params[1]], None) - - self._diagnostic_context.emit('error', "unsupported function", self.span_to_span(exp.span)) - self._diagnostic_context.render() - elif isinstance(exp, ast.Attr): - field_name = exp.field.name - tensor = self.transform_expr(exp.object) - - if field_name == "shape": - return rxShapeOf(tensor) + # TODO: fix + op = op.id.name + return rxCall(op, args, self.tvm_span(expr.span)) + # if exp.func_name.id.name in self.str_to_var: + # return self.str_to_var[exp.func_name.id.name] + # else: + # name = exp.func_name.id.name + # relax_fn = getattr(self.definition_scope, name, None) + # # builtin operator + # if relax_fn is None: + # return rxCall(rxGetBuiltin(name), params, None) + # else: + # self.module[name] = relax_fn.module[name] + # # todo: globalvar equality? use global str -> id map? 
+ # ident = Id(exp.func_name.id.name) + # return rxCall(rxGlobalVar(ident, None, None), params, None) + elif isinstance(op, ast.Op): + assert False, "TODO: sugar for python built in operators" + # if exp.func_name.name == ast.BuiltinOp.Subscript: + # tensor = self.transform_expr(exp.params[0]) + # indicies = [] + # for index in exp.params[1].values: + # indicies.append(self.transform_expr(index)) + # # TODO: Replace with relax node + # return rxTensorSlice(tensor, indicies, None) + # elif exp.func_name.name == ast.BuiltinOp.Add: + # params = [] + # for arg in exp.params: + # params.append(self.transform_expr(arg)) + # # TODO: Replace with relax node + # return rxCall("add", [params[0], params[1]], None) else: - self._diagnostic_context.emit('error', "unsupported function", self.span_to_span(exp.span)) + self._diagnostic_context.emit( + "error", "unsupported function", self.tvm_span(expr.span) + ) self._diagnostic_context.render() - elif isinstance(exp, ast.Function): - return self.transform_function(exp) - elif isinstance(exp, ast.Tuple): - assert False - elif isinstance(exp, ast.Var): - return self.str_to_var[exp.id.name] + elif isinstance(expr, ast.Tuple): + fields = [self.transform_expr(field) for field in expr.values] + return rxTuple(fields, self.tvm_span(expr.span)) + elif isinstance(expr, ast.Var): + var_name = expr.id.name + if var_name == "rx": + return "rx" + if var_name not in self.scope: + self._diagnostic_context.emit( + "error", "undefined variable", self.tvm_span(expr.span) + ) + self._diagnostic_context.render() + return self.scope[var_name] else: - self._diagnostic_context.emit('error', f"don't support this construct {type(exp)}", self.span_to_span(exp.span)) + self._diagnostic_context.emit( + "error", "unsupported expression", self.tvm_span(expr.span) + ) self._diagnostic_context.render() - def enter_block(self): - self.blocks.append([]) - - def exit_block(self): - back = self.blocks[-1] - self.blocks.pop() - return back - - def transform_block(self, block: ast.Block) -> relax.expr: - self.enter_block() - + def transform_block(self, block: ast.Block) -> rxSeqExpr: + # a block of statements needs to be converted to a SeqExpr of binding blocks + blocks = [] + current_block = [] for stmt in block.stmts[:-1]: - assert self.transform_stmt(stmt) is None - - ret_expr = self.transform_stmt(block.stmts[-1]) - # assert ret_expr is not None - - bindings = self.exit_block() - return rxLet(bindings, ret_expr) + parsed_stmt = self.transform_stmt(stmt) + if isinstance(parsed_stmt, rxDataflowBlock): + assert len(current_block) > 0, "should never have an empty block" + blocks.append(current_block) + blocks.append(parsed_stmt) + current_block = [] + else: + assert isinstance(parsed_stmt, rxBinding) + current_block.append(parsed_stmt) + if len(current_block) > 0: + blocks.append(current_block) + + ret_stmt = block.stmts[-1] + if not isinstance(ret_stmt, ast.Return): + self._diagnostic_context.emit( + "error", + "block must end with a returned expression", + self.tvm_span(ret_stmt.span), + ) + self._diagnostic_context.render() + ret_expr = self.transform_stmt(ret_stmt) + + return rxSeqExpr(blocks, ret_expr, self.tvm_span(block.span)) - def transform_parameter(self, expr: ast.Parameter) -> relax.Expr: + def transform_parameter(self, expr: ast.Parameter) -> rxExpr: pass - def transform_type(self, ty: ast.Type) -> relax.Type: - pass class TVMDiagnosticContext(synr.DiagnosticContext): def __init__(self, tvm_diag_ctx): @@ -434,6 +599,7 @@ def render(self) -> Optional[Any]: """ 
self.tvm_diag_ctx.render() + class RelaxDecoratedFn: def __init__(self, fn_name, relax_module, diag_ctx): self.fn_name = fn_name @@ -449,6 +615,7 @@ def __call__(self, *args): # compiled_f(*(list(args) + [out])) # return out + def relax(f): ir_module = tvm.IRModule({}) diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) @@ -458,8 +625,13 @@ def relax(f): module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) return RelaxDecoratedFn(f.__name__, module, diag_ctx) + @relax def my_test(x: Tensor[_, "float32"]): - match_shape(x.shape, (1, 2, 3)) + rx.match_shape((n, m), x.shape) + y = mul(x, x) + return y + -print(my_test.module['my_test']) \ No newline at end of file +f = my_test.module["my_test"] +import pdb; pdb.set_trace() From 753d2ee7f84f587d907afd8e627e619ac6b1a8bc Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 18 Aug 2021 10:01:44 -0700 Subject: [PATCH 10/20] some parser tests --- python/tvm/relax/parser.py | 56 +++++---- tests/python/relax/parser.py | 218 +++++++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 23 deletions(-) create mode 100644 tests/python/relax/parser.py diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 20c657a89a..e3d352c45a 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -8,7 +8,7 @@ from tvm.ir.module import IRModule from tvm.relay.base import Id from tvm.ir import diagnostics -from tvm import tir, relax +from tvm import tir import numpy as np @@ -190,6 +190,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxSha ) self._diagnostic_context.render() else: + # FIXME: use a special node for unknown shape vs no shape? pass # shape = None elif isinstance(shape_annotation, ast.TypeTuple): shape = self.parse_shape(shape_annotation, allow_intro) @@ -215,8 +216,17 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxSha self._diagnostic_context.render() return (rxTensor(dtype, span), shape) - # TODO: other types with args, e.g. Ref[T], Tuple[Ts...], func types + elif ty.id.name == "Tuple": + field_types = [] + field_shapes = [] + for field in ty.params: + fty, fsh = self.transform_type(field, allow_intro=False) + field_types.append(fty) + field_shapes.append(fsh) + return rxTupleType(field_types, self.tvm_span(ty.span)), field_shapes + # TODO: other types with args, e.g. 
Ref[T], func types self._diagnostic_context.emit("error", "invalid type", span) + self._diagnostic_context.render() def parse_shape( self, shape_annotation: Union[ast.TypeTuple, ast.Tuple], allow_intro: bool @@ -276,10 +286,19 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: self.tvm_span(expr.span), ) self._diagnostic_context.render() + elif isinstance(expr, ast.Constant): + if not isinstance(expr.value, int): + self._diagnostic_context.emit( + "error", "only integer constants are supported", self.tvm_span(expr.span) + ) + self._diagnostic_context.render() + return tir.const(expr.value, "int32", self.tvm_span(expr.span)) else: # TODO: parse (simple) PrimExprs self._diagnostic_context.emit( - "error", "only dimension variable expressions are currently supported" + "error", + "only dimension variable expressions are currently supported", + self.tvm_span(expr.span), ) self._diagnostic_context.render() @@ -383,12 +402,12 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: true_block = synr.ast.Block( span=stmt.true.span, stmts=stmt.true.stmts[:-1] - + synr.ast.Return(span=true_assign.span, value=true_assign.rhs), + + [synr.ast.Return(span=true_assign.span, value=true_assign.rhs)], ) false_block = synr.ast.Block( span=stmt.false.span, stmts=stmt.false.stmts[:-1] - + synr.ast.Return(span=false_assign.span, value=false_assign.rhs), + + [synr.ast.Return(span=false_assign.span, value=false_assign.rhs)], ) # parse the branches, build the final expression and binding @@ -407,6 +426,8 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: elif isinstance(stmt, ast.UnassignedCall): call: synr.ast.Call = stmt.call op = self.transform_expr(call.func_name) + # FIXME: this check is unreachable since transform_expr tries looking up func_name as a + # variable and fails if op != "rx.match_shape": self._diagnostic_context.emit( "error", "the results of operator calls must be bound", self.tvm_span(stmt.span) @@ -442,12 +463,12 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: self._diagnostic_context.render() # Exprs: - # - ArrayLiteral - # - Attr + # - ArrayLiteral: unsupported for now? + # - Attr: use for .shape, and intrinsic/special operator namespace # - Call # - Constant - # - DictLiteral - # - Slice + # - DictLiteral: unsupported for now + # - Slice: unsupported for now, could desugar to slice op # - Tuple # - Var def transform_expr(self, expr: ast.Expr) -> rxExpr: @@ -458,7 +479,7 @@ def transform_expr(self, expr: ast.Expr) -> rxExpr: if isinstance(obj, str): return obj + "." 
+ field_name elif field_name == "shape": - return rxCall("rx.shape_of", obj, self.tvm_span(expr.span)) + return rxCall("rx.shape_of", [obj], self.tvm_span(expr.span)) else: self._diagnostic_context.emit( "error", "unsupported attribute", self.tvm_span(expr.span) @@ -554,7 +575,7 @@ def transform_block(self, block: ast.Block) -> rxSeqExpr: ) self._diagnostic_context.render() ret_expr = self.transform_stmt(ret_stmt) - + return rxSeqExpr(blocks, ret_expr, self.tvm_span(block.span)) def transform_parameter(self, expr: ast.Parameter) -> rxExpr: @@ -616,7 +637,7 @@ def __call__(self, *args): # return out -def relax(f): +def script(f): ir_module = tvm.IRModule({}) diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) diag_ctx = TVMDiagnosticContext(diag_ctx) @@ -624,14 +645,3 @@ def relax(f): definition_scope = inspect.getmodule(f) module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) return RelaxDecoratedFn(f.__name__, module, diag_ctx) - - -@relax -def my_test(x: Tensor[_, "float32"]): - rx.match_shape((n, m), x.shape) - y = mul(x, x) - return y - - -f = my_test.module["my_test"] -import pdb; pdb.set_trace() diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py new file mode 100644 index 0000000000..ecfed28819 --- /dev/null +++ b/tests/python/relax/parser.py @@ -0,0 +1,218 @@ +from __future__ import annotations # must import to defer parsing of annotations +import pytest +import tvm +from tvm import relax as rx +from tvm import tir + + +def rx_func(func): + return func.module[func.fn_name] + + +def check_shape(e, s): + if not isinstance(e, (list, tuple)) and e is not None: + e = e._shape + + if s is None: + assert e is None + return + + assert isinstance(e, (list, tuple)) + assert len(e) == len(s) + + for edim, sdim in zip(e, s): + if isinstance(sdim, str): + assert isinstance(edim, tir.Var) + assert edim.name == sdim + else: + assert isinstance(edim, tir.IntImm) + assert edim.value == sdim + + +def check_tensor_var(v, s, d): + assert isinstance(v.type_annotation, rx.ty.rxTensor) + assert v.type_annotation.dtype == d + check_shape(v, s) + + +def test_annotations(): + @rx.script + def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor: + z: Tensor[(32, k), "float32"] = matmul(x, y) + w: Tensor[_, _] = mul(z, z) + t = sub(w, z) + sh: Shape = t.shape + return t + + f = rx_func(foo) + x, y = f.params + z_bind, w_bind, t_bind, sh_bind = f.body.blocks[0] + z, mm = z_bind.var, z_bind.value + w, mul = w_bind.var, w_bind.value + t, sub = t_bind.var, t_bind.value + sh, shape_of = sh_bind.var, sh_bind.value + + check_tensor_var(x, (32, "m"), "float32") + check_tensor_var(y, ("m", "k"), "float32") + check_tensor_var(z, (32, "k"), "float32") + check_tensor_var(w, None, None) + assert t.type_annotation is None + assert isinstance(sh.type_annotation, rx.ty.rxShape) + + assert mm.op == "matmul" + assert mm.args == [x, y] + + assert mul.op == "mul" + assert mul.args == [z, z] + + assert sub.op == "sub" + assert sub.args == [w, z] + + assert shape_of.op == "rx.shape_of" + assert shape_of.args == [t] + + assert f.body.body == t + + assert isinstance(f.ret_type, rx.ty.rxTensor) + + +def test_match_shape(): + @rx.script + def foo(x: Tensor[_, "float32"]): + rx.match_shape((n, m), x.shape) + y: Tensor[(n, m), "float32"] = refine(x) + return x + + f = rx_func(foo) + match_sh = f.body.blocks[0][0] + pattern, value = match_sh.pattern, match_sh.value + + check_shape(pattern, ("n", "m")) + assert isinstance(value, 
rx.expr.rxCall) + assert value.op == "rx.shape_of" + assert value.args == [f.params[0]] + + +@pytest.mark.xfail +def test_dim_var_intro_fail(): + @rx.script + def foo(x: Tensor[_, _]): + y: Tensor[(n, m), "float32"] = x + return y + + +def test_if(): + @rx.script + def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): + if cond: + w = add(x, x) + y = mul(w, w) + else: + w = mul(x, x) + y = add(w, w) + return y + + f = rx_func(foo) + cond, x = f.params + y_bind = f.body.blocks[0][0] + y, ite = y_bind.var, y_bind.value + + check_tensor_var(cond, tuple(), "bool") + check_tensor_var(x, (1,), "float32") + + assert isinstance(y, rx.expr.rxVar) + assert y.id == "y" + + assert isinstance(ite, rx.expr.rxIfThenElse) + assert isinstance(ite.true_branch, rx.expr.rxSeqExpr) + assert isinstance(ite.false_branch, rx.expr.rxSeqExpr) + + w_bind = ite.true_branch.blocks[0][0] + body = ite.true_branch.body + assert w_bind.var.id == "w" + assert isinstance(w_bind.value, rx.expr.rxCall) + assert w_bind.value.op == "add" and w_bind.value.args == [x, x] + assert isinstance(body, rx.expr.rxCall) + assert body.op == "mul" and body.args == [w_bind.var, w_bind.var] + + w_bind = ite.false_branch.blocks[0][0] + body = ite.false_branch.body + assert w_bind.var.id == "w" + assert isinstance(w_bind.value, rx.expr.rxCall) + assert w_bind.value.op == "mul" and w_bind.value.args == [x, x] + assert isinstance(body, rx.expr.rxCall) + assert body.op == "add" and body.args == [w_bind.var, w_bind.var] + + +# TODO: figure out if-else binding type and shape + + +@pytest.mark.xfail +def test_var_redefine_fail(): + @rx.script + def foo(x, y): + z = add(x, y) + y = z + return y + + +@pytest.mark.xfail +def test_var_redefine_fail_if(): + @rx.script + def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): + y = x + if cond: + w = add(x, x) + y = mul(w, w) + else: + w = mul(x, x) + y = add(w, w) + return y + + +@pytest.mark.xfail +def test_var_if_scoping_fail(): + @rx.script + def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): + if cond: + w = add(x, x) + y = mul(w, w) + else: + w = mul(x, x) + y = add(w, w) + return w + + +@pytest.mark.xfail +def test_unassigned_call_fail(): + @rx.script + def foo(x: Tensor[_, _]): + add(x, x) + return x + + +def test_tuple(): + @rx.script + def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]): + t: Tuple[Tensor[_, _], Tensor[(32,), "float32"]] = (x, y) + return t + + f = rx_func(foo) + x, y = f.params + t_bind = f.body.blocks[0][0] + t, tup = t_bind.var, t_bind.value + + assert isinstance(t.type_annotation, rx.ty.rxTupleType) + annot = t.type_annotation + assert isinstance(annot.fields[0], rx.ty.rxTensor) and annot.fields[0].dtype is None + assert isinstance(annot.fields[1], rx.ty.rxTensor) and annot.fields[1].dtype == "float32" + + assert isinstance(t._shape, list) and len(t._shape) == 2 + check_shape(t._shape[0], None) + check_shape(t._shape[1], (32,)) + + assert isinstance(tup, rx.expr.rxTuple) + assert tup.fields == [x, y] + assert isinstance(tup._shape, list) and len(tup._shape) == 2 + check_shape(tup._shape[0], None) + check_shape(tup._shape[1], (32,)) From 3fd7e5ddc3d1dfac6cb673a8090c435c78e3b6b2 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 18 Aug 2021 15:47:43 -0700 Subject: [PATCH 11/20] yolo dataflow --- python/tvm/relax/parser.py | 96 +++++++++++++++++++++++++++++++++--- tests/python/relax/parser.py | 82 ++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py 
index e3d352c45a..9bce251447 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -383,6 +383,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: ty, shape = self.transform_type(stmt.ty, allow_intro) lhs = self.decl_var(stmt.lhs.id.name, ty, shape, self.tvm_span(stmt.lhs.span)) return rxVarBinding(lhs, rhs, self.tvm_span(stmt.span)) + elif isinstance(stmt, ast.If): # TODO: proper diagnostics @@ -420,8 +421,10 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: ite_expr = rxIfThenElse(cond, true_branch, false_branch, self.tvm_span(stmt.span)) var = self.decl_var(var_name, None, None, self.tvm_span(false_assign.span)) return rxVarBinding(var, ite_expr, self.tvm_span(stmt.span)) + elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) + # match_shape is the ONLY node that doesn't have to be bound to an LHS variable! elif isinstance(stmt, ast.UnassignedCall): call: synr.ast.Call = stmt.call @@ -452,8 +455,44 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: self._diagnostic_context.render() lhs_expr = self.parse_shape(lhs, allow_intro=True) return rxMatchShape(lhs_expr, rhs_expr, self.tvm_span(stmt.span)) + elif isinstance(stmt, ast.With): - assert False, "todo with/dataflow" + if not isinstance(stmt.rhs, ast.Call): + self._diagnostic_context.emit( + "error", "unsupported with block", self.tvm_span(stmt.span) + ) + self._diagnostic_context.render() + + call = stmt.rhs + op = self.transform_expr(call.func_name) + + # TODO: perhaps this ought to be more general + + if op != "rx.dataflow": + self._diagnostic_context.emit( + "error", "unsupported with block type", self.tvm_span(call.span) + ) + self._diagnostic_context.render() + if len(call.params) > 0: + self._diagnostic_context.emit( + "error", + "dataflow block constructor takes no arguments", + self.tvm_span(call.params[0].span), + ) + self._diagnostic_context.render() + if len(stmt.lhs) > 0: + self._diagnostic_context.emit( + "error", "dataflow blocks don't bind any patterns", self.tvm_span(stmt.lhs[0].span) + ) + self._diagnostic_context.render() + + return self.parse_dataflow(stmt.body) + + elif isinstance(stmt, ast.Function): + func = self.transform_function(stmt) + func_var = self.decl_var(func.name, None, None, self.tvm_span(stmt.span)) + return rxVarBinding(func_var, func, self.tvm_span(stmt.span)) + else: self._diagnostic_context.emit( "error", @@ -462,6 +501,51 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: ) self._diagnostic_context.render() + def parse_dataflow(self, block: ast.Block): + assert len(block.stmts) > 0, "should never have an empty dataflow block" + bindings = [] + + with self.new_scope(): + for binding_stmt in block.stmts[:-1]: + if not isinstance(binding_stmt, ast.Assign): + self._diagnostic_context.emit( + "error", + "only bindings are supported in dataflow blocks", + self.tvm_span(binding_stmt.span), + ) + self._diagnostic_context.render() + binding = self.transform_stmt(binding_stmt) + bindings.append(binding) + + output_stmt = block.stmts[-1] + if not isinstance(output_stmt, ast.Return): + self._diagnostic_context.emit( + "error", + "dataflow blocks must end with returning the output variables", + self.tvm_span(output_stmt.span), + ) + self._diagnostic_context.render() + + ret_val = output_stmt.value + if isinstance(ret_val, ast.Var): + ret_val = ast.Tuple(values=[ret_val], span=ret_val.span) + + if not isinstance(ret_val, ast.Tuple) or any([not isinstance(f, ast.Var) for f in ret_val.values]): + self._diagnostic_context.emit( + "error", + "the 
returned values must be variables", + self.tvm_span(ret_val.span), + ) + + ret_vars = [self.transform_expr(v) for v in ret_val.values] + + # parent scope + for v in ret_vars: + self.scope[v.id] = v + + return rxDataflowBlock(bindings, ret_vars, self.tvm_span(block.span)) + + # Exprs: # - ArrayLiteral: unsupported for now? # - Attr: use for .shape, and intrinsic/special operator namespace @@ -491,8 +575,8 @@ def transform_expr(self, expr: ast.Expr) -> rxExpr: args = [] for arg in expr.params: args.append(self.transform_expr(arg)) - if op in self.scope: - op = self.scope[op] + if op.id.name in self.scope: + op = self.transform_expr(op) else: # TODO: fix op = op.id.name @@ -556,10 +640,10 @@ def transform_block(self, block: ast.Block) -> rxSeqExpr: for stmt in block.stmts[:-1]: parsed_stmt = self.transform_stmt(stmt) if isinstance(parsed_stmt, rxDataflowBlock): - assert len(current_block) > 0, "should never have an empty block" - blocks.append(current_block) + if current_block: + blocks.append(current_block) + current_block = [] blocks.append(parsed_stmt) - current_block = [] else: assert isinstance(parsed_stmt, rxBinding) current_block.append(parsed_stmt) diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py index ecfed28819..f95a2216f3 100644 --- a/tests/python/relax/parser.py +++ b/tests/python/relax/parser.py @@ -5,6 +5,10 @@ from tvm import tir +# TODO: replace xfails with proper diagnostics checking. +# c.f. tests/python/unittest/test_tvmscript_error_report.py + + def rx_func(func): return func.module[func.fn_name] @@ -216,3 +220,81 @@ def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]): assert isinstance(tup._shape, list) and len(tup._shape) == 2 check_shape(tup._shape[0], None) check_shape(tup._shape[1], (32,)) + + +# NOTE: this test requires patching synr to support local function definitions. +# it's an easy change (just two lines), but may break other users of synr +# (e.g. tvmscript). should investigate. 
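+# Unrelated to the synr note above (that note refers to test_local_func below):
+# a minimal debugging sketch, not exercised by any test here. It assumes the
+# prototype exposes rxDataflowBlock as rx.expr.rxDataflowBlock alongside the
+# other nodes used in this file, and summarizes how transform_block interleaves
+# plain binding lists and dataflow blocks inside a parsed function's SeqExpr.
+def _describe_blocks(func):
+    kinds = []
+    for block in func.body.blocks:
+        if isinstance(block, rx.expr.rxDataflowBlock):
+            kinds.append("dataflow")
+        else:
+            # plain blocks are currently just Python lists of bindings
+            kinds.append("bindings({})".format(len(block)))
+    return kinds
+
+
+# For instance, applying it to the parsed function from test_annotations above
+# would be expected to give ["bindings(4)"], while the dataflow example below
+# should give ["dataflow", "bindings(1)"].
+
+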
+def test_local_func(): + @rx.script + def foo(x: Tensor[_, _]): + def bar(y: Tensor[_, _]): + return y + z = bar(x) + return z + + f = rx_func(foo) + bar_bind, z_bind = f.body.blocks[0] + bar, bar_fn = bar_bind.var, bar_bind.value + bar_x = z_bind.value + + assert isinstance(bar_fn, rx.expr.rxFunction) + assert bar_fn.body.body == bar_fn.params[0] + + assert bar_x.op == bar + + +def test_dataflow(): + @rx.script + def foo(x: Tensor[_, _]): + with rx.dataflow(): + y = add(x, x) + z = mul(y, x) + w = sub(z, x) + return y, w + t = div(y, w) + return t + + f = rx_func(foo) + df_block = f.body.blocks[0] + + # TODO: check correctness + + +@pytest.mark.xfail +def test_dataflow_scope_fail(): + @rx.script + def foo(x: Tensor[_, _]): + with rx.dataflow(): + y = add(x, x) + z = mul(y, x) + w = sub(z, x) + return y, w + t = div(y, z) + return t + + +@pytest.mark.xfail +def test_dataflow_syntax_fail_pattern(): + @rx.script + def foo(x: Tensor[_, _]): + with rx.dataflow() as df: + y = add(x, x) + z = mul(y, x) + w = sub(z, x) + return y, w + t = div(y, z) + return t + + +@pytest.mark.xfail +def test_dataflow_syntax_fail_params(): + @rx.script + def foo(x: Tensor[_, _]): + with rx.dataflow(x) as df: + y = add(x, x) + z = mul(y, x) + w = sub(z, x) + return y, w + t = div(y, z) + return t From 57edf1ec2de9bf3dc2149e0fe2cdacbf05d87c43 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 23 Aug 2021 11:52:08 -0700 Subject: [PATCH 12/20] checkpoint for rebase --- include/tvm/relax/expr.h | 78 ++++++++---------- include/tvm/relax/type.h | 43 +++++++--- python/tvm/ir/type.py | 5 +- python/tvm/relax/__init__.py | 4 +- python/tvm/relax/expr.py | 55 ++++++------ python/tvm/relax/parser.py | 75 ++++++----------- python/tvm/relax/ty.py | 17 ++-- python/tvm/relay/expr.py | 6 +- src/ir/type.cc | 4 +- src/relax/expr.cc | 156 ++++++++++++++++------------------- src/relax/type.cc | 25 ++++-- src/relay/ir/expr.cc | 4 +- 12 files changed, 232 insertions(+), 240 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 8cb74a595c..73d14b393a 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -53,8 +53,7 @@ class ShapeExprNode : public ExprNode { } bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { - return equal(values, other->values) && - equal(checked_type_, other->checked_type_) && + return equal(values, other->values) && equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } @@ -72,15 +71,15 @@ class ShapeExprNode : public ExprNode { class ShapeExpr : public Expr { public: - TVM_DLL ShapeExpr(Array values); + TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode); }; - /*! \brief The variable class for all Relax bindings. */ class VarNode : public ExprNode { public: - /*! \brief The identifier of the variable, is used for comparing stable equality across transformations. */ + /*! \brief The identifier of the variable, is used for comparing stable equality across + * transformations. */ Id vid; /*! \brief The type annotation, used by binding sites and parameter declarations. */ runtime::Optional type_annotation; @@ -97,11 +96,9 @@ class VarNode : public ExprNode { } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - return equal(vid, other->vid) && - equal(type_annotation, other->type_annotation) && + return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) && // Do we use the analysis information in equality? 
- equal(checked_type_, other->checked_type_) && - equal(shape_, other->shape_); + equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -120,16 +117,12 @@ class VarNode : public ExprNode { class Var : public Expr { public: - TVM_DLL Var(String name_hint, - runtime::Optional shape_annotation, - runtime::Optional type_annotation, - Span span = Span()) - : Var(Id(name_hint), shape_annotation, type_annotation, span) {} - - TVM_DLL Var(Id vid, - runtime::Optional shape_annotation, - runtime::Optional type_annotation, - Span span = Span()); + TVM_DLL explicit Var(String name_hint, runtime::Optional> shape_annotation, + runtime::Optional type_annotation, Span span = Span()) + : Var(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL explicit Var(Id vid, runtime::Optional> shape_annotation, + runtime::Optional type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); }; @@ -147,10 +140,8 @@ class DataflowVarNode : public VarNode { } bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { - return equal(vid, other->vid) && - equal(type_annotation, other->type_annotation) && - equal(shape_, other->shape_) && - equal(checked_type_, other->checked_type_); + return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) && + equal(shape_, other->shape_) && equal(checked_type_, other->checked_type_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -168,15 +159,16 @@ class DataflowVarNode : public VarNode { class DataflowVar : public Var { public: - using Var::Var; // inherit constructors from Var + using Var::Var; // inherit constructors from Var TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); }; - /*! \brief The base class of a variable binding in Relax. */ class BindingNode : public Object { public: - void VisitAttrs(AttrVisitor* v) {} + mutable Span span; + + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; } void SHashReduce(SHashReducer hash_reduce) const {} @@ -188,10 +180,10 @@ class BindingNode : public Object { class Binding : public ObjectRef { public: + TVM_DLL explicit Binding(Span span); TVM_DEFINE_OBJECT_REF_METHODS(Binding, ObjectRef, BindingNode); }; - /*! \brief Symbolic shape match, binds the variables of the LHS with the rhs. 
*/ class MatchShape; class MatchShapeNode : public BindingNode { @@ -202,6 +194,7 @@ class MatchShapeNode : public BindingNode { void VisitAttrs(AttrVisitor* v) { v->Visit("pattern", &pattern); v->Visit("value", &value); + v->Visit("span", &span); } bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const { @@ -221,7 +214,7 @@ class MatchShapeNode : public BindingNode { class MatchShape : public Binding { public: - TVM_DLL MatchShape(Array pattern, Expr value); + TVM_DLL explicit MatchShape(Array pattern, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode); }; @@ -234,6 +227,7 @@ class VarBindingNode : public BindingNode { void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); + v->Visit("span", &span); } bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { @@ -251,23 +245,28 @@ class VarBindingNode : public BindingNode { class VarBinding : public Binding { public: - TVM_DLL VarBinding(Var var, Expr value); + TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); }; - class BindingBlock; class BindingBlockNode : public Object { public: + mutable Span span; Array bindings; + void VisitAttrs(AttrVisitor* v) { + v->Visit("span", &span); v->Visit("bindings", &bindings); } + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { return equal(bindings, other->bindings); } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr const char* _type_key = "relax.expr.BindingBlock"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -276,21 +275,17 @@ class BindingBlockNode : public Object { class BindingBlock : public ObjectRef { public: - TVM_DLL BindingBlock(Array bindings); + TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); }; - class DataflowBlock; class DataflowBlockNode : public BindingBlockNode { public: - void VisitAttrs(AttrVisitor* v) { - v->Visit("bindings", &bindings); - } bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { return equal(bindings, other->bindings); } - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -299,7 +294,7 @@ class DataflowBlockNode : public BindingBlockNode { class DataflowBlock : public BindingBlock { public: - TVM_DLL DataflowBlock(Array bindings); + TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); }; @@ -340,11 +335,10 @@ class SeqExprNode : public ExprNode { class SeqExpr : public Expr { public: - TVM_DLL SeqExpr(Array blocks, Expr body); + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); }; - /*! \brief A Relax function, eventually to replace the current Relay function definition. 
*/ class FunctionNode : public BaseFuncNode { public: @@ -372,8 +366,7 @@ class FunctionNode : public BaseFuncNode { bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return equal.DefEqual(params, other->params) && - equal(body, other->body) && + return equal.DefEqual(params, other->params) && equal(body, other->body) && equal(ret_type, other->ret_type) && equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } @@ -396,8 +389,7 @@ class FunctionNode : public BaseFuncNode { class Function : public Expr { public: - TVM_DLL Function(runtime::Optional name, Array params, - Expr body, Type ret_type); + TVM_DLL explicit Function(runtime::Optional name, Array params, Expr body, Type ret_type, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); }; diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 9967c29a5f..4b5b1640f3 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -39,7 +39,10 @@ namespace relax { class ShapeTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("span", &span); + } bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return true; @@ -53,16 +56,9 @@ class ShapeTypeNode : public TypeNode { class ShapeType : public Type { public: - explicit ShapeType(); - explicit ShapeType(runtime::ObjectPtr n) : Type(n) {} - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType); - const ShapeTypeNode* operator->() const { - return static_cast(data_.get()); - } - const ShapeTypeNode* get() const { - return operator->(); - } - using ContainerType = ShapeTypeNode; + TVM_DLL ShapeType(Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); }; class DynTensorTypeNode : public BaseTensorTypeNode { @@ -108,11 +104,34 @@ class DynTensorType : public Type { * \param shape The shape of the tensor. * \param dtype The runtime dtype of the tensor's elements. */ - TVM_DLL DynTensorType(int rank, DataType dtype); + TVM_DLL DynTensorType(int rank, DataType dtype, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); }; +class DimTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("span", &span); + } + + bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const { + return true; + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.DimType"; + TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode); +}; + +class DimType : public Type { + public: + TVM_DLL DimType(Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode); +}; + } // namespace relax } // namespace tvm #endif // TVM_RELAX_TYPE_H_ diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 4fe28f1d72..3748b7dcec 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -19,6 +19,7 @@ import tvm import tvm._ffi +from . import Span from .base import Node from . 
import _ffi_api @@ -166,8 +167,8 @@ class TupleType(Type): The fields in the tuple """ - def __init__(self, fields): - self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) + def __init__(self, fields, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span) @tvm._ffi.register_object("TypeConstraint") diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index b81df5c883..bcb24b3b52 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -19,9 +19,9 @@ from . import expr from . import ty from . import vm -from . import op from . import ir_builder from . import op +from . import parser # Expr @@ -48,8 +48,10 @@ extern = expr.extern # Type +Type = ty.Type ShapeType = ty.ShapeType DynTensorType = ty.DynTensorType +DimType = ty.DimType # VM ExecBuilder = exec_builder.ExecBuilder diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 172cf6dee4..bf8158d43e 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -34,8 +34,8 @@ class ShapeExpr(Expr): values: List[PrimExpr] - def __init__(self, values: List[PrimExpr]) -> None: - self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values) + def __init__(self, values: List[PrimExpr], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) def __getitem__(self, index): if index >= len(self): @@ -57,14 +57,16 @@ class Var(Expr): id: Id type_annotation: Optional[Type] - def __init__(self, name_hint: str, - shape_annotation: Optional[Expr] = None, - type_annotation: Optional[Type] = None) -> None: - if shape_annotation is not None: - shape_annotation = make_shape(shape_annotation) - self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, - shape_annotation, - type_annotation) + def __init__( + self, + name_hint: str, + shape_annotation: Optional[List[Type]] = None, + type_annotation: Optional[Type] = None, + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Var, name_hint, shape_annotation, type_annotation, span + ) @property def name_hint(self): @@ -80,7 +82,8 @@ class DataflowVar(Var): @tvm._ffi.register_object("relax.expr.Binding") class Binding(Node): - pass + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Binding, span) @tvm._ffi.register_object("relax.expr.MatchShape") @@ -88,8 +91,8 @@ class MatchShape(Binding): pattern: List[PrimExpr] value: Expr - def __init__(self, pattern: List[PrimExpr], value: Expr) -> None: - self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value) + def __init__(self, pattern: List[PrimExpr], value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value, span) @tvm._ffi.register_object("relax.expr.VarBinding") @@ -97,16 +100,16 @@ class VarBinding(Binding): var: Var value: Expr - def __init__(self, var: Var, value: Expr) -> None: - self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value) + def __init__(self, var: Var, value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) @tvm._ffi.register_object("relax.expr.BindingBlock") class BindingBlock(Node): bindings: List[Binding] - def __init__(self, bindings: List[Binding]) -> None: - self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings) + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.BindingBlock, 
bindings, span) @tvm._ffi.register_object("relax.expr.DataflowBlock") @@ -119,8 +122,8 @@ class SeqExpr(Expr): blocks: List[BindingBlock] body: Expr - def __init__(self, blocks: List[BindingBlock], body: Expr) -> None: - self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body) + def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) @tvm._ffi.register_object("relax.expr.Function") @@ -130,10 +133,15 @@ class Function(BaseFunc): body: Expr ret_type: Type - def __init__(self, params: List[Var], body: Expr, - ret_type: Type, name: Optional[GlobalVar] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.Function, name, params, - body, ret_type) + def __init__( + self, + params: List[Var], + body: Expr, + ret_type: Type, + name: Optional[GlobalVar] = None, + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__(_ffi_api.Function, name, params, body, ret_type, span) @tvm._ffi.register_object("relax.expr.ExternFunc") @@ -143,5 +151,6 @@ class ExternFunc(BaseFunc): def __init__(self, global_symbol: String) -> None: self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol) + def extern(name): return ExternFunc(name) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 9bce251447..f4b59ed3ee 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TypeVar, Generic, Union, Dict, List, Tuple +from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional from io import StringIO import tvm @@ -17,38 +17,10 @@ from synr.diagnostic_context import DiagnosticContext from tvm.relay.op.strategy.generic import conv1d_strategy -from tvm.relax.expr import * -from tvm.relax.ty import * +import tvm.relay as relay +import tvm.relax as rx -# TODO: make this better -# relax_scope = [] # A stack of dictionaries representing the scope -# var_table = {} - -# A node that will desugar into a different AST node in a subsequent pass -# class rxFrontendNode: -# def __init__(self, span): -# self.span = span - - -# Allows arbitrary exprs on the left and the right -# Desugars into two rxMatchShapeBinding -# TODO: might be worth not parsing this into its own node.. -# class rxFrontendMatchShapeExprs(rxFrontendNode): -# def __init__(self, lhs, rhs, span): -# self.lhs = lhs -# self.rhs = rhs -# super().__init__(span) - -# class rxShapeTuple(rxExpr): -# def __init__(self, dims, span): -# self.dims = dims -# super().__init__(span) - - -# TODO: What is this doing? -# expr.Function.__str__ = print_fn # type: ignore - # TODO: Replace with a real pretty print method once we have the real AST def pretty_print(f): print(f) @@ -102,10 +74,10 @@ def tvm_span(self, span: synr.Span) -> tvm.ir.Span: def decl_var( self, name: str, - type_annotation: Optional[rxType], - shape: Optional[rxExpr], + type_annotation: Optional[rx.Type], + shape: Optional[rx.Expr], span: tvm.ir.Span, - ) -> rxVar: + ) -> Var: """Introduces a variable with the given name and annotations to the current scope. 
Parameters @@ -129,11 +101,11 @@ def decl_var( "error", "variable has already been declared in the current scope", span ) self._diagnostic_context.render() - var = rxVar(name, type_annotation, shape, span) + var = rx.Var(name, type_annotation, shape, span) self.scope[name] = var return var - def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxShape]: + def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.Expr]: """Transforms the given synr type annotation to a Relax type and shape expression. Parameters @@ -156,11 +128,11 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxSha # simple annotation with no type arguments if isinstance(ty, ast.TypeVar): if ty.id.name == "Tensor": - return (rxTensor(None, span), None) + return (rx.DynTensorType(rank=-1, dtype=None, span=span), None) elif ty.id.name == "Shape": - return (rxShape(span), None) + return (rx.ShapeType(span), None) elif ty.id.name == "Dim": - return (rxDim(span), None) + return (rx.DimType(span), None) else: self._diagnostic_context.emit("error", "unknown type in annotation", span) self._diagnostic_context.render() @@ -215,7 +187,8 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxSha ) self._diagnostic_context.render() - return (rxTensor(dtype, span), shape) + rank = len(shape) if shape is not None else -1 + return (rx.DynTensorType(rank=rank, dtype=dtype, span=span), shape) elif ty.id.name == "Tuple": field_types = [] field_shapes = [] @@ -223,7 +196,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rxType, rxSha fty, fsh = self.transform_type(field, allow_intro=False) field_types.append(fty) field_shapes.append(fsh) - return rxTupleType(field_types, self.tvm_span(ty.span)), field_shapes + return relay.TupleType(field_types, self.tvm_span(ty.span)), field_shapes # TODO: other types with args, e.g. Ref[T], func types self._diagnostic_context.emit("error", "invalid type", span) self._diagnostic_context.render() @@ -346,7 +319,7 @@ def transform_module(self, mod: ast.Module) -> IRModule: self.module[func_name] = self.transform_function(func) return self.module - def transform_function(self, func: ast.Function) -> rxFunction: + def transform_function(self, func: ast.Function) -> rx.Function: with self.new_scope(): params = [] for param in func.params: @@ -355,7 +328,7 @@ def transform_function(self, func: ast.Function) -> rxFunction: params.append(param) new_body = self.transform_block(func.body) ret_type, _ = self.transform_type(func.ret_type, allow_intro=False) - return rxFunction(func.name, params, new_body, ret_type, self.tvm_span(func.span)) + return rx.Function(func.name, params, new_body, ret_type, self.tvm_span(func.span)) # Stmts: # - Assert: probably unsupported for now @@ -365,7 +338,7 @@ def transform_function(self, func: ast.Function) -> rxFunction: # - Return: just the returned expression, must terminate blocks? 
(special case if-else) # - UnassignedCall: match_shape # - With: rx.dataflow - def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: + def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: if isinstance(stmt, ast.Assign): if not isinstance(stmt.lhs, ast.Var): self._diagnostic_context.emit( @@ -376,13 +349,13 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: self._diagnostic_context.render() # TODO: figure out proper way of doing this rhs = self.transform_expr(stmt.rhs) - if isinstance(rhs, rxCall) and rhs.op == "rx.call_packed": + if isinstance(rhs, relay.Call) and rhs.op == tvm.ir.Op.get("rx.call_packed"): allow_intro = True else: allow_intro = False ty, shape = self.transform_type(stmt.ty, allow_intro) lhs = self.decl_var(stmt.lhs.id.name, ty, shape, self.tvm_span(stmt.lhs.span)) - return rxVarBinding(lhs, rhs, self.tvm_span(stmt.span)) + return rx.VarBinding(lhs, rhs, self.tvm_span(stmt.span)) elif isinstance(stmt, ast.If): # TODO: proper diagnostics @@ -418,9 +391,9 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: with self.new_scope(): false_branch = self.transform_block(false_block) # TODO: the spans here are all sorts of messed up, not sure how to fix - ite_expr = rxIfThenElse(cond, true_branch, false_branch, self.tvm_span(stmt.span)) + ite_expr = relay.If(cond, true_branch, false_branch, self.tvm_span(stmt.span)) var = self.decl_var(var_name, None, None, self.tvm_span(false_assign.span)) - return rxVarBinding(var, ite_expr, self.tvm_span(stmt.span)) + return rx.VarBinding(var, ite_expr, self.tvm_span(stmt.span)) elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) @@ -431,9 +404,9 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: op = self.transform_expr(call.func_name) # FIXME: this check is unreachable since transform_expr tries looking up func_name as a # variable and fails - if op != "rx.match_shape": + if op != tvm.ir.Op.get("rx.match_shape"): self._diagnostic_context.emit( - "error", "the results of operator calls must be bound", self.tvm_span(stmt.span) + "error", "the results of calls must be bound or used", self.tvm_span(stmt.span) ) self._diagnostic_context.render() if len(stmt.call.params) != 2: @@ -454,7 +427,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> rxExpr: ) self._diagnostic_context.render() lhs_expr = self.parse_shape(lhs, allow_intro=True) - return rxMatchShape(lhs_expr, rhs_expr, self.tvm_span(stmt.span)) + return rx.MatchShape(lhs_expr, rhs_expr, self.tvm_span(stmt.span)) elif isinstance(stmt, ast.With): if not isinstance(stmt.rhs, ast.Call): diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index 0c34d2797d..e8f96fc935 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -17,15 +17,15 @@ # pylint: disable=invalid-name, unused-import """The type nodes of the Relax language.""" import tvm._ffi -from tvm.ir import Type, TensorType +from tvm.ir import Type, TensorType, Span from . import _ffi_api @tvm._ffi.register_object("relax.ShapeType") class ShapeType(Type): - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.ShapeType) + def __init__(self, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.ShapeType, span) @tvm._ffi.register_object("relax.DynTensorType") @@ -43,5 +43,12 @@ class DynTensorType(TensorType): The content data type. 
""" - def __init__(self, rank=-1, dtype="float32"): - self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype) + def __init__(self, rank=-1, dtype="float32", span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype, span) + + +@tvm._ffi.register_object("relax.DimType") +class DimType(Type): + """The type of indices/shape dimensions in Relax.""" + def __init__(self, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.DimType, span) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 8461885b38..61fb7623ea 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -23,7 +23,7 @@ import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar, Node +from tvm.ir import RelayExpr, GlobalVar, Node, Span from .base import RelayNode from . import _ffi_api @@ -301,8 +301,8 @@ class If(ExprWithOp): The expression evaluated when condition is false. """ - def __init__(self, cond, true_branch, false_branch): - self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch) + def __init__(self, cond, true_branch, false_branch, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span) @tvm._ffi.register_object("relay.TupleGetItem") diff --git a/src/ir/type.cc b/src/ir/type.cc index fe8e00329b..86dda2a274 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -144,8 +144,8 @@ TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { - return TupleType(fields); +TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields, Span span) { + return TupleType(fields, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relax/expr.cc b/src/relax/expr.cc index 42f84adf6c..83892a0296 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -1,21 +1,21 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ #include namespace tvm { @@ -29,8 +29,7 @@ RelayExpr RelayExprNode::shape() const { return relay::Call(op, {self}, {}, {}); } -TVM_REGISTER_GLOBAL("ir.RelayExprShape") -.set_body_typed([](RelayExpr expr) { +TVM_REGISTER_GLOBAL("ir.RelayExprShape").set_body_typed([](RelayExpr expr) { return expr->shape(); }); @@ -40,24 +39,20 @@ using tvm::runtime::Optional; TVM_REGISTER_NODE_TYPE(ShapeExprNode); -ShapeExpr::ShapeExpr(Array values) { +ShapeExpr::ShapeExpr(Array values, Span span) { ObjectPtr n = make_object(); n->values = std::move(values); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeExpr") -.set_body_typed([](Array values) { - return ShapeExpr(values); +TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { + return ShapeExpr(values, span); }); - TVM_REGISTER_NODE_TYPE(VarNode); -Var::Var(Id vid, - Optional shape_annotation, - Optional type_annotation, - Span span) { +Var::Var(Id vid, Optional shape_annotation, Optional type_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->shape_ = std::move(shape_annotation); @@ -70,128 +65,116 @@ Var::Var(Id vid, } TVM_REGISTER_GLOBAL("relax.Var") -.set_body_typed([](String name_hint, - Optional shape_annotation, - Optional type_annotation) { - return Var(name_hint, shape_annotation, type_annotation); -}); - + .set_body_typed([](String name_hint, Optional shape_annotation, + Optional type_annotation, Span span) { + return Var(name_hint, shape_annotation, type_annotation, span); + }); TVM_REGISTER_NODE_TYPE(DataflowVarNode); TVM_REGISTER_GLOBAL("relax.DataflowVar") -.set_body_typed([](String name_hint, - Optional shape_annotation, - Optional type_annotation) { - return DataflowVar(name_hint, shape_annotation, type_annotation); -}); - + .set_body_typed([](String name_hint, Optional shape_annotation, + Optional type_annotation, Span span) { + return DataflowVar(name_hint, shape_annotation, type_annotation, span); + }); + +Binding::Binding(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} TVM_REGISTER_NODE_TYPE(BindingNode); -TVM_REGISTER_GLOBAL("relax.Binding") -.set_body_typed([]() { - return Binding(); -}); - +TVM_REGISTER_GLOBAL("relax.Binding").set_body_typed([](Span span) { return Binding(span); }); TVM_REGISTER_NODE_TYPE(MatchShapeNode); -MatchShape::MatchShape(Array pattern, - Expr value) { +MatchShape::MatchShape(Array pattern, Expr value, Span span) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->value = std::move(value); + n->span = span; data_ = std::move(n); } TVM_REGISTER_GLOBAL("relax.MatchShape") -.set_body_typed([](Array pattern, Expr value) { - return MatchShape(pattern, value); -}); - + .set_body_typed([](Array pattern, Expr value, Span span) { + return MatchShape(pattern, value, span); + }); TVM_REGISTER_NODE_TYPE(VarBindingNode); -VarBinding::VarBinding(Var var, - Expr value) { +VarBinding::VarBinding(Var var, Expr value, Span span) { ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.VarBinding") -.set_body_typed([](Var var,Expr value) { - return VarBinding(var,value); +TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { + return VarBinding(var, value, span); }); - TVM_REGISTER_NODE_TYPE(BindingBlockNode); -BindingBlock::BindingBlock(Array 
bindings) { +BindingBlock::BindingBlock(Array bindings, Span span) { ObjectPtr n = make_object(); n->bindings = std::move(bindings); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.BindingBlock") -.set_body_typed([](Array bindings) { - return BindingBlock(bindings); +TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); }); - TVM_REGISTER_NODE_TYPE(DataflowBlockNode); -DataflowBlock::DataflowBlock(Array bindings) { +DataflowBlock::DataflowBlock(Array bindings, Span span) { ObjectPtr n = make_object(); n->bindings = std::move(bindings); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowBlock") -.set_body_typed([](Array bindings) { - return DataflowBlock(bindings); +TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); }); - TVM_REGISTER_NODE_TYPE(SeqExprNode); -SeqExpr::SeqExpr(Array blocks, - Expr body) { +SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { ObjectPtr n = make_object(); n->blocks = std::move(blocks); n->body = std::move(body); + n->span = span; data_ = std::move(n); } TVM_REGISTER_GLOBAL("relax.SeqExpr") -.set_body_typed([](Array blocks, Expr body) { - return SeqExpr(blocks, body); -}); - + .set_body_typed([](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); TVM_REGISTER_NODE_TYPE(FunctionNode); -Function::Function(runtime::Optional name, - Array params, - Expr body, - Type ret_type) { +Function::Function(runtime::Optional name, Array params, Expr body, Type ret_type, + Span span) { ObjectPtr n = make_object(); n->name = std::move(name); n->params = std::move(params); n->body = std::move(body); n->ret_type = std::move(ret_type); + n->span = span; data_ = std::move(n); } TVM_REGISTER_GLOBAL("relax.Function") -.set_body_typed([](runtime::Optional name, - Array params, - Expr body, - Type ret_type) { - return Function(name, params, body, ret_type); -}); + .set_body_typed([](runtime::Optional name, Array params, Expr body, + Type ret_type, + Span span) { return Function(name, params, body, ret_type, span); }); TVM_REGISTER_NODE_TYPE(ExternFuncNode); @@ -201,10 +184,9 @@ ExternFunc::ExternFunc(String global_symbol) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ExternFunc") -.set_body_typed([](String global_symbol) { +TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol) { return ExternFunc(global_symbol); }); -} // namespace relax -} // namespace tvm +} // namespace relax +} // namespace tvm diff --git a/src/relax/type.cc b/src/relax/type.cc index 498d6082de..9c398c23ea 100644 --- a/src/relax/type.cc +++ b/src/relax/type.cc @@ -29,30 +29,37 @@ namespace relax { TVM_REGISTER_NODE_TYPE(ShapeTypeNode); -ShapeType::ShapeType() { +ShapeType::ShapeType(Span span) { ObjectPtr n = make_object(); + n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeType") -.set_body_typed([]() { - return ShapeType(); -}); +TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](Span span) { return ShapeType(span); }); -DynTensorType::DynTensorType(int rank, DataType dtype) { +DynTensorType::DynTensorType(int rank, DataType dtype, Span span) { ObjectPtr n = make_object(); n->rank = std::move(rank); n->dtype = std::move(dtype); + n->span = span; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DynTensorTypeNode); -TVM_REGISTER_GLOBAL("relax.DynTensorType") -.set_body_typed([](int rank, DataType dtype) { - return 
DynTensorType(rank, dtype); +TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int rank, DataType dtype, Span span) { + return DynTensorType(rank, dtype, span); }); +DimType::DimType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DimTypeNode); + +TVM_REGISTER_GLOBAL("relax.DimType").set_body_typed([](Span span) { return DimType(span); }); } // namespace relax } // namespace tvm diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3b3c8797d7..854d413c53 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -167,8 +167,8 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") - .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, true_branch, false_branch); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) From 06d12106daaae7575fa53beaf593fb6e000cad29 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 24 Aug 2021 15:10:18 -0700 Subject: [PATCH 13/20] hook up AST --- include/tvm/relax/expr.h | 21 ++-- include/tvm/relay/expr.h | 1 + python/tvm/relax/__init__.py | 5 + python/tvm/relax/expr.py | 14 ++- python/tvm/relax/op/base.py | 7 +- python/tvm/relax/parser.py | 236 ++++++++++++++++++----------------- src/relax/expr.cc | 7 +- src/relax/op/op.cc | 56 +++++---- tests/python/relax/parser.py | 165 ++++++++++++------------ 9 files changed, 271 insertions(+), 241 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 73d14b393a..c5349db25c 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -33,8 +33,8 @@ namespace relax { using Expr = RelayExpr; using ExprNode = RelayExprNode; -using relay::Id; using relay::Call; +using relay::Id; using relay::Tuple; using relay::TupleGetItem; @@ -117,12 +117,12 @@ class VarNode : public ExprNode { class Var : public Expr { public: - TVM_DLL explicit Var(String name_hint, runtime::Optional> shape_annotation, - runtime::Optional type_annotation, Span span = Span()) + TVM_DLL explicit Var(String name_hint, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()) : Var(Id(name_hint), shape_annotation, type_annotation, span) {} - TVM_DLL explicit Var(Id vid, runtime::Optional> shape_annotation, - runtime::Optional type_annotation, Span span = Span()); + TVM_DLL explicit Var(Id vid, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); }; @@ -389,11 +389,11 @@ class FunctionNode : public BaseFuncNode { class Function : public Expr { public: - TVM_DLL explicit Function(runtime::Optional name, Array params, Expr body, Type ret_type, Span span = Span()); + TVM_DLL explicit Function(runtime::Optional name, Array params, Expr body, + Type ret_type, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); }; - /*! \brief The extern function, which can represent packed function. 
*/ class ExternFuncNode : public BaseFuncNode { public: @@ -402,15 +402,14 @@ class ExternFuncNode : public BaseFuncNode { void VisitAttrs(AttrVisitor* v) { v->Visit("global_symbol", &global_symbol); + v->Visit("span", &span); } bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { return equal(global_symbol, other->global_symbol); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(global_symbol); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); } static constexpr const char* _type_key = "relax.expr.ExternFunc"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -420,7 +419,7 @@ class ExternFuncNode : public BaseFuncNode { class ExternFunc : public Expr { public: - TVM_DLL ExternFunc(String global_symbol); + TVM_DLL ExternFunc(String global_symbol, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode); }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index daad8514f9..50492be6d6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -109,6 +109,7 @@ class TupleNode : public ExprNode { v->Visit("fields", &fields); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("shape_", &shape_); } bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index bcb24b3b52..e3b3c61c5e 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -42,6 +42,8 @@ Tuple = expr.Tuple Function = expr.Function ExternFunc = expr.ExternFunc +Call = expr.Call +If = expr.If # helper functions const = expr.const @@ -63,3 +65,6 @@ # IRBuilder IRBuilder = ir_builder.IRBuilder + +# Parser +from .parser import script diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index bf8158d43e..2aeb8e330f 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -27,6 +27,7 @@ Type = relay.Type GlobalVar = relay.GlobalVar Call = relay.Call +If = relay.If const = relay.const @@ -45,6 +46,7 @@ def __getitem__(self, index): def __len__(self): return len(self.values) + def make_shape(shape: List[PrimExpr]) -> ShapeExpr: if isinstance(shape, (list, tuple)): return ShapeExpr(shape) @@ -54,13 +56,13 @@ def make_shape(shape: List[PrimExpr]) -> ShapeExpr: @tvm._ffi.register_object("relax.expr.Var") class Var(Expr): - id: Id + vid: Id type_annotation: Optional[Type] def __init__( self, name_hint: str, - shape_annotation: Optional[List[Type]] = None, + shape_annotation: Optional[Expr] = None, type_annotation: Optional[Type] = None, span: Span = None, ) -> None: @@ -148,9 +150,9 @@ def __init__( class ExternFunc(BaseFunc): global_symbol: String - def __init__(self, global_symbol: String) -> None: - self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol) + def __init__(self, global_symbol: String, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol, span) -def extern(name): - return ExternFunc(name) +def extern(name, span: Span = None): + return ExternFunc(name, span) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 767dbdea15..2fa88cb954 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -18,9 +18,10 @@ from . 
import _ffi_api from typing import Union, List -def call_dps(shape: Union[ShapeExpr, List[int]], - func: Expr, - args: Union[Tuple, List[Expr]]) -> Call: + +def call_dps( + shape: Union[ShapeExpr, List[int]], func: Expr, args: Union[Tuple, List[Expr]] +) -> Call: """ Call a destination-passing-style function and return the output. diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index f4b59ed3ee..11ce99cf56 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -26,6 +26,12 @@ def pretty_print(f): print(f) +def is_registered(op_name, op_set=None): + if op_set is None: + op_set = tvm.ir._ffi_api.ListOpNames() + return op_name in op_set + + class RelaxTransformer(Transformer): def __init__(self, definition_scope, diag_ctx): super().__init__() @@ -33,6 +39,7 @@ def __init__(self, definition_scope, diag_ctx): self.diag_ctx = diag_ctx self.module = {} self._scopes = [{}] # str -> Var + self._registered_ops = set(tvm.ir._ffi_api.ListOpNames()) # cached def new_scope(self): class _Scope: @@ -52,7 +59,7 @@ def __exit__(self, *exc): def scope(self): return self._scopes[-1] - def tvm_span(self, span: synr.Span) -> tvm.ir.Span: + def tvm_span(self, span: ast.Span) -> tvm.ir.Span: """Converts the synr span to a TVM span Parameters @@ -77,7 +84,8 @@ def decl_var( type_annotation: Optional[rx.Type], shape: Optional[rx.Expr], span: tvm.ir.Span, - ) -> Var: + is_dataflow: bool = False, + ) -> rx.Var: """Introduces a variable with the given name and annotations to the current scope. Parameters @@ -101,7 +109,10 @@ def decl_var( "error", "variable has already been declared in the current scope", span ) self._diagnostic_context.render() - var = rx.Var(name, type_annotation, shape, span) + if is_dataflow: + var = rx.DataflowVar(name, shape, type_annotation, span) + else: + var = rx.Var(name, shape, type_annotation, span) self.scope[name] = var return var @@ -165,7 +176,10 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E # FIXME: use a special node for unknown shape vs no shape? pass # shape = None elif isinstance(shape_annotation, ast.TypeTuple): - shape = self.parse_shape(shape_annotation, allow_intro) + shape = rx.ShapeExpr( + self.parse_shape(shape_annotation, allow_intro), + span=self.tvm_span(shape_annotation.span), + ) else: self._diagnostic_context.emit( "error", @@ -196,13 +210,15 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E fty, fsh = self.transform_type(field, allow_intro=False) field_types.append(fty) field_shapes.append(fsh) - return relay.TupleType(field_types, self.tvm_span(ty.span)), field_shapes + return (relay.TupleType(field_types, self.tvm_span(ty.span)), None) # TODO: other types with args, e.g. 
Ref[T], func types self._diagnostic_context.emit("error", "invalid type", span) self._diagnostic_context.render() def parse_shape( - self, shape_annotation: Union[ast.TypeTuple, ast.Tuple], allow_intro: bool + self, + shape_annotation: Union[ast.TypeTuple, ast.Tuple], + allow_intro: bool, ) -> List[tir.PrimExpr]: """Parses the given shape annotation to a list of PrimExprs @@ -316,10 +332,10 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: def transform_module(self, mod: ast.Module) -> IRModule: for func_name in mod.funcs: func = mod.funcs[func_name] - self.module[func_name] = self.transform_function(func) + self.module[func_name] = self.transform_function(func, is_global=True) return self.module - def transform_function(self, func: ast.Function) -> rx.Function: + def transform_function(self, func: ast.Function, is_global=False) -> rx.Function: with self.new_scope(): params = [] for param in func.params: @@ -328,7 +344,31 @@ def transform_function(self, func: ast.Function) -> rx.Function: params.append(param) new_body = self.transform_block(func.body) ret_type, _ = self.transform_type(func.ret_type, allow_intro=False) - return rx.Function(func.name, params, new_body, ret_type, self.tvm_span(func.span)) + + func_name = rx.GlobalVar(func.name) if is_global else None + return rx.Function( + params, new_body, ret_type, name=func_name, span=self.tvm_span(func.span) + ) + + def parse_binding(self, stmt: ast.Assign, is_dataflow=False): + if not isinstance(stmt.lhs, ast.Var): + self._diagnostic_context.emit( + "error", + "the left hand side of a binding must be a variable", + self.tvm_span(stmt.lhs.span), + ) + self._diagnostic_context.render() + # TODO: figure out proper way of doing this + rhs = self.transform_expr(stmt.rhs) + if isinstance(rhs, relay.Call) and rhs.op == relay.op.get("rx.call_packed"): + allow_intro = True + else: + allow_intro = False + ty, shape = self.transform_type(stmt.ty, allow_intro) + lhs = self.decl_var( + stmt.lhs.id.name, ty, shape, self.tvm_span(stmt.lhs.span), is_dataflow=is_dataflow + ) + return rx.VarBinding(lhs, rhs, self.tvm_span(stmt.span)) # Stmts: # - Assert: probably unsupported for now @@ -338,24 +378,10 @@ def transform_function(self, func: ast.Function) -> rx.Function: # - Return: just the returned expression, must terminate blocks? 
(special case if-else) # - UnassignedCall: match_shape # - With: rx.dataflow - def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: + def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.DataflowBlock]: if isinstance(stmt, ast.Assign): - if not isinstance(stmt.lhs, ast.Var): - self._diagnostic_context.emit( - "error", - "the left hand side of a binding must be a variable", - self.tvm_span(stmt.lhs.span), - ) - self._diagnostic_context.render() - # TODO: figure out proper way of doing this - rhs = self.transform_expr(stmt.rhs) - if isinstance(rhs, relay.Call) and rhs.op == tvm.ir.Op.get("rx.call_packed"): - allow_intro = True - else: - allow_intro = False - ty, shape = self.transform_type(stmt.ty, allow_intro) - lhs = self.decl_var(stmt.lhs.id.name, ty, shape, self.tvm_span(stmt.lhs.span)) - return rx.VarBinding(lhs, rhs, self.tvm_span(stmt.span)) + # dataflow bindings are handled separately in parse_dataflow + return self.parse_binding(stmt) elif isinstance(stmt, ast.If): # TODO: proper diagnostics @@ -404,7 +430,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: op = self.transform_expr(call.func_name) # FIXME: this check is unreachable since transform_expr tries looking up func_name as a # variable and fails - if op != tvm.ir.Op.get("rx.match_shape"): + if op != relay.op.get("rx.match_shape"): self._diagnostic_context.emit( "error", "the results of calls must be bound or used", self.tvm_span(stmt.span) ) @@ -441,7 +467,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: # TODO: perhaps this ought to be more general - if op != "rx.dataflow": + if op != relay.op.get("rx.dataflow"): self._diagnostic_context.emit( "error", "unsupported with block type", self.tvm_span(call.span) ) @@ -455,7 +481,9 @@ def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: self._diagnostic_context.render() if len(stmt.lhs) > 0: self._diagnostic_context.emit( - "error", "dataflow blocks don't bind any patterns", self.tvm_span(stmt.lhs[0].span) + "error", + "dataflow blocks don't bind any patterns", + self.tvm_span(stmt.lhs[0].span), ) self._diagnostic_context.render() @@ -463,8 +491,8 @@ def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: elif isinstance(stmt, ast.Function): func = self.transform_function(stmt) - func_var = self.decl_var(func.name, None, None, self.tvm_span(stmt.span)) - return rxVarBinding(func_var, func, self.tvm_span(stmt.span)) + func_var = self.decl_var(stmt.name, None, None, self.tvm_span(stmt.span)) + return rx.VarBinding(func_var, func, self.tvm_span(stmt.span)) else: self._diagnostic_context.emit( @@ -474,22 +502,13 @@ def transform_stmt(self, stmt: ast.Stmt) -> rx.Expr: ) self._diagnostic_context.render() - def parse_dataflow(self, block: ast.Block): + def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: assert len(block.stmts) > 0, "should never have an empty dataflow block" bindings = [] + output_vars = [] with self.new_scope(): - for binding_stmt in block.stmts[:-1]: - if not isinstance(binding_stmt, ast.Assign): - self._diagnostic_context.emit( - "error", - "only bindings are supported in dataflow blocks", - self.tvm_span(binding_stmt.span), - ) - self._diagnostic_context.render() - binding = self.transform_stmt(binding_stmt) - bindings.append(binding) - + # parse the return statement first to figure out which bindings assign normal Vars output_stmt = block.stmts[-1] if not isinstance(output_stmt, ast.Return): self._diagnostic_context.emit( @@ -503,21 +522,37 @@ def parse_dataflow(self, block: ast.Block): if isinstance(ret_val, ast.Var): 
ret_val = ast.Tuple(values=[ret_val], span=ret_val.span) - if not isinstance(ret_val, ast.Tuple) or any([not isinstance(f, ast.Var) for f in ret_val.values]): + if not isinstance(ret_val, ast.Tuple) or any( + [not isinstance(f, ast.Var) for f in ret_val.values] + ): self._diagnostic_context.emit( "error", "the returned values must be variables", self.tvm_span(ret_val.span), ) - ret_vars = [self.transform_expr(v) for v in ret_val.values] + # output variables are bound to normal (not data flow) Vars + output_var_names = {var.id.name for var in ret_val.values} - # parent scope - for v in ret_vars: - self.scope[v.id] = v + for binding_stmt in block.stmts[:-1]: + if not isinstance(binding_stmt, ast.Assign): + self._diagnostic_context.emit( + "error", + "only bindings are supported in dataflow blocks", + self.tvm_span(binding_stmt.span), + ) + self._diagnostic_context.render() + is_dataflow = binding_stmt.lhs.id.name not in output_var_names + binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow) + bindings.append(binding) + if not is_dataflow: + output_vars.append(binding.var) - return rxDataflowBlock(bindings, ret_vars, self.tvm_span(block.span)) + # make output variables visible in parent scope + for v in output_vars: + self.scope[v.name_hint] = v + return rx.DataflowBlock(bindings, self.tvm_span(block.span)) # Exprs: # - ArrayLiteral: unsupported for now? @@ -528,100 +563,70 @@ def parse_dataflow(self, block: ast.Block): # - Slice: unsupported for now, could desugar to slice op # - Tuple # - Var - def transform_expr(self, expr: ast.Expr) -> rxExpr: + def transform_expr(self, expr: ast.Expr) -> rx.Expr: if isinstance(expr, ast.Attr): - obj = self.transform_expr(expr.object) - field_name = expr.field.name - # TODO: use some kind of proper identifier? str bad - if isinstance(obj, str): - return obj + "." + field_name - elif field_name == "shape": - return rxCall("rx.shape_of", [obj], self.tvm_span(expr.span)) + if expr.field.name == "shape": + obj = self.transform_expr(expr.object) + return relay.op.shape_of(obj) else: - self._diagnostic_context.emit( - "error", "unsupported attribute", self.tvm_span(expr.span) - ) - self._diagnostic_context.render() + # assume it's a hierarchical op identifier (e.g. nn.softmax, rx.call_dps) + op_name = [] + attr = expr + while isinstance(attr, ast.Attr): + op_name.append(expr.field.name) + attr = attr.object + if not isinstance(attr, ast.Var): + self._diagnostic_context.emit( + "error", "unsupported field access", self.tvm_span(expr) + ) + self._diagnostic_context.render() + op_name.append(attr.id.name) + op_name = ".".join(reversed(op_name)) + return relay.op.get(op_name) # TODO: maybe diagnostics here in case this fails? + if isinstance(expr, ast.Call): - op = expr.func_name - if isinstance(op, ast.Var): - args = [] - for arg in expr.params: - args.append(self.transform_expr(arg)) - if op.id.name in self.scope: - op = self.transform_expr(op) - else: - # TODO: fix - op = op.id.name - return rxCall(op, args, self.tvm_span(expr.span)) - # if exp.func_name.id.name in self.str_to_var: - # return self.str_to_var[exp.func_name.id.name] - # else: - # name = exp.func_name.id.name - # relax_fn = getattr(self.definition_scope, name, None) - # # builtin operator - # if relax_fn is None: - # return rxCall(rxGetBuiltin(name), params, None) - # else: - # self.module[name] = relax_fn.module[name] - # # todo: globalvar equality? use global str -> id map? 
- # ident = Id(exp.func_name.id.name) - # return rxCall(rxGlobalVar(ident, None, None), params, None) - elif isinstance(op, ast.Op): - assert False, "TODO: sugar for python built in operators" - # if exp.func_name.name == ast.BuiltinOp.Subscript: - # tensor = self.transform_expr(exp.params[0]) - # indicies = [] - # for index in exp.params[1].values: - # indicies.append(self.transform_expr(index)) - # # TODO: Replace with relax node - # return rxTensorSlice(tensor, indicies, None) - # elif exp.func_name.name == ast.BuiltinOp.Add: - # params = [] - # for arg in exp.params: - # params.append(self.transform_expr(arg)) - # # TODO: Replace with relax node - # return rxCall("add", [params[0], params[1]], None) - else: - self._diagnostic_context.emit( - "error", "unsupported function", self.tvm_span(expr.span) - ) - self._diagnostic_context.render() + op = self.transform_expr(expr.func_name) + args = [self.transform_expr(arg) for arg in expr.params] + return relay.Call(op, args, span=self.tvm_span(expr.span)) + elif isinstance(expr, ast.Tuple): fields = [self.transform_expr(field) for field in expr.values] - return rxTuple(fields, self.tvm_span(expr.span)) + return relay.Tuple(fields, span=self.tvm_span(expr.span)) + elif isinstance(expr, ast.Var): var_name = expr.id.name - if var_name == "rx": - return "rx" + if is_registered(var_name, op_set=self._registered_ops): + return relay.op.get(var_name) if var_name not in self.scope: self._diagnostic_context.emit( "error", "undefined variable", self.tvm_span(expr.span) ) self._diagnostic_context.render() return self.scope[var_name] + else: self._diagnostic_context.emit( "error", "unsupported expression", self.tvm_span(expr.span) ) self._diagnostic_context.render() - def transform_block(self, block: ast.Block) -> rxSeqExpr: + def transform_block(self, block: ast.Block) -> rx.SeqExpr: # a block of statements needs to be converted to a SeqExpr of binding blocks blocks = [] current_block = [] for stmt in block.stmts[:-1]: parsed_stmt = self.transform_stmt(stmt) - if isinstance(parsed_stmt, rxDataflowBlock): + if isinstance(parsed_stmt, rx.DataflowBlock): if current_block: - blocks.append(current_block) + # FIXME: span + blocks.append(rx.BindingBlock(current_block, self.tvm_span(stmt.span))) current_block = [] blocks.append(parsed_stmt) else: - assert isinstance(parsed_stmt, rxBinding) + assert isinstance(parsed_stmt, rx.Binding) current_block.append(parsed_stmt) if len(current_block) > 0: - blocks.append(current_block) + blocks.append(rx.BindingBlock(current_block, self.tvm_span(stmt.span))) ret_stmt = block.stmts[-1] if not isinstance(ret_stmt, ast.Return): @@ -633,10 +638,7 @@ def transform_block(self, block: ast.Block) -> rxSeqExpr: self._diagnostic_context.render() ret_expr = self.transform_stmt(ret_stmt) - return rxSeqExpr(blocks, ret_expr, self.tvm_span(block.span)) - - def transform_parameter(self, expr: ast.Parameter) -> rxExpr: - pass + return rx.SeqExpr(blocks, ret_expr, self.tvm_span(block.span)) class TVMDiagnosticContext(synr.DiagnosticContext): diff --git a/src/relax/expr.cc b/src/relax/expr.cc index 83892a0296..b75c98746c 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -178,14 +178,15 @@ TVM_REGISTER_GLOBAL("relax.Function") TVM_REGISTER_NODE_TYPE(ExternFuncNode); -ExternFunc::ExternFunc(String global_symbol) { +ExternFunc::ExternFunc(String global_symbol, Span span) { ObjectPtr n = make_object(); n->global_symbol = std::move(global_symbol); + n->span = span; data_ = std::move(n); } 
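// Note (illustrative, not part of this patch): the `Span span = Span()` default used
// throughout these constructors keeps existing C++ call sites source-compatible, e.g.
//   ExternFunc f("my_packed_fn");             // hypothetical symbol; span defaults to Span()
//   ExternFunc g("my_packed_fn", some_span);  // span is attached to the node when provided
// while the packed-function registration below simply forwards the extra Span argument
// that the updated Python constructors now pass through the FFI.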
-TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol) { - return ExternFunc(global_symbol); +TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) { + return ExternFunc(global_symbol, span); }); } // namespace relax diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 676dfb22a9..0b8448478f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1,23 +1,23 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ -#include + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ #include +#include namespace tvm { namespace relax { @@ -35,8 +35,7 @@ Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) { return Call(op, {shape, func, args}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.call_dps") -.set_body_typed(MakeCallDPS); +TVM_REGISTER_GLOBAL("relax.op.call_dps").set_body_typed(MakeCallDPS); // shape_of @@ -51,5 +50,18 @@ Expr MakeShapeOf(Expr expr) { TVM_REGISTER_GLOBAL("relax.op.shape_of") .set_body_typed(MakeShapeOf); -} // namespace relax -} // namespace tvm + +RELAY_REGISTER_OP("rx.call_packed") + .set_num_inputs(2) + .add_argument("func", "Expr", "The extern packed function.") + .add_argument("args", "Tuple", "The input arguments."); + +RELAY_REGISTER_OP("rx.match_shape") + .set_num_inputs(2) + .add_argument("pattern", "Array", "The matched shape pattern.") + .add_argument("value", "Expr", "The shape expression to match on."); + +RELAY_REGISTER_OP("rx.dataflow").set_num_inputs(0); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py index f95a2216f3..145ef7bdf4 100644 --- a/tests/python/relax/parser.py +++ b/tests/python/relax/parser.py @@ -2,8 +2,8 @@ import pytest import tvm from tvm import relax as rx -from tvm import tir - +from tvm import tir, relay +from tvm.ir import structural_equal # TODO: replace xfails with proper diagnostics checking. # c.f. 
tests/python/unittest/test_tvmscript_error_report.py @@ -14,14 +14,13 @@ def rx_func(func): def check_shape(e, s): - if not isinstance(e, (list, tuple)) and e is not None: - e = e._shape + if isinstance(e, rx.Expr): + e = e.shape_ if s is None: assert e is None return - assert isinstance(e, (list, tuple)) assert len(e) == len(s) for edim, sdim in zip(e, s): @@ -34,23 +33,33 @@ def check_shape(e, s): def check_tensor_var(v, s, d): - assert isinstance(v.type_annotation, rx.ty.rxTensor) + assert isinstance(v.type_annotation, rx.ty.DynTensorType) assert v.type_annotation.dtype == d + if isinstance(s, (list, tuple)): + assert v.type_annotation.rank == len(s) check_shape(v, s) +def check_call(call, op, args): + assert isinstance(call, rx.Call) + if isinstance(op, str): + op = relay.op.get(op) + assert call.op == op + assert structural_equal(call.args, args) + + def test_annotations(): @rx.script def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor: - z: Tensor[(32, k), "float32"] = matmul(x, y) - w: Tensor[_, _] = mul(z, z) - t = sub(w, z) + z: Tensor[(32, k), "float32"] = nn.matmul(x, y) + w: Tensor[_, _] = multiply(z, z) + t = subtract(w, z) sh: Shape = t.shape return t f = rx_func(foo) x, y = f.params - z_bind, w_bind, t_bind, sh_bind = f.body.blocks[0] + z_bind, w_bind, t_bind, sh_bind = f.body.blocks[0].bindings z, mm = z_bind.var, z_bind.value w, mul = w_bind.var, w_bind.value t, sub = t_bind.var, t_bind.value @@ -59,42 +68,33 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor: check_tensor_var(x, (32, "m"), "float32") check_tensor_var(y, ("m", "k"), "float32") check_tensor_var(z, (32, "k"), "float32") - check_tensor_var(w, None, None) + check_tensor_var(w, None, "") assert t.type_annotation is None - assert isinstance(sh.type_annotation, rx.ty.rxShape) - - assert mm.op == "matmul" - assert mm.args == [x, y] + assert isinstance(sh.type_annotation, rx.ty.ShapeType) - assert mul.op == "mul" - assert mul.args == [z, z] - - assert sub.op == "sub" - assert sub.args == [w, z] - - assert shape_of.op == "rx.shape_of" - assert shape_of.args == [t] + check_call(mm, "nn.matmul", [x, y]) + check_call(mul, "multiply", [z, z]) + check_call(sub, "subtract", [w, z]) + check_call(shape_of, "shape_of", [t]) assert f.body.body == t - assert isinstance(f.ret_type, rx.ty.rxTensor) + assert isinstance(f.ret_type, rx.ty.DynTensorType) def test_match_shape(): @rx.script def foo(x: Tensor[_, "float32"]): rx.match_shape((n, m), x.shape) - y: Tensor[(n, m), "float32"] = refine(x) + y: Tensor[(n, m), "float32"] = add(x, x) return x f = rx_func(foo) - match_sh = f.body.blocks[0][0] + match_sh = f.body.blocks[0].bindings[0] pattern, value = match_sh.pattern, match_sh.value check_shape(pattern, ("n", "m")) - assert isinstance(value, rx.expr.rxCall) - assert value.op == "rx.shape_of" - assert value.args == [f.params[0]] + check_call(value, "shape_of", [f.params[0]]) @pytest.mark.xfail @@ -110,42 +110,38 @@ def test_if(): def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): if cond: w = add(x, x) - y = mul(w, w) + y = multiply(w, w) else: - w = mul(x, x) + w = multiply(x, x) y = add(w, w) return y f = rx_func(foo) cond, x = f.params - y_bind = f.body.blocks[0][0] + y_bind = f.body.blocks[0].bindings[0] y, ite = y_bind.var, y_bind.value check_tensor_var(cond, tuple(), "bool") check_tensor_var(x, (1,), "float32") - assert isinstance(y, rx.expr.rxVar) - assert y.id == "y" + assert isinstance(y, rx.Var) + assert y.name_hint == "y" - assert isinstance(ite, 
rx.expr.rxIfThenElse) - assert isinstance(ite.true_branch, rx.expr.rxSeqExpr) - assert isinstance(ite.false_branch, rx.expr.rxSeqExpr) + assert isinstance(ite, rx.If) + assert isinstance(ite.true_branch, rx.SeqExpr) + assert isinstance(ite.false_branch, rx.SeqExpr) - w_bind = ite.true_branch.blocks[0][0] + w_bind = ite.true_branch.blocks[0].bindings[0] body = ite.true_branch.body - assert w_bind.var.id == "w" - assert isinstance(w_bind.value, rx.expr.rxCall) - assert w_bind.value.op == "add" and w_bind.value.args == [x, x] - assert isinstance(body, rx.expr.rxCall) - assert body.op == "mul" and body.args == [w_bind.var, w_bind.var] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "add", [x, x]) + check_call(body, "multiply", [w_bind.var, w_bind.var]) - w_bind = ite.false_branch.blocks[0][0] + w_bind = ite.false_branch.blocks[0].bindings[0] body = ite.false_branch.body - assert w_bind.var.id == "w" - assert isinstance(w_bind.value, rx.expr.rxCall) - assert w_bind.value.op == "mul" and w_bind.value.args == [x, x] - assert isinstance(body, rx.expr.rxCall) - assert body.op == "add" and body.args == [w_bind.var, w_bind.var] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "multiply", [x, x]) + check_call(body, "add", [w_bind.var, w_bind.var]) # TODO: figure out if-else binding type and shape @@ -167,9 +163,9 @@ def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): y = x if cond: w = add(x, x) - y = mul(w, w) + y = multiply(w, w) else: - w = mul(x, x) + w = multiply(x, x) y = add(w, w) return y @@ -180,13 +176,26 @@ def test_var_if_scoping_fail(): def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): if cond: w = add(x, x) - y = mul(w, w) + y = multiply(w, w) else: - w = mul(x, x) + w = multiply(x, x) y = add(w, w) return w +@pytest.mark.xfail +def test_if_mismatch_var_fail(): + @rx.script + def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]): + if cond: + w = add(x, x) + y = multiply(w, w) + else: + w = multiply(x, x) + z = add(w, w) + return z + + @pytest.mark.xfail def test_unassigned_call_fail(): @rx.script @@ -203,23 +212,21 @@ def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]): f = rx_func(foo) x, y = f.params - t_bind = f.body.blocks[0][0] + t_bind = f.body.blocks[0].bindings[0] t, tup = t_bind.var, t_bind.value - assert isinstance(t.type_annotation, rx.ty.rxTupleType) + assert isinstance(t.type_annotation, relay.TupleType) annot = t.type_annotation - assert isinstance(annot.fields[0], rx.ty.rxTensor) and annot.fields[0].dtype is None - assert isinstance(annot.fields[1], rx.ty.rxTensor) and annot.fields[1].dtype == "float32" + assert isinstance(annot.fields[0], rx.ty.DynTensorType) and annot.fields[0].dtype == "" + assert isinstance(annot.fields[1], rx.ty.DynTensorType) and annot.fields[1].dtype == "float32" - assert isinstance(t._shape, list) and len(t._shape) == 2 - check_shape(t._shape[0], None) - check_shape(t._shape[1], (32,)) + assert t.shape_ is None - assert isinstance(tup, rx.expr.rxTuple) - assert tup.fields == [x, y] - assert isinstance(tup._shape, list) and len(tup._shape) == 2 - check_shape(tup._shape[0], None) - check_shape(tup._shape[1], (32,)) + assert isinstance(tup, rx.Tuple) + assert structural_equal(tup.fields, [x, y]) + assert tup.shape_ is None + check_shape(tup.fields[0], None) + check_shape(tup.fields[1], (32,)) # NOTE: this test requires patching synr to support local function definitions. 
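# Illustrative sketch (not one of the tests in this patch): with the parser and
# DynTensorType changes above, a parameter annotated `Tensor[(32, m), "float32"]`
# is expected to come back roughly as
#   isinstance(x.type_annotation, rx.DynTensorType)  # rank == 2, dtype == "float32"
#   isinstance(x.shape_, rx.ShapeExpr)               # values correspond to (32, m)
# which is what the check_tensor_var/check_shape helpers above assert.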
@@ -234,11 +241,11 @@ def bar(y: Tensor[_, _]): return z f = rx_func(foo) - bar_bind, z_bind = f.body.blocks[0] + bar_bind, z_bind = f.body.blocks[0].bindings bar, bar_fn = bar_bind.var, bar_bind.value bar_x = z_bind.value - assert isinstance(bar_fn, rx.expr.rxFunction) + assert isinstance(bar_fn, rx.Function) assert bar_fn.body.body == bar_fn.params[0] assert bar_x.op == bar @@ -249,10 +256,10 @@ def test_dataflow(): def foo(x: Tensor[_, _]): with rx.dataflow(): y = add(x, x) - z = mul(y, x) - w = sub(z, x) + z = multiply(y, x) + w = subtract(z, x) return y, w - t = div(y, w) + t = divide(y, w) return t f = rx_func(foo) @@ -267,10 +274,10 @@ def test_dataflow_scope_fail(): def foo(x: Tensor[_, _]): with rx.dataflow(): y = add(x, x) - z = mul(y, x) - w = sub(z, x) + z = multiply(y, x) + w = subtract(z, x) return y, w - t = div(y, z) + t = divide(y, z) return t @@ -280,10 +287,10 @@ def test_dataflow_syntax_fail_pattern(): def foo(x: Tensor[_, _]): with rx.dataflow() as df: y = add(x, x) - z = mul(y, x) - w = sub(z, x) + z = multiply(y, x) + w = subtract(z, x) return y, w - t = div(y, z) + t = divide(y, z) return t @@ -293,8 +300,8 @@ def test_dataflow_syntax_fail_params(): def foo(x: Tensor[_, _]): with rx.dataflow(x) as df: y = add(x, x) - z = mul(y, x) - w = sub(z, x) + z = multiply(y, x) + w = subtract(z, x) return y, w - t = div(y, z) + t = divide(y, z) return t From 31cefdd583c2c3af4dfab3f7a3a68f84ed668986 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 27 Aug 2021 10:13:08 -0700 Subject: [PATCH 14/20] add inline TIR parsing --- include/tvm/relax/expr.h | 8 +- python/tvm/relax/expr.py | 11 +- python/tvm/relax/parser.py | 220 ++++++++++++++++------------------- src/relax/expr.cc | 10 ++ src/relax/op/op.cc | 6 +- tests/python/relax/parser.py | 42 ++++++- 6 files changed, 169 insertions(+), 128 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index c5349db25c..714acd0d45 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -159,7 +159,13 @@ class DataflowVarNode : public VarNode { class DataflowVar : public Var { public: - using Var::Var; // inherit constructors from Var + TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()) + : DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL explicit DataflowVar(Id vid, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); }; diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 2aeb8e330f..d4ed3cc797 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -79,7 +79,16 @@ def name_hint(self): @tvm._ffi.register_object("relax.expr.DataflowVar") class DataflowVar(Var): - pass + def __init__( + self, + name_hint: str, + shape_annotation: Optional[Expr] = None, + type_annotation: Optional[Type] = None, + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span + ) @tvm._ffi.register_object("relax.expr.Binding") diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 11ce99cf56..91f2143127 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -5,6 +5,7 @@ from io import StringIO import tvm +import tvm.script from tvm.ir.module import IRModule from tvm.relay.base import Id from tvm.ir import diagnostics @@ -32,15 +33,22 
@@ def is_registered(op_name, op_set=None): return op_name in op_set +def _tir_from_synr(synr_ast: ast.Node, diag_ctx: tvm.script.diagnostics.TVMDiagnosticCtx): + parser = tvm.script.parser.TVMScriptParser(synr_ast.span.start_line) + return parser.do_transform(synr_ast, diag_ctx) + + class RelaxTransformer(Transformer): - def __init__(self, definition_scope, diag_ctx): + def __init__(self, definition_scope): super().__init__() self.definition_scope = definition_scope - self.diag_ctx = diag_ctx self.module = {} self._scopes = [{}] # str -> Var self._registered_ops = set(tvm.ir._ffi_api.ListOpNames()) # cached + def to_tvm_span(self, span: ast.Span) -> tvm.ir.Span: + return self._diagnostic_context.to_tvm_span(self._diagnostic_context.source_name, span) + def new_scope(self): class _Scope: def __init__(self, transformer: "RelaxTransformer"): @@ -59,25 +67,6 @@ def __exit__(self, *exc): def scope(self): return self._scopes[-1] - def tvm_span(self, span: ast.Span) -> tvm.ir.Span: - """Converts the synr span to a TVM span - - Parameters - ---------- - span : synr.Span - The synr span - - Returns - ------- - tvm.ir.Span - The corresponding TVM span - """ - src_name = self.diag_ctx.str_to_source_name[span.filename] - tvm_span = tvm.ir.Span( - src_name, span.start_line, span.end_line, span.start_column, span.end_column - ) - return tvm_span - def decl_var( self, name: str, @@ -134,7 +123,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E if ty is None: return (None, None) - span = self.tvm_span(ty.span) + span = self.to_tvm_span(ty.span) # simple annotation with no type arguments if isinstance(ty, ast.TypeVar): @@ -169,7 +158,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E self._diagnostic_context.emit( "error", "variable Tensor shape annotations not yet supported", - self.tvm_span(shape_annotation.span), + shape_annotation.span, ) self._diagnostic_context.render() else: @@ -178,13 +167,13 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E elif isinstance(shape_annotation, ast.TypeTuple): shape = rx.ShapeExpr( self.parse_shape(shape_annotation, allow_intro), - span=self.tvm_span(shape_annotation.span), + span=self.to_tvm_span(shape_annotation.span), ) else: self._diagnostic_context.emit( "error", "unsupported shape annotation", - self.tvm_span(shape_annotation.span), + shape_annotation.span, ) self._diagnostic_context.render() @@ -197,7 +186,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E self._diagnostic_context.emit( "error", "Tensor dtype annotations must be concrete or erased", - self.tvm_span(dtype_annotation.span), + dtype_annotation.span, ) self._diagnostic_context.render() @@ -210,7 +199,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E fty, fsh = self.transform_type(field, allow_intro=False) field_types.append(fty) field_shapes.append(fsh) - return (relay.TupleType(field_types, self.tvm_span(ty.span)), None) + return (relay.TupleType(field_types, span), None) # TODO: other types with args, e.g. 
Ref[T], func types self._diagnostic_context.emit("error", "invalid type", span) self._diagnostic_context.render() @@ -259,35 +248,35 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: self._diagnostic_context.emit( "error", "non-dimension variables cannot appear in dimension expressions", - self.tvm_span(expr.span), + expr.span, ) self._diagnostic_context.render() return var elif allow_intro: # introduce TIR variable to scope, e.g. for func params or rx.call_packed - var = tir.Var(var_name, "int32", self.tvm_span(expr.span)) + var = tir.Var(var_name, "int32", self.to_tvm_span(expr.span)) self.scope[var_name] = var return var else: self._diagnostic_context.emit( "error", "cannot introduce new dimension variables in this expression", - self.tvm_span(expr.span), + expr.span, ) self._diagnostic_context.render() elif isinstance(expr, ast.Constant): if not isinstance(expr.value, int): self._diagnostic_context.emit( - "error", "only integer constants are supported", self.tvm_span(expr.span) + "error", "only integer constants are supported", expr.span ) self._diagnostic_context.render() - return tir.const(expr.value, "int32", self.tvm_span(expr.span)) + return tir.const(expr.value, "int32", self.to_tvm_span(expr.span)) else: # TODO: parse (simple) PrimExprs self._diagnostic_context.emit( "error", "only dimension variable expressions are currently supported", - self.tvm_span(expr.span), + expr.span, ) self._diagnostic_context.render() @@ -325,7 +314,7 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: # self._diagnostic_context.emit( # "error", # "The shape expression can only contain arithmetic operators, integer constants and variables", - # self.tvm_span(expr.span), + # (expr.span), # ) # self._diagnostic_context.render() @@ -336,18 +325,25 @@ def transform_module(self, mod: ast.Module) -> IRModule: return self.module def transform_function(self, func: ast.Function, is_global=False) -> rx.Function: + if ( + len(func.decorators) == 1 + and isinstance(func.decorators[0], ast.Var) + and func.decorators[0].id.name == "tir" + ): + return _tir_from_synr(func, self._diagnostic_context) + with self.new_scope(): params = [] for param in func.params: ty, shape = self.transform_type(param.ty, allow_intro=True) - param = self.decl_var(param.name, ty, shape, self.tvm_span(param.span)) + param = self.decl_var(param.name, ty, shape, self.to_tvm_span(param.span)) params.append(param) new_body = self.transform_block(func.body) ret_type, _ = self.transform_type(func.ret_type, allow_intro=False) func_name = rx.GlobalVar(func.name) if is_global else None return rx.Function( - params, new_body, ret_type, name=func_name, span=self.tvm_span(func.span) + params, new_body, ret_type, name=func_name, span=self.to_tvm_span(func.span) ) def parse_binding(self, stmt: ast.Assign, is_dataflow=False): @@ -355,20 +351,20 @@ def parse_binding(self, stmt: ast.Assign, is_dataflow=False): self._diagnostic_context.emit( "error", "the left hand side of a binding must be a variable", - self.tvm_span(stmt.lhs.span), + stmt.lhs.span, ) self._diagnostic_context.render() # TODO: figure out proper way of doing this rhs = self.transform_expr(stmt.rhs) - if isinstance(rhs, relay.Call) and rhs.op == relay.op.get("rx.call_packed"): + if isinstance(rhs, relay.Call) and rhs.op == relay.op.get("relax.call_packed"): allow_intro = True else: allow_intro = False ty, shape = self.transform_type(stmt.ty, allow_intro) lhs = self.decl_var( - stmt.lhs.id.name, ty, shape, 
self.tvm_span(stmt.lhs.span), is_dataflow=is_dataflow + stmt.lhs.id.name, ty, shape, self.to_tvm_span(stmt.lhs.span), is_dataflow=is_dataflow ) - return rx.VarBinding(lhs, rhs, self.tvm_span(stmt.span)) + return rx.VarBinding(lhs, rhs, self.to_tvm_span(stmt.span)) # Stmts: # - Assert: probably unsupported for now @@ -417,9 +413,9 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl with self.new_scope(): false_branch = self.transform_block(false_block) # TODO: the spans here are all sorts of messed up, not sure how to fix - ite_expr = relay.If(cond, true_branch, false_branch, self.tvm_span(stmt.span)) - var = self.decl_var(var_name, None, None, self.tvm_span(false_assign.span)) - return rx.VarBinding(var, ite_expr, self.tvm_span(stmt.span)) + ite_expr = relay.If(cond, true_branch, false_branch, self.to_tvm_span(stmt.span)) + var = self.decl_var(var_name, None, None, self.to_tvm_span(false_assign.span)) + return rx.VarBinding(var, ite_expr, self.to_tvm_span(stmt.span)) elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) @@ -428,16 +424,14 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl elif isinstance(stmt, ast.UnassignedCall): call: synr.ast.Call = stmt.call op = self.transform_expr(call.func_name) - # FIXME: this check is unreachable since transform_expr tries looking up func_name as a - # variable and fails - if op != relay.op.get("rx.match_shape"): + if op != relay.op.get("relax.match_shape"): self._diagnostic_context.emit( - "error", "the results of calls must be bound or used", self.tvm_span(stmt.span) + "error", "the results of calls must be bound or used", stmt.span ) self._diagnostic_context.render() if len(stmt.call.params) != 2: self._diagnostic_context.emit( - "error", "rx.match_shape takes exactly two arguments", self.tvm_span(stmt.span) + "error", "relax.match_shape takes exactly two arguments", stmt.span ) self._diagnostic_context.render() @@ -448,18 +442,16 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl if not isinstance(lhs, ast.Tuple): self._diagnostic_context.emit( "error", - "the pattern (lhs) of rx.match_shape must be a tuple", - self.tvm_span(lhs.span), + "the pattern (lhs) of relax.match_shape must be a tuple", + lhs.span, ) self._diagnostic_context.render() lhs_expr = self.parse_shape(lhs, allow_intro=True) - return rx.MatchShape(lhs_expr, rhs_expr, self.tvm_span(stmt.span)) + return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span)) elif isinstance(stmt, ast.With): if not isinstance(stmt.rhs, ast.Call): - self._diagnostic_context.emit( - "error", "unsupported with block", self.tvm_span(stmt.span) - ) + self._diagnostic_context.emit("error", "unsupported with block", stmt.span) self._diagnostic_context.render() call = stmt.rhs @@ -467,23 +459,21 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl # TODO: perhaps this ought to be more general - if op != relay.op.get("rx.dataflow"): - self._diagnostic_context.emit( - "error", "unsupported with block type", self.tvm_span(call.span) - ) + if op != relay.op.get("relax.dataflow"): + self._diagnostic_context.emit("error", "unsupported with block type", call.span) self._diagnostic_context.render() if len(call.params) > 0: self._diagnostic_context.emit( "error", "dataflow block constructor takes no arguments", - self.tvm_span(call.params[0].span), + call.params[0].span, ) self._diagnostic_context.render() if len(stmt.lhs) > 0: self._diagnostic_context.emit( 
"error", "dataflow blocks don't bind any patterns", - self.tvm_span(stmt.lhs[0].span), + stmt.lhs[0].span, ) self._diagnostic_context.render() @@ -491,14 +481,14 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl elif isinstance(stmt, ast.Function): func = self.transform_function(stmt) - func_var = self.decl_var(stmt.name, None, None, self.tvm_span(stmt.span)) - return rx.VarBinding(func_var, func, self.tvm_span(stmt.span)) + func_var = self.decl_var(stmt.name, None, None, self.to_tvm_span(stmt.span)) + return rx.VarBinding(func_var, func, self.to_tvm_span(stmt.span)) else: self._diagnostic_context.emit( "error", "unsupported statement", - self.tvm_span(stmt.span), + stmt.span, ) self._diagnostic_context.render() @@ -514,7 +504,7 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: self._diagnostic_context.emit( "error", "dataflow blocks must end with returning the output variables", - self.tvm_span(output_stmt.span), + output_stmt.span, ) self._diagnostic_context.render() @@ -528,7 +518,7 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: self._diagnostic_context.emit( "error", "the returned values must be variables", - self.tvm_span(ret_val.span), + ret_val.span, ) # output variables are bound to normal (not data flow) Vars @@ -539,7 +529,7 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: self._diagnostic_context.emit( "error", "only bindings are supported in dataflow blocks", - self.tvm_span(binding_stmt.span), + binding_stmt.span, ) self._diagnostic_context.render() is_dataflow = binding_stmt.lhs.id.name not in output_var_names @@ -552,7 +542,7 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: for v in output_vars: self.scope[v.name_hint] = v - return rx.DataflowBlock(bindings, self.tvm_span(block.span)) + return rx.DataflowBlock(bindings, self.to_tvm_span(block.span)) # Exprs: # - ArrayLiteral: unsupported for now? 
@@ -576,9 +566,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: op_name.append(expr.field.name) attr = attr.object if not isinstance(attr, ast.Var): - self._diagnostic_context.emit( - "error", "unsupported field access", self.tvm_span(expr) - ) + self._diagnostic_context.emit("error", "unsupported field access", expr.span) self._diagnostic_context.render() op_name.append(attr.id.name) op_name = ".".join(reversed(op_name)) @@ -587,27 +575,23 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: if isinstance(expr, ast.Call): op = self.transform_expr(expr.func_name) args = [self.transform_expr(arg) for arg in expr.params] - return relay.Call(op, args, span=self.tvm_span(expr.span)) + return relay.Call(op, args, span=self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Tuple): fields = [self.transform_expr(field) for field in expr.values] - return relay.Tuple(fields, span=self.tvm_span(expr.span)) + return relay.Tuple(fields, span=self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Var): var_name = expr.id.name if is_registered(var_name, op_set=self._registered_ops): return relay.op.get(var_name) if var_name not in self.scope: - self._diagnostic_context.emit( - "error", "undefined variable", self.tvm_span(expr.span) - ) + self._diagnostic_context.emit("error", "undefined variable", expr.span) self._diagnostic_context.render() return self.scope[var_name] else: - self._diagnostic_context.emit( - "error", "unsupported expression", self.tvm_span(expr.span) - ) + self._diagnostic_context.emit("error", "unsupported expression", expr.span) self._diagnostic_context.render() def transform_block(self, block: ast.Block) -> rx.SeqExpr: @@ -619,65 +603,65 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr: if isinstance(parsed_stmt, rx.DataflowBlock): if current_block: # FIXME: span - blocks.append(rx.BindingBlock(current_block, self.tvm_span(stmt.span))) + blocks.append(rx.BindingBlock(current_block, self.to_tvm_span(stmt.span))) current_block = [] blocks.append(parsed_stmt) else: assert isinstance(parsed_stmt, rx.Binding) current_block.append(parsed_stmt) if len(current_block) > 0: - blocks.append(rx.BindingBlock(current_block, self.tvm_span(stmt.span))) + blocks.append(rx.BindingBlock(current_block, self.to_tvm_span(stmt.span))) ret_stmt = block.stmts[-1] if not isinstance(ret_stmt, ast.Return): self._diagnostic_context.emit( "error", "block must end with a returned expression", - self.tvm_span(ret_stmt.span), + ret_stmt.span, ) self._diagnostic_context.render() ret_expr = self.transform_stmt(ret_stmt) - return rx.SeqExpr(blocks, ret_expr, self.tvm_span(block.span)) + return rx.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span)) -class TVMDiagnosticContext(synr.DiagnosticContext): - def __init__(self, tvm_diag_ctx): - self.tvm_diag_ctx = tvm_diag_ctx - self.str_to_source_name = {} +# class TVMDiagnosticContext(synr.DiagnosticContext): +# def __init__(self, tvm_diag_ctx): +# self.tvm_diag_ctx = tvm_diag_ctx +# self.str_to_source_name = {} - def add_source(self, name: str, source: str) -> None: - """Add a file with source code to the context. This will be called - before any call to :py:func:`emit` that contains a span in this - file. 
- """ - src_name = self.tvm_diag_ctx.module.source_map.add(name, source) - self.str_to_source_name[name] = src_name - - def emit(self, level: str, message: str, span: tvm.ir.Span) -> None: - """Called when an error has occured.""" - - if level == "error": - level = diagnostics.DiagnosticLevel.ERROR - elif level == "bug": - level = diagnostics.DiagnosticLevel.BUG - elif level == "warning": - level = diagnostics.DiagnosticLevel.WARNING - else: - level = "error" +# def add_source(self, name: str, source: str) -> None: +# """Add a file with source code to the context. This will be called +# before any call to :py:func:`emit` that contains a span in this +# file. +# """ +# src_name = self.tvm_diag_ctx.module.source_map.add(name, source) +# self.str_to_source_name[name] = src_name - assert span, "Span must not be null" - assert isinstance(span, tvm.ir.Span), "Expected tvm.ir.Span, but got " + str(type(span)) +# def emit(self, level: str, message: str, span: tvm.ir.Span) -> None: +# """Called when an error has occured.""" - diag = diagnostics.Diagnostic(level, span, message) +# if level == "error": +# level = diagnostics.DiagnosticLevel.ERROR +# elif level == "bug": +# level = diagnostics.DiagnosticLevel.BUG +# elif level == "warning": +# level = diagnostics.DiagnosticLevel.WARNING +# else: +# level = "error" - self.tvm_diag_ctx.emit(diag) +# assert span, "Span must not be null" +# assert isinstance(span, tvm.ir.Span), "Expected tvm.ir.Span, but got " + str(type(span)) - def render(self) -> Optional[Any]: - """Render out all error messages. Can either return a value or raise - and execption. - """ - self.tvm_diag_ctx.render() +# diag = diagnostics.Diagnostic(level, span, message) + +# self.tvm_diag_ctx.emit(diag) + +# def render(self) -> Optional[Any]: +# """Render out all error messages. Can either return a value or raise +# and execption. 
+# """ +# self.tvm_diag_ctx.render() class RelaxDecoratedFn: @@ -697,10 +681,10 @@ def __call__(self, *args): def script(f): - ir_module = tvm.IRModule({}) - diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) - diag_ctx = TVMDiagnosticContext(diag_ctx) + # ir_module = tvm.IRModule({}) + # diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) + diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() ast = synr.to_ast(f, diag_ctx) definition_scope = inspect.getmodule(f) - module = RelaxTransformer(definition_scope, diag_ctx).do_transform(ast, diag_ctx) + module = RelaxTransformer(definition_scope).do_transform(ast, diag_ctx) return RelaxDecoratedFn(f.__name__, module, diag_ctx) diff --git a/src/relax/expr.cc b/src/relax/expr.cc index b75c98746c..08e1a01f1d 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -72,6 +72,16 @@ TVM_REGISTER_GLOBAL("relax.Var") TVM_REGISTER_NODE_TYPE(DataflowVarNode); +DataflowVar::DataflowVar(Id vid, Optional shape_annotation, Optional type_annotation, + Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + n->shape_ = std::move(shape_annotation); + n->type_annotation = std::move(type_annotation); + n->span = std::move(span); + data_ = std::move(n); +} + TVM_REGISTER_GLOBAL("relax.DataflowVar") .set_body_typed([](String name_hint, Optional shape_annotation, Optional type_annotation, Span span) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 0b8448478f..316fa7d5a2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -51,17 +51,17 @@ Expr MakeShapeOf(Expr expr) { TVM_REGISTER_GLOBAL("relax.op.shape_of") .set_body_typed(MakeShapeOf); -RELAY_REGISTER_OP("rx.call_packed") +RELAY_REGISTER_OP("relax.call_packed") .set_num_inputs(2) .add_argument("func", "Expr", "The extern packed function.") .add_argument("args", "Tuple", "The input arguments."); -RELAY_REGISTER_OP("rx.match_shape") +RELAY_REGISTER_OP("relax.match_shape") .set_num_inputs(2) .add_argument("pattern", "Array", "The matched shape pattern.") .add_argument("value", "Expr", "The shape expression to match on."); -RELAY_REGISTER_OP("rx.dataflow").set_num_inputs(0); +RELAY_REGISTER_OP("relax.dataflow").set_num_inputs(0); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py index 145ef7bdf4..3469dc1c41 100644 --- a/tests/python/relax/parser.py +++ b/tests/python/relax/parser.py @@ -85,7 +85,7 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor: def test_match_shape(): @rx.script def foo(x: Tensor[_, "float32"]): - rx.match_shape((n, m), x.shape) + relax.match_shape((n, m), x.shape) y: Tensor[(n, m), "float32"] = add(x, x) return x @@ -254,7 +254,7 @@ def bar(y: Tensor[_, _]): def test_dataflow(): @rx.script def foo(x: Tensor[_, _]): - with rx.dataflow(): + with relax.dataflow(): y = add(x, x) z = multiply(y, x) w = subtract(z, x) @@ -264,15 +264,22 @@ def foo(x: Tensor[_, _]): f = rx_func(foo) df_block = f.body.blocks[0] + y_bind, z_bind, w_bind = df_block.bindings + + assert isinstance(y_bind.var, rx.Var) + assert isinstance(z_bind.var, rx.DataflowVar) + assert isinstance(w_bind.var, rx.Var) # TODO: check correctness + # import pdb; pdb.set_trace() + @pytest.mark.xfail def test_dataflow_scope_fail(): @rx.script def foo(x: Tensor[_, _]): - with rx.dataflow(): + with relax.dataflow(): y = add(x, x) z = multiply(y, x) w = subtract(z, x) @@ -285,7 +292,7 @@ def foo(x: Tensor[_, _]): def test_dataflow_syntax_fail_pattern(): @rx.script 
def foo(x: Tensor[_, _]): - with rx.dataflow() as df: + with relax.dataflow() as df: y = add(x, x) z = multiply(y, x) w = subtract(z, x) @@ -298,10 +305,35 @@ def foo(x: Tensor[_, _]): def test_dataflow_syntax_fail_params(): @rx.script def foo(x: Tensor[_, _]): - with rx.dataflow(x) as df: + with relax.dataflow(x) as df: y = add(x, x) z = multiply(y, x) w = subtract(z, x) return y, w t = divide(y, z) return t + + +@pytest.mark.xfail +def test_func_no_return_fail(): + @rx.script + def foo(x: Tensor[_, _]): + y = add(x, x) + + +def test_inline_tir(): + @rx.script + def foo(x: Tensor[(128, 128), "float32"], y: Tensor[(128, 128), "float32"]): + @tir + def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = tir.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + z = relax.call_dps(my_matmul, x, y) + return z From 6d9398d40bbc1062c2a54c0d6ffe4d6675644d4c Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 27 Aug 2021 16:04:24 -0700 Subject: [PATCH 15/20] some cleanup --- python/tvm/relax/parser.py | 257 +++++++++------------- python/tvm/relax/parser_tests/__init__.py | 0 python/tvm/relax/parser_tests/failing.py | 4 - python/tvm/relax/parser_tests/passing.py | 38 ---- tests/python/relax/parser.py | 41 +++- 5 files changed, 132 insertions(+), 208 deletions(-) delete mode 100644 python/tvm/relax/parser_tests/__init__.py delete mode 100644 python/tvm/relax/parser_tests/failing.py delete mode 100644 python/tvm/relax/parser_tests/passing.py diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 91f2143127..5aab17abb9 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -49,6 +49,10 @@ def __init__(self, definition_scope): def to_tvm_span(self, span: ast.Span) -> tvm.ir.Span: return self._diagnostic_context.to_tvm_span(self._diagnostic_context.source_name, span) + def report_error(self, msg: str, span: ast.Span): + self._diagnostic_context.emit("error", msg, span) + self._diagnostic_context.render() + def new_scope(self): class _Scope: def __init__(self, transformer: "RelaxTransformer"): @@ -72,7 +76,7 @@ def decl_var( name: str, type_annotation: Optional[rx.Type], shape: Optional[rx.Expr], - span: tvm.ir.Span, + span: ast.Span, is_dataflow: bool = False, ) -> rx.Var: """Introduces a variable with the given name and annotations to the current scope. 
@@ -85,7 +89,7 @@ def decl_var( The type annotation shape : Optional[rxExpr] The shape annotation - span : tvm.ir.Span + span : ast.Span The span where the variable is declared Returns @@ -94,14 +98,11 @@ def decl_var( The declared variable """ if name in self.scope: - self._diagnostic_context.emit( - "error", "variable has already been declared in the current scope", span - ) - self._diagnostic_context.render() + self.report_error("variable has already been declared in the current scope", span) if is_dataflow: - var = rx.DataflowVar(name, shape, type_annotation, span) + var = rx.DataflowVar(name, shape, type_annotation, self.to_tvm_span(span)) else: - var = rx.Var(name, shape, type_annotation, span) + var = rx.Var(name, shape, type_annotation, self.to_tvm_span(span)) self.scope[name] = var return var @@ -134,19 +135,16 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E elif ty.id.name == "Dim": return (rx.DimType(span), None) else: - self._diagnostic_context.emit("error", "unknown type in annotation", span) - self._diagnostic_context.render() + self.report_error("unknown type in annotation", span) # annotation with type arguments/shape annotation if isinstance(ty, ast.TypeApply): if ty.id.name == "Tensor": if len(ty.params) != 2: - self._diagnostic_context.emit( - "error", + self.report_error( "Tensor type annotations must have 2 fields (shape and dtype)", span, ) - self._diagnostic_context.render() shape_annotation, dtype_annotation = ty.params shape, dtype = None, None @@ -155,12 +153,10 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E if isinstance(shape_annotation, ast.TypeVar): if shape_annotation.id.name != "_": # TODO: handle variable annotations, e.g. x: Tensor[my_shape, _] - self._diagnostic_context.emit( - "error", + self.report_error( "variable Tensor shape annotations not yet supported", shape_annotation.span, ) - self._diagnostic_context.render() else: # FIXME: use a special node for unknown shape vs no shape? pass # shape = None @@ -170,12 +166,10 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E span=self.to_tvm_span(shape_annotation.span), ) else: - self._diagnostic_context.emit( - "error", + self.report_error( "unsupported shape annotation", shape_annotation.span, ) - self._diagnostic_context.render() # parse the dtype annotation if isinstance(dtype_annotation, ast.TypeVar) and dtype_annotation.id.name == "_": @@ -183,12 +177,10 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E elif isinstance(dtype_annotation, ast.TypeConstant): dtype = dtype_annotation.value # TODO: parse to TVM DType? else: - self._diagnostic_context.emit( - "error", + self.report_error( "Tensor dtype annotations must be concrete or erased", dtype_annotation.span, ) - self._diagnostic_context.render() rank = len(shape) if shape is not None else -1 return (rx.DynTensorType(rank=rank, dtype=dtype, span=span), shape) @@ -201,8 +193,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E field_shapes.append(fsh) return (relay.TupleType(field_types, span), None) # TODO: other types with args, e.g. 
Ref[T], func types - self._diagnostic_context.emit("error", "invalid type", span) - self._diagnostic_context.render() + self.report_error("invalid type", span) def parse_shape( self, @@ -245,12 +236,10 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: if var_name in self.scope: var = self.scope[var_name] if not isinstance(var, tir.Var): - self._diagnostic_context.emit( - "error", + self.report_error( "non-dimension variables cannot appear in dimension expressions", expr.span, ) - self._diagnostic_context.render() return var elif allow_intro: # introduce TIR variable to scope, e.g. for func params or rx.call_packed @@ -258,65 +247,20 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: self.scope[var_name] = var return var else: - self._diagnostic_context.emit( - "error", + self.report_error( "cannot introduce new dimension variables in this expression", expr.span, ) - self._diagnostic_context.render() elif isinstance(expr, ast.Constant): if not isinstance(expr.value, int): - self._diagnostic_context.emit( - "error", "only integer constants are supported", expr.span - ) - self._diagnostic_context.render() + self.report_error("only integer constants are supported", expr.span) return tir.const(expr.value, "int32", self.to_tvm_span(expr.span)) else: # TODO: parse (simple) PrimExprs - self._diagnostic_context.emit( - "error", + self.report_error( "only dimension variable expressions are currently supported", expr.span, ) - self._diagnostic_context.render() - - # Turns a tuple into an array of PrimExprs - # Allow arithmetic indicates whether we are letting there be - # def expr_to_primexpr(self, expr: ast.Expr, allow_arithmetic=False) -> PrimExpr: - # if not allow_arithmetic and not isinstance(expr, ast.Var): - # # TODO: improve error message - # self._diagnostic_context.emit( - # "error", - # "You can only use single variables here, not an expression", - # self.span_to_span(expr.span), - # ) - # self._diagnostic_context.render() - # else: - # if isinstance(expr, ast.Var): - # return tir.Var(expr.id.name, "int32") - - # # TODO: do all the ops here - # elif isinstance(expr, ast.Constant) and isinstance(expr.value, int): - # return tir.IntImm("int32", expr.value) - # elif isinstance(expr, ast.Call): - # if exp.func_name.name == ast.BuiltinOp.Add: - # # TODO: call this fn on args and return primexpr containing result - # assert False - # if exp.func_name.name == ast.BuiltinOp.Sub: - # assert False - # if exp.func_name.name == ast.BuiltinOp.Mul: - # assert False - # if exp.func_name.name == ast.BuiltinOp.Div: - # assert False - # if exp.func_name.name == ast.BuiltinOp.Mod: - # assert False - # else: - # self._diagnostic_context.emit( - # "error", - # "The shape expression can only contain arithmetic operators, integer constants and variables", - # (expr.span), - # ) - # self._diagnostic_context.render() def transform_module(self, mod: ast.Module) -> IRModule: for func_name in mod.funcs: @@ -336,7 +280,7 @@ def transform_function(self, func: ast.Function, is_global=False) -> rx.Function params = [] for param in func.params: ty, shape = self.transform_type(param.ty, allow_intro=True) - param = self.decl_var(param.name, ty, shape, self.to_tvm_span(param.span)) + param = self.decl_var(param.name, ty, shape, param.span) params.append(param) new_body = self.transform_block(func.body) ret_type, _ = self.transform_type(func.ret_type, allow_intro=False) @@ -346,14 +290,39 @@ def transform_function(self, func: ast.Function, is_global=False) -> 
rx.Function params, new_body, ret_type, name=func_name, span=self.to_tvm_span(func.span) ) - def parse_binding(self, stmt: ast.Assign, is_dataflow=False): + def parse_binding(self, stmt: ast.Stmt, is_dataflow=False): + assert isinstance(stmt, (ast.Assign, ast.UnassignedCall)) + if isinstance(stmt, ast.Assign): + return self.parse_var_binding(stmt, is_dataflow=is_dataflow) + else: + return self.parse_shape_binding(stmt) + + def parse_shape_binding(self, stmt: ast.UnassignedCall): + call: synr.ast.Call = stmt.call + op = self.transform_expr(call.func_name) + if op != relay.op.get("relax.match_shape"): + self.report_error("the results of calls must be bound or used", stmt.span) + if len(stmt.call.params) != 2: + self.report_error("relax.match_shape takes exactly two arguments", stmt.span) + + lhs = stmt.call.params[0] + rhs = stmt.call.params[1] + + rhs_expr = self.transform_expr(rhs) + if not isinstance(lhs, ast.Tuple): + self.report_error( + "the pattern (lhs) of relax.match_shape must be a tuple", + lhs.span, + ) + lhs_expr = self.parse_shape(lhs, allow_intro=True) + return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span)) + + def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False): if not isinstance(stmt.lhs, ast.Var): - self._diagnostic_context.emit( - "error", + self.report_error( "the left hand side of a binding must be a variable", stmt.lhs.span, ) - self._diagnostic_context.render() # TODO: figure out proper way of doing this rhs = self.transform_expr(stmt.rhs) if isinstance(rhs, relay.Call) and rhs.op == relay.op.get("relax.call_packed"): @@ -361,9 +330,7 @@ def parse_binding(self, stmt: ast.Assign, is_dataflow=False): else: allow_intro = False ty, shape = self.transform_type(stmt.ty, allow_intro) - lhs = self.decl_var( - stmt.lhs.id.name, ty, shape, self.to_tvm_span(stmt.lhs.span), is_dataflow=is_dataflow - ) + lhs = self.decl_var(stmt.lhs.id.name, ty, shape, stmt.lhs.span, is_dataflow=is_dataflow) return rx.VarBinding(lhs, rhs, self.to_tvm_span(stmt.span)) # Stmts: @@ -380,18 +347,32 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl return self.parse_binding(stmt) elif isinstance(stmt, ast.If): - # TODO: proper diagnostics - # check branches are non-empty - assert stmt.true.stmts - assert stmt.false.stmts + if len(stmt.true.stmts) == 0 or len(stmt.false.stmts) == 0: + self.report_error("both branches of an if-else block must be non-empty", stmt.span) true_assign = stmt.true.stmts[-1] false_assign = stmt.false.stmts[-1] # check last statement in each branch lines up - assert isinstance(true_assign, ast.Assign) and isinstance(true_assign.lhs, ast.Var) - assert isinstance(false_assign, ast.Assign) and isinstance(false_assign.lhs, ast.Var) - assert true_assign.lhs.id.name == false_assign.lhs.id.name + if not isinstance(true_assign, ast.Assign) or not isinstance(true_assign.lhs, ast.Var): + self.report_error( + "each branch of an if-else statement must end in a variable assignment", + true_assign.span, + ) + if not isinstance(false_assign, ast.Assign) or not isinstance( + false_assign.lhs, ast.Var + ): + self.report_error( + "each branch of an if-else statement must end in a variable assignment", + false_assign.span, + ) + union_span = ast.Span.union([true_assign.span, false_assign.span]) + if true_assign.lhs.id.name != false_assign.lhs.id.name: + self.report_error( + "the final assignment of both branches must have the same variable", + union_span, + ) + var_name = true_assign.lhs.id.name # rewrite branches to have a return 
statement so the blocks properly parse to SeqExprs @@ -414,45 +395,20 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl false_branch = self.transform_block(false_block) # TODO: the spans here are all sorts of messed up, not sure how to fix ite_expr = relay.If(cond, true_branch, false_branch, self.to_tvm_span(stmt.span)) - var = self.decl_var(var_name, None, None, self.to_tvm_span(false_assign.span)) - return rx.VarBinding(var, ite_expr, self.to_tvm_span(stmt.span)) + # TODO: type and shape of return var + var = self.decl_var(var_name, None, None, union_span) + return rx.VarBinding(var, ite_expr, self.to_tvm_span(union_span)) elif isinstance(stmt, ast.Return): return self.transform_expr(stmt.value) - # match_shape is the ONLY node that doesn't have to be bound to an LHS variable! elif isinstance(stmt, ast.UnassignedCall): - call: synr.ast.Call = stmt.call - op = self.transform_expr(call.func_name) - if op != relay.op.get("relax.match_shape"): - self._diagnostic_context.emit( - "error", "the results of calls must be bound or used", stmt.span - ) - self._diagnostic_context.render() - if len(stmt.call.params) != 2: - self._diagnostic_context.emit( - "error", "relax.match_shape takes exactly two arguments", stmt.span - ) - self._diagnostic_context.render() - - lhs = stmt.call.params[0] - rhs = stmt.call.params[1] - - rhs_expr = self.transform_expr(rhs) - if not isinstance(lhs, ast.Tuple): - self._diagnostic_context.emit( - "error", - "the pattern (lhs) of relax.match_shape must be a tuple", - lhs.span, - ) - self._diagnostic_context.render() - lhs_expr = self.parse_shape(lhs, allow_intro=True) - return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span)) + # FIXME: when we add ref support, ref_write can be unassigned + return self.parse_shape_binding(stmt) elif isinstance(stmt, ast.With): if not isinstance(stmt.rhs, ast.Call): - self._diagnostic_context.emit("error", "unsupported with block", stmt.span) - self._diagnostic_context.render() + self.report_error("unsupported with block", stmt.span) call = stmt.rhs op = self.transform_expr(call.func_name) @@ -460,37 +416,30 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl # TODO: perhaps this ought to be more general if op != relay.op.get("relax.dataflow"): - self._diagnostic_context.emit("error", "unsupported with block type", call.span) - self._diagnostic_context.render() + self.report_error("unsupported with block type", call.span) if len(call.params) > 0: - self._diagnostic_context.emit( - "error", + self.report_error( "dataflow block constructor takes no arguments", call.params[0].span, ) - self._diagnostic_context.render() if len(stmt.lhs) > 0: - self._diagnostic_context.emit( - "error", + self.report_error( "dataflow blocks don't bind any patterns", stmt.lhs[0].span, ) - self._diagnostic_context.render() return self.parse_dataflow(stmt.body) elif isinstance(stmt, ast.Function): func = self.transform_function(stmt) - func_var = self.decl_var(stmt.name, None, None, self.to_tvm_span(stmt.span)) + func_var = self.decl_var(stmt.name, None, None, stmt.span) return rx.VarBinding(func_var, func, self.to_tvm_span(stmt.span)) else: - self._diagnostic_context.emit( - "error", + self.report_error( "unsupported statement", stmt.span, ) - self._diagnostic_context.render() def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: assert len(block.stmts) > 0, "should never have an empty dataflow block" @@ -501,12 +450,10 @@ def parse_dataflow(self, block: ast.Block) -> 
rx.DataflowBlock: # parse the return statement first to figure out which bindings assign normal Vars output_stmt = block.stmts[-1] if not isinstance(output_stmt, ast.Return): - self._diagnostic_context.emit( - "error", + self.report_error( "dataflow blocks must end with returning the output variables", output_stmt.span, ) - self._diagnostic_context.render() ret_val = output_stmt.value if isinstance(ret_val, ast.Var): @@ -515,8 +462,7 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: if not isinstance(ret_val, ast.Tuple) or any( [not isinstance(f, ast.Var) for f in ret_val.values] ): - self._diagnostic_context.emit( - "error", + self.report_error( "the returned values must be variables", ret_val.span, ) @@ -525,22 +471,30 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: output_var_names = {var.id.name for var in ret_val.values} for binding_stmt in block.stmts[:-1]: - if not isinstance(binding_stmt, ast.Assign): - self._diagnostic_context.emit( - "error", + if not isinstance(binding_stmt, (ast.Assign, ast.UnassignedCall)): + self.report_error( "only bindings are supported in dataflow blocks", binding_stmt.span, ) - self._diagnostic_context.render() - is_dataflow = binding_stmt.lhs.id.name not in output_var_names + is_match_shape = isinstance(binding_stmt, ast.UnassignedCall) + is_dataflow = ( + False if is_match_shape else (binding_stmt.lhs.id.name not in output_var_names) + ) binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow) bindings.append(binding) if not is_dataflow: - output_vars.append(binding.var) + if is_match_shape: + for var in binding.pattern: + output_vars.append(var) + else: + output_vars.append(binding.var) # make output variables visible in parent scope for v in output_vars: - self.scope[v.name_hint] = v + # v could already be in scope if it was a previously bound dimension variable + v_name = v.name if isinstance(v, tir.Var) else v.name_hint + if v not in self.scope: + self.scope[v_name] = v return rx.DataflowBlock(bindings, self.to_tvm_span(block.span)) @@ -566,8 +520,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: op_name.append(expr.field.name) attr = attr.object if not isinstance(attr, ast.Var): - self._diagnostic_context.emit("error", "unsupported field access", expr.span) - self._diagnostic_context.render() + self.report_error("unsupported field access", expr.span) op_name.append(attr.id.name) op_name = ".".join(reversed(op_name)) return relay.op.get(op_name) # TODO: maybe diagnostics here in case this fails? 
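As a point of reference, the attribute handling above is what turns a namespaced call such as relax.match_shape(...) into an operator lookup by joining the dotted name and querying the operator registry. A rough standalone sketch of the same name-joining walk (resolve_op_name is a hypothetical helper, not part of this patch; it assumes synr's ast.Attr/ast.Var node shapes):

def resolve_op_name(attr):
    # Walk e.g. the expression relax.match_shape from the outermost field
    # inward, then rejoin the pieces as the string "relax.match_shape".
    parts = []
    while isinstance(attr, ast.Attr):
        parts.append(attr.field.name)
        attr = attr.object
    assert isinstance(attr, ast.Var), "only simple dotted names are supported"
    parts.append(attr.id.name)
    return ".".join(reversed(parts))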
@@ -586,13 +539,11 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: if is_registered(var_name, op_set=self._registered_ops): return relay.op.get(var_name) if var_name not in self.scope: - self._diagnostic_context.emit("error", "undefined variable", expr.span) - self._diagnostic_context.render() + self.report_error("undefined variable", expr.span) return self.scope[var_name] else: - self._diagnostic_context.emit("error", "unsupported expression", expr.span) - self._diagnostic_context.render() + self.report_error("unsupported expression", expr.span) def transform_block(self, block: ast.Block) -> rx.SeqExpr: # a block of statements needs to be converted to a SeqExpr of binding blocks @@ -614,12 +565,10 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr: ret_stmt = block.stmts[-1] if not isinstance(ret_stmt, ast.Return): - self._diagnostic_context.emit( - "error", + self.report_error( "block must end with a returned expression", ret_stmt.span, ) - self._diagnostic_context.render() ret_expr = self.transform_stmt(ret_stmt) return rx.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span)) diff --git a/python/tvm/relax/parser_tests/__init__.py b/python/tvm/relax/parser_tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/tvm/relax/parser_tests/failing.py b/python/tvm/relax/parser_tests/failing.py deleted file mode 100644 index 68ca0347f6..0000000000 --- a/python/tvm/relax/parser_tests/failing.py +++ /dev/null @@ -1,4 +0,0 @@ -from .. import relax -@relax -def my_test(x : Tensor[_, _]): - return None \ No newline at end of file diff --git a/python/tvm/relax/parser_tests/passing.py b/python/tvm/relax/parser_tests/passing.py deleted file mode 100644 index f5ac0a41e6..0000000000 --- a/python/tvm/relax/parser_tests/passing.py +++ /dev/null @@ -1,38 +0,0 @@ -from tvm.relax.parser import relax - -# Type annotation tests -""" -@relax -def my_test(x : Tensor[_, _]): - return None - -@relax -def my_test(x: Tensor[(a, b, c), "int32"]): - return None - -@relax -def my_test(x: Tensor[(1, 2, 3), _]): - return None - -@relax -def my_test(x: Tensor[_, "int32"]): - return None -""" -# Builtin functions - -@relax -def my_test(x: Tensor[_, "float32"]): - match_shape(x.shape, (1, 2, 3)) - - -# These should pass in the future but don't right now -""" -@relax -def my_test(x: Tensor[_, "float32"]): - match_shape(x.shape, (1, 2, 3)) - - -@relax -def my_test(x : Tensor[(1, 2, 3), "int32"], y: Tensor[_, _]): - return call_packed("my_func", x, y) -""" \ No newline at end of file diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py index 3469dc1c41..44a5f3e3f6 100644 --- a/tests/python/relax/parser.py +++ b/tests/python/relax/parser.py @@ -229,21 +229,18 @@ def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]): check_shape(tup.fields[1], (32,)) -# NOTE: this test requires patching synr to support local function definitions. -# it's an easy change (just two lines), but may break other users of synr -# (e.g. tvmscript). should investigate. 
def test_local_func(): @rx.script def foo(x: Tensor[_, _]): def bar(y: Tensor[_, _]): return y - z = bar(x) - return z + y = bar(x) # tests local function variable scoping + return y f = rx_func(foo) - bar_bind, z_bind = f.body.blocks[0].bindings + bar_bind, y_bind = f.body.blocks[0].bindings bar, bar_fn = bar_bind.var, bar_bind.value - bar_x = z_bind.value + bar_x = y_bind.value assert isinstance(bar_fn, rx.Function) assert bar_fn.body.body == bar_fn.params[0] @@ -263,16 +260,36 @@ def foo(x: Tensor[_, _]): return t f = rx_func(foo) + assert len(f.body.blocks) == 2 df_block = f.body.blocks[0] y_bind, z_bind, w_bind = df_block.bindings + (t_bind,) = f.body.blocks[1].bindings + x = f.params[0] + y, z, w, t = map(lambda b: b.var, [y_bind, z_bind, w_bind, t_bind]) - assert isinstance(y_bind.var, rx.Var) - assert isinstance(z_bind.var, rx.DataflowVar) - assert isinstance(w_bind.var, rx.Var) + assert isinstance(y, rx.Var) + assert isinstance(z, rx.DataflowVar) + assert isinstance(w, rx.Var) - # TODO: check correctness + check_call(y_bind.value, "add", [x, x]) + check_call(z_bind.value, "multiply", [y, x]) + check_call(w_bind.value, "subtract", [z, x]) + check_call(t_bind.value, "divide", [y, w]) - # import pdb; pdb.set_trace() + assert f.body.body == t + + +def test_dataflow_match_shape(): + @rx.script + def foo(x: Tensor[_, _]): + with relax.dataflow(): + y = add(x, x) + z = multiply(y, x) + relax.match_shape((n, m), z.shape) + w: Tensor[(n, m), _] = subtract(z, x) + return y, w + t: Tensor[(n, m), _] = divide(y, w) + return t @pytest.mark.xfail From f9b0078563bfaf1295a29ad668d2747de189597f Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 3 Sep 2021 13:10:22 -0700 Subject: [PATCH 16/20] support call_packed parsing to ExternFunc call --- python/tvm/relax/parser.py | 69 +++++++++++++++++++++++++----------- tests/python/relax/parser.py | 19 ++++++++++ 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 5aab17abb9..bceec4fb69 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -3,6 +3,7 @@ import inspect from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional from io import StringIO +from enum import Enum import tvm import tvm.script @@ -38,6 +39,13 @@ def _tir_from_synr(synr_ast: ast.Node, diag_ctx: tvm.script.diagnostics.TVMDiagn return parser.do_transform(synr_ast, diag_ctx) +# NOTE: call_dps is an actual registered operator +class SpecialOp(Enum): + MATCH_SHAPE = "relax.match_shape" + CALL_PACKED = "relax.call_packed" + DATAFLOW = "relax.dataflow" + + class RelaxTransformer(Transformer): def __init__(self, definition_scope): super().__init__() @@ -300,10 +308,10 @@ def parse_binding(self, stmt: ast.Stmt, is_dataflow=False): def parse_shape_binding(self, stmt: ast.UnassignedCall): call: synr.ast.Call = stmt.call op = self.transform_expr(call.func_name) - if op != relay.op.get("relax.match_shape"): + if op != SpecialOp.MATCH_SHAPE: self.report_error("the results of calls must be bound or used", stmt.span) if len(stmt.call.params) != 2: - self.report_error("relax.match_shape takes exactly two arguments", stmt.span) + self.report_error(op.value + " takes exactly two arguments", stmt.span) lhs = stmt.call.params[0] rhs = stmt.call.params[1] @@ -311,7 +319,7 @@ def parse_shape_binding(self, stmt: ast.UnassignedCall): rhs_expr = self.transform_expr(rhs) if not isinstance(lhs, ast.Tuple): self.report_error( - "the pattern (lhs) of relax.match_shape must be a tuple", + "the pattern 
(lhs) of " + op.value + " must be a tuple", lhs.span, ) lhs_expr = self.parse_shape(lhs, allow_intro=True) @@ -323,9 +331,9 @@ def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False): "the left hand side of a binding must be a variable", stmt.lhs.span, ) - # TODO: figure out proper way of doing this rhs = self.transform_expr(stmt.rhs) - if isinstance(rhs, relay.Call) and rhs.op == relay.op.get("relax.call_packed"): + # an ExternFunc call comes from call_packed + if isinstance(rhs, relay.Call) and isinstance(rhs.op, rx.ExternFunc): allow_intro = True else: allow_intro = False @@ -415,20 +423,20 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl # TODO: perhaps this ought to be more general - if op != relay.op.get("relax.dataflow"): + if op == SpecialOp.DATAFLOW: + if len(call.params) > 0: + self.report_error( + "dataflow block constructor takes no arguments", + call.params[0].span, + ) + if len(stmt.lhs) > 0: + self.report_error( + "dataflow blocks don't bind any patterns", + stmt.lhs[0].span, + ) + return self.parse_dataflow(stmt.body) + else: self.report_error("unsupported with block type", call.span) - if len(call.params) > 0: - self.report_error( - "dataflow block constructor takes no arguments", - call.params[0].span, - ) - if len(stmt.lhs) > 0: - self.report_error( - "dataflow blocks don't bind any patterns", - stmt.lhs[0].span, - ) - - return self.parse_dataflow(stmt.body) elif isinstance(stmt, ast.Function): func = self.transform_function(stmt) @@ -523,11 +531,32 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: self.report_error("unsupported field access", expr.span) op_name.append(attr.id.name) op_name = ".".join(reversed(op_name)) - return relay.op.get(op_name) # TODO: maybe diagnostics here in case this fails? + # NOTE: at least for now, all special operators are namespaced + try: + return SpecialOp(op_name) + except ValueError: + return relay.op.get(op_name) # TODO: maybe diagnostics here in case this fails? 
if isinstance(expr, ast.Call): op = self.transform_expr(expr.func_name) - args = [self.transform_expr(arg) for arg in expr.params] + if op == SpecialOp.CALL_PACKED: + if len(expr.params) != 2: + self.report_error( + op.value + " takes an extern function name and a tuple of arguments", + expr.span, + ) + extern_func = expr.params[0] + if not ( + isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str) + ): + self.report_error( + "the first argument of " + op.value + " must be the extern function name", + extern_func.span, + ) + op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span)) + args = [self.transform_expr(expr.params[1])] + else: + args = [self.transform_expr(arg) for arg in expr.params] return relay.Call(op, args, span=self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Tuple): diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py index 44a5f3e3f6..935a19bc59 100644 --- a/tests/python/relax/parser.py +++ b/tests/python/relax/parser.py @@ -252,6 +252,8 @@ def test_dataflow(): @rx.script def foo(x: Tensor[_, _]): with relax.dataflow(): + # TODO: parse this + # nonlocal y, w y = add(x, x) z = multiply(y, x) w = subtract(z, x) @@ -354,3 +356,20 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: z = relax.call_dps(my_matmul, x, y) return z + + +def test_call_packed(): + @rx.script + def foo(x: Tensor[(3, 4), "float32"]): + # test that we can intro dim vars + z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x)) + return z + + f = rx_func(foo) + x = f.params[0] + (z_bind,) = f.body.blocks[0].bindings + check_tensor_var(z_bind.var, ("n", "m"), "float32") + + assert isinstance(z_bind.value.op, rx.ExternFunc) + assert z_bind.value.op.global_symbol == "contrib.my_matmul" + assert structural_equal(z_bind.value.args, [rx.Tuple([x, x])]) From c9832c1e357abb32810dc8c3b312f687233b5107 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 3 Sep 2021 13:25:55 -0700 Subject: [PATCH 17/20] remove stub ops --- src/relax/op/op.cc | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 316fa7d5a2..242bb3249e 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -51,17 +51,5 @@ Expr MakeShapeOf(Expr expr) { TVM_REGISTER_GLOBAL("relax.op.shape_of") .set_body_typed(MakeShapeOf); -RELAY_REGISTER_OP("relax.call_packed") - .set_num_inputs(2) - .add_argument("func", "Expr", "The extern packed function.") - .add_argument("args", "Tuple", "The input arguments."); - -RELAY_REGISTER_OP("relax.match_shape") - .set_num_inputs(2) - .add_argument("pattern", "Array", "The matched shape pattern.") - .add_argument("value", "Expr", "The shape expression to match on."); - -RELAY_REGISTER_OP("relax.dataflow").set_num_inputs(0); - } // namespace relax } // namespace tvm From 7e5dc4fd6121b193924a9457e7ed6309bbfb4747 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 3 Sep 2021 13:54:17 -0700 Subject: [PATCH 18/20] improve docstrings --- python/tvm/relax/parser.py | 218 +++++++++++++++++++++++++++++++++---- 1 file changed, 194 insertions(+), 24 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index bceec4fb69..dbc1b66a95 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -23,18 +23,20 @@ import tvm.relax as rx -# TODO: Replace with a real pretty print method once we have the real AST +# TODO(@altanh): Replace with a real pretty print method once we have the real AST def pretty_print(f): print(f) -def 
is_registered(op_name, op_set=None): +def is_registered(op_name: str, op_set=None) -> bool: if op_set is None: op_set = tvm.ir._ffi_api.ListOpNames() return op_name in op_set -def _tir_from_synr(synr_ast: ast.Node, diag_ctx: tvm.script.diagnostics.TVMDiagnosticCtx): +def _tir_from_synr( + synr_ast: ast.Node, diag_ctx: tvm.script.diagnostics.TVMDiagnosticCtx +) -> tir.PrimFunc: parser = tvm.script.parser.TVMScriptParser(synr_ast.span.start_line) return parser.do_transform(synr_ast, diag_ctx) @@ -55,13 +57,42 @@ def __init__(self, definition_scope): self._registered_ops = set(tvm.ir._ffi_api.ListOpNames()) # cached def to_tvm_span(self, span: ast.Span) -> tvm.ir.Span: + """Helper method for converting synr span to TVM span. + + Parameters + ---------- + span : ast.Span + The synr span + + Returns + ------- + tvm.ir.Span + The corresponding TVM span + """ return self._diagnostic_context.to_tvm_span(self._diagnostic_context.source_name, span) def report_error(self, msg: str, span: ast.Span): + """Helper method for emitting and immediately rendering an error. + + Parameters + ---------- + msg : str + The error message + span : ast.Span + The span to report the error at + """ self._diagnostic_context.emit("error", msg, span) self._diagnostic_context.render() def new_scope(self): + """Helper method for creating a new scope context object + + Returns + ------- + _Scope + An internal scope context object used in a with block to create a new scope + """ + class _Scope: def __init__(self, transformer: "RelaxTransformer"): self.transformer = transformer @@ -77,6 +108,13 @@ def __exit__(self, *exc): @property def scope(self): + """Returns the current definition scope. + + Returns + ------- + Dict[str, Union[rx.Var, tir.Var]] + The scope of all currently defined variables (Relax and TIR). + """ return self._scopes[-1] def decl_var( @@ -93,19 +131,20 @@ def decl_var( ---------- name : str The name of the variable - type_annotation : Optional[rxType] + type_annotation : Optional[rx.Type] The type annotation - shape : Optional[rxExpr] + shape : Optional[rx.Expr] The shape annotation span : ast.Span The span where the variable is declared Returns ------- - rxVar + rx.Var The declared variable """ if name in self.scope: + # TODO(@altanh): maybe emit an error at the declaration site and report it together self.report_error("variable has already been declared in the current scope", span) if is_dataflow: var = rx.DataflowVar(name, shape, type_annotation, self.to_tvm_span(span)) @@ -126,7 +165,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E Returns ------- - Tuple[rxType, rxExpr]: + Tuple[rx.Type, rx.Expr]: The corresponding Relax type and shape expression """ if ty is None: @@ -148,6 +187,8 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E # annotation with type arguments/shape annotation if isinstance(ty, ast.TypeApply): if ty.id.name == "Tensor": + # TODO(@altanh): forgetting dtype like "Tensor[(n, m)]" ends up getting parsed as + # Tensor[n, m] which makes correct errors difficult here... if len(ty.params) != 2: self.report_error( "Tensor type annotations must have 2 fields (shape and dtype)", @@ -160,13 +201,13 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E # parse the shape annotation if isinstance(shape_annotation, ast.TypeVar): if shape_annotation.id.name != "_": - # TODO: handle variable annotations, e.g. x: Tensor[my_shape, _] + # TODO(@altanh): handle variable annotations, e.g. 
x: Tensor[my_shape, _] self.report_error( "variable Tensor shape annotations not yet supported", shape_annotation.span, ) else: - # FIXME: use a special node for unknown shape vs no shape? + # FIXME(@altanh): use a special node for unknown shape vs no shape? pass # shape = None elif isinstance(shape_annotation, ast.TypeTuple): shape = rx.ShapeExpr( @@ -183,7 +224,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E if isinstance(dtype_annotation, ast.TypeVar) and dtype_annotation.id.name == "_": pass # dtype = None elif isinstance(dtype_annotation, ast.TypeConstant): - dtype = dtype_annotation.value # TODO: parse to TVM DType? + dtype = dtype_annotation.value else: self.report_error( "Tensor dtype annotations must be concrete or erased", @@ -200,7 +241,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E field_types.append(fty) field_shapes.append(fsh) return (relay.TupleType(field_types, span), None) - # TODO: other types with args, e.g. Ref[T], func types + # TODO(@altanh): other types with args, e.g. Ref[T], func types self.report_error("invalid type", span) def parse_shape( @@ -208,11 +249,11 @@ def parse_shape( shape_annotation: Union[ast.TypeTuple, ast.Tuple], allow_intro: bool, ) -> List[tir.PrimExpr]: - """Parses the given shape annotation to a list of PrimExprs + """Parses the given shape annotation to a list of PrimExprs. Parameters ---------- - shape_annotation : ast.TypeTuple + shape_annotation : Union[ast.TypeTuple, ast.Tuple] The shape annotation in synr allow_intro : bool Whether or not the annotation can bind previously free variables @@ -225,7 +266,7 @@ def parse_shape( return [self.parse_primexpr(field, allow_intro) for field in shape_annotation.values] def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: - """Parses the given expression to a PrimExpr + """Parses the given expression to a PrimExpr. Parameters ---------- @@ -264,19 +305,45 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: self.report_error("only integer constants are supported", expr.span) return tir.const(expr.value, "int32", self.to_tvm_span(expr.span)) else: - # TODO: parse (simple) PrimExprs + # TODO(@altanh): parse (simple) PrimExprs self.report_error( "only dimension variable expressions are currently supported", expr.span, ) def transform_module(self, mod: ast.Module) -> IRModule: + """Transforms the given synr Module to a Relax IRModule. + + Parameters + ---------- + mod : ast.Module + The input synr Module + + Returns + ------- + IRModule + The parsed Relax IRModule + """ for func_name in mod.funcs: func = mod.funcs[func_name] self.module[func_name] = self.transform_function(func, is_global=True) return self.module - def transform_function(self, func: ast.Function, is_global=False) -> rx.Function: + def transform_function(self, func: ast.Function, is_global: bool = False) -> rx.Function: + """Transforms the given synr Function to a Relax Function. 
+ + Parameters + ---------- + func : ast.Function + The input synr Function + is_global : bool, optional + Whether or not the input function is global/module-level, by default False + + Returns + ------- + rx.Function + The parsed Relax Function + """ if ( len(func.decorators) == 1 and isinstance(func.decorators[0], ast.Var) @@ -298,14 +365,40 @@ def transform_function(self, func: ast.Function, is_global=False) -> rx.Function params, new_body, ret_type, name=func_name, span=self.to_tvm_span(func.span) ) - def parse_binding(self, stmt: ast.Stmt, is_dataflow=False): + def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.Binding: + """Parses the input synr statement to the corresponding Relax binding. + + Parameters + ---------- + stmt : ast.Stmt + The input synr statement (either an assignment or a unassigned call) + is_dataflow : bool, optional + Whether or not the binding is in a dataflow block, by default False + + Returns + ------- + rx.Binding + The parsed Relax binding + """ assert isinstance(stmt, (ast.Assign, ast.UnassignedCall)) if isinstance(stmt, ast.Assign): return self.parse_var_binding(stmt, is_dataflow=is_dataflow) else: return self.parse_shape_binding(stmt) - def parse_shape_binding(self, stmt: ast.UnassignedCall): + def parse_shape_binding(self, stmt: ast.UnassignedCall) -> rx.MatchShape: + """Parses the input synr statement to a Relax shape binding. + + Parameters + ---------- + stmt : ast.UnassignedCall + The input synr statement + + Returns + ------- + rx.MatchShape + The parsed Relax shape binding + """ call: synr.ast.Call = stmt.call op = self.transform_expr(call.func_name) if op != SpecialOp.MATCH_SHAPE: @@ -325,7 +418,21 @@ def parse_shape_binding(self, stmt: ast.UnassignedCall): lhs_expr = self.parse_shape(lhs, allow_intro=True) return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span)) - def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False): + def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBinding: + """Parses the input synr assignment to a Relax variable binding. + + Parameters + ---------- + stmt : ast.Assign + The input synr assignment + is_dataflow : bool, optional + Whether or not the binding is in a dataflow block, by default False + + Returns + ------- + rx.VarBinding + The prased Relax variable binding + """ if not isinstance(stmt.lhs, ast.Var): self.report_error( "the left hand side of a binding must be a variable", @@ -350,6 +457,18 @@ def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False): # - UnassignedCall: match_shape # - With: rx.dataflow def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.DataflowBlock]: + """Transforms the given synr statement to the corresponding Relax node. 
+ + Parameters + ---------- + stmt : ast.Stmt + The input synr statement + + Returns + ------- + Union[rx.Expr, rx.Binding, rx.DataflowBlock] + The parsed Relax node + """ if isinstance(stmt, ast.Assign): # dataflow bindings are handled separately in parse_dataflow return self.parse_binding(stmt) @@ -401,9 +520,9 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl true_branch = self.transform_block(true_block) with self.new_scope(): false_branch = self.transform_block(false_block) - # TODO: the spans here are all sorts of messed up, not sure how to fix + # TODO(@altanh): the spans here are all sorts of messed up, not sure how to fix ite_expr = relay.If(cond, true_branch, false_branch, self.to_tvm_span(stmt.span)) - # TODO: type and shape of return var + # TODO(@altanh): type and shape of return var var = self.decl_var(var_name, None, None, union_span) return rx.VarBinding(var, ite_expr, self.to_tvm_span(union_span)) @@ -421,7 +540,7 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl call = stmt.rhs op = self.transform_expr(call.func_name) - # TODO: perhaps this ought to be more general + # TODO(@altanh): perhaps this ought to be more general if op == SpecialOp.DATAFLOW: if len(call.params) > 0: @@ -450,6 +569,18 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl ) def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: + """Parses the input synr block to a Relax dataflow block. + + Parameters + ---------- + block : ast.Block + The input synr block + + Returns + ------- + rx.DataflowBlock + The parsed Relax dataflow block + """ assert len(block.stmts) > 0, "should never have an empty dataflow block" bindings = [] output_vars = [] @@ -516,6 +647,18 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: # - Tuple # - Var def transform_expr(self, expr: ast.Expr) -> rx.Expr: + """Transforms the input synr expression to a Relax expression. + + Parameters + ---------- + expr : ast.Expr + The input synr + + Returns + ------- + rx.Expr + The corresponding Relax expression + """ if isinstance(expr, ast.Attr): if expr.field.name == "shape": obj = self.transform_expr(expr.object) @@ -535,7 +678,8 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: try: return SpecialOp(op_name) except ValueError: - return relay.op.get(op_name) # TODO: maybe diagnostics here in case this fails? + # TODO(@altanh): maybe diagnostics here in case this fails? + return relay.op.get(op_name) if isinstance(expr, ast.Call): op = self.transform_expr(expr.func_name) @@ -575,6 +719,19 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: self.report_error("unsupported expression", expr.span) def transform_block(self, block: ast.Block) -> rx.SeqExpr: + """Transforms the given synr block to a Relax SeqExpr (sequence of Blocks with a final + expression). + + Parameters + ---------- + block : ast.Block + The input synr block + + Returns + ------- + rx.SeqExpr + The parsed SeqExpr + """ # a block of statements needs to be converted to a SeqExpr of binding blocks blocks = [] current_block = [] @@ -642,6 +799,7 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr: # self.tvm_diag_ctx.render() +# TODO(@altanh, @jroesch): revisit this? 
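# A rough sketch of the surface syntax that transform_stmt and parse_dataflow
# above are meant to accept, pieced together from the SpecialOp.DATAFLOW
# handling and the tests further below. The relax.dataflow() spelling follows
# the "With: rx.dataflow" note above, relax.output(...) is an assumed output
# declaration (that part of parse_dataflow is not shown here), and nn.softmax
# stands in for any registered operator:
#
#     @rx.script
#     def foo(x: Tensor[_, _]):
#         with relax.dataflow():
#             y = nn.softmax(x)      # inner binding, becomes a DataflowVar
#             z = nn.softmax(y)      # declared as an output, so a plain rx.Var
#             relax.output(z)        # assumed spelling of the output declaration
#         return z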
class RelaxDecoratedFn: def __init__(self, fn_name, relax_module, diag_ctx): self.fn_name = fn_name @@ -658,7 +816,19 @@ def __call__(self, *args): # return out -def script(f): +def script(f) -> RelaxDecoratedFn: + """Parses the decorated Relax function (in Relax IR) to the Relax AST + + Parameters + ---------- + f : function + The function to be parsed, written in the Relax IR + + Returns + ------- + RelaxDecoratedFn + The parsed Relax function + """ # ir_module = tvm.IRModule({}) # diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() From 1809bcf94c64bddaf123675c49b79a70e1eca91c Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 7 Sep 2021 10:57:20 -0700 Subject: [PATCH 19/20] address nits --- include/tvm/relax/expr.h | 2 +- python/tvm/relax/parser.py | 36 +++++++++++++++++++----------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 714acd0d45..5c05e7c9aa 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -78,7 +78,7 @@ class ShapeExpr : public Expr { /*! \brief The variable class for all Relax bindings. */ class VarNode : public ExprNode { public: - /*! \brief The identifier of the variable, is used for comparing stable equality across + /*! \brief The identifier of the variable, which is used for comparing stable equality across * transformations. */ Id vid; /*! \brief The type annotation, used by binding sites and parameter declarations. */ diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index dbc1b66a95..11147d8e55 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -140,7 +140,7 @@ def decl_var( Returns ------- - rx.Var + Union[rx.Var, rx.DataflowVar] The declared variable """ if name in self.scope: @@ -153,14 +153,14 @@ def decl_var( self.scope[name] = var return var - def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.Expr]: + def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[rx.Type, rx.Expr]: """Transforms the given synr type annotation to a Relax type and shape expression. Parameters ---------- ty : ast.Type The synr type - allow_intro : bool + bind_free_vars : bool Whether or not the shape annotation can introduce new dimension variables Returns @@ -211,7 +211,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E pass # shape = None elif isinstance(shape_annotation, ast.TypeTuple): shape = rx.ShapeExpr( - self.parse_shape(shape_annotation, allow_intro), + self.parse_shape(shape_annotation, bind_free_vars), span=self.to_tvm_span(shape_annotation.span), ) else: @@ -237,7 +237,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E field_types = [] field_shapes = [] for field in ty.params: - fty, fsh = self.transform_type(field, allow_intro=False) + fty, fsh = self.transform_type(field, bind_free_vars=False) field_types.append(fty) field_shapes.append(fsh) return (relay.TupleType(field_types, span), None) @@ -247,7 +247,7 @@ def transform_type(self, ty: ast.Type, allow_intro: bool) -> Tuple[rx.Type, rx.E def parse_shape( self, shape_annotation: Union[ast.TypeTuple, ast.Tuple], - allow_intro: bool, + bind_free_vars: bool, ) -> List[tir.PrimExpr]: """Parses the given shape annotation to a list of PrimExprs. 
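The allow_intro to bind_free_vars rename continues through the hunks below; the behavior itself is unchanged. As a hedged illustration of what the flag controls at the surface level (the call_packed spelling and the packed function name are assumptions, not taken from this patch):

    @rx.script
    def foo(x: Tensor[(n, m), "float32"]):
        # the parameter annotation is parsed with bind_free_vars=True, so the
        # free names n and m are introduced as int32 tir.Vars in scope
        y: Tensor[(n, k), "float32"] = relax.call_packed("my_packed_fn", x)
        # the rhs is an ExternFunc call, so the annotation may bind the new
        # dimension k; with any other rhs an unbound shape name is an error
        return y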
@@ -255,7 +255,7 @@ def parse_shape( ---------- shape_annotation : Union[ast.TypeTuple, ast.Tuple] The shape annotation in synr - allow_intro : bool + bind_free_vars : bool Whether or not the annotation can bind previously free variables Returns @@ -263,16 +263,16 @@ def parse_shape( List[tir.PrimExpr] The parsed shape as a list of PrimExprs """ - return [self.parse_primexpr(field, allow_intro) for field in shape_annotation.values] + return [self.parse_primexpr(field, bind_free_vars) for field in shape_annotation.values] - def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: + def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr: """Parses the given expression to a PrimExpr. Parameters ---------- expr : ast.Expr The input expression - allow_intro : bool + bind_free_vars : bool Whether or not the expression can bind previously free variables Returns @@ -285,12 +285,14 @@ def parse_primexpr(self, expr: ast.Expr, allow_intro: bool) -> tir.PrimExpr: if var_name in self.scope: var = self.scope[var_name] if not isinstance(var, tir.Var): + # TODO(@altanh): we may wish to relax this in the future to support constructing + # shapes from Dim-typed Relax expressions self.report_error( "non-dimension variables cannot appear in dimension expressions", expr.span, ) return var - elif allow_intro: + elif bind_free_vars: # introduce TIR variable to scope, e.g. for func params or rx.call_packed var = tir.Var(var_name, "int32", self.to_tvm_span(expr.span)) self.scope[var_name] = var @@ -354,11 +356,11 @@ def transform_function(self, func: ast.Function, is_global: bool = False) -> rx. with self.new_scope(): params = [] for param in func.params: - ty, shape = self.transform_type(param.ty, allow_intro=True) + ty, shape = self.transform_type(param.ty, bind_free_vars=True) param = self.decl_var(param.name, ty, shape, param.span) params.append(param) new_body = self.transform_block(func.body) - ret_type, _ = self.transform_type(func.ret_type, allow_intro=False) + ret_type, _ = self.transform_type(func.ret_type, bind_free_vars=False) func_name = rx.GlobalVar(func.name) if is_global else None return rx.Function( @@ -415,7 +417,7 @@ def parse_shape_binding(self, stmt: ast.UnassignedCall) -> rx.MatchShape: "the pattern (lhs) of " + op.value + " must be a tuple", lhs.span, ) - lhs_expr = self.parse_shape(lhs, allow_intro=True) + lhs_expr = self.parse_shape(lhs, bind_free_vars=True) return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span)) def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBinding: @@ -441,10 +443,10 @@ def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBindin rhs = self.transform_expr(stmt.rhs) # an ExternFunc call comes from call_packed if isinstance(rhs, relay.Call) and isinstance(rhs.op, rx.ExternFunc): - allow_intro = True + bind_free_vars = True else: - allow_intro = False - ty, shape = self.transform_type(stmt.ty, allow_intro) + bind_free_vars = False + ty, shape = self.transform_type(stmt.ty, bind_free_vars) lhs = self.decl_var(stmt.lhs.id.name, ty, shape, stmt.lhs.span, is_dataflow=is_dataflow) return rx.VarBinding(lhs, rhs, self.to_tvm_span(stmt.span)) From c7bf50f98df65f46ca5d84dbe39b4aabc30fd578 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 9 Sep 2021 14:53:24 -0700 Subject: [PATCH 20/20] support coercing tuples to ShapeExpr when possible for call_dps --- python/tvm/relax/parser.py | 26 +++++++++++++++++++++++--- tests/python/relax/parser.py | 19 +++++++++++++++++-- 2 files 
changed, 40 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 11147d8e55..62e587b708 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -618,8 +618,8 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock: binding_stmt.span, ) is_match_shape = isinstance(binding_stmt, ast.UnassignedCall) - is_dataflow = ( - False if is_match_shape else (binding_stmt.lhs.id.name not in output_var_names) + is_dataflow = not is_match_shape and ( + binding_stmt.lhs.id.name not in output_var_names ) binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow) bindings.append(binding) @@ -666,7 +666,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: obj = self.transform_expr(expr.object) return relay.op.shape_of(obj) else: - # assume it's a hierarchical op identifier (e.g. nn.softmax, rx.call_dps) + # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps) op_name = [] attr = expr while isinstance(attr, ast.Attr): @@ -703,10 +703,18 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: args = [self.transform_expr(expr.params[1])] else: args = [self.transform_expr(arg) for arg in expr.params] + # TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass? return relay.Call(op, args, span=self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Tuple): fields = [self.transform_expr(field) for field in expr.values] + + # TODO(@altanh): this check might be too weak; we really only accept integral PrimExprs + # (e.g. int constants, dim vars, and integer operations on these) + + # coerce to ShapeExpr when fields are all PrimExprs + if all([isinstance(f, tir.PrimExpr) for f in fields]): + return rx.ShapeExpr(fields, span=self.to_tvm_span(expr.span)) return relay.Tuple(fields, span=self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Var): @@ -717,6 +725,18 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: self.report_error("undefined variable", expr.span) return self.scope[var_name] + elif isinstance(expr, ast.Constant): + # FIXME(@altanh): use internal representation that doesn't have precision limits here + if isinstance(expr.value, int): + return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span)) + elif isinstance(expr.value, float): + return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span)) + else: + self.report_error( + "unsupported constant expression (we currently only support int and float)", + expr.span, + ) + else: self.report_error("unsupported expression", expr.span) diff --git a/tests/python/relax/parser.py b/tests/python/relax/parser.py index 935a19bc59..12ba128e25 100644 --- a/tests/python/relax/parser.py +++ b/tests/python/relax/parser.py @@ -234,6 +234,7 @@ def test_local_func(): def foo(x: Tensor[_, _]): def bar(y: Tensor[_, _]): return y + y = bar(x) # tests local function variable scoping return y @@ -342,7 +343,7 @@ def foo(x: Tensor[_, _]): def test_inline_tir(): @rx.script - def foo(x: Tensor[(128, 128), "float32"], y: Tensor[(128, 128), "float32"]): + def foo(x: Tensor[(B, 128), "float32"], y: Tensor[(128, 128), "float32"]): @tir def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: A = tir.match_buffer(a, [128, 128]) @@ -354,9 +355,23 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: C[vi, vj] = tir.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - z = relax.call_dps(my_matmul, x, y) + z = relax.call_dps((B, 128), my_matmul, (x, y)) return z + f = rx_func(foo) + x, y = f.params + 
B = x.shape_[0] + mm_bind, z_bind = f.body.blocks[0].bindings + + assert mm_bind.var.name_hint == "my_matmul" + assert isinstance(mm_bind.value, tir.PrimFunc) + + check_call( + z_bind.value, + "relax.call_dps", + [rx.ShapeExpr([B, tir.IntImm("int32", 128)]), mm_bind.var, rx.Tuple([x, y])], + ) + def test_call_packed(): @rx.script
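# How the tuple coercion added above relates to the call_dps test: in
#   z = relax.call_dps((B, 128), my_matmul, (x, y))
# the fields of (B, 128) are both PrimExprs once parsed (the dim var B and
# tir.IntImm("int32", 128)), so transform_expr coerces that tuple to
#   rx.ShapeExpr([B, tir.IntImm("int32", 128)])
# whereas (x, y) contains relax Vars and is left as relay.Tuple([x, y]),
# which is exactly the argument list the check_call assertion above expects.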