Skip to content

Commit

Permalink
more enum-related speedups
Browse files Browse the repository at this point in the history
As a followup to python#9394 address a few more O(n**2) behaviors
caused by decomposing enums into unions of literals.
  • Loading branch information
hugues-aff committed Jan 21, 2022
1 parent 0cec4f7 commit a690223
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 18 deletions.
32 changes: 32 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
if isinstance(declared, UnionType):
return make_simplified_union([narrow_declared_type(x, narrowed)
for x in declared.relevant_items()])
if is_enum_overlapping_union(declared, narrowed):
return narrowed
elif not is_overlapping_types(declared, narrowed,
prohibit_none_typevar_overlap=True):
if state.strict_optional:
Expand Down Expand Up @@ -137,6 +139,24 @@ def get_possible_variants(typ: Type) -> List[Type]:
return [typ]


def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
return (
isinstance(x, Instance) and x.type.is_enum and
isinstance(y, UnionType) and
all(isinstance(z, LiteralType) and z.fallback.type == x.type # type: ignore[misc]
for z in y.items)
)


def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
return (
isinstance(x, LiteralType) and isinstance(y, UnionType) and any(
isinstance(z, LiteralType) and z == x # type: ignore[misc]
for z in y.items
)
)


def is_overlapping_types(left: Type,
right: Type,
ignore_promotions: bool = False,
Expand Down Expand Up @@ -198,6 +218,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
#
# These checks will also handle the NoneType and UninhabitedType cases for us.

# enums are sometimes expanded into an Union of Literals
# when that happens we want to make sure we treat the two as overlapping
# and crucially, we want to do that *fast* in case the enum is large
# so we do it before expanding variants below to avoid O(n**2) behavior
if (
is_enum_overlapping_union(left, right) or
is_enum_overlapping_union(right, left) or
is_literal_in_union(left, right) or
is_literal_in_union(right, left)
):
return True

if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
return True
Expand Down
30 changes: 24 additions & 6 deletions mypy/sametypes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Sequence
from typing import Sequence, Tuple, Set, List

from mypy.types import (
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
ProperType, get_proper_type, TypeAliasType, ParamSpecType
)
from mypy.typeops import tuple_fallback, make_simplified_union
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal


def is_same_type(left: Type, right: Type) -> bool:
Expand Down Expand Up @@ -143,14 +143,32 @@ def visit_literal_type(self, left: LiteralType) -> bool:

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, UnionType):
# fast path for simple literals
def _extract_literals(u: UnionType) -> Tuple[Set[LiteralType], List[Type]]:
lit = set() # type: Set[LiteralType]
rem = [] # type: List[Type]
for i in u.items:
if is_simple_literal(i):
assert isinstance(i, LiteralType) # type: ignore[misc]
lit.add(i)
else:
rem.append(i)
return lit, rem

left_lit, left_rem = _extract_literals(left)
right_lit, right_rem = _extract_literals(self.right)

if left_lit != right_lit:
return False

# Check that everything in left is in right
for left_item in left.items:
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
for left_item in left_rem:
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
return False

# Check that everything in right is in left
for right_item in self.right.items:
if not any(is_same_type(right_item, left_item) for left_item in left.items):
for right_item in right_rem:
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
return False

return True
Expand Down
47 changes: 37 additions & 10 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,19 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, Instance):
literal_types = set() # type: Set[Instance]
# avoid redundant check for union of literals
for item in left.items:
if mypy.typeops.is_simple_literal(item):
assert isinstance(item, LiteralType) # type: ignore[misc]
if item.fallback in literal_types:
continue
literal_types.add(item.fallback)
item = item.fallback
if not self._is_subtype(item, self.orig_right):
return False
return True
return all(self._is_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
Expand Down Expand Up @@ -1137,6 +1150,17 @@ def report(*args: Any) -> None:
return applied


def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
new_items = [] # type: List[Type]
for i in t.relevant_items():
it = get_proper_type(i)
if not mypy.typeops.is_simple_literal(it):
return None
if it != s:
new_items.append(i)
return new_items


def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
"""Return t minus s for runtime type assertions.
Expand All @@ -1150,10 +1174,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
s = get_proper_type(s)

if isinstance(t, UnionType):
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))]
new_items = try_restrict_literal_union(t, s) if isinstance(s, LiteralType) else []
new_items = new_items or [
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))
]
return UnionType.make_union(new_items)
elif covers_at_runtime(t, s, ignore_promotions):
return UninhabitedType()
Expand Down Expand Up @@ -1223,11 +1250,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
right = get_proper_type(right)

if isinstance(right, UnionType) and not isinstance(left, UnionType):
return any([is_proper_subtype(orig_left, item,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items])
return any(is_proper_subtype(orig_left, item,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items)
return left.accept(ProperSubtypeVisitor(orig_right,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
Expand Down Expand Up @@ -1418,7 +1445,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
# TODO: What's the right thing to do here?
Expand Down
4 changes: 2 additions & 2 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,13 @@ def callable_corresponding_argument(typ: CallableType,
return by_name if by_name is not None else by_pos


def is_simple_literal(t: ProperType) -> bool:
def is_simple_literal(t: Type) -> bool:
"""
Whether a type is a simple enough literal to allow for fast Union simplification
For now this means enum or string
"""
return isinstance(t, LiteralType) and (
return isinstance(t, LiteralType) and ( # type: ignore[misc]
t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str'
)

Expand Down

0 comments on commit a690223

Please sign in to comment.