diff --git a/mypy/typeops.py b/mypy/typeops.py index 09f418b129ed5..41c59c88f1c4d 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,7 +5,7 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar +from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Tuple from typing_extensions import Type as TypingType import sys @@ -311,6 +311,15 @@ def callable_corresponding_argument(typ: CallableType, return by_name if by_name is not None else by_pos +def is_simple_literal(t: Type) -> bool: + """ + Whether a type is a simple enough literal to allow for fast set-based Union simplification + + For now this means enuum or string + """ + return isinstance(t, LiteralType) and (t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str') + + def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = -1, *, keep_erased: bool = False) -> ProperType: @@ -344,35 +353,49 @@ def make_simplified_union(items: Sequence[Type], from mypy.subtypes import is_proper_subtype removed = set() # type: Set[int] - - # Avoid slow nested for loop for Union of Literal of strings (issue #9169) - if all((isinstance(item, LiteralType) and - item.fallback.type.fullname == 'builtins.str') - for item in items): - seen = set() # type: Set[str] - for index, item in enumerate(items): + seen = set() # type: Set[Tuple[str, str]] + + # NB: having a separate fast path for Union of Literal and slow path for other things + # would arguably be cleaner, however it breaks down when simplifying the Union of two + # different enum types as try_expanding_enum_to_union works recursively and will + # trigger intermediate simplifications that would render the fast path useless + for i, item in enumerate(items): + if i in removed: + continue + # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) + if is_simple_literal(item): assert isinstance(item, LiteralType) assert isinstance(item.value, str) - if item.value in seen: - removed.add(index) - seen.add(item.value) + k = (item.value, item.fallback.type.fullname) + if k in seen: + removed.add(i) + continue - else: - for i, ti in enumerate(items): - if i in removed: continue - # Keep track of the truishness info for deleted subtypes which can be relevant - cbt = cbf = False - for j, tj in enumerate(items): - if i != j and is_proper_subtype(tj, ti, keep_erased_types=keep_erased): - # We found a redundant item in the union. - removed.add(j) - cbt = cbt or tj.can_be_true - cbf = cbf or tj.can_be_false - # if deleted subtypes had more general truthiness, use that - if not ti.can_be_true and cbt: - items[i] = true_or_false(ti) - elif not ti.can_be_false and cbf: - items[i] = true_or_false(ti) + # NB: one would naively expect that it would be safe to skip the slow path + # always for literals. One would be sorely mistaken. Indeed, some simplifications + # such as that of None/Optional when strict optional is false, do require that we + # proceed with the slow path. Thankfully, all literals will have the same subtype + # relationship to non-literal types, so we only need to do that walk for the first + # literal, which keeps the fast path fast even in the presence of a mixture of + # literals and other types. + safe_skip = len(seen) > 0 + seen.add(k) + if safe_skip: + continue + # Keep track of the truishness info for deleted subtypes which can be relevant + cbt = cbf = False + for j, tj in enumerate(items): + # NB: we don't need to check literals as the fast path above takes care of that + if i != j and not is_simple_literal(tj) and is_proper_subtype(tj, item, keep_erased_types=keep_erased): + # We found a redundant item in the union. + removed.add(j) + cbt = cbt or tj.can_be_true + cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that + if not item.can_be_true and cbt: + items[i] = true_or_false(item) + elif not item.can_be_false and cbf: + items[i] = true_or_false(item) simplified_set = [items[i] for i in range(len(items)) if i not in removed] return UnionType.make_union(simplified_set, line, column)