From f0c7129b0b1a1b01849b265e80356f333cfd2237 Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Tue, 29 Oct 2019 21:51:20 -0700 Subject: [PATCH] [Relay] Add Python type functor and tests (#4209) * Add Python type functor and tests * Lint roller --- python/tvm/relay/__init__.py | 6 + python/tvm/relay/type_functor.py | 194 ++++++++++++++++++++++++ tests/python/relay/test_type_functor.py | 107 +++++++++++++ 3 files changed, 307 insertions(+) create mode 100644 python/tvm/relay/type_functor.py create mode 100644 tests/python/relay/test_type_functor.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f05098bd0c8e..bd3f5bd1fb8d 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -23,6 +23,7 @@ from . import base from . import ty from . import expr +from . import type_functor from . import expr_functor from . import module from . import adt @@ -118,6 +119,11 @@ function_pass = transform.function_pass alpha_equal = analysis.alpha_equal +# TypeFunctor +TypeFunctor = type_functor.TypeFunctor +TypeVisitor = type_functor.TypeVisitor +TypeMutator = type_functor.TypeMutator + # ExprFunctor ExprFunctor = expr_functor.ExprFunctor ExprVisitor = expr_functor.ExprVisitor diff --git a/python/tvm/relay/type_functor.py b/python/tvm/relay/type_functor.py new file mode 100644 index 000000000000..1331058b37ca --- /dev/null +++ b/python/tvm/relay/type_functor.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The type functor of Relay.""" +from .ty import (TypeVar, IncompleteType, TensorType, FuncType, + TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall) +from .adt import TypeData + +class TypeFunctor: + """ + An abstract visitor defined over Type. + + Defines the default dispatch over types. + """ + def __init__(self): + # TODO(weberlo): make type vars hashable, so we can memoize + pass + + # pylint: disable=no-else-return + def visit(self, typ): + """Apply the visitor to a type.""" + if isinstance(typ, TypeVar): + return self.visit_type_var(typ) + elif isinstance(typ, IncompleteType): + return self.visit_incomplete_type(typ) + elif isinstance(typ, TensorType): + return self.visit_tensor_type(typ) + elif isinstance(typ, FuncType): + return self.visit_func_type(typ) + elif isinstance(typ, TupleType): + return self.visit_tuple_type(typ) + elif isinstance(typ, TypeRelation): + return self.visit_type_relation(typ) + elif isinstance(typ, RefType): + return self.visit_ref_type(typ) + elif isinstance(typ, GlobalTypeVar): + return self.visit_global_type_var(typ) + elif isinstance(typ, TypeCall): + return self.visit_type_call(typ) + elif isinstance(typ, TypeData): + return self.visit_type_data(typ) + else: + raise Exception('unhandled case: {0}'.format(type(typ))) + + def visit_type_var(self, _): + raise NotImplementedError() + + def visit_incomplete_type(self, _): + raise NotImplementedError() + + def visit_tensor_type(self, _): + raise NotImplementedError() + + def visit_func_type(self, _): + raise NotImplementedError() + + def visit_tuple_type(self, _): + raise NotImplementedError() + + def visit_type_relation(self, _): + raise NotImplementedError() + + def visit_ref_type(self, _): + raise NotImplementedError() + + def visit_global_type_var(self, _): + raise NotImplementedError() + + def visit_type_call(self, _): + raise NotImplementedError() + + def visit_type_data(self, _): + raise NotImplementedError() + + +class TypeVisitor(TypeFunctor): + """ + A visitor over Type. + + The default behavior recursively traverses the AST. + """ + def visit_type_var(self, tv): + pass + + def visit_incomplete_type(self, it): + pass + + def visit_tensor_type(self, tt): + pass + + def visit_func_type(self, ft): + for arg_type in ft.arg_types: + self.visit(arg_type) + self.visit(ft.ret_type) + for type_param in getattr(ft, 'type_params', []): + self.visit(type_param) + for type_constraint in getattr(ft, 'type_constraints', []): + self.visit(type_constraint) + + def visit_tuple_type(self, tt): + for field in tt.fields: + self.visit(field) + + def visit_type_relation(self, tr): + for arg in tr.args: + self.visit(arg) + + def visit_ref_type(self, rt): + self.visit(rt.value) + + def visit_global_type_var(self, gtv): + pass + + def visit_type_call(self, tc): + self.visit(tc.func) + for arg in tc.args: + self.visit(arg) + + def visit_type_data(self, td): + self.visit(td.header) + for type_var in td.type_vars: + self.visit(type_var) + + +class TypeMutator(TypeFunctor): + """ + A functional visitor over Type. + + The default behavior recursively traverses the AST + and reconstructs the AST. + """ + def visit_type_var(self, tv): + return TypeVar(tv.var.name, tv.kind) + + def visit_incomplete_type(self, it): + return IncompleteType(it.kind) + + def visit_tensor_type(self, tt): + return TensorType(tt.shape, tt.dtype) + + def visit_func_type(self, ft): + new_arg_types = [self.visit(arg_type) for arg_type in ft.arg_types] + new_ret_type = self.visit(ft.ret_type) + new_type_params = [ + self.visit(type_param) + for type_param in getattr(ft, 'type_params', [])] + new_type_constraints = [ + self.visit(type_constraint) + for type_constraint in getattr(ft, 'type_constraints', [])] + return FuncType( + new_arg_types, + new_ret_type, + new_type_params, + new_type_constraints) + + def visit_tuple_type(self, tt): + return TupleType([self.visit(field) for field in tt.fields]) + + def visit_type_relation(self, tr): + return TypeRelation( + tr.func, + [self.visit(arg) for arg in tr.args], + tr.num_inputs, + tr.attrs) + + def visit_ref_type(self, rt): + return RefType(self.visit(rt.value)) + + def visit_global_type_var(self, gtv): + return GlobalTypeVar(gtv.var.name, gtv.kind) + + def visit_type_call(self, tc): + return TypeCall( + self.visit(tc.func), + [self.visit(arg) for arg in tc.args]) + + def visit_type_data(self, td): + return TypeData( + self.visit(td.header), + [self.visit(type_var) for type_var in td.type_vars], + td.constructors) diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py new file mode 100644 index 000000000000..d09a8938bb54 --- /dev/null +++ b/tests/python/relay/test_type_functor.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor +from tvm.relay.analysis import assert_graph_equal +from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType, + TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall) +from tvm.relay.adt import TypeData + +def check_visit(typ): + try: + ef = TypeFunctor() + ef.visit(typ) + assert False + except NotImplementedError: + pass + + ev = TypeVisitor() + ev.visit(typ) + + assert_graph_equal(TypeMutator().visit(typ), typ) + + +def test_type_var(): + tv = TypeVar('a') + check_visit(tv) + + +def test_incomplete_type(): + it = IncompleteType() + check_visit(it) + + +def test_tensor_type(): + tt = TensorType([]) + check_visit(tt) + + +def test_func_type(): + tv = TypeVar('tv') + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') + ft = FuncType([tt], tt, type_params=[tv]) + check_visit(ft) + + +def test_tuple_type(): + tt = TupleType([TupleType([])]) + check_visit(tt) + + +def test_type_relation(): + func = tvm.get_env_func('tvm.relay.type_relation.Broadcast') + attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4)) + tp = TypeVar('tp') + tf = FuncType([], TupleType([]), [], []) + tt = TensorType([1, 2, 3], 'float32') + tr = TypeRelation(func, [tp, tf, tt], 2, attrs) + + check_visit(tr) + + +def test_ref_type(): + rt = RefType(TupleType([])) + check_visit(rt) + + +def test_global_type_var(): + gtv = GlobalTypeVar('gtv') + check_visit(gtv) + + +def test_type_call(): + tc = TypeCall(GlobalTypeVar('tf'), [TupleType([])]) + check_visit(tc) + + +def test_type_data(): + td = TypeData(GlobalTypeVar('td'), [TypeVar('tv')], []) + check_visit(td) + + +if __name__ == "__main__": + test_type_var() + test_incomplete_type() + test_tensor_type() + test_func_type() + test_tuple_type() + test_type_relation() + test_ref_type() + test_global_type_var() + test_type_call() + test_type_data()