Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Simplify union types #14363

Merged
merged 2 commits into from
Dec 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,18 @@ def type_to_rtype(self, typ: Type | None) -> RType:
elif isinstance(typ, NoneTyp):
return none_rprimitive
elif isinstance(typ, UnionType):
return RUnion([self.type_to_rtype(item) for item in typ.items])
# Remove redundant items using set + list to preserve item order
seen = set()
items = []
for item in typ.items:
rtype = self.type_to_rtype(item)
if rtype not in seen:
items.append(rtype)
seen.add(rtype)
if len(items) > 1:
return RUnion(items)
else:
return items[0]
elif isinstance(typ, AnyType):
return object_rprimitive
elif isinstance(typ, TypeType):
Expand Down
37 changes: 37 additions & 0 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,40 @@ L4:
L5:
res = r8
return 1

[case testSimplifyListUnion]
from typing import List, Union

def f(a: Union[List[str], List[bytes], int]) -> int:
if isinstance(a, list):
return len(a)
return a
[out]
def f(a):
a :: union[list, int]
r0 :: object
r1 :: int32
r2 :: bit
r3 :: bool
r4 :: list
r5 :: ptr
r6 :: native_int
r7 :: short_int
r8 :: int
L0:
r0 = load_address PyList_Type
r1 = PyObject_IsInstance(a, r0)
r2 = r1 >= 0 :: signed
r3 = truncate r1: int32 to builtins.bool
if r3 goto L1 else goto L2 :: bool
L1:
r4 = borrow cast(list, a)
r5 = get_element_ptr r4 ob_size :: PyVarObject
r6 = load_mem r5 :: native_int*
keep_alive r4
r7 = r6 << 1
keep_alive a
return r7
L2:
r8 = unbox(int, a)
return r8
14 changes: 5 additions & 9 deletions mypyc/test-data/irbuild-optional.test
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,10 @@ class B:

[out]
def f(o):
o :: union[object, object]
r0 :: object
r1 :: str
r2, r3 :: object
o :: object
r0 :: str
r1 :: object
L0:
r0 = o
r1 = 'x'
r2 = CPyObject_GetAttr(r0, r1)
r3 = r2
L1:
r0 = 'x'
r1 = CPyObject_GetAttr(o, r0)
return 1