Skip to content

Commit

Permalink
Fix strict equality check if operand item type has custom __eq__ (#14513
Browse files Browse the repository at this point in the history
)

Don't complain about comparing lists, variable-length tuples or sets if
one of the operands has an item type with a custom `__eq__` method.

Fix #14511.
  • Loading branch information
JukkaL authored Jan 23, 2023
1 parent 4de3f5d commit 9ca3035
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 18 deletions.
38 changes: 21 additions & 17 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2988,20 +2988,14 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
# testCustomEqCheckStrictEquality for an example.
if not w.has_new_errors() and operator in ("==", "!="):
right_type = self.accept(right)
# We suppress the error if there is a custom __eq__() method on either
# side. User defined (or even standard library) classes can define this
# to return True for comparisons between non-overlapping types.
if not custom_special_method(
left_type, "__eq__"
) and not custom_special_method(right_type, "__eq__"):
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, "equality", e)
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, "equality", e)

elif operator == "is" or operator == "is not":
right_type = self.accept(right) # validate the right operand
Expand Down Expand Up @@ -3064,6 +3058,12 @@ def dangerous_comparison(

left, right = get_proper_types((left, right))

# We suppress the error if there is a custom __eq__() method on either
# side. User defined (or even standard library) classes can define this
# to return True for comparisons between non-overlapping types.
if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"):
return False

if self.chk.binder.is_unreachable_warning_suppressed():
# We are inside a function that contains type variables with value restrictions in
# its signature. In this case we just suppress all strict-equality checks to avoid
Expand Down Expand Up @@ -3094,14 +3094,18 @@ def dangerous_comparison(
return False
if isinstance(left, Instance) and isinstance(right, Instance):
# Special case some builtin implementations of AbstractSet.
left_name = left.type.fullname
right_name = right.type.fullname
if (
left.type.fullname in OVERLAPPING_TYPES_ALLOWLIST
and right.type.fullname in OVERLAPPING_TYPES_ALLOWLIST
left_name in OVERLAPPING_TYPES_ALLOWLIST
and right_name in OVERLAPPING_TYPES_ALLOWLIST
):
abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet")
left = map_instance_to_supertype(left, abstract_set)
right = map_instance_to_supertype(right, abstract_set)
return not is_overlapping_types(left.args[0], right.args[0])
return self.dangerous_comparison(left.args[0], right.args[0])
elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name:
return self.dangerous_comparison(left.args[0], right.args[0])
if isinstance(left, LiteralType) and isinstance(right, LiteralType):
if isinstance(left.value, bool) and isinstance(right.value, bool):
# Comparing different booleans is not dangerous.
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,24 @@ class B:
A() == B() # E: Unsupported operand types for == ("A" and "B")
[builtins fixtures/bool.pyi]

[case testStrictEqualitySequenceAndCustomEq]
# flags: --strict-equality
from typing import Tuple

class C: pass
class D:
def __eq__(self, other): return True

a = [C()]
b = [D()]
a == b
b == a
t1: Tuple[C, ...]
t2: Tuple[D, ...]
t1 == t2
t2 == t1
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityOKInstance]
# flags: --strict-equality
class A:
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fixtures/bool.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ class float: pass
class str: pass
class unicode: pass
class ellipsis: pass
class list: pass
class list(Generic[T]): pass
class property: pass
1 change: 1 addition & 0 deletions test-data/unit/fixtures/set.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ T = TypeVar('T')

class object:
def __init__(self) -> None: pass
def __eq__(self, other): pass

class type: pass
class tuple(Generic[T]): pass
Expand Down

0 comments on commit 9ca3035

Please sign in to comment.