diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 33682fe83240..8ad64f5559da 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -11,6 +11,7 @@ functions are transformed in mypyc.irbuild.function. """ +from mypyc.irbuild.prepare import RegisterImplInfo from typing import Callable, Dict, List, Tuple, Optional, Union, Sequence, Set, Any from typing_extensions import overload from mypy.backports import OrderedDict @@ -20,7 +21,7 @@ MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr, CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr, TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF, - ArgKind, ARG_POS, ARG_NAMED, + ArgKind, ARG_POS, ARG_NAMED, FuncDef, ) from mypy.types import ( Type, Instance, TupleType, UninhabitedType, get_proper_type @@ -85,7 +86,8 @@ def __init__(self, mapper: Mapper, pbv: PreBuildVisitor, visitor: IRVisitor, - options: CompilerOptions) -> None: + options: CompilerOptions, + singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]]) -> None: self.builder = LowLevelIRBuilder(current_module, mapper, options) self.builders = [self.builder] self.symtables: List[OrderedDict[SymbolNode, SymbolTarget]] = [OrderedDict()] @@ -116,7 +118,7 @@ def __init__(self, self.encapsulating_funcs = pbv.encapsulating_funcs self.nested_fitems = pbv.nested_funcs.keys() self.fdefs_to_decorators = pbv.funcs_to_decorators - self.singledispatch_impls = pbv.singledispatch_impls + self.singledispatch_impls = singledispatch_impls self.visitor = visitor diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index c4e99646eb03..5a0ac89540bb 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,6 +10,7 @@ instance of the callable class. """ +from mypyc.irbuild.prepare import RegisterImplInfo from mypy.build import topsort from typing import ( NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator, @@ -806,6 +807,22 @@ def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int return builder.call_c(slow_isinstance_op, [obj, class_obj], line) +def load_func(builder: IRBuilder, func_name: str, fullname: Optional[str], line: int) -> Value: + if fullname is not None and not fullname.startswith(builder.current_module): + # we're calling a function in a different module + + # We can't use load_module_attr_by_fullname here because we need to load the function using + # func_name, not the name specified by fullname (which can be different for underscore + # function) + module = fullname.rsplit('.')[0] + loaded_module = builder.load_module(module) + + func = builder.py_get_attr(loaded_module, func_name, line) + else: + func = builder.load_global_str(func_name, line) + return func + + def generate_singledispatch_dispatch_function( builder: IRBuilder, main_singledispatch_function_name: str, @@ -816,15 +833,27 @@ def generate_singledispatch_dispatch_function( current_func_decl = builder.mapper.func_to_decl[fitem] arg_info = get_args(builder, current_func_decl.sig.args, line) - def gen_func_call_and_return(func_name: str) -> None: - func = builder.load_global_str(func_name, line) - # TODO: don't pass optional arguments if they weren't passed to this function + def gen_func_call_and_return(func_name: str, fullname: Optional[str] = None) -> None: + func = load_func(builder, func_name, fullname, line) ret_val = builder.builder.py_call( func, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names ) coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) builder.nonlocal_control[-1].gen_return(builder, coerced, line) + # Add all necessary imports of other modules that have registered functions in other modules + # We're doing this in a separate pass over the implementations because that avoids the + # complexity and code size implications of generating this import before every call to a + # registered implementation that might need this imported + # TODO: avoid adding imports if we use native calls for all of the registered implementations + # in a module (once we add support for using native calls for registered implementations) + for _, impl in impls: + module_name = impl.fullname.rsplit('.')[0] + if module_name not in builder.imports: + # We need to generate an import here because the module needs to be imported before we + # try loading the function from it + builder.gen_import(module_name, line) + # Sort the list of implementations so that we check any subclasses before we check the classes # they inherit from, to better match singledispatch's behavior of going through the argument's # MRO, and using the first implementation it finds @@ -839,9 +868,14 @@ def gen_func_call_and_return(func_name: str) -> None: # The shortname of a function is just '{class}.{func_name}', and we don't support # singledispatchmethod yet, so that is always the same as the function name name = short_id_from_name(impl.name, impl.name, impl.line) - gen_func_call_and_return(name) + gen_func_call_and_return(name, fullname=impl.fullname) builder.activate_block(next_impl) + # We don't pass fullname here because getting the fullname of the main generated singledispatch + # function isn't easy, and we don't need it because the fullname is only needed for making sure + # we load the function from another module instead of the globals dict if it's defined in + # another module, which will never be true for the main singledispatch function (it's always + # generated in the same module as the dispatch function) gen_func_call_and_return(main_singledispatch_function_name) @@ -864,8 +898,8 @@ def gen_dispatch_func_ir( def sort_with_subclasses_first( - impls: List[Tuple[TypeInfo, FuncDef]] -) -> Iterator[Tuple[TypeInfo, FuncDef]]: + impls: List[RegisterImplInfo] +) -> Iterator[RegisterImplInfo]: # graph with edges pointing from every class to their subclasses graph = {typ: set(typ.mro[1:]) for typ, _ in impls} diff --git a/mypyc/irbuild/main.py b/mypyc/irbuild/main.py index e0dee7e67035..5d666384b0a9 100644 --- a/mypyc/irbuild/main.py +++ b/mypyc/irbuild/main.py @@ -36,7 +36,7 @@ def f(x: int) -> int: from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature from mypyc.irbuild.prebuildvisitor import PreBuildVisitor from mypyc.irbuild.vtable import compute_vtable -from mypyc.irbuild.prepare import build_type_map +from mypyc.irbuild.prepare import build_type_map, find_singledispatch_register_impls from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.visitor import IRBuilderVisitor from mypyc.irbuild.mapper import Mapper @@ -58,6 +58,7 @@ def build_ir(modules: List[MypyFile], """Build IR for a set of modules that have been type-checked by mypy.""" build_type_map(mapper, modules, graph, types, options, errors) + singledispatch_info = find_singledispatch_register_impls(modules, errors) result: ModuleIRs = OrderedDict() @@ -66,13 +67,14 @@ def build_ir(modules: List[MypyFile], for module in modules: # First pass to determine free symbols. - pbv = PreBuildVisitor(errors, module) + pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove) module.accept(pbv) # Construct and configure builder objects (cyclic runtime dependency). visitor = IRBuilderVisitor() builder = IRBuilder( - module.fullname, types, graph, errors, mapper, pbv, visitor, options + module.fullname, types, graph, errors, mapper, pbv, visitor, options, + singledispatch_info.singledispatch_impls, ) visitor.builder = builder diff --git a/mypyc/irbuild/prebuildvisitor.py b/mypyc/irbuild/prebuildvisitor.py index 8b0b0dd8073b..55928a57b839 100644 --- a/mypyc/irbuild/prebuildvisitor.py +++ b/mypyc/irbuild/prebuildvisitor.py @@ -1,11 +1,9 @@ from mypyc.errors import Errors -from mypy.types import Instance, get_proper_type -from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple -from collections import defaultdict +from typing import Dict, List, Set from mypy.nodes import ( Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr, - CallExpr, RefExpr, TypeInfo, MypyFile + MypyFile ) from mypy.traverser import TraverserVisitor @@ -24,7 +22,12 @@ class PreBuildVisitor(TraverserVisitor): The main IR build pass uses this information. """ - def __init__(self, errors: Errors, current_file: MypyFile) -> None: + def __init__( + self, + errors: Errors, + current_file: MypyFile, + decorators_to_remove: Dict[FuncDef, List[int]], + ) -> None: super().__init__() # Dict from a function to symbols defined directly in the # function that are used as non-local (free) variables within a @@ -54,9 +57,8 @@ def __init__(self, errors: Errors, current_file: MypyFile) -> None: # Map function to its non-special decorators. self.funcs_to_decorators: Dict[FuncDef, List[Expression]] = {} - # Map of main singledispatch function to list of registered implementations - self.singledispatch_impls: DefaultDict[ - FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list) + # Map function to indices of decorators to remove + self.decorators_to_remove: Dict[FuncDef, List[int]] = decorators_to_remove self.errors: Errors = errors @@ -76,37 +78,15 @@ def visit_decorator(self, dec: Decorator) -> None: self.prop_setters.add(dec.func) else: decorators_to_store = dec.decorators.copy() - removed: List[int] = [] - # the index of the last non-register decorator before finding a register decorator - # when going through decorators from top to bottom - last_non_register: Optional[int] = None - for i, d in enumerate(decorators_to_store): - impl = get_singledispatch_register_call_info(d, dec.func) - if impl is not None: - self.singledispatch_impls[impl.singledispatch_func].append( - (impl.dispatch_type, dec.func)) - removed.append(i) - if last_non_register is not None: - # found a register decorator after a non-register decorator, which we - # don't support because we'd have to make a copy of the function before - # calling the decorator so that we can call it later, which complicates - # the implementation for something that is probably not commonly used - self.errors.error( - "Calling decorator after registering function not supported", - self.current_file.path, - decorators_to_store[last_non_register].line, - ) - else: - last_non_register = i - # calling register on a function that tries to dispatch based on type annotations - # raises a TypeError because compiled functions don't have an __annotations__ - # attribute - for i in reversed(removed): - del decorators_to_store[i] - # if the only decorators are register calls, we shouldn't treat this - # as a decorated function because there aren't any decorators to apply - if not decorators_to_store: - return + if dec.func in self.decorators_to_remove: + to_remove = self.decorators_to_remove[dec.func] + + for i in reversed(to_remove): + del decorators_to_store[i] + # if all of the decorators are removed, we shouldn't treat this as a decorated + # function because there aren't any decorators to apply + if not decorators_to_store: + return self.funcs_to_decorators[dec.func] = decorators_to_store super().visit_decorator(dec) @@ -186,45 +166,3 @@ def add_free_variable(self, symbol: SymbolNode) -> None: # and mark is as a non-local symbol within that function. func = self.symbols_to_funcs[symbol] self.free_variables.setdefault(func, set()).add(symbol) - - -class RegisteredImpl(NamedTuple): - singledispatch_func: FuncDef - dispatch_type: TypeInfo - - -def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef - ) -> Optional[RegisteredImpl]: - # @fun.register(complex) - # def g(arg): ... - if (isinstance(decorator, CallExpr) and len(decorator.args) == 1 - and isinstance(decorator.args[0], RefExpr)): - callee = decorator.callee - dispatch_type = decorator.args[0].node - if not isinstance(dispatch_type, TypeInfo): - return None - - if isinstance(callee, MemberExpr): - return registered_impl_from_possible_register_call(callee, dispatch_type) - # @fun.register - # def g(arg: int): ... - elif isinstance(decorator, MemberExpr): - # we don't know if this is a register call yet, so we can't be sure that the function - # actually has arguments - if not func.arguments: - return None - arg_type = get_proper_type(func.arguments[0].variable.type) - if not isinstance(arg_type, Instance): - return None - info = arg_type.type - return registered_impl_from_possible_register_call(decorator, info) - return None - - -def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo - ) -> Optional[RegisteredImpl]: - if expr.name == 'register' and isinstance(expr.expr, NameExpr): - node = expr.expr.node - if isinstance(node, Decorator): - return RegisteredImpl(node.func, dispatch_type) - return None diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index a4f2fc996893..338d0940901b 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -11,13 +11,14 @@ Also build a mapping from mypy TypeInfos to ClassIR objects. """ -from typing import List, Dict, Iterable, Optional, Union +from typing import List, Dict, Iterable, Optional, Union, DefaultDict, NamedTuple, Tuple from mypy.nodes import ( MypyFile, TypeInfo, FuncDef, ClassDef, Decorator, OverloadedFuncDef, MemberExpr, Var, - Expression, SymbolNode, ARG_STAR, ARG_STAR2 + Expression, SymbolNode, ARG_STAR, ARG_STAR2, CallExpr, Decorator, Expression, FuncDef, + MemberExpr, MypyFile, NameExpr, RefExpr, TypeInfo, ) -from mypy.types import Type +from mypy.types import Type, Instance, get_proper_type from mypy.build import Graph from mypyc.ir.ops import DeserMaps @@ -34,6 +35,8 @@ from mypyc.errors import Errors from mypyc.options import CompilerOptions from mypyc.crash import catch_errors +from collections import defaultdict +from mypy.traverser import TraverserVisitor def build_type_map(mapper: Mapper, @@ -303,3 +306,113 @@ def prepare_non_ext_class_def(path: str, module_name: str, cdef: ClassDef, ): errors.error( "Non-extension classes may not inherit from extension classes", path, cdef.line) + + +RegisterImplInfo = Tuple[TypeInfo, FuncDef] + + +class SingledispatchInfo(NamedTuple): + singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]] + decorators_to_remove: Dict[FuncDef, List[int]] + + +def find_singledispatch_register_impls( + modules: List[MypyFile], + errors: Errors, +) -> SingledispatchInfo: + visitor = SingledispatchVisitor(errors) + for module in modules: + visitor.current_path = module.path + module.accept(visitor) + return SingledispatchInfo(visitor.singledispatch_impls, visitor.decorators_to_remove) + + +class SingledispatchVisitor(TraverserVisitor): + current_path: str + + def __init__(self, errors: Errors) -> None: + super().__init__() + + # Map of main singledispatch function to list of registered implementations + self.singledispatch_impls: DefaultDict[FuncDef, List[RegisterImplInfo]] = defaultdict(list) + + # Map of decorated function to the indices of any register decorators + self.decorators_to_remove: Dict[FuncDef, List[int]] = {} + + self.errors: Errors = errors + + def visit_decorator(self, dec: Decorator) -> None: + if dec.decorators: + decorators_to_store = dec.decorators.copy() + register_indices: List[int] = [] + # the index of the last non-register decorator before finding a register decorator + # when going through decorators from top to bottom + last_non_register: Optional[int] = None + for i, d in enumerate(decorators_to_store): + impl = get_singledispatch_register_call_info(d, dec.func) + if impl is not None: + self.singledispatch_impls[impl.singledispatch_func].append( + (impl.dispatch_type, dec.func)) + register_indices.append(i) + if last_non_register is not None: + # found a register decorator after a non-register decorator, which we + # don't support because we'd have to make a copy of the function before + # calling the decorator so that we can call it later, which complicates + # the implementation for something that is probably not commonly used + self.errors.error( + "Calling decorator after registering function not supported", + self.current_path, + decorators_to_store[last_non_register].line, + ) + else: + last_non_register = i + + if register_indices: + # calling register on a function that tries to dispatch based on type annotations + # raises a TypeError because compiled functions don't have an __annotations__ + # attribute + self.decorators_to_remove[dec.func] = register_indices + + super().visit_decorator(dec) + + +class RegisteredImpl(NamedTuple): + singledispatch_func: FuncDef + dispatch_type: TypeInfo + + +def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef + ) -> Optional[RegisteredImpl]: + # @fun.register(complex) + # def g(arg): ... + if (isinstance(decorator, CallExpr) and len(decorator.args) == 1 + and isinstance(decorator.args[0], RefExpr)): + callee = decorator.callee + dispatch_type = decorator.args[0].node + if not isinstance(dispatch_type, TypeInfo): + return None + + if isinstance(callee, MemberExpr): + return registered_impl_from_possible_register_call(callee, dispatch_type) + # @fun.register + # def g(arg: int): ... + elif isinstance(decorator, MemberExpr): + # we don't know if this is a register call yet, so we can't be sure that the function + # actually has arguments + if not func.arguments: + return None + arg_type = get_proper_type(func.arguments[0].variable.type) + if not isinstance(arg_type, Instance): + return None + info = arg_type.type + return registered_impl_from_possible_register_call(decorator, info) + return None + + +def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo + ) -> Optional[RegisteredImpl]: + if expr.name == 'register' and isinstance(expr.expr, NameExpr): + node = expr.expr.node + if isinstance(node, Decorator): + return RegisteredImpl(node.func, dispatch_type) + return None diff --git a/mypyc/test-data/run-singledispatch.test b/mypyc/test-data/run-singledispatch.test index 5d38e27272af..11035a121bf2 100644 --- a/mypyc/test-data/run-singledispatch.test +++ b/mypyc/test-data/run-singledispatch.test @@ -512,3 +512,34 @@ def test_singledispatch(): assert f(B()) == 'b' assert f(C()) == 'c' assert f(1) == 'default' + +[case testRegisteredImplementationsInDifferentFiles] +from other_a import f, A, B, C +@f.register +def a(arg: A) -> int: + return 2 + +@f.register +def _(arg: C) -> int: + return 3 + +def test_singledispatch(): + assert f(B()) == 1 + assert f(A()) == 2 + assert f(C()) == 3 + assert f(1) == 0 + +[file other_a.py] +from functools import singledispatch + +class A: pass +class B(A): pass +class C(B): pass + +@singledispatch +def f(arg) -> int: + return 0 + +@f.register +def g(arg: B) -> int: + return 1