Skip to content

Commit

Permalink
[mypyc] Simplify union types (#14363)
Browse files Browse the repository at this point in the history
We can sometimes simplify a mypyc RType union, even if the mypy union
couldn't be simplified. A typical example is `list[x] | list[y]` which
can be simplified to just `list`. Previously this would generate a
redundant union `union[list, list]`.
  • Loading branch information
JukkaL authored Dec 28, 2022
1 parent 8e7e220 commit 86dad8a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 10 deletions.
13 changes: 12 additions & 1 deletion mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,18 @@ def type_to_rtype(self, typ: Type | None) -> RType:
elif isinstance(typ, NoneTyp):
return none_rprimitive
elif isinstance(typ, UnionType):
return RUnion([self.type_to_rtype(item) for item in typ.items])
# Remove redundant items using set + list to preserve item order
seen = set()
items = []
for item in typ.items:
rtype = self.type_to_rtype(item)
if rtype not in seen:
items.append(rtype)
seen.add(rtype)
if len(items) > 1:
return RUnion(items)
else:
return items[0]
elif isinstance(typ, AnyType):
return object_rprimitive
elif isinstance(typ, TypeType):
Expand Down
37 changes: 37 additions & 0 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,40 @@ L4:
L5:
res = r8
return 1

[case testSimplifyListUnion]
from typing import List, Union

def f(a: Union[List[str], List[bytes], int]) -> int:
if isinstance(a, list):
return len(a)
return a
[out]
def f(a):
a :: union[list, int]
r0 :: object
r1 :: int32
r2 :: bit
r3 :: bool
r4 :: list
r5 :: ptr
r6 :: native_int
r7 :: short_int
r8 :: int
L0:
r0 = load_address PyList_Type
r1 = PyObject_IsInstance(a, r0)
r2 = r1 >= 0 :: signed
r3 = truncate r1: int32 to builtins.bool
if r3 goto L1 else goto L2 :: bool
L1:
r4 = borrow cast(list, a)
r5 = get_element_ptr r4 ob_size :: PyVarObject
r6 = load_mem r5 :: native_int*
keep_alive r4
r7 = r6 << 1
keep_alive a
return r7
L2:
r8 = unbox(int, a)
return r8
14 changes: 5 additions & 9 deletions mypyc/test-data/irbuild-optional.test
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,10 @@ class B:

[out]
def f(o):
o :: union[object, object]
r0 :: object
r1 :: str
r2, r3 :: object
o :: object
r0 :: str
r1 :: object
L0:
r0 = o
r1 = 'x'
r2 = CPyObject_GetAttr(r0, r1)
r3 = r2
L1:
r0 = 'x'
r1 = CPyObject_GetAttr(o, r0)
return 1

0 comments on commit 86dad8a

Please sign in to comment.