From 9b101751cc3214b07799f9ef83258d38aefa4960 Mon Sep 17 00:00:00 2001 From: pranavrajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Tue, 27 Jul 2021 14:18:23 -0700 Subject: [PATCH] [mypyc] Allow registering implementations for singledispatch functions in different files (#10880) * Add test for singledispatch across multiple files Add a test to make sure that we handle registering a singledispatch function in a different module than the main function correctly. This test uses a pattern similar to the code in https://github.com/mypyc/mypyc/issues/802#issuecomment-883780584 to make sure that we don't accidentally end up dynamically registering functions in other modules and just relying on the normal singledispatch machinery to check the dynamically registered functions. * Move info about a register implementation to type alias Add a type alias for the type that we use for information about a register implementation (which is a tuple of the dispatch type's TypeInfo and the function's FuncDef). * Run register-finding pass over entire SCC Instead of trying to look for singledispatch register implementations in one file at a time and only using the register implementations found in the same file as the singledispatch main function when generating code, look for register implementations across the entire SCC once, and then use that list when compiling every module. That change necessitates adding an extra argument for the PreBuildVisitor constructor so that we can remove any register calls that we found (so we can avoid register trying to access `__annotations__` on a builtin function) and adding an extra argument to IRBuilder so that we can access the list of singledispatch functions when building IR in any module. * Load functions from other modules correctly When we're generating calls to registered implementations in other modules, we need to load those functions by loading the module and accessing the correct attribute, instead of trying to load that function from the globals dict of the current module. That change also necessitates generating imports of any modules with registered implementations that weren't already imported so that we can load those modules in the dispatch function. --- mypyc/irbuild/builder.py | 8 +- mypyc/irbuild/function.py | 46 +++++++-- mypyc/irbuild/main.py | 8 +- mypyc/irbuild/prebuildvisitor.py | 100 ++++---------------- mypyc/irbuild/prepare.py | 119 +++++++++++++++++++++++- mypyc/test-data/run-singledispatch.test | 31 ++++++ 6 files changed, 216 insertions(+), 96 deletions(-) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 33682fe83240..8ad64f5559da 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -11,6 +11,7 @@ functions are transformed in mypyc.irbuild.function. """ +from mypyc.irbuild.prepare import RegisterImplInfo from typing import Callable, Dict, List, Tuple, Optional, Union, Sequence, Set, Any from typing_extensions import overload from mypy.backports import OrderedDict @@ -20,7 +21,7 @@ MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr, CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr, TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF, - ArgKind, ARG_POS, ARG_NAMED, + ArgKind, ARG_POS, ARG_NAMED, FuncDef, ) from mypy.types import ( Type, Instance, TupleType, UninhabitedType, get_proper_type @@ -85,7 +86,8 @@ def __init__(self, mapper: Mapper, pbv: PreBuildVisitor, visitor: IRVisitor, - options: CompilerOptions) -> None: + options: CompilerOptions, + singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]]) -> None: self.builder = LowLevelIRBuilder(current_module, mapper, options) self.builders = [self.builder] self.symtables: List[OrderedDict[SymbolNode, SymbolTarget]] = [OrderedDict()] @@ -116,7 +118,7 @@ def __init__(self, self.encapsulating_funcs = pbv.encapsulating_funcs self.nested_fitems = pbv.nested_funcs.keys() self.fdefs_to_decorators = pbv.funcs_to_decorators - self.singledispatch_impls = pbv.singledispatch_impls + self.singledispatch_impls = singledispatch_impls self.visitor = visitor diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index c4e99646eb03..5a0ac89540bb 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,6 +10,7 @@ instance of the callable class. """ +from mypyc.irbuild.prepare import RegisterImplInfo from mypy.build import topsort from typing import ( NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator, @@ -806,6 +807,22 @@ def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int return builder.call_c(slow_isinstance_op, [obj, class_obj], line) +def load_func(builder: IRBuilder, func_name: str, fullname: Optional[str], line: int) -> Value: + if fullname is not None and not fullname.startswith(builder.current_module): + # we're calling a function in a different module + + # We can't use load_module_attr_by_fullname here because we need to load the function using + # func_name, not the name specified by fullname (which can be different for underscore + # function) + module = fullname.rsplit('.')[0] + loaded_module = builder.load_module(module) + + func = builder.py_get_attr(loaded_module, func_name, line) + else: + func = builder.load_global_str(func_name, line) + return func + + def generate_singledispatch_dispatch_function( builder: IRBuilder, main_singledispatch_function_name: str, @@ -816,15 +833,27 @@ def generate_singledispatch_dispatch_function( current_func_decl = builder.mapper.func_to_decl[fitem] arg_info = get_args(builder, current_func_decl.sig.args, line) - def gen_func_call_and_return(func_name: str) -> None: - func = builder.load_global_str(func_name, line) - # TODO: don't pass optional arguments if they weren't passed to this function + def gen_func_call_and_return(func_name: str, fullname: Optional[str] = None) -> None: + func = load_func(builder, func_name, fullname, line) ret_val = builder.builder.py_call( func, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names ) coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) builder.nonlocal_control[-1].gen_return(builder, coerced, line) + # Add all necessary imports of other modules that have registered functions in other modules + # We're doing this in a separate pass over the implementations because that avoids the + # complexity and code size implications of generating this import before every call to a + # registered implementation that might need this imported + # TODO: avoid adding imports if we use native calls for all of the registered implementations + # in a module (once we add support for using native calls for registered implementations) + for _, impl in impls: + module_name = impl.fullname.rsplit('.')[0] + if module_name not in builder.imports: + # We need to generate an import here because the module needs to be imported before we + # try loading the function from it + builder.gen_import(module_name, line) + # Sort the list of implementations so that we check any subclasses before we check the classes # they inherit from, to better match singledispatch's behavior of going through the argument's # MRO, and using the first implementation it finds @@ -839,9 +868,14 @@ def gen_func_call_and_return(func_name: str) -> None: # The shortname of a function is just '{class}.{func_name}', and we don't support # singledispatchmethod yet, so that is always the same as the function name name = short_id_from_name(impl.name, impl.name, impl.line) - gen_func_call_and_return(name) + gen_func_call_and_return(name, fullname=impl.fullname) builder.activate_block(next_impl) + # We don't pass fullname here because getting the fullname of the main generated singledispatch + # function isn't easy, and we don't need it because the fullname is only needed for making sure + # we load the function from another module instead of the globals dict if it's defined in + # another module, which will never be true for the main singledispatch function (it's always + # generated in the same module as the dispatch function) gen_func_call_and_return(main_singledispatch_function_name) @@ -864,8 +898,8 @@ def gen_dispatch_func_ir( def sort_with_subclasses_first( - impls: List[Tuple[TypeInfo, FuncDef]] -) -> Iterator[Tuple[TypeInfo, FuncDef]]: + impls: List[RegisterImplInfo] +) -> Iterator[RegisterImplInfo]: # graph with edges pointing from every class to their subclasses graph = {typ: set(typ.mro[1:]) for typ, _ in impls} diff --git a/mypyc/irbuild/main.py b/mypyc/irbuild/main.py index e0dee7e67035..5d666384b0a9 100644 --- a/mypyc/irbuild/main.py +++ b/mypyc/irbuild/main.py @@ -36,7 +36,7 @@ def f(x: int) -> int: from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature from mypyc.irbuild.prebuildvisitor import PreBuildVisitor from mypyc.irbuild.vtable import compute_vtable -from mypyc.irbuild.prepare import build_type_map +from mypyc.irbuild.prepare import build_type_map, find_singledispatch_register_impls from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.visitor import IRBuilderVisitor from mypyc.irbuild.mapper import Mapper @@ -58,6 +58,7 @@ def build_ir(modules: List[MypyFile], """Build IR for a set of modules that have been type-checked by mypy.""" build_type_map(mapper, modules, graph, types, options, errors) + singledispatch_info = find_singledispatch_register_impls(modules, errors) result: ModuleIRs = OrderedDict() @@ -66,13 +67,14 @@ def build_ir(modules: List[MypyFile], for module in modules: # First pass to determine free symbols. - pbv = PreBuildVisitor(errors, module) + pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove) module.accept(pbv) # Construct and configure builder objects (cyclic runtime dependency). visitor = IRBuilderVisitor() builder = IRBuilder( - module.fullname, types, graph, errors, mapper, pbv, visitor, options + module.fullname, types, graph, errors, mapper, pbv, visitor, options, + singledispatch_info.singledispatch_impls, ) visitor.builder = builder diff --git a/mypyc/irbuild/prebuildvisitor.py b/mypyc/irbuild/prebuildvisitor.py index 8b0b0dd8073b..55928a57b839 100644 --- a/mypyc/irbuild/prebuildvisitor.py +++ b/mypyc/irbuild/prebuildvisitor.py @@ -1,11 +1,9 @@ from mypyc.errors import Errors -from mypy.types import Instance, get_proper_type -from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple -from collections import defaultdict +from typing import Dict, List, Set from mypy.nodes import ( Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr, - CallExpr, RefExpr, TypeInfo, MypyFile + MypyFile ) from mypy.traverser import TraverserVisitor @@ -24,7 +22,12 @@ class PreBuildVisitor(TraverserVisitor): The main IR build pass uses this information. """ - def __init__(self, errors: Errors, current_file: MypyFile) -> None: + def __init__( + self, + errors: Errors, + current_file: MypyFile, + decorators_to_remove: Dict[FuncDef, List[int]], + ) -> None: super().__init__() # Dict from a function to symbols defined directly in the # function that are used as non-local (free) variables within a @@ -54,9 +57,8 @@ def __init__(self, errors: Errors, current_file: MypyFile) -> None: # Map function to its non-special decorators. self.funcs_to_decorators: Dict[FuncDef, List[Expression]] = {} - # Map of main singledispatch function to list of registered implementations - self.singledispatch_impls: DefaultDict[ - FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list) + # Map function to indices of decorators to remove + self.decorators_to_remove: Dict[FuncDef, List[int]] = decorators_to_remove self.errors: Errors = errors @@ -76,37 +78,15 @@ def visit_decorator(self, dec: Decorator) -> None: self.prop_setters.add(dec.func) else: decorators_to_store = dec.decorators.copy() - removed: List[int] = [] - # the index of the last non-register decorator before finding a register decorator - # when going through decorators from top to bottom - last_non_register: Optional[int] = None - for i, d in enumerate(decorators_to_store): - impl = get_singledispatch_register_call_info(d, dec.func) - if impl is not None: - self.singledispatch_impls[impl.singledispatch_func].append( - (impl.dispatch_type, dec.func)) - removed.append(i) - if last_non_register is not None: - # found a register decorator after a non-register decorator, which we - # don't support because we'd have to make a copy of the function before - # calling the decorator so that we can call it later, which complicates - # the implementation for something that is probably not commonly used - self.errors.error( - "Calling decorator after registering function not supported", - self.current_file.path, - decorators_to_store[last_non_register].line, - ) - else: - last_non_register = i - # calling register on a function that tries to dispatch based on type annotations - # raises a TypeError because compiled functions don't have an __annotations__ - # attribute - for i in reversed(removed): - del decorators_to_store[i] - # if the only decorators are register calls, we shouldn't treat this - # as a decorated function because there aren't any decorators to apply - if not decorators_to_store: - return + if dec.func in self.decorators_to_remove: + to_remove = self.decorators_to_remove[dec.func] + + for i in reversed(to_remove): + del decorators_to_store[i] + # if all of the decorators are removed, we shouldn't treat this as a decorated + # function because there aren't any decorators to apply + if not decorators_to_store: + return self.funcs_to_decorators[dec.func] = decorators_to_store super().visit_decorator(dec) @@ -186,45 +166,3 @@ def add_free_variable(self, symbol: SymbolNode) -> None: # and mark is as a non-local symbol within that function. func = self.symbols_to_funcs[symbol] self.free_variables.setdefault(func, set()).add(symbol) - - -class RegisteredImpl(NamedTuple): - singledispatch_func: FuncDef - dispatch_type: TypeInfo - - -def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef - ) -> Optional[RegisteredImpl]: - # @fun.register(complex) - # def g(arg): ... - if (isinstance(decorator, CallExpr) and len(decorator.args) == 1 - and isinstance(decorator.args[0], RefExpr)): - callee = decorator.callee - dispatch_type = decorator.args[0].node - if not isinstance(dispatch_type, TypeInfo): - return None - - if isinstance(callee, MemberExpr): - return registered_impl_from_possible_register_call(callee, dispatch_type) - # @fun.register - # def g(arg: int): ... - elif isinstance(decorator, MemberExpr): - # we don't know if this is a register call yet, so we can't be sure that the function - # actually has arguments - if not func.arguments: - return None - arg_type = get_proper_type(func.arguments[0].variable.type) - if not isinstance(arg_type, Instance): - return None - info = arg_type.type - return registered_impl_from_possible_register_call(decorator, info) - return None - - -def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo - ) -> Optional[RegisteredImpl]: - if expr.name == 'register' and isinstance(expr.expr, NameExpr): - node = expr.expr.node - if isinstance(node, Decorator): - return RegisteredImpl(node.func, dispatch_type) - return None diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index a4f2fc996893..338d0940901b 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -11,13 +11,14 @@ Also build a mapping from mypy TypeInfos to ClassIR objects. """ -from typing import List, Dict, Iterable, Optional, Union +from typing import List, Dict, Iterable, Optional, Union, DefaultDict, NamedTuple, Tuple from mypy.nodes import ( MypyFile, TypeInfo, FuncDef, ClassDef, Decorator, OverloadedFuncDef, MemberExpr, Var, - Expression, SymbolNode, ARG_STAR, ARG_STAR2 + Expression, SymbolNode, ARG_STAR, ARG_STAR2, CallExpr, Decorator, Expression, FuncDef, + MemberExpr, MypyFile, NameExpr, RefExpr, TypeInfo, ) -from mypy.types import Type +from mypy.types import Type, Instance, get_proper_type from mypy.build import Graph from mypyc.ir.ops import DeserMaps @@ -34,6 +35,8 @@ from mypyc.errors import Errors from mypyc.options import CompilerOptions from mypyc.crash import catch_errors +from collections import defaultdict +from mypy.traverser import TraverserVisitor def build_type_map(mapper: Mapper, @@ -303,3 +306,113 @@ def prepare_non_ext_class_def(path: str, module_name: str, cdef: ClassDef, ): errors.error( "Non-extension classes may not inherit from extension classes", path, cdef.line) + + +RegisterImplInfo = Tuple[TypeInfo, FuncDef] + + +class SingledispatchInfo(NamedTuple): + singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]] + decorators_to_remove: Dict[FuncDef, List[int]] + + +def find_singledispatch_register_impls( + modules: List[MypyFile], + errors: Errors, +) -> SingledispatchInfo: + visitor = SingledispatchVisitor(errors) + for module in modules: + visitor.current_path = module.path + module.accept(visitor) + return SingledispatchInfo(visitor.singledispatch_impls, visitor.decorators_to_remove) + + +class SingledispatchVisitor(TraverserVisitor): + current_path: str + + def __init__(self, errors: Errors) -> None: + super().__init__() + + # Map of main singledispatch function to list of registered implementations + self.singledispatch_impls: DefaultDict[FuncDef, List[RegisterImplInfo]] = defaultdict(list) + + # Map of decorated function to the indices of any register decorators + self.decorators_to_remove: Dict[FuncDef, List[int]] = {} + + self.errors: Errors = errors + + def visit_decorator(self, dec: Decorator) -> None: + if dec.decorators: + decorators_to_store = dec.decorators.copy() + register_indices: List[int] = [] + # the index of the last non-register decorator before finding a register decorator + # when going through decorators from top to bottom + last_non_register: Optional[int] = None + for i, d in enumerate(decorators_to_store): + impl = get_singledispatch_register_call_info(d, dec.func) + if impl is not None: + self.singledispatch_impls[impl.singledispatch_func].append( + (impl.dispatch_type, dec.func)) + register_indices.append(i) + if last_non_register is not None: + # found a register decorator after a non-register decorator, which we + # don't support because we'd have to make a copy of the function before + # calling the decorator so that we can call it later, which complicates + # the implementation for something that is probably not commonly used + self.errors.error( + "Calling decorator after registering function not supported", + self.current_path, + decorators_to_store[last_non_register].line, + ) + else: + last_non_register = i + + if register_indices: + # calling register on a function that tries to dispatch based on type annotations + # raises a TypeError because compiled functions don't have an __annotations__ + # attribute + self.decorators_to_remove[dec.func] = register_indices + + super().visit_decorator(dec) + + +class RegisteredImpl(NamedTuple): + singledispatch_func: FuncDef + dispatch_type: TypeInfo + + +def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef + ) -> Optional[RegisteredImpl]: + # @fun.register(complex) + # def g(arg): ... + if (isinstance(decorator, CallExpr) and len(decorator.args) == 1 + and isinstance(decorator.args[0], RefExpr)): + callee = decorator.callee + dispatch_type = decorator.args[0].node + if not isinstance(dispatch_type, TypeInfo): + return None + + if isinstance(callee, MemberExpr): + return registered_impl_from_possible_register_call(callee, dispatch_type) + # @fun.register + # def g(arg: int): ... + elif isinstance(decorator, MemberExpr): + # we don't know if this is a register call yet, so we can't be sure that the function + # actually has arguments + if not func.arguments: + return None + arg_type = get_proper_type(func.arguments[0].variable.type) + if not isinstance(arg_type, Instance): + return None + info = arg_type.type + return registered_impl_from_possible_register_call(decorator, info) + return None + + +def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo + ) -> Optional[RegisteredImpl]: + if expr.name == 'register' and isinstance(expr.expr, NameExpr): + node = expr.expr.node + if isinstance(node, Decorator): + return RegisteredImpl(node.func, dispatch_type) + return None diff --git a/mypyc/test-data/run-singledispatch.test b/mypyc/test-data/run-singledispatch.test index 5d38e27272af..11035a121bf2 100644 --- a/mypyc/test-data/run-singledispatch.test +++ b/mypyc/test-data/run-singledispatch.test @@ -512,3 +512,34 @@ def test_singledispatch(): assert f(B()) == 'b' assert f(C()) == 'c' assert f(1) == 'default' + +[case testRegisteredImplementationsInDifferentFiles] +from other_a import f, A, B, C +@f.register +def a(arg: A) -> int: + return 2 + +@f.register +def _(arg: C) -> int: + return 3 + +def test_singledispatch(): + assert f(B()) == 1 + assert f(A()) == 2 + assert f(C()) == 3 + assert f(1) == 0 + +[file other_a.py] +from functools import singledispatch + +class A: pass +class B(A): pass +class C(B): pass + +@singledispatch +def f(arg) -> int: + return 0 + +@f.register +def g(arg: B) -> int: + return 1