Skip to content

Commit

Permalink
add assert_error() (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 23, 2022
1 parent b11f15a commit 11ca89a
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 41 deletions.
5 changes: 3 additions & 2 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

## Unreleased

- Add support for `assert_type()` (#433)
- Add support for `assert_error()` (#435)
- Add support for `assert_type()` (#434)
- `reveal_type()` and `dump_value()` now return their argument,
the anticipated behavior for `typing.reveal_type()` in Python
3.11 (#432)
3.11 (#433)
- Fix return type of async generator functions (#431)
- Type check function decorators (#428)
- Handle `NoReturn` in `async def` functions (#427)
Expand Down
18 changes: 18 additions & 0 deletions pyanalyze/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
import typing_extensions
import pyanalyze
Expand All @@ -18,6 +19,7 @@
Container,
Dict,
Iterable,
Iterator,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -414,6 +416,22 @@ def f(x: int) -> None:
return val


@contextmanager
def assert_error() -> Iterator[None]:
"""Context manager that asserts that code produces a type checker error.
Example::
with assert_error(): # ok
1 + "x"
with assert_error(): # error: no error found in this block
1 + 1
"""
yield


_overloads: Dict[str, List[Callable[..., Any]]] = defaultdict(list)
_type_evaluations: Dict[str, List[Callable[..., Any]]] = defaultdict(list)

Expand Down
32 changes: 28 additions & 4 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from .boolability import Boolability, get_boolability
from .checker import Checker, CheckerAttrContext
from .error_code import ErrorCode, ERROR_DESCRIPTION
from .extensions import ParameterTypeGuard, patch_typing_overload
from .extensions import ParameterTypeGuard, assert_error, patch_typing_overload
from .find_unused import UnusedObjectFinder, used
from .options import (
ConfigOption,
Expand Down Expand Up @@ -147,6 +147,7 @@
AnnotatedValue,
AnySource,
AnyValue,
AssertErrorExtension,
CanAssign,
CanAssignError,
ConstraintExtension,
Expand All @@ -155,6 +156,7 @@
UNINITIALIZED_VALUE,
NO_RETURN_VALUE,
NoReturnConstraintExtension,
annotate_value,
flatten_values,
is_union,
kv_pairs_from_mapping,
Expand Down Expand Up @@ -3286,16 +3288,32 @@ def _member_value_of_iterator(
return result

def visit_With(self, node: ast.With) -> None:
for item in node.items:
self.visit_withitem(item)
contexts = [self.visit_withitem(item) for item in node.items]
if len(contexts) == 1:
context = contexts[0]
if isinstance(context, AnnotatedValue) and context.has_metadata_of_type(
AssertErrorExtension
):
self._visit_assert_errors_block(node)
return
self._generic_visit_list(node.body)

def _visit_assert_errors_block(self, node: ast.With) -> None:
with self.catch_errors() as caught:
self._generic_visit_list(node.body)
if not caught:
self._show_error_if_checking(
node,
"No errors found in assert_error() block",
error_code=ErrorCode.inference_failure,
)

def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
for item in node.items:
self.visit_withitem(item, is_async=True)
self._generic_visit_list(node.body)

def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> None:
def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
context = self.visit(node.context_expr)
if is_async:
protocol = "typing.AsyncContextManager"
Expand All @@ -3316,6 +3334,7 @@ def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> None:
if node.optional_vars is not None:
with qcore.override(self, "being_assigned", assigned):
self.visit(node.optional_vars)
return context

def visit_try_except(self, node: ast.Try) -> List[SubScope]:
# reset yield checks between branches to avoid incorrect errors when we yield both in the
Expand Down Expand Up @@ -4178,6 +4197,11 @@ def visit_Call(self, node: ast.Call) -> Value:
if caller is not None:
self.collector.record_call(caller, callee_val)

if (
isinstance(callee_wrapped, KnownValue)
and callee_wrapped.val is assert_error
):
return annotate_value(return_value, [AssertErrorExtension()])
return return_value

def _can_perform_call(
Expand Down
35 changes: 0 additions & 35 deletions pyanalyze/test_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,41 +1161,6 @@ def capybara(x: Union[Type[str], Type[int]]) -> None:
assert_is_value(x, SubclassValue(TypedValue(int)))


class TestInferenceHelpers(TestNameCheckVisitorBase):
@assert_passes()
def test(self) -> None:
from pyanalyze import dump_value, assert_is_value
from pyanalyze.value import Value

def capybara(val: Value) -> None:
reveal_type(dump_value) # E: inference_failure
dump_value(reveal_type) # E: inference_failure
assert_is_value(1, KnownValue(1))
assert_is_value(1, KnownValue(2)) # E: inference_failure
assert_is_value(1, val) # E: inference_failure

@assert_passes()
def test_return_value(self) -> None:
from pyanalyze import dump_value, assert_is_value

def capybara():
x = dump_value(1) # E: inference_failure
y = reveal_type(1) # E: inference_failure
assert_is_value(x, KnownValue(1))
assert_is_value(y, KnownValue(1))

@assert_passes()
def test_assert_type(self) -> None:
from pyanalyze.extensions import assert_type
from typing import Any

def capybara(x: int) -> None:
assert_type(x, int)
assert_type(x, "int")
assert_type(x, Any) # E: inference_failure
assert_type(x, str) # E: inference_failure


class TestCallableGuards(TestNameCheckVisitorBase):
@assert_passes()
def test_callable(self):
Expand Down
55 changes: 55 additions & 0 deletions pyanalyze/test_inference_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# static analysis: ignore
from .test_name_check_visitor import TestNameCheckVisitorBase
from .test_node_visitor import assert_passes
from .value import KnownValue


class TestInferenceHelpers(TestNameCheckVisitorBase):
@assert_passes()
def test(self) -> None:
from pyanalyze import dump_value, assert_is_value
from pyanalyze.value import Value

def capybara(val: Value) -> None:
reveal_type(dump_value) # E: inference_failure
dump_value(reveal_type) # E: inference_failure
assert_is_value(1, KnownValue(1))
assert_is_value(1, KnownValue(2)) # E: inference_failure
assert_is_value(1, val) # E: inference_failure

@assert_passes()
def test_return_value(self) -> None:
from pyanalyze import dump_value, assert_is_value

def capybara():
x = dump_value(1) # E: inference_failure
y = reveal_type(1) # E: inference_failure
assert_is_value(x, KnownValue(1))
assert_is_value(y, KnownValue(1))

@assert_passes()
def test_assert_type(self) -> None:
from pyanalyze.extensions import assert_type
from typing import Any

def capybara(x: int) -> None:
assert_type(x, int)
assert_type(x, "int")
assert_type(x, Any) # E: inference_failure
assert_type(x, str) # E: inference_failure


class TestAssertError(TestNameCheckVisitorBase):
@assert_passes()
def test(self) -> None:
from pyanalyze.extensions import assert_error

def f(x: int) -> None:
pass

def capybara() -> None:
with assert_error():
f("x")

with assert_error(): # E: inference_failure
f(1)
5 changes: 5 additions & 0 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,11 @@ class AlwaysPresentExtension(Extension):
"""


@dataclass(frozen=True)
class AssertErrorExtension(Extension):
"""Used for the implementation of :func:`pyanalyze.extensions.assert_error`."""


@dataclass(frozen=True)
class AnnotatedValue(Value):
"""Value representing a `PEP 593 <https://www.python.org/dev/peps/pep-0593/>`_ Annotated object.
Expand Down

0 comments on commit 11ca89a

Please sign in to comment.