Skip to content

Commit

Permalink
[mypyc] Allow registering implementations for singledispatch function…
Browse files Browse the repository at this point in the history
…s in different files (#10880)

* Add test for singledispatch across multiple files

Add a test to make sure that we handle registering a singledispatch
function in a different module than the main function correctly.

This test uses a pattern similar to the code in
mypyc/mypyc#802 (comment) to make
sure that we don't accidentally end up dynamically registering functions
in other modules and just relying on the normal singledispatch machinery
to check the dynamically registered functions.

* Move info about a register implementation to type alias

Add a type alias for the type that we use for information about a
register implementation (which is a tuple of the dispatch type's
TypeInfo and the function's FuncDef).

* Run register-finding pass over entire SCC

Instead of trying to look for singledispatch register implementations in
one file at a time and only using the register implementations found in
the same file as the singledispatch main function when generating code,
look for register implementations across the entire SCC once, and then
use that list when compiling every module.

That change necessitates adding an extra argument for the
PreBuildVisitor constructor so that we can remove any register calls
that we found (so we can avoid register trying to access
`__annotations__` on a builtin function) and adding an extra argument to
IRBuilder so that we can access the list of singledispatch functions
when building IR in any module.

* Load functions from other modules correctly

When we're generating calls to registered implementations in other
modules, we need to load those functions by loading the module and
accessing the correct attribute, instead of trying to load that function
from the globals dict of the current module.

That change also necessitates generating imports of any modules with
registered implementations that weren't already imported so that we can
load those modules in the dispatch function.
  • Loading branch information
pranavrajpal authored Jul 27, 2021
1 parent 14e06c2 commit 9b10175
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 96 deletions.
8 changes: 5 additions & 3 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
functions are transformed in mypyc.irbuild.function.
"""

from mypyc.irbuild.prepare import RegisterImplInfo
from typing import Callable, Dict, List, Tuple, Optional, Union, Sequence, Set, Any
from typing_extensions import overload
from mypy.backports import OrderedDict
Expand All @@ -20,7 +21,7 @@
MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr,
CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr,
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF,
ArgKind, ARG_POS, ARG_NAMED,
ArgKind, ARG_POS, ARG_NAMED, FuncDef,
)
from mypy.types import (
Type, Instance, TupleType, UninhabitedType, get_proper_type
Expand Down Expand Up @@ -85,7 +86,8 @@ def __init__(self,
mapper: Mapper,
pbv: PreBuildVisitor,
visitor: IRVisitor,
options: CompilerOptions) -> None:
options: CompilerOptions,
singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]]) -> None:
self.builder = LowLevelIRBuilder(current_module, mapper, options)
self.builders = [self.builder]
self.symtables: List[OrderedDict[SymbolNode, SymbolTarget]] = [OrderedDict()]
Expand Down Expand Up @@ -116,7 +118,7 @@ def __init__(self,
self.encapsulating_funcs = pbv.encapsulating_funcs
self.nested_fitems = pbv.nested_funcs.keys()
self.fdefs_to_decorators = pbv.funcs_to_decorators
self.singledispatch_impls = pbv.singledispatch_impls
self.singledispatch_impls = singledispatch_impls

self.visitor = visitor

Expand Down
46 changes: 40 additions & 6 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
instance of the callable class.
"""

from mypyc.irbuild.prepare import RegisterImplInfo
from mypy.build import topsort
from typing import (
NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator,
Expand Down Expand Up @@ -806,6 +807,22 @@ def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int
return builder.call_c(slow_isinstance_op, [obj, class_obj], line)


def load_func(builder: IRBuilder, func_name: str, fullname: Optional[str], line: int) -> Value:
if fullname is not None and not fullname.startswith(builder.current_module):
# we're calling a function in a different module

# We can't use load_module_attr_by_fullname here because we need to load the function using
# func_name, not the name specified by fullname (which can be different for underscore
# function)
module = fullname.rsplit('.')[0]
loaded_module = builder.load_module(module)

func = builder.py_get_attr(loaded_module, func_name, line)
else:
func = builder.load_global_str(func_name, line)
return func


def generate_singledispatch_dispatch_function(
builder: IRBuilder,
main_singledispatch_function_name: str,
Expand All @@ -816,15 +833,27 @@ def generate_singledispatch_dispatch_function(
current_func_decl = builder.mapper.func_to_decl[fitem]
arg_info = get_args(builder, current_func_decl.sig.args, line)

def gen_func_call_and_return(func_name: str) -> None:
func = builder.load_global_str(func_name, line)
# TODO: don't pass optional arguments if they weren't passed to this function
def gen_func_call_and_return(func_name: str, fullname: Optional[str] = None) -> None:
func = load_func(builder, func_name, fullname, line)
ret_val = builder.builder.py_call(
func, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names
)
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.nonlocal_control[-1].gen_return(builder, coerced, line)

# Add all necessary imports of other modules that have registered functions in other modules
# We're doing this in a separate pass over the implementations because that avoids the
# complexity and code size implications of generating this import before every call to a
# registered implementation that might need this imported
# TODO: avoid adding imports if we use native calls for all of the registered implementations
# in a module (once we add support for using native calls for registered implementations)
for _, impl in impls:
module_name = impl.fullname.rsplit('.')[0]
if module_name not in builder.imports:
# We need to generate an import here because the module needs to be imported before we
# try loading the function from it
builder.gen_import(module_name, line)

# Sort the list of implementations so that we check any subclasses before we check the classes
# they inherit from, to better match singledispatch's behavior of going through the argument's
# MRO, and using the first implementation it finds
Expand All @@ -839,9 +868,14 @@ def gen_func_call_and_return(func_name: str) -> None:
# The shortname of a function is just '{class}.{func_name}', and we don't support
# singledispatchmethod yet, so that is always the same as the function name
name = short_id_from_name(impl.name, impl.name, impl.line)
gen_func_call_and_return(name)
gen_func_call_and_return(name, fullname=impl.fullname)
builder.activate_block(next_impl)

# We don't pass fullname here because getting the fullname of the main generated singledispatch
# function isn't easy, and we don't need it because the fullname is only needed for making sure
# we load the function from another module instead of the globals dict if it's defined in
# another module, which will never be true for the main singledispatch function (it's always
# generated in the same module as the dispatch function)
gen_func_call_and_return(main_singledispatch_function_name)


Expand All @@ -864,8 +898,8 @@ def gen_dispatch_func_ir(


def sort_with_subclasses_first(
impls: List[Tuple[TypeInfo, FuncDef]]
) -> Iterator[Tuple[TypeInfo, FuncDef]]:
impls: List[RegisterImplInfo]
) -> Iterator[RegisterImplInfo]:

# graph with edges pointing from every class to their subclasses
graph = {typ: set(typ.mro[1:]) for typ, _ in impls}
Expand Down
8 changes: 5 additions & 3 deletions mypyc/irbuild/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def f(x: int) -> int:
from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature
from mypyc.irbuild.prebuildvisitor import PreBuildVisitor
from mypyc.irbuild.vtable import compute_vtable
from mypyc.irbuild.prepare import build_type_map
from mypyc.irbuild.prepare import build_type_map, find_singledispatch_register_impls
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.visitor import IRBuilderVisitor
from mypyc.irbuild.mapper import Mapper
Expand All @@ -58,6 +58,7 @@ def build_ir(modules: List[MypyFile],
"""Build IR for a set of modules that have been type-checked by mypy."""

build_type_map(mapper, modules, graph, types, options, errors)
singledispatch_info = find_singledispatch_register_impls(modules, errors)

result: ModuleIRs = OrderedDict()

Expand All @@ -66,13 +67,14 @@ def build_ir(modules: List[MypyFile],

for module in modules:
# First pass to determine free symbols.
pbv = PreBuildVisitor(errors, module)
pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove)
module.accept(pbv)

# Construct and configure builder objects (cyclic runtime dependency).
visitor = IRBuilderVisitor()
builder = IRBuilder(
module.fullname, types, graph, errors, mapper, pbv, visitor, options
module.fullname, types, graph, errors, mapper, pbv, visitor, options,
singledispatch_info.singledispatch_impls,
)
visitor.builder = builder

Expand Down
100 changes: 19 additions & 81 deletions mypyc/irbuild/prebuildvisitor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from mypyc.errors import Errors
from mypy.types import Instance, get_proper_type
from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple
from collections import defaultdict
from typing import Dict, List, Set

from mypy.nodes import (
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr,
CallExpr, RefExpr, TypeInfo, MypyFile
MypyFile
)
from mypy.traverser import TraverserVisitor

Expand All @@ -24,7 +22,12 @@ class PreBuildVisitor(TraverserVisitor):
The main IR build pass uses this information.
"""

def __init__(self, errors: Errors, current_file: MypyFile) -> None:
def __init__(
self,
errors: Errors,
current_file: MypyFile,
decorators_to_remove: Dict[FuncDef, List[int]],
) -> None:
super().__init__()
# Dict from a function to symbols defined directly in the
# function that are used as non-local (free) variables within a
Expand Down Expand Up @@ -54,9 +57,8 @@ def __init__(self, errors: Errors, current_file: MypyFile) -> None:
# Map function to its non-special decorators.
self.funcs_to_decorators: Dict[FuncDef, List[Expression]] = {}

# Map of main singledispatch function to list of registered implementations
self.singledispatch_impls: DefaultDict[
FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list)
# Map function to indices of decorators to remove
self.decorators_to_remove: Dict[FuncDef, List[int]] = decorators_to_remove

self.errors: Errors = errors

Expand All @@ -76,37 +78,15 @@ def visit_decorator(self, dec: Decorator) -> None:
self.prop_setters.add(dec.func)
else:
decorators_to_store = dec.decorators.copy()
removed: List[int] = []
# the index of the last non-register decorator before finding a register decorator
# when going through decorators from top to bottom
last_non_register: Optional[int] = None
for i, d in enumerate(decorators_to_store):
impl = get_singledispatch_register_call_info(d, dec.func)
if impl is not None:
self.singledispatch_impls[impl.singledispatch_func].append(
(impl.dispatch_type, dec.func))
removed.append(i)
if last_non_register is not None:
# found a register decorator after a non-register decorator, which we
# don't support because we'd have to make a copy of the function before
# calling the decorator so that we can call it later, which complicates
# the implementation for something that is probably not commonly used
self.errors.error(
"Calling decorator after registering function not supported",
self.current_file.path,
decorators_to_store[last_non_register].line,
)
else:
last_non_register = i
# calling register on a function that tries to dispatch based on type annotations
# raises a TypeError because compiled functions don't have an __annotations__
# attribute
for i in reversed(removed):
del decorators_to_store[i]
# if the only decorators are register calls, we shouldn't treat this
# as a decorated function because there aren't any decorators to apply
if not decorators_to_store:
return
if dec.func in self.decorators_to_remove:
to_remove = self.decorators_to_remove[dec.func]

for i in reversed(to_remove):
del decorators_to_store[i]
# if all of the decorators are removed, we shouldn't treat this as a decorated
# function because there aren't any decorators to apply
if not decorators_to_store:
return

self.funcs_to_decorators[dec.func] = decorators_to_store
super().visit_decorator(dec)
Expand Down Expand Up @@ -186,45 +166,3 @@ def add_free_variable(self, symbol: SymbolNode) -> None:
# and mark is as a non-local symbol within that function.
func = self.symbols_to_funcs[symbol]
self.free_variables.setdefault(func, set()).add(symbol)


class RegisteredImpl(NamedTuple):
singledispatch_func: FuncDef
dispatch_type: TypeInfo


def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef
) -> Optional[RegisteredImpl]:
# @fun.register(complex)
# def g(arg): ...
if (isinstance(decorator, CallExpr) and len(decorator.args) == 1
and isinstance(decorator.args[0], RefExpr)):
callee = decorator.callee
dispatch_type = decorator.args[0].node
if not isinstance(dispatch_type, TypeInfo):
return None

if isinstance(callee, MemberExpr):
return registered_impl_from_possible_register_call(callee, dispatch_type)
# @fun.register
# def g(arg: int): ...
elif isinstance(decorator, MemberExpr):
# we don't know if this is a register call yet, so we can't be sure that the function
# actually has arguments
if not func.arguments:
return None
arg_type = get_proper_type(func.arguments[0].variable.type)
if not isinstance(arg_type, Instance):
return None
info = arg_type.type
return registered_impl_from_possible_register_call(decorator, info)
return None


def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo
) -> Optional[RegisteredImpl]:
if expr.name == 'register' and isinstance(expr.expr, NameExpr):
node = expr.expr.node
if isinstance(node, Decorator):
return RegisteredImpl(node.func, dispatch_type)
return None
Loading

0 comments on commit 9b10175

Please sign in to comment.