Skip to content

Commit

Permalink
[mypyc] Fixes to union simplification
Browse files Browse the repository at this point in the history
Fix crash related to unions in loops. The crash was introduced in #14363.

Flatten nested unions before simplifying unions.
  • Loading branch information
JukkaL committed Dec 28, 2022
1 parent 86dad8a commit 2a810c8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 17 deletions.
37 changes: 37 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,30 @@ def __init__(self, items: list[RType]) -> None:
self.items_set = frozenset(items)
self._ctype = "PyObject *"

@staticmethod
def make_simplified_union(items: list[RType]) -> RType:
"""Return a normalized union that covers the given items.
Flatten nested unions and remove duplicate items.
Overlapping items are *not* simplified. For example,
[object, str] will not be simplified.
"""
items = flatten_nested_unions(items)
assert items

# Remove duplicate items using set + list to preserve item order
seen = set()
new_items = []
for item in items:
if item not in seen:
new_items.append(item)
seen.add(item)
if len(new_items) > 1:
return RUnion(new_items)
else:
return new_items[0]

def accept(self, visitor: RTypeVisitor[T]) -> T:
return visitor.visit_runion(self)

Expand All @@ -823,6 +847,19 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RUnion:
return RUnion(types)


def flatten_nested_unions(types: list[RType]) -> list[RType]:
if not any(isinstance(t, RUnion) for t in types):
return types # Fast path

flat_items: list[RType] = []
for t in types:
if isinstance(t, RUnion):
flat_items.extend(flatten_nested_unions(t.items))
else:
flat_items.append(t)
return flat_items


def optional_value_type(rtype: RType) -> RType | None:
"""If rtype is the union of none_rprimitive and another type X, return X.
Expand Down
13 changes: 11 additions & 2 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Type,
TypeOfAny,
UninhabitedType,
UnionType,
get_proper_type,
)
from mypy.util import split_target
Expand Down Expand Up @@ -85,6 +86,7 @@
RInstance,
RTuple,
RType,
RUnion,
bitmap_rprimitive,
c_int_rprimitive,
c_pyssize_t_rprimitive,
Expand Down Expand Up @@ -864,8 +866,15 @@ def extract_int(self, e: Expression) -> int | None:
return None

def get_sequence_type(self, expr: Expression) -> RType:
target_type = get_proper_type(self.types[expr])
assert isinstance(target_type, Instance)
return self.get_sequence_type_from_type(self.types[expr])

def get_sequence_type_from_type(self, target_type: Type) -> RType:
target_type = get_proper_type(target_type)
if isinstance(target_type, UnionType):
return RUnion.make_simplified_union(
[self.get_sequence_type_from_type(item) for item in target_type.items]
)
assert isinstance(target_type, Instance), target_type
if target_type.type.fullname == "builtins.str":
return str_rprimitive
else:
Expand Down
13 changes: 1 addition & 12 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,7 @@ def type_to_rtype(self, typ: Type | None) -> RType:
elif isinstance(typ, NoneTyp):
return none_rprimitive
elif isinstance(typ, UnionType):
# Remove redundant items using set + list to preserve item order
seen = set()
items = []
for item in typ.items:
rtype = self.type_to_rtype(item)
if rtype not in seen:
items.append(rtype)
seen.add(rtype)
if len(items) > 1:
return RUnion(items)
else:
return items[0]
return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items])
elif isinstance(typ, AnyType):
return object_rprimitive
elif isinstance(typ, TypeType):
Expand Down
70 changes: 67 additions & 3 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,20 @@ L5:
return 1

[case testSimplifyListUnion]
from typing import List, Union
from typing import List, Union, Optional

def f(a: Union[List[str], List[bytes], int]) -> int:
def narrow(a: Union[List[str], List[bytes], int]) -> int:
if isinstance(a, list):
return len(a)
return a
def loop(a: Union[List[str], List[bytes]]) -> None:
for x in a:
pass
def nested_union(a: Union[List[str], List[Optional[str]]]) -> None:
for x in a:
pass
[out]
def f(a):
def narrow(a):
a :: union[list, int]
r0 :: object
r1 :: int32
Expand Down Expand Up @@ -465,3 +471,61 @@ L1:
L2:
r8 = unbox(int, a)
return r8
def loop(a):
a :: list
r0 :: short_int
r1 :: ptr
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: union[str, bytes]
r7 :: short_int
L0:
r0 = 0
L1:
r1 = get_element_ptr a ob_size :: PyVarObject
r2 = load_mem r1 :: native_int*
keep_alive a
r3 = r2 << 1
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyList_GetItemUnsafe(a, r0)
r6 = cast(union[str, bytes], r5)
x = r6
L3:
r7 = r0 + 2
r0 = r7
goto L1
L4:
return 1
def nested_union(a):
a :: list
r0 :: short_int
r1 :: ptr
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: union[str, None]
r7 :: short_int
L0:
r0 = 0
L1:
r1 = get_element_ptr a ob_size :: PyVarObject
r2 = load_mem r1 :: native_int*
keep_alive a
r3 = r2 << 1
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyList_GetItemUnsafe(a, r0)
r6 = cast(union[str, None], r5)
x = r6
L3:
r7 = r0 + 2
r0 = r7
goto L1
L4:
return 1

0 comments on commit 2a810c8

Please sign in to comment.