diff --git a/mypy/checker.py b/mypy/checker.py index 1f635c09bc0a..f8461fefc55f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4500,6 +4500,26 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: # Non-tuple iterable. return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0] + def analyze_iterable_item_type_without_expression( + self, type: Type, context: Context + ) -> tuple[Type, Type]: + """Analyse iterable type and return iterator and iterator item types.""" + echk = self.expr_checker + iterable = get_proper_type(type) + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0] + + if isinstance(iterable, TupleType): + joined: Type = UninhabitedType() + for item in iterable.items: + joined = join_types(joined, item) + return iterator, joined + else: + # Non-tuple iterable. + return ( + iterator, + echk.check_method_call_by_name("__next__", iterator, [], [], context)[0], + ) + def analyze_range_native_int_type(self, expr: Expression) -> Type | None: """Try to infer native int item type from arguments to range(...). diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d918eb9b5467..2a04aeddb634 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2919,68 +2919,108 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: That is, 'a < b > c == d' is check as 'a < b and b > c and c == d' """ result: Type | None = None - sub_result: Type | None = None + sub_result: Type # Check each consecutive operand pair and their operator for left, right, operator in zip(e.operands, e.operands[1:], e.operators): left_type = self.accept(left) - method_type: mypy.types.Type | None = None - if operator == "in" or operator == "not in": + # This case covers both iterables and containers, which have different meanings. + # For a container, the in operator calls the __contains__ method. + # For an iterable, the in operator iterates over the iterable, and compares each item one-by-one. + # We allow `in` for a union of containers and iterables as long as at least one of them matches the + # type of the left operand, as the operation will simply return False if the union's container/iterator + # type doesn't match the left operand. + # If the right operand has partial type, look it up without triggering # a "Need type annotation ..." message, as it would be noise. right_type = self.find_partial_type_ref_fast_path(right) if right_type is None: right_type = self.accept(right) # Validate the right operand - # Keep track of whether we get type check errors (these won't be reported, they - # are just to verify whether something is valid typing wise). - with self.msg.filter_errors(save_filtered_errors=True) as local_errors: - _, method_type = self.check_method_call_by_name( - method="__contains__", - base_type=right_type, - args=[left], - arg_kinds=[ARG_POS], - context=e, - ) + right_type = get_proper_type(right_type) + item_types: Sequence[Type] = [right_type] + if isinstance(right_type, UnionType): + item_types = list(right_type.items) sub_result = self.bool_type() - # Container item type for strict type overlap checks. Note: we need to only - # check for nominal type, because a usual "Unsupported operands for in" - # will be reported for types incompatible with __contains__(). - # See testCustomContainsCheckStrictEquality for an example. - cont_type = self.chk.analyze_container_item_type(right_type) - if isinstance(right_type, PartialType): - # We don't really know if this is an error or not, so just shut up. - pass - elif ( - local_errors.has_new_errors() - and - # is_valid_var_arg is True for any Iterable - self.is_valid_var_arg(right_type) - ): - _, itertype = self.chk.analyze_iterable_item_type(right) - method_type = CallableType( - [left_type], - [nodes.ARG_POS], - [None], - self.bool_type(), - self.named_type("builtins.function"), - ) - if not is_subtype(left_type, itertype): - self.msg.unsupported_operand_types("in", left_type, right_type, e) - # Only show dangerous overlap if there are no other errors. - elif ( - not local_errors.has_new_errors() - and cont_type - and self.dangerous_comparison( - left_type, cont_type, original_container=right_type, prefer_literal=False - ) - ): - self.msg.dangerous_comparison(left_type, cont_type, "container", e) - else: - self.msg.add_errors(local_errors.filtered_errors()) + + container_types: list[Type] = [] + iterable_types: list[Type] = [] + failed_out = False + encountered_partial_type = False + + for item_type in item_types: + # Keep track of whether we get type check errors (these won't be reported, they + # are just to verify whether something is valid typing wise). + with self.msg.filter_errors(save_filtered_errors=True) as container_errors: + _, method_type = self.check_method_call_by_name( + method="__contains__", + base_type=item_type, + args=[left], + arg_kinds=[ARG_POS], + context=e, + original_type=right_type, + ) + # Container item type for strict type overlap checks. Note: we need to only + # check for nominal type, because a usual "Unsupported operands for in" + # will be reported for types incompatible with __contains__(). + # See testCustomContainsCheckStrictEquality for an example. + cont_type = self.chk.analyze_container_item_type(item_type) + + if isinstance(item_type, PartialType): + # We don't really know if this is an error or not, so just shut up. + encountered_partial_type = True + pass + elif ( + container_errors.has_new_errors() + and + # is_valid_var_arg is True for any Iterable + self.is_valid_var_arg(item_type) + ): + # it's not a container, but it is an iterable + with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors: + _, itertype = self.chk.analyze_iterable_item_type_without_expression( + item_type, e + ) + if iterable_errors.has_new_errors(): + self.msg.add_errors(iterable_errors.filtered_errors()) + failed_out = True + else: + method_type = CallableType( + [left_type], + [nodes.ARG_POS], + [None], + self.bool_type(), + self.named_type("builtins.function"), + ) + e.method_types.append(method_type) + iterable_types.append(itertype) + elif not container_errors.has_new_errors() and cont_type: + container_types.append(cont_type) + e.method_types.append(method_type) + else: + self.msg.add_errors(container_errors.filtered_errors()) + failed_out = True + + if not encountered_partial_type and not failed_out: + iterable_type = UnionType.make_union(iterable_types) + if not is_subtype(left_type, iterable_type): + if len(container_types) == 0: + self.msg.unsupported_operand_types("in", left_type, right_type, e) + else: + container_type = UnionType.make_union(container_types) + if self.dangerous_comparison( + left_type, + container_type, + original_container=right_type, + prefer_literal=False, + ): + self.msg.dangerous_comparison( + left_type, container_type, "container", e + ) + elif operator in operators.op_methods: method = operators.op_methods[operator] @@ -2988,6 +3028,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: sub_result, method_type = self.check_op( method, left_type, right, e, allow_reverse=True ) + e.method_types.append(method_type) # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. @@ -3007,12 +3048,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: left_type = try_getting_literal(left_type) right_type = try_getting_literal(right_type) self.msg.dangerous_comparison(left_type, right_type, "identity", e) - method_type = None + e.method_types.append(None) else: raise RuntimeError(f"Unknown comparison operator {operator}") - e.method_types.append(method_type) - # Determine type of boolean-and of result and sub_result if result is None: result = sub_result diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index cabc28e786b2..65d5c1abc7e8 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -1202,3 +1202,20 @@ def foo( yield i foo([1]) [builtins fixtures/list.pyi] + +[case testUnionIterableContainer] +from typing import Iterable, Container, Union + +i: Iterable[str] +c: Container[str] +u: Union[Iterable[str], Container[str]] +ni: Union[Iterable[str], int] +nc: Union[Container[str], int] + +'x' in i +'x' in c +'x' in u +'x' in ni # E: Unsupported right operand type for in ("Union[Iterable[str], int]") +'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]") +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi]