Skip to content

Commit

Permalink
[mypyc] Use a cache to speed up singledispatch calls (#10972)
Browse files Browse the repository at this point in the history
This adds a cache to the dispatching code for singledispatch functions,
which makes calling singledispatch functions much faster.
  • Loading branch information
pranavrajpal authored Aug 14, 2021
1 parent 6444968 commit 4c90aa4
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 62 deletions.
2 changes: 1 addition & 1 deletion mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def generate_class(cl: ClassIR, module: str, emitter: Emitter) -> None:
fields['tp_name'] = '"{}"'.format(name)

generate_full = not cl.is_trait and not cl.builtin_base
needs_getseters = not cl.is_generated
needs_getseters = cl.needs_getseters or not cl.is_generated

if not cl.builtin_base:
fields['tp_new'] = new_name
Expand Down
5 changes: 5 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def __init__(self, name: str, module_name: str, is_trait: bool = False,
self.has_dict = False
# Do we allow interpreted subclasses? Derived from a mypyc_attr.
self.allow_interpreted_subclasses = False
# Does this class need getseters to be generated for its attributes? (getseters are also
# added if is_generated is False)
self.needs_getseters = False
# If this a subclass of some built-in python class, the name
# of the object for that class. We currently only support this
# in a few ad-hoc cases.
Expand Down Expand Up @@ -279,6 +282,7 @@ def serialize(self) -> JsonDict:
'inherits_python': self.inherits_python,
'has_dict': self.has_dict,
'allow_interpreted_subclasses': self.allow_interpreted_subclasses,
'needs_getseters': self.needs_getseters,
'builtin_base': self.builtin_base,
'ctor': self.ctor.serialize(),
# We serialize dicts as lists to ensure order is preserved
Expand Down Expand Up @@ -329,6 +333,7 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'ClassIR':
ir.inherits_python = data['inherits_python']
ir.has_dict = data['has_dict']
ir.allow_interpreted_subclasses = data['allow_interpreted_subclasses']
ir.needs_getseters = data['needs_getseters']
ir.builtin_base = data['builtin_base']
ir.ctor = FuncDecl.deserialize(data['ctor'], ctx)
ir.attributes = OrderedDict(
Expand Down
69 changes: 54 additions & 15 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from mypyc.primitives.misc_ops import (
check_stop_op, yield_from_except_op, coro_op, send_op, register_function
)
from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op
from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op, dict_get_method_with_none
from mypyc.common import SELF_NAME, LAMBDA_NAME, decorator_helper_name
from mypyc.sametype import is_same_method_signature
from mypyc.irbuild.util import is_constant
Expand Down Expand Up @@ -841,6 +841,43 @@ def generate_singledispatch_dispatch_function(
current_func_decl = builder.mapper.func_to_decl[fitem]
arg_info = get_args(builder, current_func_decl.sig.args, line)

dispatch_func_obj = builder.self()

arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line)
dispatch_cache = builder.builder.get_attr(
dispatch_func_obj, 'dispatch_cache', dict_rprimitive, line
)
call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock()
get_result = builder.call_c(dict_get_method_with_none, [dispatch_cache, arg_type], line)
is_not_none = builder.translate_is_op(get_result, builder.none_object(), 'is not', line)
impl_to_use = Register(object_rprimitive)
builder.add_bool_branch(is_not_none, use_cache, call_find_impl)

builder.activate_block(use_cache)
builder.assign(impl_to_use, get_result, line)
builder.goto(call_func)

builder.activate_block(call_find_impl)
find_impl = builder.load_module_attr_by_fullname('functools._find_impl', line)
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
uncached_impl = builder.py_call(find_impl, [arg_type, registry], line)
builder.call_c(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
builder.assign(impl_to_use, uncached_impl, line)
builder.goto(call_func)

builder.activate_block(call_func)
gen_calls_to_correct_impl(builder, impl_to_use, arg_info, fitem, line)


def gen_calls_to_correct_impl(
builder: IRBuilder,
impl_to_use: Value,
arg_info: ArgInfo,
fitem: FuncDef,
line: int,
) -> None:
current_func_decl = builder.mapper.func_to_decl[fitem]

def gen_native_func_call_and_return(fdef: FuncDef) -> None:
func_decl = builder.mapper.func_to_decl[fdef]
ret_val = builder.builder.call(
Expand All @@ -849,16 +886,6 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.add(Return(coerced))

registry = load_singledispatch_registry(builder, fitem, line)

# TODO: cache the output of _find_impl - without adding that caching, this implementation is
# probably slower than the standard library functools implementation because functools caches
# the output of _find_impl and _find_impl looks like it is very slow

find_impl = builder.load_module_attr_by_fullname('functools._find_impl', line)
arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line)
impl_to_use = builder.py_call(find_impl, [arg_type, registry], line)

typ, src = builtin_names['builtins.int']
int_type_obj = builder.add(LoadAddress(typ, src, line))
is_int = builder.builder.type_is_op(impl_to_use, int_type_obj, line)
Expand Down Expand Up @@ -913,7 +940,9 @@ def gen_dispatch_func_ir(
builder.enter(FuncInfo(fitem, dispatch_name))
setup_callable_class(builder)
builder.fn_info.callable_class.ir.attributes['registry'] = dict_rprimitive
builder.fn_info.callable_class.ir.attributes['dispatch_cache'] = dict_rprimitive
builder.fn_info.callable_class.ir.has_dict = True
builder.fn_info.callable_class.ir.needs_getseters = True
generate_singledispatch_callable_class_ctor(builder)

generate_singledispatch_dispatch_function(builder, main_func_name, fitem)
Expand All @@ -927,12 +956,16 @@ def gen_dispatch_func_ir(


def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None:
"""Create an __init__ that sets registry to an empty dict"""
"""Create an __init__ that sets registry and dispatch_cache to empty dicts"""
line = -1
class_ir = builder.fn_info.callable_class.ir
builder.enter_method(class_ir, '__init__', bool_rprimitive)
empty_dict = builder.call_c(dict_new_op, [], line)
builder.add(SetAttr(builder.self(), 'registry', empty_dict, line))
cache_dict = builder.call_c(dict_new_op, [], line)
dispatch_cache_str = builder.load_str('dispatch_cache')
# use the py_setattr_op instead of SetAttr so that it also gets added to our __dict__
builder.call_c(py_setattr_op, [builder.self(), dispatch_cache_str, cache_dict], line)
# the generated C code seems to expect that __init__ returns a char, so just return 1
builder.add(Return(Integer(1, bool_rprimitive, line), line))
builder.leave_method()
Expand All @@ -948,8 +981,7 @@ def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo)
builder.leave_method()


def load_singledispatch_registry(builder: IRBuilder, fitem: FuncDef, line: int) -> Value:
dispatch_func_obj = load_func(builder, fitem.name, fitem.fullname, line)
def load_singledispatch_registry(builder: IRBuilder, dispatch_func_obj: Value, line: int) -> Value:
return builder.builder.get_attr(dispatch_func_obj, 'registry', dict_rprimitive, line)


Expand Down Expand Up @@ -997,10 +1029,17 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
load_literal = LoadLiteral(current_id, object_rprimitive)
to_insert = builder.add(load_literal)
# TODO: avoid reloading the registry here if we just created it
registry = load_singledispatch_registry(builder, singledispatch_func, line)
dispatch_func_obj = load_func(
builder, singledispatch_func.name, singledispatch_func.fullname, line
)
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
for typ in types:
loaded_type = load_type(builder, typ, line)
builder.call_c(dict_set_item_op, [registry, loaded_type, to_insert], line)
dispatch_cache = builder.builder.get_attr(
dispatch_func_obj, 'dispatch_cache', dict_rprimitive, line
)
builder.gen_method_call(dispatch_cache, 'clear', [], None, line)


def get_native_impl_ids(builder: IRBuilder, singledispatch_func: FuncDef) -> Dict[FuncDef, int]:
Expand Down
5 changes: 5 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,11 @@ PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func,
goto fail;
}

// clear the cache so we consider the newly added function when dispatching
PyObject *dispatch_cache = PyObject_GetAttrString(singledispatch_func, "dispatch_cache");
if (dispatch_cache == NULL) goto fail;
PyDict_Clear(dispatch_cache);

Py_INCREF(func);
return func;

Expand Down
2 changes: 1 addition & 1 deletion mypyc/primitives/dict_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
error_kind=ERR_MAGIC)

# dict.get(key)
method_op(
dict_get_method_with_none = method_op(
name='get',
arg_types=[dict_rprimitive, object_rprimitive],
return_type=object_rprimitive,
Expand Down
111 changes: 66 additions & 45 deletions mypyc/test-data/irbuild-singledispatch.test
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@ def f_obj.__init__(__mypyc_self__):
__mypyc_self__ :: __main__.f_obj
r0 :: dict
r1 :: bool
r2 :: dict
r3 :: str
r4 :: int32
r5 :: bit
L0:
r0 = PyDict_New()
__mypyc_self__.registry = r0; r1 = is_error
r2 = PyDict_New()
r3 = 'dispatch_cache'
r4 = PyObject_SetAttr(__mypyc_self__, r3, r2)
r5 = r4 >= 0 :: signed
return 1
def f_obj.__get__(__mypyc_self__, instance, owner):
__mypyc_self__, instance, owner, r0 :: object
Expand All @@ -42,57 +50,70 @@ L0:
def f_obj.__call__(__mypyc_self__, arg):
__mypyc_self__ :: __main__.f_obj
arg :: object
r0 :: dict
r1 :: str
r2 :: object
r3 :: str
r4, r5 :: object
r6 :: str
r7 :: object
r8 :: ptr
r9, r10, r11 :: object
r12 :: ptr
r13 :: object
r14 :: bit
r15 :: int
r16 :: bit
r17 :: int
r18 :: bool
r19 :: object
r20 :: bool
r0 :: ptr
r1 :: object
r2 :: dict
r3, r4 :: object
r5 :: bit
r6, r7 :: object
r8 :: str
r9 :: object
r10 :: dict
r11 :: object
r12 :: int32
r13 :: bit
r14 :: object
r15 :: ptr
r16 :: object
r17 :: bit
r18 :: int
r19 :: bit
r20 :: int
r21 :: bool
r22 :: object
r23 :: bool
L0:
r0 = __main__.globals :: static
r1 = 'f'
r2 = CPyDict_GetItem(r0, r1)
r3 = 'registry'
r4 = CPyObject_GetAttr(r2, r3)
r5 = functools :: module
r6 = '_find_impl'
r7 = CPyObject_GetAttr(r5, r6)
r8 = get_element_ptr arg ob_type :: PyObject
r9 = load_mem r8 :: builtins.object*
r0 = get_element_ptr arg ob_type :: PyObject
r1 = load_mem r0 :: builtins.object*
keep_alive arg
r10 = PyObject_CallFunctionObjArgs(r7, r9, r4, 0)
r11 = load_address PyLong_Type
r12 = get_element_ptr r10 ob_type :: PyObject
r13 = load_mem r12 :: builtins.object*
keep_alive r10
r14 = r13 == r11
if r14 goto L1 else goto L4 :: bool
r2 = __mypyc_self__.dispatch_cache
r3 = CPyDict_GetWithNone(r2, r1)
r4 = load_address _Py_NoneStruct
r5 = r3 != r4
if r5 goto L1 else goto L2 :: bool
L1:
r15 = unbox(int, r10)
r16 = r15 == 0
if r16 goto L2 else goto L3 :: bool
r6 = r3
goto L3
L2:
r17 = unbox(int, arg)
r18 = g(r17)
return r18
r7 = functools :: module
r8 = '_find_impl'
r9 = CPyObject_GetAttr(r7, r8)
r10 = __mypyc_self__.registry
r11 = PyObject_CallFunctionObjArgs(r9, r1, r10, 0)
r12 = CPyDict_SetItem(r2, r1, r11)
r13 = r12 >= 0 :: signed
r6 = r11
L3:
unreachable
r14 = load_address PyLong_Type
r15 = get_element_ptr r6 ob_type :: PyObject
r16 = load_mem r15 :: builtins.object*
keep_alive r6
r17 = r16 == r14
if r17 goto L4 else goto L7 :: bool
L4:
r19 = PyObject_CallFunctionObjArgs(r10, arg, 0)
r20 = unbox(bool, r19)
return r20
r18 = unbox(int, r6)
r19 = r18 == 0
if r19 goto L5 else goto L6 :: bool
L5:
r20 = unbox(int, arg)
r21 = g(r20)
return r21
L6:
unreachable
L7:
r22 = PyObject_CallFunctionObjArgs(r6, arg, 0)
r23 = unbox(bool, r22)
return r23
def g(arg):
arg :: int
L0:
Expand Down
41 changes: 41 additions & 0 deletions mypyc/test-data/run-singledispatch.test
Original file line number Diff line number Diff line change
Expand Up @@ -655,3 +655,44 @@ with assertRaises(TypeError, 'Invalid first argument to `register()`'):
@f.register
def _():
pass

[file driver.py]
import register

[case testCacheClearedWhenNewFunctionRegistered]
from functools import singledispatch

@singledispatch
def f(arg) -> str:
return 'default'

[file register.py]
from native import f
class A: pass
class B: pass
class C: pass

# annotated function
assert f(A()) == 'default'
@f.register
def _(arg: A) -> str:
return 'a'
assert f(A()) == 'a'

# type passed as argument
assert f(B()) == 'default'
@f.register(B)
def _(arg: B) -> str:
return 'b'
assert f(B()) == 'b'

# 2 argument form
assert f(C()) == 'default'
def c(arg) -> str:
return 'c'
f.register(C, c)
assert f(C()) == 'c'


[file driver.py]
import register

0 comments on commit 4c90aa4

Please sign in to comment.