From 4c90aa42e6a74f127956124aec94f55dbc46a281 Mon Sep 17 00:00:00 2001 From: pranavrajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Sat, 14 Aug 2021 12:22:44 -0700 Subject: [PATCH] [mypyc] Use a cache to speed up singledispatch calls (#10972) This adds a cache to the dispatching code for singledispatch functions, which makes calling singledispatch functions much faster. --- mypyc/codegen/emitclass.py | 2 +- mypyc/ir/class_ir.py | 5 + mypyc/irbuild/function.py | 69 +++++++++--- mypyc/lib-rt/misc_ops.c | 5 + mypyc/primitives/dict_ops.py | 2 +- mypyc/test-data/irbuild-singledispatch.test | 111 ++++++++++++-------- mypyc/test-data/run-singledispatch.test | 41 ++++++++ 7 files changed, 173 insertions(+), 62 deletions(-) diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 755f5c0b3e8e..9c960cf80bdd 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -203,7 +203,7 @@ def generate_class(cl: ClassIR, module: str, emitter: Emitter) -> None: fields['tp_name'] = '"{}"'.format(name) generate_full = not cl.is_trait and not cl.builtin_base - needs_getseters = not cl.is_generated + needs_getseters = cl.needs_getseters or not cl.is_generated if not cl.builtin_base: fields['tp_new'] = new_name diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index d3454b7371f3..742ed9bf7631 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -102,6 +102,9 @@ def __init__(self, name: str, module_name: str, is_trait: bool = False, self.has_dict = False # Do we allow interpreted subclasses? Derived from a mypyc_attr. self.allow_interpreted_subclasses = False + # Does this class need getseters to be generated for its attributes? (getseters are also + # added if is_generated is False) + self.needs_getseters = False # If this a subclass of some built-in python class, the name # of the object for that class. We currently only support this # in a few ad-hoc cases. @@ -279,6 +282,7 @@ def serialize(self) -> JsonDict: 'inherits_python': self.inherits_python, 'has_dict': self.has_dict, 'allow_interpreted_subclasses': self.allow_interpreted_subclasses, + 'needs_getseters': self.needs_getseters, 'builtin_base': self.builtin_base, 'ctor': self.ctor.serialize(), # We serialize dicts as lists to ensure order is preserved @@ -329,6 +333,7 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'ClassIR': ir.inherits_python = data['inherits_python'] ir.has_dict = data['has_dict'] ir.allow_interpreted_subclasses = data['allow_interpreted_subclasses'] + ir.needs_getseters = data['needs_getseters'] ir.builtin_base = data['builtin_base'] ir.ctor = FuncDecl.deserialize(data['ctor'], ctx) ir.attributes = OrderedDict( diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 0a040ba10712..66028722a453 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -36,7 +36,7 @@ from mypyc.primitives.misc_ops import ( check_stop_op, yield_from_except_op, coro_op, send_op, register_function ) -from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op +from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op, dict_get_method_with_none from mypyc.common import SELF_NAME, LAMBDA_NAME, decorator_helper_name from mypyc.sametype import is_same_method_signature from mypyc.irbuild.util import is_constant @@ -841,6 +841,43 @@ def generate_singledispatch_dispatch_function( current_func_decl = builder.mapper.func_to_decl[fitem] arg_info = get_args(builder, current_func_decl.sig.args, line) + dispatch_func_obj = builder.self() + + arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line) + dispatch_cache = builder.builder.get_attr( + dispatch_func_obj, 'dispatch_cache', dict_rprimitive, line + ) + call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock() + get_result = builder.call_c(dict_get_method_with_none, [dispatch_cache, arg_type], line) + is_not_none = builder.translate_is_op(get_result, builder.none_object(), 'is not', line) + impl_to_use = Register(object_rprimitive) + builder.add_bool_branch(is_not_none, use_cache, call_find_impl) + + builder.activate_block(use_cache) + builder.assign(impl_to_use, get_result, line) + builder.goto(call_func) + + builder.activate_block(call_find_impl) + find_impl = builder.load_module_attr_by_fullname('functools._find_impl', line) + registry = load_singledispatch_registry(builder, dispatch_func_obj, line) + uncached_impl = builder.py_call(find_impl, [arg_type, registry], line) + builder.call_c(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line) + builder.assign(impl_to_use, uncached_impl, line) + builder.goto(call_func) + + builder.activate_block(call_func) + gen_calls_to_correct_impl(builder, impl_to_use, arg_info, fitem, line) + + +def gen_calls_to_correct_impl( + builder: IRBuilder, + impl_to_use: Value, + arg_info: ArgInfo, + fitem: FuncDef, + line: int, +) -> None: + current_func_decl = builder.mapper.func_to_decl[fitem] + def gen_native_func_call_and_return(fdef: FuncDef) -> None: func_decl = builder.mapper.func_to_decl[fdef] ret_val = builder.builder.call( @@ -849,16 +886,6 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) builder.add(Return(coerced)) - registry = load_singledispatch_registry(builder, fitem, line) - - # TODO: cache the output of _find_impl - without adding that caching, this implementation is - # probably slower than the standard library functools implementation because functools caches - # the output of _find_impl and _find_impl looks like it is very slow - - find_impl = builder.load_module_attr_by_fullname('functools._find_impl', line) - arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line) - impl_to_use = builder.py_call(find_impl, [arg_type, registry], line) - typ, src = builtin_names['builtins.int'] int_type_obj = builder.add(LoadAddress(typ, src, line)) is_int = builder.builder.type_is_op(impl_to_use, int_type_obj, line) @@ -913,7 +940,9 @@ def gen_dispatch_func_ir( builder.enter(FuncInfo(fitem, dispatch_name)) setup_callable_class(builder) builder.fn_info.callable_class.ir.attributes['registry'] = dict_rprimitive + builder.fn_info.callable_class.ir.attributes['dispatch_cache'] = dict_rprimitive builder.fn_info.callable_class.ir.has_dict = True + builder.fn_info.callable_class.ir.needs_getseters = True generate_singledispatch_callable_class_ctor(builder) generate_singledispatch_dispatch_function(builder, main_func_name, fitem) @@ -927,12 +956,16 @@ def gen_dispatch_func_ir( def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None: - """Create an __init__ that sets registry to an empty dict""" + """Create an __init__ that sets registry and dispatch_cache to empty dicts""" line = -1 class_ir = builder.fn_info.callable_class.ir builder.enter_method(class_ir, '__init__', bool_rprimitive) empty_dict = builder.call_c(dict_new_op, [], line) builder.add(SetAttr(builder.self(), 'registry', empty_dict, line)) + cache_dict = builder.call_c(dict_new_op, [], line) + dispatch_cache_str = builder.load_str('dispatch_cache') + # use the py_setattr_op instead of SetAttr so that it also gets added to our __dict__ + builder.call_c(py_setattr_op, [builder.self(), dispatch_cache_str, cache_dict], line) # the generated C code seems to expect that __init__ returns a char, so just return 1 builder.add(Return(Integer(1, bool_rprimitive, line), line)) builder.leave_method() @@ -948,8 +981,7 @@ def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) builder.leave_method() -def load_singledispatch_registry(builder: IRBuilder, fitem: FuncDef, line: int) -> Value: - dispatch_func_obj = load_func(builder, fitem.name, fitem.fullname, line) +def load_singledispatch_registry(builder: IRBuilder, dispatch_func_obj: Value, line: int) -> Value: return builder.builder.get_attr(dispatch_func_obj, 'registry', dict_rprimitive, line) @@ -997,10 +1029,17 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None: load_literal = LoadLiteral(current_id, object_rprimitive) to_insert = builder.add(load_literal) # TODO: avoid reloading the registry here if we just created it - registry = load_singledispatch_registry(builder, singledispatch_func, line) + dispatch_func_obj = load_func( + builder, singledispatch_func.name, singledispatch_func.fullname, line + ) + registry = load_singledispatch_registry(builder, dispatch_func_obj, line) for typ in types: loaded_type = load_type(builder, typ, line) builder.call_c(dict_set_item_op, [registry, loaded_type, to_insert], line) + dispatch_cache = builder.builder.get_attr( + dispatch_func_obj, 'dispatch_cache', dict_rprimitive, line + ) + builder.gen_method_call(dispatch_cache, 'clear', [], None, line) def get_native_impl_ids(builder: IRBuilder, singledispatch_func: FuncDef) -> Dict[FuncDef, int]: diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index f301fa874211..0701ca9d71a8 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -766,6 +766,11 @@ PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, goto fail; } + // clear the cache so we consider the newly added function when dispatching + PyObject *dispatch_cache = PyObject_GetAttrString(singledispatch_func, "dispatch_cache"); + if (dispatch_cache == NULL) goto fail; + PyDict_Clear(dispatch_cache); + Py_INCREF(func); return func; diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index 51dcde12506b..4fe8693c66c5 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -119,7 +119,7 @@ error_kind=ERR_MAGIC) # dict.get(key) -method_op( +dict_get_method_with_none = method_op( name='get', arg_types=[dict_rprimitive, object_rprimitive], return_type=object_rprimitive, diff --git a/mypyc/test-data/irbuild-singledispatch.test b/mypyc/test-data/irbuild-singledispatch.test index cfbef229cc46..a00f624a9d61 100644 --- a/mypyc/test-data/irbuild-singledispatch.test +++ b/mypyc/test-data/irbuild-singledispatch.test @@ -16,9 +16,17 @@ def f_obj.__init__(__mypyc_self__): __mypyc_self__ :: __main__.f_obj r0 :: dict r1 :: bool + r2 :: dict + r3 :: str + r4 :: int32 + r5 :: bit L0: r0 = PyDict_New() __mypyc_self__.registry = r0; r1 = is_error + r2 = PyDict_New() + r3 = 'dispatch_cache' + r4 = PyObject_SetAttr(__mypyc_self__, r3, r2) + r5 = r4 >= 0 :: signed return 1 def f_obj.__get__(__mypyc_self__, instance, owner): __mypyc_self__, instance, owner, r0 :: object @@ -42,57 +50,70 @@ L0: def f_obj.__call__(__mypyc_self__, arg): __mypyc_self__ :: __main__.f_obj arg :: object - r0 :: dict - r1 :: str - r2 :: object - r3 :: str - r4, r5 :: object - r6 :: str - r7 :: object - r8 :: ptr - r9, r10, r11 :: object - r12 :: ptr - r13 :: object - r14 :: bit - r15 :: int - r16 :: bit - r17 :: int - r18 :: bool - r19 :: object - r20 :: bool + r0 :: ptr + r1 :: object + r2 :: dict + r3, r4 :: object + r5 :: bit + r6, r7 :: object + r8 :: str + r9 :: object + r10 :: dict + r11 :: object + r12 :: int32 + r13 :: bit + r14 :: object + r15 :: ptr + r16 :: object + r17 :: bit + r18 :: int + r19 :: bit + r20 :: int + r21 :: bool + r22 :: object + r23 :: bool L0: - r0 = __main__.globals :: static - r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) - r3 = 'registry' - r4 = CPyObject_GetAttr(r2, r3) - r5 = functools :: module - r6 = '_find_impl' - r7 = CPyObject_GetAttr(r5, r6) - r8 = get_element_ptr arg ob_type :: PyObject - r9 = load_mem r8 :: builtins.object* + r0 = get_element_ptr arg ob_type :: PyObject + r1 = load_mem r0 :: builtins.object* keep_alive arg - r10 = PyObject_CallFunctionObjArgs(r7, r9, r4, 0) - r11 = load_address PyLong_Type - r12 = get_element_ptr r10 ob_type :: PyObject - r13 = load_mem r12 :: builtins.object* - keep_alive r10 - r14 = r13 == r11 - if r14 goto L1 else goto L4 :: bool + r2 = __mypyc_self__.dispatch_cache + r3 = CPyDict_GetWithNone(r2, r1) + r4 = load_address _Py_NoneStruct + r5 = r3 != r4 + if r5 goto L1 else goto L2 :: bool L1: - r15 = unbox(int, r10) - r16 = r15 == 0 - if r16 goto L2 else goto L3 :: bool + r6 = r3 + goto L3 L2: - r17 = unbox(int, arg) - r18 = g(r17) - return r18 + r7 = functools :: module + r8 = '_find_impl' + r9 = CPyObject_GetAttr(r7, r8) + r10 = __mypyc_self__.registry + r11 = PyObject_CallFunctionObjArgs(r9, r1, r10, 0) + r12 = CPyDict_SetItem(r2, r1, r11) + r13 = r12 >= 0 :: signed + r6 = r11 L3: - unreachable + r14 = load_address PyLong_Type + r15 = get_element_ptr r6 ob_type :: PyObject + r16 = load_mem r15 :: builtins.object* + keep_alive r6 + r17 = r16 == r14 + if r17 goto L4 else goto L7 :: bool L4: - r19 = PyObject_CallFunctionObjArgs(r10, arg, 0) - r20 = unbox(bool, r19) - return r20 + r18 = unbox(int, r6) + r19 = r18 == 0 + if r19 goto L5 else goto L6 :: bool +L5: + r20 = unbox(int, arg) + r21 = g(r20) + return r21 +L6: + unreachable +L7: + r22 = PyObject_CallFunctionObjArgs(r6, arg, 0) + r23 = unbox(bool, r22) + return r23 def g(arg): arg :: int L0: diff --git a/mypyc/test-data/run-singledispatch.test b/mypyc/test-data/run-singledispatch.test index de4bc2f649f5..61e4897c96d6 100644 --- a/mypyc/test-data/run-singledispatch.test +++ b/mypyc/test-data/run-singledispatch.test @@ -655,3 +655,44 @@ with assertRaises(TypeError, 'Invalid first argument to `register()`'): @f.register def _(): pass + +[file driver.py] +import register + +[case testCacheClearedWhenNewFunctionRegistered] +from functools import singledispatch + +@singledispatch +def f(arg) -> str: + return 'default' + +[file register.py] +from native import f +class A: pass +class B: pass +class C: pass + +# annotated function +assert f(A()) == 'default' +@f.register +def _(arg: A) -> str: + return 'a' +assert f(A()) == 'a' + +# type passed as argument +assert f(B()) == 'default' +@f.register(B) +def _(arg: B) -> str: + return 'b' +assert f(B()) == 'b' + +# 2 argument form +assert f(C()) == 'default' +def c(arg) -> str: + return 'c' +f.register(C, c) +assert f(C()) == 'c' + + +[file driver.py] +import register