diff --git a/mypy/typeops.py b/mypy/typeops.py index 718800967b44..952d02be04e9 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -314,6 +314,17 @@ def callable_corresponding_argument(typ: CallableType, return by_name if by_name is not None else by_pos +def is_simple_literal(t: ProperType) -> bool: + """ + Whether a type is a simple enough literal to allow for fast Union simplification + + For now this means enum or string + """ + return isinstance(t, LiteralType) and ( + t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str' + ) + + def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = -1, *, keep_erased: bool = False, @@ -348,36 +359,54 @@ def make_simplified_union(items: Sequence[Type], from mypy.subtypes import is_proper_subtype removed: Set[int] = set() - - # Avoid slow nested for loop for Union of Literal of strings (issue #9169) - if all((isinstance(item, LiteralType) and - item.fallback.type.fullname == 'builtins.str') - for item in items): - seen: Set[str] = set() - for index, item in enumerate(items): + seen: Set[Tuple[str, str]] = set() + + # NB: having a separate fast path for Union of Literal and slow path for other things + # would arguably be cleaner, however it breaks down when simplifying the Union of two + # different enum types as try_expanding_enum_to_union works recursively and will + # trigger intermediate simplifications that would render the fast path useless + for i, item in enumerate(items): + if i in removed: + continue + # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) + if is_simple_literal(item): assert isinstance(item, LiteralType) assert isinstance(item.value, str) - if item.value in seen: - removed.add(index) - seen.add(item.value) + k = (item.value, item.fallback.type.fullname) + if k in seen: + removed.add(i) + continue - else: - for i, ti in enumerate(items): - if i in removed: continue - # Keep track of the truishness info for deleted subtypes which can be relevant - cbt = cbf = False - for j, tj in enumerate(items): - if i != j and is_proper_subtype(tj, ti, keep_erased_types=keep_erased) and \ - is_redundant_literal_instance(ti, tj): - # We found a redundant item in the union. - removed.add(j) - cbt = cbt or tj.can_be_true - cbf = cbf or tj.can_be_false - # if deleted subtypes had more general truthiness, use that - if not ti.can_be_true and cbt: - items[i] = true_or_false(ti) - elif not ti.can_be_false and cbf: - items[i] = true_or_false(ti) + # NB: one would naively expect that it would be safe to skip the slow path + # always for literals. One would be sorely mistaken. Indeed, some simplifications + # such as that of None/Optional when strict optional is false, do require that we + # proceed with the slow path. Thankfully, all literals will have the same subtype + # relationship to non-literal types, so we only need to do that walk for the first + # literal, which keeps the fast path fast even in the presence of a mixture of + # literals and other types. + safe_skip = len(seen) > 0 + seen.add(k) + if safe_skip: + continue + # Keep track of the truishness info for deleted subtypes which can be relevant + cbt = cbf = False + for j, tj in enumerate(items): + # NB: we don't need to check literals as the fast path above takes care of that + if ( + i != j + and not is_simple_literal(tj) + and is_proper_subtype(tj, item, keep_erased_types=keep_erased) + and is_redundant_literal_instance(item, tj) # XXX? + ): + # We found a redundant item in the union. + removed.add(j) + cbt = cbt or tj.can_be_true + cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that + if not item.can_be_true and cbt: + items[i] = true_or_false(item) + elif not item.can_be_false and cbf: + items[i] = true_or_false(item) simplified_set = [items[i] for i in range(len(items)) if i not in removed]