Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up freshening type variables #14323

Merged
merged 8 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload
from typing_extensions import Final

from mypy.nodes import ARG_POS, ARG_STAR, Var
from mypy.type_visitor import TypeTranslator
from mypy.types import (
ANY_STRATEGY,
AnyType,
BoolTypeQuery,
CallableType,
DeletedType,
ErasedType,
Expand Down Expand Up @@ -138,13 +141,30 @@ def freshen_function_type_vars(callee: F) -> F:
return cast(F, fresh_overload)


class HasGenericCallable(BoolTypeQuery):
def __init__(self) -> None:
super().__init__(ANY_STRATEGY)

def visit_callable_type(self, t: CallableType) -> bool:
return t.is_generic() or super().visit_callable_type(t)


# Share a singleton since this is performance sensitive
has_generic_callable: Final = HasGenericCallable()


T = TypeVar("T", bound=Type)


def freshen_all_functions_type_vars(t: T) -> T:
result = t.accept(FreshenCallableVisitor())
assert isinstance(result, type(t))
return result
result: Type
has_generic_callable.reset()
if not t.accept(has_generic_callable):
return t # Fast path to avoid expensive freshening
else:
result = t.accept(FreshenCallableVisitor())
assert isinstance(result, type(t))
return result


class FreshenCallableVisitor(TypeTranslator):
Expand Down
156 changes: 156 additions & 0 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from abc import abstractmethod
from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
from typing_extensions import Final

from mypy_extensions import mypyc_attr, trait

Expand Down Expand Up @@ -417,3 +418,158 @@ def visit_type_alias_type(self, t: TypeAliasType) -> T:
def query_types(self, types: Iterable[Type]) -> T:
"""Perform a query for a list of types using the strategy to combine the results."""
return self.strategy([t.accept(self) for t in types])


# Return True if at least one type component returns True
ANY_STRATEGY: Final = 0
# Return True if no type component returns False
ALL_STRATEGY: Final = 1


class BoolTypeQuery(SyntheticTypeVisitor[bool]):
"""Visitor for performing recursive queries of types with a bool result.

Use TypeQuery if you need non-bool results.

'strategy' is used to combine results for a series of types. It must
be ANY_STRATEGY or ALL_STRATEGY.

Note: This visitor keeps an internal state (tracks type aliases to avoid
recursion), so it should *never* be re-used for querying different types
unless you call reset() first.
"""

def __init__(self, strategy: int) -> None:
self.strategy = strategy
if strategy == ANY_STRATEGY:
self.default = False
else:
assert strategy == ALL_STRATEGY
self.default = True
# Keep track of the type aliases already visited. This is needed to avoid
# infinite recursion on types like A = Union[int, List[A]]. An empty set is
# represented as None as a micro-optimization.
self.seen_aliases: set[TypeAliasType] | None = None
# By default, we eagerly expand type aliases, and query also types in the
# alias target. In most cases this is a desired behavior, but we may want
# to skip targets in some cases (e.g. when collecting type variables).
self.skip_alias_target = False

def reset(self) -> None:
"""Clear mutable state (but preserve strategy).

This *must* be called if you want to reuse the visitor.
"""
self.seen_aliases = None

def visit_unbound_type(self, t: UnboundType) -> bool:
return self.query_types(t.args)

def visit_type_list(self, t: TypeList) -> bool:
return self.query_types(t.items)

def visit_callable_argument(self, t: CallableArgument) -> bool:
return t.typ.accept(self)

def visit_any(self, t: AnyType) -> bool:
return self.default

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
return self.default

def visit_none_type(self, t: NoneType) -> bool:
return self.default

def visit_erased_type(self, t: ErasedType) -> bool:
return self.default

def visit_deleted_type(self, t: DeletedType) -> bool:
return self.default

def visit_type_var(self, t: TypeVarType) -> bool:
return self.query_types([t.upper_bound] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> bool:
return self.default

def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
return self.default

def visit_unpack_type(self, t: UnpackType) -> bool:
return self.query_types([t.type])

def visit_parameters(self, t: Parameters) -> bool:
return self.query_types(t.arg_types)

def visit_partial_type(self, t: PartialType) -> bool:
return self.default

def visit_instance(self, t: Instance) -> bool:
return self.query_types(t.args)

def visit_callable_type(self, t: CallableType) -> bool:
# FIX generics
# Avoid allocating any objects here as an optimization.
args = self.query_types(t.arg_types)
ret = t.ret_type.accept(self)
if self.strategy == ANY_STRATEGY:
return args or ret
else:
return args and ret

def visit_tuple_type(self, t: TupleType) -> bool:
return self.query_types(t.items)

def visit_typeddict_type(self, t: TypedDictType) -> bool:
return self.query_types(list(t.items.values()))

def visit_raw_expression_type(self, t: RawExpressionType) -> bool:
return self.default

def visit_literal_type(self, t: LiteralType) -> bool:
return self.default

def visit_star_type(self, t: StarType) -> bool:
return t.type.accept(self)

def visit_union_type(self, t: UnionType) -> bool:
return self.query_types(t.items)

def visit_overloaded(self, t: Overloaded) -> bool:
return self.query_types(t.items) # type: ignore[arg-type]

def visit_type_type(self, t: TypeType) -> bool:
return t.item.accept(self)

def visit_ellipsis_type(self, t: EllipsisType) -> bool:
return self.default

def visit_placeholder_type(self, t: PlaceholderType) -> bool:
return self.query_types(t.args)

def visit_type_alias_type(self, t: TypeAliasType) -> bool:
# Skip type aliases already visited types to avoid infinite recursion.
# TODO: Ideally we should fire subvisitors here (or use caching) if we care
# about duplicates.
if self.seen_aliases is None:
self.seen_aliases = set()
Comment on lines +554 to +555
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just start with an empty set rather than a None? That should simplify this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We start with None since it saves the creation (and deletion) of an empty set in the common code path where we don't use the set at all. It's a micro-optimization for sure, but allocating objects is kind of expensive in CPython and this is a hot code path, so it should make things a little bit faster.

elif t in self.seen_aliases:
return self.default
self.seen_aliases.add(t)
if self.skip_alias_target:
return self.query_types(t.args)
return get_proper_type(t).accept(self)

def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool:
"""Perform a query for a sequence of types using the strategy to combine the results."""
# Special-case for lists and tuples to allow mypyc to produce better code.
if isinstance(types, list):
if self.strategy == ANY_STRATEGY:
return any(t.accept(self) for t in types)
else:
return all(t.accept(self) for t in types)
else:
if self.strategy == ANY_STRATEGY:
return any(t.accept(self) for t in types)
else:
return all(t.accept(self) for t in types)
5 changes: 4 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,7 +2874,10 @@ def get_proper_types(it: Iterable[Type | None]) -> list[ProperType] | list[Prope
# to make it easier to gradually get modules working with mypyc.
# Import them here, after the types are defined.
# This is intended as a re-export also.
from mypy.type_visitor import ( # noqa: F811
from mypy.type_visitor import ( # noqa: F811,F401
ALL_STRATEGY as ALL_STRATEGY,
ANY_STRATEGY as ANY_STRATEGY,
BoolTypeQuery as BoolTypeQuery,
SyntheticTypeVisitor as SyntheticTypeVisitor,
TypeQuery as TypeQuery,
TypeTranslator as TypeTranslator,
Expand Down