From d27bff62f71a8b914b1df239467148e81d2e88a2 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 3 Aug 2022 14:46:34 +0100 Subject: [PATCH] Merge subtype visitors (#13303) Fixes #3297 This removes a significant chunk of code duplication. This is not a pure refactor, there were some cases when one of the visitors (mostly non-proper one) was more correct and/or complete. In few corner cases, where it was hard to decide, I merged behavior with `if` checks. --- mypy/meet.py | 23 +- mypy/subtypes.py | 717 +++++++++----------------- mypy/test/testtypes.py | 2 +- mypy/typestate.py | 6 + test-data/unit/check-expressions.test | 13 + test-data/unit/check-overloading.test | 4 +- 6 files changed, 271 insertions(+), 494 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index deb95f11283a..8bc820ba8d09 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -361,8 +361,8 @@ def _type_object_overlap(left: Type, right: Type) -> bool: """Special cases for type object types overlaps.""" # TODO: these checks are a bit in gray area, adjust if they cause problems. left, right = get_proper_types((left, right)) - # 1. Type[C] vs Callable[..., C], where the latter is class object. - if isinstance(left, TypeType) and isinstance(right, CallableType) and right.is_type_obj(): + # 1. Type[C] vs Callable[..., C] overlap even if the latter is not class object. + if isinstance(left, TypeType) and isinstance(right, CallableType): return _is_overlapping_types(left.item, right.ret_type) # 2. Type[C] vs Meta, where Meta is a metaclass for C. if isinstance(left, TypeType) and isinstance(right, Instance): @@ -381,13 +381,18 @@ def _type_object_overlap(left: Type, right: Type) -> bool: return _type_object_overlap(left, right) or _type_object_overlap(right, left) if isinstance(left, CallableType) and isinstance(right, CallableType): - return is_callable_compatible( - left, - right, - is_compat=_is_overlapping_types, - ignore_pos_arg_names=True, - allow_partial_overlap=True, - ) + + def _callable_overlap(left: CallableType, right: CallableType) -> bool: + return is_callable_compatible( + left, + right, + is_compat=_is_overlapping_types, + ignore_pos_arg_names=True, + allow_partial_overlap=True, + ) + + # Compare both directions to handle type objects. + return _callable_overlap(left, right) or _callable_overlap(right, left) elif isinstance(left, CallableType): left = left.fallback elif isinstance(right, CallableType): diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7ef702e8493d..1c639172ffa4 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -65,26 +65,50 @@ IS_CLASSVAR: Final = 2 IS_CLASS_OR_STATIC: Final = 3 -TypeParameterChecker: _TypeAlias = Callable[[Type, Type, int], bool] +TypeParameterChecker: _TypeAlias = Callable[[Type, Type, int, bool], bool] -def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool: - if variance == COVARIANT: - return is_subtype(lefta, righta) - elif variance == CONTRAVARIANT: - return is_subtype(righta, lefta) - else: - return is_equivalent(lefta, righta) - +class SubtypeContext: + def __init__( + self, + *, + # Non-proper subtype flags + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + # Supported for both proper and non-proper + ignore_promotions: bool = False, + # Proper subtype flags + erase_instances: bool = False, + keep_erased_types: bool = False, + options: Optional[Options] = None, + ) -> None: + self.ignore_type_params = ignore_type_params + self.ignore_pos_arg_names = ignore_pos_arg_names + self.ignore_declared_variance = ignore_declared_variance + self.ignore_promotions = ignore_promotions + self.erase_instances = erase_instances + self.keep_erased_types = keep_erased_types + self.options = options -def ignore_type_parameter(s: Type, t: Type, v: int) -> bool: - return True + def check_context(self, proper_subtype: bool) -> None: + # Historically proper and non-proper subtypes were defined using different helpers + # and different visitors. Check if flag values are such that we definitely support. + if proper_subtype: + assert ( + not self.ignore_type_params + and not self.ignore_pos_arg_names + and not self.ignore_declared_variance + ) + else: + assert not self.erase_instances and not self.keep_erased_types def is_subtype( left: Type, right: Type, *, + subtype_context: Optional[SubtypeContext] = None, ignore_type_params: bool = False, ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, @@ -102,6 +126,24 @@ def is_subtype( between the type arguments (e.g., A and B), taking the variance of the type var into account. """ + if subtype_context is None: + subtype_context = SubtypeContext( + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions, + options=options, + ) + else: + assert not any( + { + ignore_type_params, + ignore_pos_arg_names, + ignore_declared_variance, + ignore_promotions, + options, + } + ), "Don't pass both context and individual flags" if TypeState.is_assumed_subtype(left, right): return True if ( @@ -129,63 +171,107 @@ def is_subtype( # B = Union[int, Tuple[B, ...]] # When checking if A <: B we push pair (A, B) onto 'assuming' stack, then when after few # steps we come back to initial call is_subtype(A, B) and immediately return True. - with pop_on_exit(TypeState._assuming, left, right): - return _is_subtype( - left, - right, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options, - ) - return _is_subtype( - left, - right, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options, - ) + with pop_on_exit(TypeState.get_assumptions(is_proper=False), left, right): + return _is_subtype(left, right, subtype_context, proper_subtype=False) + return _is_subtype(left, right, subtype_context, proper_subtype=False) -def _is_subtype( +def is_proper_subtype( left: Type, right: Type, *, + subtype_context: Optional[SubtypeContext] = None, + ignore_promotions: bool = False, + erase_instances: bool = False, + keep_erased_types: bool = False, +) -> bool: + """Is left a proper subtype of right? + + For proper subtypes, there's no need to rely on compatibility due to + Any types. Every usable type is a proper subtype of itself. + + If erase_instances is True, erase left instance *after* mapping it to supertype + (this is useful for runtime isinstance() checks). If keep_erased_types is True, + do not consider ErasedType a subtype of all types (used by type inference against unions). + """ + if subtype_context is None: + subtype_context = SubtypeContext( + ignore_promotions=ignore_promotions, + erase_instances=erase_instances, + keep_erased_types=keep_erased_types, + ) + else: + assert not any( + {ignore_promotions, erase_instances, keep_erased_types} + ), "Don't pass both context and individual flags" + if TypeState.is_assumed_proper_subtype(left, right): + return True + if ( + isinstance(left, TypeAliasType) + and isinstance(right, TypeAliasType) + and left.is_recursive + and right.is_recursive + ): + # Same as for non-proper subtype, see detailed comment there for explanation. + with pop_on_exit(TypeState.get_assumptions(is_proper=True), left, right): + return _is_subtype(left, right, subtype_context, proper_subtype=True) + return _is_subtype(left, right, subtype_context, proper_subtype=True) + + +def is_equivalent( + a: Type, + b: Type, + *, ignore_type_params: bool = False, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False, options: Optional[Options] = None, ) -> bool: + return is_subtype( + a, + b, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + options=options, + ) and is_subtype( + b, + a, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + options=options, + ) + + +# This is a common entry point for subtyping checks (both proper and non-proper). +# Never call this private function directly, use the public versions. +def _is_subtype( + left: Type, right: Type, subtype_context: SubtypeContext, proper_subtype: bool +) -> bool: + subtype_context.check_context(proper_subtype) orig_right = right orig_left = left left = get_proper_type(left) right = get_proper_type(right) - if ( + if not proper_subtype and ( isinstance(right, AnyType) or isinstance(right, UnboundType) or isinstance(right, ErasedType) ): + # TODO: should we consider all types proper subtypes of UnboundType and/or + # ErasedType as we do for non-proper subtyping. return True - elif isinstance(right, UnionType) and not isinstance(left, UnionType): + + def check_item(left: Type, right: Type, subtype_context: SubtypeContext) -> bool: + if proper_subtype: + return is_proper_subtype(left, right, subtype_context=subtype_context) + return is_subtype(left, right, subtype_context=subtype_context) + + if isinstance(right, UnionType) and not isinstance(left, UnionType): # Normally, when 'left' is not itself a union, the only way # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. is_subtype_of_item = any( - is_subtype( - orig_left, - item, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options, - ) - for item in right.items + check_item(orig_left, item, subtype_context) for item in right.items ) # Recombine rhs literal types, to make an enum type a subtype # of a union of all enum items as literal types. Only do it if @@ -199,16 +285,7 @@ def _is_subtype( ): right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items)) is_subtype_of_item = any( - is_subtype( - orig_left, - item, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options, - ) - for item in right.items + check_item(orig_left, item, subtype_context) for item in right.items ) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be @@ -221,105 +298,68 @@ def _is_subtype( elif is_subtype_of_item: return True # otherwise, fall through - return left.accept( - SubtypeVisitor( - orig_right, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options, - ) - ) + return left.accept(SubtypeVisitor(orig_right, subtype_context, proper_subtype)) -def is_equivalent( - a: Type, - b: Type, - *, - ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False, - options: Optional[Options] = None, -) -> bool: - return is_subtype( - a, - b, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - options=options, - ) and is_subtype( - b, - a, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - options=options, - ) +def check_type_parameter(lefta: Type, righta: Type, variance: int, proper_subtype: bool) -> bool: + def check(left: Type, right: Type) -> bool: + return is_proper_subtype(left, right) if proper_subtype else is_subtype(left, right) + + if variance == COVARIANT: + return check(lefta, righta) + elif variance == CONTRAVARIANT: + return check(righta, lefta) + else: + if proper_subtype: + return mypy.sametypes.is_same_type(lefta, righta) + return is_equivalent(lefta, righta) + + +def ignore_type_parameter(lefta: Type, righta: Type, variance: int, proper_subtype: bool) -> bool: + return True class SubtypeVisitor(TypeVisitor[bool]): - def __init__( - self, - right: Type, - *, - ignore_type_params: bool, - ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False, - options: Optional[Options] = None, - ) -> None: + def __init__(self, right: Type, subtype_context: SubtypeContext, proper_subtype: bool) -> None: self.right = get_proper_type(right) self.orig_right = right - self.ignore_type_params = ignore_type_params - self.ignore_pos_arg_names = ignore_pos_arg_names - self.ignore_declared_variance = ignore_declared_variance - self.ignore_promotions = ignore_promotions + self.proper_subtype = proper_subtype + self.subtype_context = subtype_context self.check_type_parameter = ( - ignore_type_parameter if ignore_type_params else check_type_parameter - ) - self.options = options - self._subtype_kind = SubtypeVisitor.build_subtype_kind( - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, + ignore_type_parameter if subtype_context.ignore_type_params else check_type_parameter ) + self.options = subtype_context.options + self._subtype_kind = SubtypeVisitor.build_subtype_kind(subtype_context, proper_subtype) @staticmethod - def build_subtype_kind( - *, - ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False, - ) -> SubtypeKind: + def build_subtype_kind(subtype_context: SubtypeContext, proper_subtype: bool) -> SubtypeKind: return ( state.strict_optional, - False, # is proper subtype? - ignore_type_params, - ignore_pos_arg_names, - ignore_declared_variance, - ignore_promotions, + proper_subtype, + subtype_context.ignore_type_params, + subtype_context.ignore_pos_arg_names, + subtype_context.ignore_declared_variance, + subtype_context.ignore_promotions, + subtype_context.erase_instances, + subtype_context.keep_erased_types, ) def _is_subtype(self, left: Type, right: Type) -> bool: - return is_subtype( - left, - right, - ignore_type_params=self.ignore_type_params, - ignore_pos_arg_names=self.ignore_pos_arg_names, - ignore_declared_variance=self.ignore_declared_variance, - ignore_promotions=self.ignore_promotions, - options=self.options, - ) + if self.proper_subtype: + return is_proper_subtype(left, right, subtype_context=self.subtype_context) + return is_subtype(left, right, subtype_context=self.subtype_context) # visit_x(left) means: is left (which is an instance of X) a subtype of # right? def visit_unbound_type(self, left: UnboundType) -> bool: + # This can be called if there is a bad type annotation. The result probably + # doesn't matter much but by returning True we simplify these bad types away + # from unions, which could filter out some bogus messages. return True def visit_any(self, left: AnyType) -> bool: - return True + return isinstance(self.right, AnyType) if self.proper_subtype else True def visit_none_type(self, left: NoneType) -> bool: if state.strict_optional: @@ -341,13 +381,18 @@ def visit_uninhabited_type(self, left: UninhabitedType) -> bool: return True def visit_erased_type(self, left: ErasedType) -> bool: + # This may be encountered during type inference. The result probably doesn't + # matter much. + # TODO: it actually does matter, figure out more principled logic about this. + if self.subtype_context.keep_erased_types: + return False return True def visit_deleted_type(self, left: DeletedType) -> bool: return True def visit_instance(self, left: Instance) -> bool: - if left.type.fallback_to_any: + if left.type.fallback_to_any and not self.proper_subtype: if isinstance(self.right, NoneType): # NOTE: `None` is a *non-subclassable* singleton, therefore no class # can by a subtype of it, even with an `Any` fallback. @@ -361,7 +406,7 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(right, Instance): if TypeState.is_cached_subtype_check(self._subtype_kind, left, right): return True - if not self.ignore_promotions: + if not self.subtype_context.ignore_promotions: for base in left.type.mro: if base._promote and any( self._is_subtype(p, self.right) for p in base._promote @@ -386,9 +431,13 @@ def visit_instance(self, left: Instance) -> bool: rname in TYPED_NAMEDTUPLE_NAMES and any(l.is_named_tuple for l in left.type.mro) ) - ) and not self.ignore_declared_variance: + ) and not self.subtype_context.ignore_declared_variance: # Map left type to corresponding right instances. t = map_instance_to_supertype(left, right.type) + if self.subtype_context.erase_instances: + erased = erase_type(t) + assert isinstance(erased, Instance) + t = erased nominal = True if right.type.has_type_var_tuple_type: left_prefix, left_middle, left_suffix = split_with_instance(left) @@ -464,28 +513,36 @@ def check_mixed( type_params = zip(t.args, right.args, right.type.defn.type_vars) for lefta, righta, tvar in type_params: if isinstance(tvar, TypeVarType): - if not self.check_type_parameter(lefta, righta, tvar.variance): + if not self.check_type_parameter( + lefta, righta, tvar.variance, self.proper_subtype + ): nominal = False else: - if not self.check_type_parameter(lefta, righta, COVARIANT): + if not self.check_type_parameter( + lefta, righta, COVARIANT, self.proper_subtype + ): nominal = False if nominal: TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return nominal - if right.type.is_protocol and is_protocol_implementation(left, right): + if right.type.is_protocol and is_protocol_implementation( + left, right, proper_subtype=self.proper_subtype + ): return True return False if isinstance(right, TypeType): item = right.item if isinstance(item, TupleType): item = mypy.typeops.tuple_fallback(item) - if is_named_instance(left, "builtins.type"): - return self._is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) - if left.type.is_metaclass(): - if isinstance(item, AnyType): - return True - if isinstance(item, Instance): - return is_named_instance(item, "builtins.object") + # TODO: this is a bit arbitrary, we should only skip Any-related cases. + if not self.proper_subtype: + if is_named_instance(left, "builtins.type"): + return self._is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) + if left.type.is_metaclass(): + if isinstance(item, AnyType): + return True + if isinstance(item, Instance): + return is_named_instance(item, "builtins.object") if isinstance(right, LiteralType) and left.last_known_value is not None: return self._is_subtype(left.last_known_value, right) if isinstance(right, CallableType): @@ -535,7 +592,7 @@ def visit_parameters(self, left: Parameters) -> bool: left, right, is_compat=self._is_subtype, - ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, ) else: return False @@ -554,7 +611,7 @@ def visit_callable_type(self, left: CallableType) -> bool: left, right, is_compat=self._is_subtype, - ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, strict_concatenate=self.options.strict_concatenate if self.options else True, ) elif isinstance(right, Overloaded): @@ -577,7 +634,7 @@ def visit_callable_type(self, left: CallableType) -> bool: left, right, is_compat=self._is_subtype, - ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, ) else: return False @@ -591,7 +648,15 @@ def visit_tuple_type(self, left: TupleType) -> bool: if right.args: iter_type = right.args[0] else: + if self.proper_subtype: + return False iter_type = AnyType(TypeOfAny.special_form) + if is_named_instance(right, "builtins.tuple") and isinstance( + get_proper_type(iter_type), AnyType + ): + # TODO: We shouldn't need this special case. This is currently needed + # for isinstance(x, tuple), though it's unclear why. + return True return all(self._is_subtype(li, iter_type) for li in left.items) elif self._is_subtype(mypy.typeops.tuple_fallback(left), right): return True @@ -624,9 +689,16 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if not left.names_are_wider_than(right): return False for name, l, r in left.zip(right): - if not is_equivalent( - l, r, ignore_type_params=self.ignore_type_params, options=self.options - ): + if self.proper_subtype: + check = mypy.sametypes.is_same_type(l, r) + else: + check = is_equivalent( + l, + r, + ignore_type_params=self.subtype_context.ignore_type_params, + options=self.options, + ) + if not check: return False # Non-required key is not compatible with a required key since # indexing may fail unexpectedly if a required key is missing. @@ -699,14 +771,14 @@ def visit_overloaded(self, left: Overloaded) -> bool: right_item, is_compat=self._is_subtype, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, strict_concatenate=strict_concat, ) or is_callable_compatible( right_item, left_item, is_compat=self._is_subtype, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, strict_concatenate=strict_concat, ): # If this is an overload that's already been matched, there's no @@ -751,6 +823,9 @@ def visit_union_type(self, left: UnionType) -> bool: def visit_partial_type(self, left: PartialType) -> bool: # This is indeterminate as we don't really know the complete type yet. + if self.proper_subtype: + # TODO: What's the right thing to do here? + return False if left.type is None: # Special case, partial `None`. This might happen when defining # class-level attributes with explicit `None`. @@ -768,6 +843,10 @@ def visit_type_type(self, left: TypeType) -> bool: return self._is_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname in ["builtins.object", "builtins.type"]: + # TODO: Strictly speaking, the type builtins.type is considered equivalent to + # Type[Any]. However, this would break the is_proper_subtype check in + # conditional_types for cases like isinstance(x, type) when the type + # of x is Type[int]. It's unclear what's the right way to address this. return True item = left.item if isinstance(item, TypeVarType): @@ -879,9 +958,12 @@ def f(self) -> A: ... if not proper_subtype: # Nominal check currently ignores arg names, but __call__ is special for protocols ignore_names = right.type.protocol_members != ["__call__"] - subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=ignore_names) else: - subtype_kind = ProperSubtypeVisitor.build_subtype_kind() + ignore_names = False + subtype_kind = SubtypeVisitor.build_subtype_kind( + subtype_context=SubtypeContext(ignore_pos_arg_names=ignore_names), + proper_subtype=proper_subtype, + ) TypeState.record_subtype_cache_entry(subtype_kind, left, right) return True @@ -1520,335 +1602,6 @@ def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> b return False -def is_proper_subtype( - left: Type, - right: Type, - *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False, -) -> bool: - """Is left a proper subtype of right? - - For proper subtypes, there's no need to rely on compatibility due to - Any types. Every usable type is a proper subtype of itself. - - If erase_instances is True, erase left instance *after* mapping it to supertype - (this is useful for runtime isinstance() checks). If keep_erased_types is True, - do not consider ErasedType a subtype of all types (used by type inference against unions). - """ - if TypeState.is_assumed_proper_subtype(left, right): - return True - if ( - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): - # This case requires special care because it may cause infinite recursion. - # See is_subtype() for more info. - with pop_on_exit(TypeState._assuming_proper, left, right): - return _is_proper_subtype( - left, - right, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types, - ) - return _is_proper_subtype( - left, - right, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types, - ) - - -def _is_proper_subtype( - left: Type, - right: Type, - *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False, -) -> bool: - orig_left = left - orig_right = right - left = get_proper_type(left) - right = get_proper_type(right) - - if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any( - is_proper_subtype( - orig_left, - item, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types, - ) - for item in right.items - ) - return left.accept( - ProperSubtypeVisitor( - orig_right, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types, - ) - ) - - -class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__( - self, - right: Type, - *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False, - ) -> None: - self.right = get_proper_type(right) - self.orig_right = right - self.ignore_promotions = ignore_promotions - self.erase_instances = erase_instances - self.keep_erased_types = keep_erased_types - self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types, - ) - - @staticmethod - def build_subtype_kind( - *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False, - ) -> SubtypeKind: - return (state.strict_optional, True, ignore_promotions, erase_instances, keep_erased_types) - - def _is_proper_subtype(self, left: Type, right: Type) -> bool: - return is_proper_subtype( - left, - right, - ignore_promotions=self.ignore_promotions, - erase_instances=self.erase_instances, - keep_erased_types=self.keep_erased_types, - ) - - def visit_unbound_type(self, left: UnboundType) -> bool: - # This can be called if there is a bad type annotation. The result probably - # doesn't matter much but by returning True we simplify these bad types away - # from unions, which could filter out some bogus messages. - return True - - def visit_any(self, left: AnyType) -> bool: - return isinstance(self.right, AnyType) - - def visit_none_type(self, left: NoneType) -> bool: - if state.strict_optional: - return isinstance(self.right, NoneType) or is_named_instance( - self.right, "builtins.object" - ) - return True - - def visit_uninhabited_type(self, left: UninhabitedType) -> bool: - return True - - def visit_erased_type(self, left: ErasedType) -> bool: - # This may be encountered during type inference. The result probably doesn't - # matter much. - # TODO: it actually does matter, figure out more principled logic about this. - if self.keep_erased_types: - return False - return True - - def visit_deleted_type(self, left: DeletedType) -> bool: - return True - - def visit_instance(self, left: Instance) -> bool: - right = self.right - if isinstance(right, Instance): - if TypeState.is_cached_subtype_check(self._subtype_kind, left, right): - return True - if not self.ignore_promotions: - for base in left.type.mro: - if base._promote and any( - self._is_proper_subtype(p, right) for p in base._promote - ): - TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) - return True - - if left.type.has_base(right.type.fullname): - # Map left type to corresponding right instances. - left = map_instance_to_supertype(left, right.type) - if self.erase_instances: - erased = erase_type(left) - assert isinstance(erased, Instance) - left = erased - - nominal = True - for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars): - if isinstance(tvar, TypeVarType): - variance = tvar.variance - if variance == COVARIANT: - nominal = self._is_proper_subtype(ta, ra) - elif variance == CONTRAVARIANT: - nominal = self._is_proper_subtype(ra, ta) - else: - nominal = mypy.sametypes.is_same_type(ta, ra) - else: - nominal = mypy.sametypes.is_same_type(ta, ra) - if not nominal: - break - - if nominal: - TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) - return nominal - if right.type.is_protocol and is_protocol_implementation( - left, right, proper_subtype=True - ): - return True - return False - if isinstance(right, CallableType): - call = find_member("__call__", left, left, is_operator=True) - if call: - return self._is_proper_subtype(call, right) - return False - return False - - def visit_type_var(self, left: TypeVarType) -> bool: - if isinstance(self.right, TypeVarType) and left.id == self.right.id: - return True - if left.values and self._is_proper_subtype( - mypy.typeops.make_simplified_union(left.values), self.right - ): - return True - return self._is_proper_subtype(left.upper_bound, self.right) - - def visit_param_spec(self, left: ParamSpecType) -> bool: - right = self.right - if ( - isinstance(right, ParamSpecType) - and right.id == left.id - and right.flavor == left.flavor - ): - return True - return self._is_proper_subtype(left.upper_bound, self.right) - - def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool: - right = self.right - if isinstance(right, TypeVarTupleType) and right.id == left.id: - return True - return self._is_proper_subtype(left.upper_bound, self.right) - - def visit_unpack_type(self, left: UnpackType) -> bool: - if isinstance(self.right, UnpackType): - return self._is_proper_subtype(left.type, self.right.type) - return False - - def visit_parameters(self, left: Parameters) -> bool: - right = self.right - if isinstance(right, Parameters) or isinstance(right, CallableType): - return are_parameters_compatible(left, right, is_compat=self._is_proper_subtype) - else: - return False - - def visit_callable_type(self, left: CallableType) -> bool: - right = self.right - if isinstance(right, CallableType): - return is_callable_compatible(left, right, is_compat=self._is_proper_subtype) - elif isinstance(right, Overloaded): - return all(self._is_proper_subtype(left, item) for item in right.items) - elif isinstance(right, Instance): - return self._is_proper_subtype(left.fallback, right) - elif isinstance(right, TypeType): - # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and self._is_proper_subtype(left.ret_type, right.item) - return False - - def visit_tuple_type(self, left: TupleType) -> bool: - right = self.right - if isinstance(right, Instance): - if is_named_instance(right, TUPLE_LIKE_INSTANCE_NAMES): - if not right.args: - return False - iter_type = get_proper_type(right.args[0]) - if is_named_instance(right, "builtins.tuple") and isinstance(iter_type, AnyType): - # TODO: We shouldn't need this special case. This is currently needed - # for isinstance(x, tuple), though it's unclear why. - return True - return all(self._is_proper_subtype(li, iter_type) for li in left.items) - return self._is_proper_subtype(mypy.typeops.tuple_fallback(left), right) - elif isinstance(right, TupleType): - if len(left.items) != len(right.items): - return False - for l, r in zip(left.items, right.items): - if not self._is_proper_subtype(l, r): - return False - return self._is_proper_subtype( - mypy.typeops.tuple_fallback(left), mypy.typeops.tuple_fallback(right) - ) - return False - - def visit_typeddict_type(self, left: TypedDictType) -> bool: - right = self.right - if isinstance(right, TypedDictType): - for name, typ in left.items.items(): - if name in right.items and not mypy.sametypes.is_same_type(typ, right.items[name]): - return False - for name, typ in right.items.items(): - if name not in left.items: - return False - return True - return self._is_proper_subtype(left.fallback, right) - - def visit_literal_type(self, left: LiteralType) -> bool: - if isinstance(self.right, LiteralType): - return left == self.right - else: - return self._is_proper_subtype(left.fallback, self.right) - - def visit_overloaded(self, left: Overloaded) -> bool: - # TODO: What's the right thing to do here? - return False - - def visit_union_type(self, left: UnionType) -> bool: - return all(self._is_proper_subtype(item, self.orig_right) for item in left.items) - - def visit_partial_type(self, left: PartialType) -> bool: - # TODO: What's the right thing to do here? - return False - - def visit_type_type(self, left: TypeType) -> bool: - right = self.right - if isinstance(right, TypeType): - # This is unsound, we don't check the __init__ signature. - return self._is_proper_subtype(left.item, right.item) - if isinstance(right, CallableType): - # This is also unsound because of __init__. - return right.is_type_obj() and self._is_proper_subtype(left.item, right.ret_type) - if isinstance(right, Instance): - if right.type.fullname == "builtins.type": - # TODO: Strictly speaking, the type builtins.type is considered equivalent to - # Type[Any]. However, this would break the is_proper_subtype check in - # conditional_types for cases like isinstance(x, type) when the type - # of x is Type[int]. It's unclear what's the right way to address this. - return True - if right.type.fullname == "builtins.object": - return True - item = left.item - if isinstance(item, TypeVarType): - item = get_proper_type(item.upper_bound) - if isinstance(item, Instance): - metaclass = item.type.metaclass_type - return metaclass is not None and self._is_proper_subtype(metaclass, right) - return False - - def visit_type_alias_type(self, left: TypeAliasType) -> bool: - assert False, f"This should be never called, got {left}" - - def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Check if left is a more precise type than right. diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index fb9e3e80b854..173d80b85426 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -535,7 +535,7 @@ def test_simplified_union_with_literals(self) -> None: [fx.lit1_inst, fx.lit3_inst], UnionType([fx.lit1_inst, fx.lit3_inst]) ) self.assert_simplified_union([fx.lit1_inst, fx.uninhabited], fx.lit1_inst) - self.assert_simplified_union([fx.lit1, fx.lit1_inst], UnionType([fx.lit1, fx.lit1_inst])) + self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1) self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst])) self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst])) diff --git a/mypy/typestate.py b/mypy/typestate.py index 91cfb9562139..389dc9c2a358 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -108,6 +108,12 @@ def is_assumed_proper_subtype(left: Type, right: Type) -> bool: return True return False + @staticmethod + def get_assumptions(is_proper: bool) -> List[Tuple[TypeAliasType, TypeAliasType]]: + if is_proper: + return TypeState._assuming_proper + return TypeState._assuming + @staticmethod def reset_all_subtype_caches() -> None: """Completely reset all known subtype caches.""" diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ab4f0d4e1b06..577e71d78482 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2166,6 +2166,19 @@ if x in (1, 2): [builtins fixtures/tuple.pyi] [typing fixtures/typing-full.pyi] + +[case testOverlappingClassCallables] +# flags: --strict-equality +from typing import Any, Callable, Type + +x: Type[int] +y: Callable[[], Any] +x == y +y == x +int == y +y == int +[builtins fixtures/bool.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name "Any" is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any") diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 66ac67af1126..3454e2cce948 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5274,14 +5274,14 @@ def f2(g: G[A, B], x: int = ...) -> B: ... def f2(g: Any, x: int = ...) -> Any: ... [case testOverloadTypeVsCallable] -from typing import TypeVar, Type, Callable, Any, overload +from typing import TypeVar, Type, Callable, Any, overload, Optional class Foo: def __init__(self, **kwargs: Any): pass _T = TypeVar('_T') @overload def register(cls: Type[_T]) -> int: ... @overload -def register(cls: Callable[..., _T]) -> str: ... +def register(cls: Callable[..., _T]) -> Optional[int]: ... def register(cls: Any) -> Any: return None