Skip to content

Commit

Permalink
Support type narrowing on enums and bools in match statements (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 9, 2022
1 parent 5b54c35 commit f3e96ee
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 44 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- Do not error on boolean operations on values typed
as `object` (#388)
- Support type narrowing on enum types and `bool`
in `match` statements (#387)
- Support some imports from stub-only modules (#386)
- Support type evaluation functions in stubs (#386)
- Support `TypedDict` in stubs (#386)
Expand Down
2 changes: 2 additions & 0 deletions pyanalyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from . import node_visitor
from . import options
from . import patma
from . import predicates
from . import reexport
from . import safe
from . import shared_options
Expand Down Expand Up @@ -60,3 +61,4 @@
used(options)
used(shared_options)
used(functions)
used(predicates)
6 changes: 6 additions & 0 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4479,6 +4479,12 @@ def visit_Match(self, node: Match) -> None:
with self.scopes.subscope() as case_scope:
for constraint in constraints_to_apply:
self.add_constraint(case, constraint)
self.match_subject = self.match_subject._replace(
value=constrain_value(
self.match_subject.value,
AndConstraint.make(constraints_to_apply),
)
)

pattern_constraint = patma_visitor.visit(case.pattern)
constraints = [pattern_constraint]
Expand Down
30 changes: 8 additions & 22 deletions pyanalyze/patma.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .annotations import type_from_value
from .extensions import CustomCheck
from .error_code import ErrorCode
from .predicates import EqualsPredicate, IsAssignablePredicate
from .stacked_scopes import (
NULL_CONSTRAINT,
AbstractConstraint,
Expand All @@ -38,7 +39,6 @@
ConstraintType,
OrConstraint,
constrain_value,
IsAssignablePredicate,
)
from .value import (
AnnotatedValue,
Expand Down Expand Up @@ -178,7 +178,10 @@ class PatmaVisitor(ast.NodeVisitor):

def visit_MatchSingleton(self, node: MatchSingleton) -> AbstractConstraint:
self.check_impossible_pattern(node, KnownValue(node.value))
return self.make_constraint(ConstraintType.is_value, node.value)
return self.make_constraint(
ConstraintType.predicate,
EqualsPredicate(node.value, self.visitor, use_is=True),
)

def visit_MatchValue(self, node: MatchValue) -> AbstractConstraint:
pattern_val = self.visitor.visit(node.value)
Expand All @@ -191,26 +194,9 @@ def visit_MatchValue(self, node: MatchValue) -> AbstractConstraint:
)
return NULL_CONSTRAINT

def predicate_func(value: Value, positive: bool) -> Optional[Value]:
if isinstance(value, KnownValue):
try:
if positive:
result = value.val == pattern_val.val
else:
result = value.val != pattern_val.val
except Exception:
pass
else:
if not result:
return None
elif positive:
if value.is_assignable(pattern_val, self.visitor):
return pattern_val
else:
return None
return value

return self.make_constraint(ConstraintType.predicate, predicate_func)
return self.make_constraint(
ConstraintType.predicate, EqualsPredicate(pattern_val.val, self.visitor)
)

def visit_MatchSequence(self, node: MatchSequence) -> AbstractConstraint:
self.check_impossible_pattern(node, MatchableSequence)
Expand Down
95 changes: 95 additions & 0 deletions pyanalyze/predicates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Reusable predicates.
"""
from dataclasses import dataclass
import enum
import operator
from typing import Optional

from .safe import safe_issubclass
from .value import (
KnownValue,
TypedValue,
Value,
unite_values,
CanAssignContext,
is_overlapping,
unannotate,
)


@dataclass
class IsAssignablePredicate:
"""Predicate that filters out values that are not assignable to pattern_value."""

pattern_value: Value
ctx: CanAssignContext
positive_only: bool

def __call__(self, value: Value, positive: bool) -> Optional[Value]:
compatible = is_overlapping(self.pattern_value, value, self.ctx)
if positive:
if not compatible:
return None
if value.is_assignable(self.pattern_value, self.ctx):
return self.pattern_value
elif not self.positive_only:
if compatible:
return None
return value


_OPERATOR = {
(True, True): operator.is_,
(False, True): operator.is_not,
(True, False): operator.eq,
(False, False): operator.ne,
}


@dataclass
class EqualsPredicate:
"""Predicate that filters out values that are not equal to pattern_val."""

pattern_val: object
ctx: CanAssignContext
use_is: bool = False

def __call__(self, value: Value, positive: bool) -> Optional[Value]:
inner_value = unannotate(value)
if isinstance(inner_value, KnownValue):
op = _OPERATOR[(positive, self.use_is)]
try:
result = op(inner_value.val, self.pattern_val)
except Exception:
pass
else:
if not result:
return None
elif positive:
known_self = KnownValue(self.pattern_val)
if value.is_assignable(known_self, self.ctx):
return known_self
else:
return None
else:
pattern_type = type(self.pattern_val)
if pattern_type is bool:
simplified = unannotate(value)
if isinstance(simplified, TypedValue) and simplified.typ is bool:
return KnownValue(not self.pattern_val)
elif safe_issubclass(pattern_type, enum.Enum):
simplified = unannotate(value)
if isinstance(simplified, TypedValue) and simplified.typ is type(
self.pattern_val
):
return unite_values(
*[
KnownValue(val)
for val in pattern_type
if val is not self.pattern_val
]
)
return value
21 changes: 0 additions & 21 deletions pyanalyze/stacked_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@
unite_and_simplify,
unite_values,
flatten_values,
CanAssignContext,
is_overlapping,
)

T = TypeVar("T")
Expand Down Expand Up @@ -516,25 +514,6 @@ def invert(self) -> AbstractConstraint:
return NULL_CONSTRAINT


@dataclass
class IsAssignablePredicate:
pattern_value: Value
ctx: CanAssignContext
positive_only: bool

def __call__(self, value: Value, positive: bool) -> Optional[Value]:
compatible = is_overlapping(self.pattern_value, value, self.ctx)
if positive:
if not compatible:
return None
if value.is_assignable(self.pattern_value, self.ctx):
return self.pattern_value
elif not self.positive_only:
if compatible:
return None
return value


@dataclass(frozen=True)
class AndConstraint(AbstractConstraint):
"""Represents the AND of two constraints."""
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/test_boolability.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_get_boolability() -> None:
assert Boolability.erroring_bool == get_boolability(future)
assert Boolability.boolable == get_boolability(TypedValue(object))
assert Boolability.boolable == get_boolability(TypedValue(int))
assert Boolability.boolable == get_boolability(TypedValue(object))

# MultiValuedValue and AnnotatedValue
assert Boolability.erroring_bool == get_boolability(
Expand Down
52 changes: 52 additions & 0 deletions pyanalyze/test_patma.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,55 @@ def capybara(obj: object):
pass
"""
)

@skip_before((3, 10))
def test_bool_narrowing(self):
self.assert_passes(
"""
class X:
true = True
def capybara(b: bool):
match b:
# Make sure we hit the MatchValue case, not MatchSingleton
case X.true:
assert_is_value(b, KnownValue(True))
case _ as b2:
assert_is_value(b, KnownValue(False))
assert_is_value(b2, KnownValue(False))
"""
)
self.assert_passes(
"""
def capybara(b: bool):
match b:
case True:
assert_is_value(b, KnownValue(True))
case _ as b2:
assert_is_value(b, KnownValue(False))
assert_is_value(b2, KnownValue(False))
"""
)

@skip_before((3, 10))
def test_enum_narrowing(self):
self.assert_passes(
"""
from enum import Enum
class Planet(Enum):
mercury = 1
venus = 2
earth = 3
def capybara(p: Planet):
match p:
case Planet.mercury:
assert_is_value(p, KnownValue(Planet.mercury))
case Planet.venus:
assert_is_value(p, KnownValue(Planet.venus))
case _ as p2:
assert_is_value(p2, KnownValue(Planet.earth))
assert_is_value(p, KnownValue(Planet.earth))
"""
)
2 changes: 1 addition & 1 deletion pyanalyze/type_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
Union,
)

from .predicates import IsAssignablePredicate
from .stacked_scopes import (
Constraint,
ConstraintType,
VarnameWithOrigin,
constrain_value,
IsAssignablePredicate,
)
from .safe import all_of_type
from .value import (
Expand Down

0 comments on commit f3e96ee

Please sign in to comment.