Skip to content

Commit

Permalink
[mypyc] Optimize two-argument super() for the simple case (#8903)
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan authored May 28, 2020
1 parent 4dce036 commit 07c9f6f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
22 changes: 21 additions & 1 deletion mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,28 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr


def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: SuperExpr) -> Value:
if callee.info is None or callee.call.args:
if callee.info is None or (len(callee.call.args) != 0 and len(callee.call.args) != 2):
return translate_call(builder, expr, callee)

# We support two-argument super but only when it is super(CurrentClass, self)
# TODO: We could support it when it is a parent class in many cases?
if len(callee.call.args) == 2:
self_arg = callee.call.args[1]
if (
not isinstance(self_arg, NameExpr)
or not isinstance(self_arg.node, Var)
or not self_arg.node.is_self
):
return translate_call(builder, expr, callee)

typ_arg = callee.call.args[0]
if (
not isinstance(typ_arg, NameExpr)
or not isinstance(typ_arg.node, TypeInfo)
or callee.info is not typ_arg.node
):
return translate_call(builder, expr, callee)

ir = builder.mapper.type_to_ir[callee.info]
# Search for the method in the mro, skipping ourselves.
for base in ir.mro[1:]:
Expand Down
14 changes: 14 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -709,15 +709,27 @@ from typing import List
class A:
def __init__(self, x: int) -> None:
self.x = x

def foo(self, x: int) -> int:
return x

class B(A):
def __init__(self, x: int, y: int) -> None:
super().__init__(x)
self.y = y

def foo(self, x: int) -> int:
return super().foo(x+1)

class C(B):
def __init__(self, x: int, y: int) -> None:
init = super(C, self).__init__
init(x, y+1)

def foo(self, x: int) -> int:
# should go to A, not B
return super(B, self).foo(x+1)

class X:
def __init__(self, x: int) -> None:
self.x = x
Expand Down Expand Up @@ -753,6 +765,8 @@ assert c.x == 10 and c.y == 21
z = Z(10, 20)
assert z.x == 10 and z.y == 20

assert c.foo(10) == 11

PrintList().v_list([1,2,3])
[out]
yo!
Expand Down

0 comments on commit 07c9f6f

Please sign in to comment.