Skip to content

Commit

Permalink
Fix type narrowing with in on enum types in the negative case (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Apr 11, 2023
1 parent 2462052 commit e5066bf
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 39 deletions.
8 changes: 3 additions & 5 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# Unreleased

- Better typechecking support for async generators (#594)

# Changelog

## Unreleased

- Fix type narrowing with `in` on enum types in the negative case (#606)
- Fix crash when `getattr()` on a module object throws an error (#603)
- Fix handling of positional-only arguments using `/` syntax in stubs (#601)
- Fix bug where objects with a `__call__` method that takes `*args` instead
of `self` was not considered callable (#600)
of `self` were not considered callable (#600)
- Better typechecking support for async generators (#594)

## Version 0.9.0 (January 16, 2023)

Expand Down
45 changes: 12 additions & 33 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
StringSequenceOption,
)
from .patma import PatmaVisitor
from .predicates import EqualsPredicate
from .predicates import EqualsPredicate, InPredicate
from .reexport import ImplicitReexportTracker
from .safe import (
is_dataclass_type,
Expand Down Expand Up @@ -3149,39 +3149,18 @@ def _constraint_from_compare_op(
positive = isinstance(op, ast.Eq)
return Constraint(varname, ConstraintType.predicate, positive, predicate)
elif isinstance(op, (ast.In, ast.NotIn)) and is_right:

def predicate_func(value: Value, positive: bool) -> Optional[Value]:
op = _in if positive else _not_in
if isinstance(value, KnownValue):
try:
result = op(value.val, other_val)
except Exception:
pass
else:
if not result:
return None
elif positive:
known_other = KnownValue(other_val)
member_values = concrete_values_from_iterable(known_other, self)
if isinstance(member_values, CanAssignError):
return value
elif isinstance(member_values, Value):
if value.is_assignable(member_values, self):
return member_values
return None
else:
possible_values = [
val
for val in member_values
if value.is_assignable(val, self)
]
return unite_values(*possible_values)
return value

try:
predicate_vals = list(other_val)
predicate_types = {type(val) for val in predicate_vals}
if len(predicate_types) == 1:
pattern_type = next(iter(predicate_types))
else:
pattern_type = object
except Exception:
return NULL_CONSTRAINT
predicate = InPredicate(other_val, pattern_type, self)
positive = isinstance(op, ast.In)
return Constraint(
varname, ConstraintType.predicate, positive, predicate_func
)
return Constraint(varname, ConstraintType.predicate, positive, predicate)
else:
positive_operator, negative_operator = COMPARATOR_TO_OPERATOR[type(op)]

Expand Down
50 changes: 49 additions & 1 deletion pyanalyze/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import enum
import operator
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Sequence

from .safe import safe_issubclass
from .value import (
Expand Down Expand Up @@ -125,3 +125,51 @@ def __call__(self, value: Value, positive: bool) -> Optional[Value]:
]
)
return value


@dataclass
class InPredicate:
"""Predicate that filters out values that are not in pattern_vals."""

pattern_vals: Sequence[object]
pattern_type: type
ctx: CanAssignContext

def __call__(self, value: Value, positive: bool) -> Optional[Value]:
inner_value = unannotate(value)
if isinstance(inner_value, KnownValue):
try:
if positive:
result = inner_value.val in self.pattern_vals
else:
result = inner_value.val not in self.pattern_vals
except Exception:
pass
else:
if not result:
return None
elif positive:
acceptable_values = [
KnownValue(pattern_val)
for pattern_val in self.pattern_vals
if value.is_assignable(KnownValue(pattern_val), self.ctx)
]
if acceptable_values:
return unite_values(*acceptable_values)
else:
return None
else:
if safe_issubclass(self.pattern_type, enum.Enum):
simplified = unannotate(value)
if (
isinstance(simplified, TypedValue)
and simplified.typ is self.pattern_type
):
return unite_values(
*[
KnownValue(val)
for val in self.pattern_type
if val not in self.pattern_vals
]
)
return value
19 changes: 19 additions & 0 deletions pyanalyze/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class TestNarrowing(TestNameCheckVisitorBase):
@assert_passes()
def test_exhaustive(self):
from enum import Enum
from typing_extensions import assert_never

class X(Enum):
a = 1
Expand All @@ -88,6 +89,24 @@ def capybara_is(x: X):
else:
assert_is_value(x, KnownValue(X.b))

def capybara_in_list(x: X):
if x in [X.a]:
assert_is_value(x, KnownValue(X.a))
else:
assert_is_value(x, KnownValue(X.b))

def capybara_in_tuple(x: X):
if x in (X.a,):
assert_is_value(x, KnownValue(X.a))
else:
assert_is_value(x, KnownValue(X.b))

def test_multi_in(x: X):
if x in (X.a, X.b):
assert_is_value(x, KnownValue(X.a) | KnownValue(X.b))
else:
assert_never(x)

def whatever(x):
if x == X.a:
assert_is_value(x, KnownValue(X.a))
Expand Down
27 changes: 27 additions & 0 deletions pyanalyze/test_never.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,33 @@ def capybara(x: Union[int, str]) -> None:
else:
assert_never(x)

@assert_passes()
def test_enum(self):
from typing_extensions import assert_never
import enum

class Capybara(enum.Enum):
hydrochaeris = 1
isthmius = 2

def capybara(x: Capybara) -> None:
if x in (Capybara.hydrochaeris, Capybara.isthmius):
pass
else:
assert_never(x)

@assert_passes()
def test_literal(self):
from typing_extensions import Literal, assert_never, assert_type

def capybara(x: Literal["a", "b", "c"]) -> None:
if x in ("a", "b"):
assert_type(x, Literal["a", "b"])
elif x == "c":
assert_type(x, Literal["c"])
else:
assert_never(x)


class TestNeverCall(TestNameCheckVisitorBase):
@assert_passes()
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ skip_missing_interpreters = True
deps =
pytest
mypy_extensions
attrs

commands =
pytest pyanalyze/
Expand Down

0 comments on commit e5066bf

Please sign in to comment.