Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support context managers that suppress exceptions #496

Merged
merged 9 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Support context managers that may suppress exceptions (#496)
- Fix type inference for `with` assignment targets on
Python 3.7 and higher (#495)
- Fix bug where code after a `while` loop is considered
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _get_attribute_from_typed(
return KnownValue(typ)
elif ctx.attr == "__dict__":
return TypedValue(dict)

result, provider, should_unwrap = _get_attribute_from_mro(typ, ctx, on_class=False)
result = _substitute_typevars(typ, generic_args, result, provider, ctx)
if should_unwrap:
Expand Down
86 changes: 73 additions & 13 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
TypeVar,
Container,
)
from typing_extensions import Annotated
from typing_extensions import Annotated, Protocol

import asynq
import qcore
Expand Down Expand Up @@ -205,6 +205,7 @@
Match = Any

T = TypeVar("T")
U = TypeVar("U")
AwaitableValue = GenericValue(collections.abc.Awaitable, [TypeVarValue(T)])
KnownNone = KnownValue(None)
ExceptionValue = TypedValue(BaseException) | SubclassValue(TypedValue(BaseException))
Expand Down Expand Up @@ -276,6 +277,32 @@ def _not_in(a: object, b: Container[object]) -> bool:
}


class CustomContextManager(Protocol[T, U]):
def __enter__(self) -> T:
raise NotImplementedError

def __exit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_value: Optional[BaseException],
__traceback: Optional[types.TracebackType],
) -> U:
raise NotImplementedError


class AsyncCustomContextManager(Protocol[T, U]):
async def __aenter__(self) -> T:
raise NotImplementedError

async def __aexit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_value: Optional[BaseException],
__traceback: Optional[types.TracebackType],
) -> U:
raise NotImplementedError


@dataclass
class _StarredValue(Value):
"""Helper Value to represent the result of "*x".
Expand Down Expand Up @@ -2213,6 +2240,7 @@ def _simulate_import(
private=not force_public,
)

pseudo_module = None
nbdaaron marked this conversation as resolved.
Show resolved Hide resolved
with tempfile.NamedTemporaryFile(suffix=".py") as f:
f.write(source_code.encode("utf-8"))
f.flush()
Expand All @@ -2233,6 +2261,7 @@ def _simulate_import(
if pseudo_module_name in sys.modules:
del sys.modules[pseudo_module_name]

assert pseudo_module
for name, value in pseudo_module.__dict__.items():
if name.startswith("__") or (
hasattr(builtins, name) and value == getattr(builtins, name)
Expand Down Expand Up @@ -3410,15 +3439,16 @@ def _member_value_of_iterator(
return result

def visit_With(self, node: ast.With) -> None:
contexts = [self.visit_withitem(item) for item in node.items]
if len(contexts) == 1:
context = contexts[0]
if len(node.items) == 1:
with self.scopes.subscope():
context = self.visit(node.items[0].context_expr)
if isinstance(context, AnnotatedValue) and context.has_metadata_of_type(
AssertErrorExtension
):
self._visit_assert_errors_block(node)
return
self._generic_visit_list(node.body)

self.visit_single_cm(node.items, node.body, False)
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved

def _visit_assert_errors_block(self, node: ast.With) -> None:
with self.catch_errors() as caught:
Expand All @@ -3431,17 +3461,35 @@ def _visit_assert_errors_block(self, node: ast.With) -> None:
)

def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
for item in node.items:
self.visit_withitem(item, is_async=True)
self._generic_visit_list(node.body)
self.visit_single_cm(node.items, node.body, True)
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved

def visit_single_cm(
self, items: List[ast.withitem], body: Iterable[ast.AST], is_async: bool = False
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
if len(items) == 0:
self._generic_visit_list(body)
return
first_item = items[0]
can_suppress = self.visit_withitem(first_item, is_async)
with self.scopes.subscope() as rest_scope:
# get scope for visiting remaining CMs + Body
self.visit_single_cm(items[1:], body, is_async)
if can_suppress:
# If an exception was suppressed, assume no other CMs
# or any code in the body was executed.
with self.scopes.subscope() as dummy_subscope:
pass
self.scopes.combine_subscopes([rest_scope, dummy_subscope])
else:
self.scopes.combine_subscopes([rest_scope])

def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> bool:
context = self.visit(node.context_expr)
if is_async:
protocol = "typing.AsyncContextManager"
protocol = AsyncCustomContextManager
else:
protocol = "typing.ContextManager"
val = GenericValue(protocol, [TypeVarValue(T)])
protocol = CustomContextManager
val = GenericValue(protocol, [TypeVarValue(T), TypeVarValue(U)])
can_assign = get_tv_map(val, context, self)
if isinstance(can_assign, CanAssignError):
self._show_error_if_checking(
Expand All @@ -3451,12 +3499,24 @@ def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
error_code=ErrorCode.invalid_context_manager,
)
assigned = AnyValue(AnySource.error)
can_suppress = False
else:
assigned = can_assign.get(T, AnyValue(AnySource.generic_argument))
exit_assigned = can_assign.get(U, AnyValue(AnySource.generic_argument))
exit_boolability = get_boolability(exit_assigned)
can_suppress = not exit_boolability.is_safely_false()
if isinstance(exit_assigned, AnyValue) or (
isinstance(context, TypedValue)
and context.typ
in ["typing.ContextManager", "typing.AsyncContextManager"]
):
# cannot easily infer what the context manager will do,
# assume it does not suppress exceptions.
can_suppress = False
if node.optional_vars is not None:
with qcore.override(self, "being_assigned", assigned):
self.visit(node.optional_vars)
return context
return can_suppress

def visit_try_except(self, node: ast.Try) -> List[SubScope]:
# reset yield checks between branches to avoid incorrect errors when we yield both in the
Expand Down
156 changes: 156 additions & 0 deletions pyanalyze/test_name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,3 +2035,159 @@ def capybara(x: Optional[str], y: Optional[str]) -> Optional[str]:
assert_is_value(x, TypedValue(str) | KnownValue(None))
assert_is_value(y, TypedValue(str) | KnownValue(None))
return x or y


class TestContextManagerWithSuppression(TestNameCheckVisitorBase):
@assert_passes()
def test_sync(self):
from typing import Optional, Type, ContextManager, Iterator
from types import TracebackType
import contextlib

class SuppressException:
def __enter__(self) -> None:
pass

def __exit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
return isinstance(exn, Exception)

class EmptyContext(object):
def __enter__(self) -> None:
pass

def __exit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass

def empty_context_manager() -> ContextManager[None]:
return EmptyContext()

@contextlib.contextmanager
def empty_contextlib_manager() -> Iterator[None]:
yield

def use_suppress_exception():
a = 2
with SuppressException():
a = 3
assert_is_value(a, KnownValue(2) | KnownValue(3))

def use_empty_context():
a = 2 # static analysis: ignore[unused_variable]
with EmptyContext():
a = 3
assert_is_value(a, KnownValue(3))

def use_context_manager():
a = 2 # static analysis: ignore[unused_variable]
with empty_context_manager():
a = 3
assert_is_value(a, KnownValue(3))

def use_builtin_function():
a = 2 # static analysis: ignore[unused_variable]
with open("test_file.txt"):
a = 3
assert_is_value(a, KnownValue(3))

def use_contextlib_manager():
a = 2 # static analysis: ignore[unused_variable]
with empty_contextlib_manager():
a = 3
assert_is_value(a, KnownValue(3))

def use_nested_contexts():
b = 2
with SuppressException(), EmptyContext() as b:
assert_is_value(b, KnownValue(None))
assert_is_value(b, KnownValue(2) | KnownValue(None))

c = 2 # static analysis: ignore[unused_variable]
with EmptyContext() as c, SuppressException():
assert_is_value(c, KnownValue(None))
assert_is_value(c, KnownValue(None))

@assert_passes()
def test_async(self):
from typing import Optional, Type, AsyncContextManager
from types import TracebackType

class AsyncSuppressException(object):
async def __aenter__(self) -> None:
pass

async def __aexit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
return isinstance(exn, Exception)

class AsyncEmptyContext(object):
async def __aenter__(self) -> None:
pass

async def __aexit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass

def async_empty_context_manager() -> AsyncContextManager[None]:
return AsyncEmptyContext()

async def use_async_suppress_exception():
a = 2
async with AsyncSuppressException():
a = 3
assert_is_value(a, KnownValue(2) | KnownValue(3))

async def use_async_empty_context():
a = 2 # static analysis: ignore[unused_variable]
async with AsyncEmptyContext():
a = 3
assert_is_value(a, KnownValue(3))

async def use_async_context_manager():
a = 2 # static analysis: ignore[unused_variable]
async with async_empty_context_manager():
a = 3
assert_is_value(a, KnownValue(3))

async def use_async_nested_contexts():
b = 2
async with AsyncSuppressException(), AsyncEmptyContext() as b:
assert_is_value(b, KnownValue(None))
assert_is_value(b, KnownValue(2) | KnownValue(None))

c = 2 # static analysis: ignore[unused_variable]
async with AsyncEmptyContext() as c, AsyncSuppressException():
assert_is_value(c, KnownValue(None))
assert_is_value(c, KnownValue(None))

@skip_before((3, 7))
def test_async_contextlib_manager(self):
import contextlib
from typing import AsyncIterator

@contextlib.asynccontextmanager
async def async_empty_contextlib_manager() -> AsyncIterator[None]:
yield

async def use_async_contextlib_manager():
a = 2 # static analysis: ignore[unused_variable]
async with async_empty_contextlib_manager():
a = 3
assert_is_value(a, KnownValue(3))
nbdaaron marked this conversation as resolved.
Show resolved Hide resolved