Skip to content

Commit

Permalink
support context managers that suppress exceptions (#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbdaaron authored Mar 10, 2022
1 parent 68a3cb9 commit 5aeab3a
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Support context managers that may suppress exceptions (#496)
- Fix type inference for `with` assignment targets on
Python 3.7 and higher (#495)
- Fix bug where code after a `while` loop is considered
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _get_attribute_from_typed(
return KnownValue(typ)
elif ctx.attr == "__dict__":
return TypedValue(dict)

result, provider, should_unwrap = _get_attribute_from_mro(typ, ctx, on_class=False)
result = _substitute_typevars(typ, generic_args, result, provider, ctx)
if should_unwrap:
Expand Down
88 changes: 75 additions & 13 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
TypeVar,
Container,
)
from typing_extensions import Annotated
from typing_extensions import Annotated, Protocol

import asynq
import qcore
Expand Down Expand Up @@ -205,6 +205,7 @@
Match = Any

T = TypeVar("T")
U = TypeVar("U")
AwaitableValue = GenericValue(collections.abc.Awaitable, [TypeVarValue(T)])
KnownNone = KnownValue(None)
ExceptionValue = TypedValue(BaseException) | SubclassValue(TypedValue(BaseException))
Expand Down Expand Up @@ -276,6 +277,32 @@ def _not_in(a: object, b: Container[object]) -> bool:
}


class CustomContextManager(Protocol[T, U]):
def __enter__(self) -> T:
raise NotImplementedError

def __exit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_value: Optional[BaseException],
__traceback: Optional[types.TracebackType],
) -> U:
raise NotImplementedError


class AsyncCustomContextManager(Protocol[T, U]):
async def __aenter__(self) -> T:
raise NotImplementedError

async def __aexit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_value: Optional[BaseException],
__traceback: Optional[types.TracebackType],
) -> U:
raise NotImplementedError


@dataclass
class _StarredValue(Value):
"""Helper Value to represent the result of "*x".
Expand Down Expand Up @@ -3410,15 +3437,16 @@ def _member_value_of_iterator(
return result

def visit_With(self, node: ast.With) -> None:
contexts = [self.visit_withitem(item) for item in node.items]
if len(contexts) == 1:
context = contexts[0]
if len(node.items) == 1:
with self.scopes.subscope():
context = self.visit(node.items[0].context_expr)
if isinstance(context, AnnotatedValue) and context.has_metadata_of_type(
AssertErrorExtension
):
self._visit_assert_errors_block(node)
return
self._generic_visit_list(node.body)

self.visit_single_cm(node.items, node.body, is_async=False)

def _visit_assert_errors_block(self, node: ast.With) -> None:
with self.catch_errors() as caught:
Expand All @@ -3431,17 +3459,39 @@ def _visit_assert_errors_block(self, node: ast.With) -> None:
)

def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
for item in node.items:
self.visit_withitem(item, is_async=True)
self._generic_visit_list(node.body)
self.visit_single_cm(node.items, node.body, is_async=True)

def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
def visit_single_cm(
self,
items: List[ast.withitem],
body: Iterable[ast.AST],
*,
is_async: bool = False,
) -> None:
if len(items) == 0:
self._generic_visit_list(body)
return
first_item = items[0]
can_suppress = self.visit_withitem(first_item, is_async)
with self.scopes.subscope() as rest_scope:
# get scope for visiting remaining CMs + Body
self.visit_single_cm(items[1:], body, is_async=is_async)
if can_suppress:
# If an exception was suppressed, assume no other CMs
# or any code in the body was executed.
with self.scopes.subscope() as dummy_subscope:
pass
self.scopes.combine_subscopes([rest_scope, dummy_subscope])
else:
self.scopes.combine_subscopes([rest_scope])

def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> bool:
context = self.visit(node.context_expr)
if is_async:
protocol = "typing.AsyncContextManager"
protocol = AsyncCustomContextManager
else:
protocol = "typing.ContextManager"
val = GenericValue(protocol, [TypeVarValue(T)])
protocol = CustomContextManager
val = GenericValue(protocol, [TypeVarValue(T), TypeVarValue(U)])
can_assign = get_tv_map(val, context, self)
if isinstance(can_assign, CanAssignError):
self._show_error_if_checking(
Expand All @@ -3451,12 +3501,24 @@ def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
error_code=ErrorCode.invalid_context_manager,
)
assigned = AnyValue(AnySource.error)
can_suppress = False
else:
assigned = can_assign.get(T, AnyValue(AnySource.generic_argument))
exit_assigned = can_assign.get(U, AnyValue(AnySource.generic_argument))
exit_boolability = get_boolability(exit_assigned)
can_suppress = not exit_boolability.is_safely_false()
if isinstance(exit_assigned, AnyValue) or (
isinstance(context, TypedValue)
and context.typ
in ["typing.ContextManager", "typing.AsyncContextManager"]
):
# cannot easily infer what the context manager will do,
# assume it does not suppress exceptions.
can_suppress = False
if node.optional_vars is not None:
with qcore.override(self, "being_assigned", assigned):
self.visit(node.optional_vars)
return context
return can_suppress

def visit_try_except(self, node: ast.Try) -> List[SubScope]:
# reset yield checks between branches to avoid incorrect errors when we yield both in the
Expand Down
4 changes: 2 additions & 2 deletions pyanalyze/patma.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def visit_MatchMapping(self, node: MatchMapping) -> AbstractConstraint:
)
]
kv_pairs = list(reversed(kv_pairs))
optional_pairs = set()
removed_pairs = set()
optional_pairs: Set[KVPair] = set()
removed_pairs: Set[KVPair] = set()
for key, pattern in zip(node.keys, node.patterns):
key_val = self.visitor.visit(key)
value, new_optional_pairs, new_removed_pairs = get_value_from_kv_pairs(
Expand Down
156 changes: 156 additions & 0 deletions pyanalyze/test_name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,3 +2035,159 @@ def capybara(x: Optional[str], y: Optional[str]) -> Optional[str]:
assert_is_value(x, TypedValue(str) | KnownValue(None))
assert_is_value(y, TypedValue(str) | KnownValue(None))
return x or y


class TestContextManagerWithSuppression(TestNameCheckVisitorBase):
@assert_passes()
def test_sync(self):
from typing import Optional, Type, ContextManager, Iterator
from types import TracebackType
import contextlib

class SuppressException:
def __enter__(self) -> None:
pass

def __exit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
return isinstance(exn, Exception)

class EmptyContext(object):
def __enter__(self) -> None:
pass

def __exit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass

def empty_context_manager() -> ContextManager[None]:
return EmptyContext()

@contextlib.contextmanager
def empty_contextlib_manager() -> Iterator[None]:
yield

def use_suppress_exception():
a = 2
with SuppressException():
a = 3
assert_is_value(a, KnownValue(2) | KnownValue(3))

def use_empty_context():
a = 2 # static analysis: ignore[unused_variable]
with EmptyContext():
a = 3
assert_is_value(a, KnownValue(3))

def use_context_manager():
a = 2 # static analysis: ignore[unused_variable]
with empty_context_manager():
a = 3
assert_is_value(a, KnownValue(3))

def use_builtin_function():
a = 2 # static analysis: ignore[unused_variable]
with open("test_file.txt"):
a = 3
assert_is_value(a, KnownValue(3))

def use_contextlib_manager():
a = 2 # static analysis: ignore[unused_variable]
with empty_contextlib_manager():
a = 3
assert_is_value(a, KnownValue(3))

def use_nested_contexts():
b = 2
with SuppressException(), EmptyContext() as b:
assert_is_value(b, KnownValue(None))
assert_is_value(b, KnownValue(2) | KnownValue(None))

c = 2 # static analysis: ignore[unused_variable]
with EmptyContext() as c, SuppressException():
assert_is_value(c, KnownValue(None))
assert_is_value(c, KnownValue(None))

@assert_passes()
def test_async(self):
from typing import Optional, Type, AsyncContextManager
from types import TracebackType

class AsyncSuppressException(object):
async def __aenter__(self) -> None:
pass

async def __aexit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
return isinstance(exn, Exception)

class AsyncEmptyContext(object):
async def __aenter__(self) -> None:
pass

async def __aexit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass

def async_empty_context_manager() -> AsyncContextManager[None]:
return AsyncEmptyContext()

async def use_async_suppress_exception():
a = 2
async with AsyncSuppressException():
a = 3
assert_is_value(a, KnownValue(2) | KnownValue(3))

async def use_async_empty_context():
a = 2 # static analysis: ignore[unused_variable]
async with AsyncEmptyContext():
a = 3
assert_is_value(a, KnownValue(3))

async def use_async_context_manager():
a = 2 # static analysis: ignore[unused_variable]
async with async_empty_context_manager():
a = 3
assert_is_value(a, KnownValue(3))

async def use_async_nested_contexts():
b = 2
async with AsyncSuppressException(), AsyncEmptyContext() as b:
assert_is_value(b, KnownValue(None))
assert_is_value(b, KnownValue(2) | KnownValue(None))

c = 2 # static analysis: ignore[unused_variable]
async with AsyncEmptyContext() as c, AsyncSuppressException():
assert_is_value(c, KnownValue(None))
assert_is_value(c, KnownValue(None))

@skip_before((3, 7))
def test_async_contextlib_manager(self):
import contextlib
from typing import AsyncIterator

@contextlib.asynccontextmanager
async def async_empty_contextlib_manager() -> AsyncIterator[None]:
yield

async def use_async_contextlib_manager():
a = 2 # static analysis: ignore[unused_variable]
async with async_empty_contextlib_manager():
a = 3
assert_is_value(a, KnownValue(3))

0 comments on commit 5aeab3a

Please sign in to comment.