Skip to content

Commit

Permalink
implement return annotation behavior for type eval (#408)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 12, 2022
1 parent d5258c7 commit 15538e6
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 44 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Implement return annotation behavior for type evaluation
functions (#408)
- Support `extend_config` option in `pyproject.toml` (#407)
- Remove the old method return type check. Use the new
`incompatible_override` check instead (#404)
Expand Down
3 changes: 0 additions & 3 deletions docs/type_evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,10 @@ in pyanalyze:

Currently unsupported features include:

- Usage in stubs
- pyanalyze should provide a way to register
an evaluation function for a runtime function,
to replace some impls.
- Type compatibility for evaluated functions.
- Implementation of the desired behavior for
return annotations

Areas that need more thought include:

Expand Down
32 changes: 8 additions & 24 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
import qcore
import ast
import builtins
import inspect
from collections.abc import Callable, Iterable
import textwrap
from typing import (
Any,
Container,
Expand All @@ -59,7 +57,6 @@
NoReturnGuard,
ParameterTypeGuard,
TypeGuard,
get_type_evaluation,
)
from .find_unused import used
from .functions import FunctionDefNode
Expand Down Expand Up @@ -170,25 +167,6 @@ def get_name(self, node: ast.Name) -> Value:
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
return self.get_name_from_globals(node.id, self.globals)

@classmethod
def get_for(cls, func: typing.Callable[..., Any]) -> Optional["RuntimeEvaluator"]:
try:
key = f"{func.__module__}.{func.__qualname__}"
except AttributeError:
return None
evaluation_func = get_type_evaluation(key)
if evaluation_func is None or not hasattr(evaluation_func, "__globals__"):
return None
lines, _ = inspect.getsourcelines(evaluation_func)
code = textwrap.dedent("".join(lines))
body = ast.parse(code)
if not body.body:
return None
evaluator = body.body[0]
if not isinstance(evaluator, ast.FunctionDef):
return None
return RuntimeEvaluator(evaluator, evaluation_func.__globals__, evaluation_func)


@dataclass
class SyntheticEvaluator(type_evaluation.Evaluator):
Expand Down Expand Up @@ -217,10 +195,16 @@ def get_name(self, node: ast.Name) -> Value:

@classmethod
def from_visitor(
cls, node: FunctionDefNode, visitor: "NameCheckVisitor"
cls,
node: FunctionDefNode,
visitor: "NameCheckVisitor",
return_annotation: Value,
) -> "SyntheticEvaluator":
return cls(
node, visitor, _DefaultContext(visitor, node, use_name_node_for_error=True)
node,
return_annotation,
visitor,
_DefaultContext(visitor, node, use_name_node_for_error=True),
)


Expand Down
44 changes: 36 additions & 8 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .options import Options, PyObjectSequenceOption
from .analysis_lib import is_positional_only_arg_name
from .extensions import CustomCheck, get_overloads
from .extensions import CustomCheck, get_overloads, get_type_evaluation
from .annotations import Context, RuntimeEvaluator, type_from_runtime
from .config import Config
from .find_unused import used
Expand Down Expand Up @@ -60,6 +60,7 @@
import qcore
import inspect
import sys
import textwrap
from types import FunctionType, ModuleType
from typing import (
Any,
Expand Down Expand Up @@ -509,6 +510,37 @@ def _cached_get_argspec(
self.known_argspecs[obj] = extended
return extended

def _maybe_make_evaluator_sig(
self, func: Callable[..., Any], impl: Optional[Impl], is_asynq: bool
) -> MaybeSignature:
try:
key = f"{func.__module__}.{func.__qualname__}"
except AttributeError:
return None
evaluation_func = get_type_evaluation(key)
if evaluation_func is None or not hasattr(evaluation_func, "__globals__"):
return None
sig = self._cached_get_argspec(
evaluation_func, impl, is_asynq, in_overload_resolution=True
)
if sig is None:
return None
lines, _ = inspect.getsourcelines(evaluation_func)
code = textwrap.dedent("".join(lines))
body = ast.parse(code)
if not body.body:
return None
evaluator_node = body.body[0]
if not isinstance(evaluator_node, ast.FunctionDef):
return None
evaluator = RuntimeEvaluator(
evaluator_node,
sig.return_value,
evaluation_func.__globals__,
evaluation_func,
)
return replace(sig, evaluator=evaluator)

def _uncached_get_argspec(
self,
obj: Any,
Expand All @@ -529,13 +561,9 @@ def _uncached_get_argspec(
]
if all_of_type(sigs, Signature):
return OverloadedSignature(sigs)
evaluator = RuntimeEvaluator.get_for(obj)
if evaluator is not None:
sig = self._cached_get_argspec(
evaluator.func, impl, is_asynq, in_overload_resolution=True
)
if isinstance(sig, Signature):
return replace(sig, evaluator=evaluator)
evaluator_sig = self._maybe_make_evaluator_sig(obj, impl, is_asynq)
if evaluator_sig is not None:
return evaluator_sig

if isinstance(obj, tuple) or hasattr(obj, "__getattr__"):
return None # lost cause
Expand Down
5 changes: 4 additions & 1 deletion pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,10 @@ def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult:
if self._is_collecting() or isinstance(node, ast.Lambda):
return FunctionResult(parameters=params)
with self.scopes.allow_only_module_scope():
evaluator = SyntheticEvaluator.from_visitor(node, self)
# The return annotation doesn't actually matter for validation.
evaluator = SyntheticEvaluator.from_visitor(
node, self, AnyValue(AnySource.marker)
)
ctx = type_evaluation.EvalContext(
variables={param.name: param.annotation for param in params},
positions={param.name: type_evaluation.DEFAULT for param in params},
Expand Down
7 changes: 7 additions & 0 deletions pyanalyze/stubs/_pyanalyze_tests-stubs/evaluated.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,10 @@ def open(mode: str):
return BinaryIO
else:
return IO[Any]

@evaluated
def open2(mode: str) -> IO[Any]:
if mode == "r":
return TextIO
elif mode == "rb":
return BinaryIO
13 changes: 13 additions & 0 deletions pyanalyze/test_type_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,19 @@ def capybara(i: int):
has_default() # E: incompatible_call
has_default(i) # E: incompatible_call

@assert_passes()
def test_return(self):
from pyanalyze.extensions import evaluated

@evaluated
def maybe_use_header(x: bool) -> int:
if x is True:
return str

def capybara(x: bool):
assert_is_value(maybe_use_header(True), TypedValue(str))
assert_is_value(maybe_use_header(x), TypedValue(int))


class TestBoolOp(TestNameCheckVisitorBase):
@assert_passes()
Expand Down
11 changes: 10 additions & 1 deletion pyanalyze/test_typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_evaluated(self):
@assert_passes()
def test_evaluated_import(self):
def capybara(unannotated):
from _pyanalyze_tests.evaluated import open
from _pyanalyze_tests.evaluated import open, open2
from typing import TextIO, BinaryIO, IO

assert_is_value(open("r"), TypedValue(TextIO))
Expand All @@ -235,6 +235,15 @@ def capybara(unannotated):
open("r" if unannotated else "rb"),
TypedValue(TextIO) | TypedValue(BinaryIO),
)
assert_is_value(open2("r"), TypedValue(TextIO))
assert_is_value(open2("rb"), TypedValue(BinaryIO))
assert_is_value(
open2(unannotated), GenericValue(IO, [AnyValue(AnySource.explicit)])
)
assert_is_value(
open2("r" if unannotated else "rb"),
TypedValue(TextIO) | TypedValue(BinaryIO),
)

@assert_passes()
def test_recursive_base(self):
Expand Down
8 changes: 2 additions & 6 deletions pyanalyze/type_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
from .safe import all_of_type
from .value import (
NO_RETURN_VALUE,
AnySource,
AnyValue,
CanAssign,
CanAssignContext,
CanAssignError,
Expand Down Expand Up @@ -281,6 +279,7 @@ def narrow_variables(self, varmap: Optional[VarMap]) -> Iterator[None]:
@dataclass
class Evaluator:
node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
return_annotation: Value

def evaluate(self, ctx: EvalContext) -> Tuple[Value, Sequence[UserRaisedError]]:
visitor = EvaluateVisitor(self, ctx)
Expand Down Expand Up @@ -629,10 +628,7 @@ def run(self) -> Value:

def _evaluate_ret(self, ret: EvalReturn, node: ast.AST) -> Value:
if ret is None:
# TODO return the func's return annotation instead
if not self.validation_mode:
self.add_invalid("Evaluator failed to return", node)
return AnyValue(AnySource.error)
return self.evaluator.return_annotation
elif isinstance(ret, CombinedReturn):
children = [self._evaluate_ret(child, node) for child in ret.children]
return unite_values(*children)
Expand Down
4 changes: 3 additions & 1 deletion pyanalyze/typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,9 @@ def _get_signature_from_func_def(
cleaned_arguments.append(arg)
if is_evaluated:
ctx = _AnnotationContext(self, mod)
evaluator = SyntheticEvaluator(node, _DummyErrorContext(), ctx)
evaluator = SyntheticEvaluator(
node, return_value, _DummyErrorContext(), ctx
)
else:
evaluator = None
return Signature.make(
Expand Down

0 comments on commit 15538e6

Please sign in to comment.