Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support context managers that suppress exceptions #496

Merged
merged 9 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _get_attribute_from_typed(
return KnownValue(typ)
elif ctx.attr == "__dict__":
return TypedValue(dict)

result, provider, should_unwrap = _get_attribute_from_mro(typ, ctx, on_class=False)
result = _substitute_typevars(typ, generic_args, result, provider, ctx)
if should_unwrap:
Expand Down
54 changes: 47 additions & 7 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
Type,
TypeVar,
Container,
Protocol,
)
from typing_extensions import Annotated

Expand Down Expand Up @@ -205,6 +206,7 @@
Match = Any

T = TypeVar("T")
U = TypeVar("U")
AwaitableValue = GenericValue(collections.abc.Awaitable, [TypeVarValue(T)])
KnownNone = KnownValue(None)
ExceptionValue = TypedValue(BaseException) | SubclassValue(TypedValue(BaseException))
Expand Down Expand Up @@ -276,6 +278,19 @@ def _not_in(a: object, b: Container[object]) -> bool:
}


class CustomContextManager(Protocol[T, U]):
def __enter__(self) -> T:
raise NotImplementedError

def __exit__(
self,
_exc_type: Optional[Type[BaseException]],
_exc_value: Optional[BaseException],
nbdaaron marked this conversation as resolved.
Show resolved Hide resolved
__traceback: Optional[types.TracebackType],
) -> U:
raise NotImplementedError


@dataclass
class _StarredValue(Value):
"""Helper Value to represent the result of "*x".
Expand Down Expand Up @@ -2213,6 +2228,7 @@ def _simulate_import(
private=not force_public,
)

pseudo_module = None
nbdaaron marked this conversation as resolved.
Show resolved Hide resolved
with tempfile.NamedTemporaryFile(suffix=".py") as f:
f.write(source_code.encode("utf-8"))
f.flush()
Expand All @@ -2233,6 +2249,7 @@ def _simulate_import(
if pseudo_module_name in sys.modules:
del sys.modules[pseudo_module_name]

assert pseudo_module
for name, value in pseudo_module.__dict__.items():
if name.startswith("__") or (
hasattr(builtins, name) and value == getattr(builtins, name)
Expand Down Expand Up @@ -3410,14 +3427,27 @@ def _member_value_of_iterator(
return result

def visit_With(self, node: ast.With) -> None:
contexts = [self.visit_withitem(item) for item in node.items]
if len(contexts) == 1:
context = contexts[0]
results = [self.visit_withitem(item) for item in node.items]
if len(results) == 1:
context, _ = results[0]
if isinstance(context, AnnotatedValue) and context.has_metadata_of_type(
AssertErrorExtension
):
self._visit_assert_errors_block(node)
return

can_suppress = any(can_suppress for _, can_suppress in results)

if can_suppress:
with self.scopes.subscope() as body_scope:
self._generic_visit_list(node.body)

# assume nothing ran if an exception was suppressed
with self.scopes.subscope() as dummy_subscope:
pass

self.scopes.combine_subscopes([body_scope, dummy_subscope])
return
self._generic_visit_list(node.body)

def _visit_assert_errors_block(self, node: ast.With) -> None:
Expand All @@ -3435,13 +3465,15 @@ def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
self.visit_withitem(item, is_async=True)
self._generic_visit_list(node.body)

def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
def visit_withitem(
self, node: ast.withitem, is_async: bool = False
) -> Tuple[Value, bool]:
context = self.visit(node.context_expr)
if is_async:
protocol = "typing.AsyncContextManager"
nbdaaron marked this conversation as resolved.
Show resolved Hide resolved
else:
protocol = "typing.ContextManager"
val = GenericValue(protocol, [TypeVarValue(T)])
protocol = CustomContextManager
val = GenericValue(protocol, [TypeVarValue(T), TypeVarValue(U)])
can_assign = get_tv_map(val, context, self)
if isinstance(can_assign, CanAssignError):
self._show_error_if_checking(
Expand All @@ -3451,12 +3483,20 @@ def visit_withitem(self, node: ast.withitem, is_async: bool = False) -> Value:
error_code=ErrorCode.invalid_context_manager,
)
assigned = AnyValue(AnySource.error)
can_suppress = False
else:
assigned = can_assign.get(T, AnyValue(AnySource.generic_argument))
exit_assigned = can_assign.get(U, AnyValue(AnySource.generic_argument))
exit_boolability = get_boolability(exit_assigned)
can_suppress = not exit_boolability.is_safely_false()
if isinstance(exit_assigned, AnyValue):
# cannot easily infer what the context manager will do,
# assume it does not suppress exceptions.
can_suppress = False
if node.optional_vars is not None:
with qcore.override(self, "being_assigned", assigned):
self.visit(node.optional_vars)
return context
return (context, can_suppress)

def visit_try_except(self, node: ast.Try) -> List[SubScope]:
# reset yield checks between branches to avoid incorrect errors when we yield both in the
Expand Down
69 changes: 69 additions & 0 deletions pyanalyze/test_name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,3 +2035,72 @@ def capybara(x: Optional[str], y: Optional[str]) -> Optional[str]:
assert_is_value(x, TypedValue(str) | KnownValue(None))
assert_is_value(y, TypedValue(str) | KnownValue(None))
return x or y


class TestContextManagerWithSuppression(TestNameCheckVisitorBase):
@assert_passes()
def test(self):
from typing import Optional, Type, ContextManager, Iterator
from types import TracebackType
import contextlib

class SuppressException:
def __enter__(self):
pass

def __exit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
return isinstance(exn, Exception)

class EmptyContext(object):
def __enter__(self):
pass

def __exit__(
self,
typ: Optional[Type[BaseException]],
exn: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass

def empty_context_manager() -> ContextManager[None]:
return EmptyContext()

@contextlib.contextmanager
def contextlib_manager() -> Iterator[None]:
yield

def capybara():
a = 2
with SuppressException():
a = 3
assert_is_value(a, KnownValue(2) | KnownValue(3))

def pacarana():
a = 2 # static analysis: ignore[unused_variable]
with EmptyContext():
a = 3
assert_is_value(a, KnownValue(3))

def use_context_manager():
a = 2 # static analysis: ignore[unused_variable]
with empty_context_manager():
a = 3
assert_is_value(a, KnownValue(3))

def use_builtin_function():
a = 2 # static analysis: ignore[unused_variable]
with open("test_file.txt"):
a = 3
assert_is_value(a, KnownValue(3))

def use_contextlib_manager():
a = 2 # static analysis: ignore[unused_variable]
with contextlib_manager():
a = 3
assert_is_value(a, KnownValue(3))
nbdaaron marked this conversation as resolved.
Show resolved Hide resolved