diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 4364b2b6c511..a108766644ce 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -116,7 +116,18 @@ def type_to_rtype(self, typ: Type | None) -> RType: elif isinstance(typ, NoneTyp): return none_rprimitive elif isinstance(typ, UnionType): - return RUnion([self.type_to_rtype(item) for item in typ.items]) + # Remove redundant items using set + list to preserve item order + seen = set() + items = [] + for item in typ.items: + rtype = self.type_to_rtype(item) + if rtype not in seen: + items.append(rtype) + seen.add(rtype) + if len(items) > 1: + return RUnion(items) + else: + return items[0] elif isinstance(typ, AnyType): return object_rprimitive elif isinstance(typ, TypeType): diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test index 47f7ada709e3..b82217465fef 100644 --- a/mypyc/test-data/irbuild-lists.test +++ b/mypyc/test-data/irbuild-lists.test @@ -428,3 +428,40 @@ L4: L5: res = r8 return 1 + +[case testSimplifyListUnion] +from typing import List, Union + +def f(a: Union[List[str], List[bytes], int]) -> int: + if isinstance(a, list): + return len(a) + return a +[out] +def f(a): + a :: union[list, int] + r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4 :: list + r5 :: ptr + r6 :: native_int + r7 :: short_int + r8 :: int +L0: + r0 = load_address PyList_Type + r1 = PyObject_IsInstance(a, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L2 :: bool +L1: + r4 = borrow cast(list, a) + r5 = get_element_ptr r4 ob_size :: PyVarObject + r6 = load_mem r5 :: native_int* + keep_alive r4 + r7 = r6 << 1 + keep_alive a + return r7 +L2: + r8 = unbox(int, a) + return r8 diff --git a/mypyc/test-data/irbuild-optional.test b/mypyc/test-data/irbuild-optional.test index 4b1d3d1ffec2..e98cf1b19e2e 100644 --- a/mypyc/test-data/irbuild-optional.test +++ b/mypyc/test-data/irbuild-optional.test @@ -527,14 +527,10 @@ class B: [out] def f(o): - o :: union[object, object] - r0 :: object - r1 :: str - r2, r3 :: object + o :: object + r0 :: str + r1 :: object L0: - r0 = o - r1 = 'x' - r2 = CPyObject_GetAttr(r0, r1) - r3 = r2 -L1: + r0 = 'x' + r1 = CPyObject_GetAttr(o, r0) return 1