Skip to content

Commit

Permalink
Add suggested param/return type (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Dec 21, 2021
1 parent 4240d99 commit b1ee354
Show file tree
Hide file tree
Showing 13 changed files with 327 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Add check that suggests parameter and return types for untyped
functions, using the new `suggested_parameter_type` and
`suggested_return_type` codes (#358)
- Extract constraints from multi-comparisons (`a < b < c`) (#354)
- Support positional-only arguments with the `__` prefix
outside of stubs (#353)
Expand Down
2 changes: 2 additions & 0 deletions pyanalyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import safe
from . import signature
from . import stacked_scopes
from . import suggested_type
from . import test_config
from . import type_object
from . import typeshed
Expand All @@ -44,3 +45,4 @@
used(value.UNRESOLVED_VALUE) # keeping it around for now just in case
used(reexport)
used(checker)
used(suggested_type)
7 changes: 7 additions & 0 deletions pyanalyze/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
import sys
from typing import Iterable, Iterator, List, Set, Tuple, Union, Dict

from .node_visitor import Failure
from .value import TypedValue
from .arg_spec import ArgSpecCache
from .config import Config
from .reexport import ImplicitReexportTracker
from .safe import is_instance_of_typing_name, is_typing_name, safe_getattr
from .type_object import TypeObject, get_mro
from .suggested_type import CallableTracker


@dataclass
class Checker:
config: Config
arg_spec_cache: ArgSpecCache = field(init=False)
reexport_tracker: ImplicitReexportTracker = field(init=False)
callable_tracker: CallableTracker = field(init=False)
type_object_cache: Dict[Union[type, super, str], TypeObject] = field(
default_factory=dict, init=False, repr=False
)
Expand All @@ -32,6 +35,10 @@ class Checker:
def __post_init__(self) -> None:
    """Create the helper objects for the fields declared with ``init=False``."""
    config = self.config
    self.arg_spec_cache = ArgSpecCache(config)
    self.reexport_tracker = ImplicitReexportTracker(config)
    # The callable tracker takes no configuration.
    self.callable_tracker = CallableTracker()

def perform_final_checks(self) -> List[Failure]:
    """Return whatever failures the callable tracker's ``check()`` reports."""
    tracker = self.callable_tracker
    return tracker.check()

def get_additional_bases(self, typ: Union[type, super]) -> Set[type]:
    """Forward the lookup for *typ* to the configuration object."""
    config = self.config
    return config.get_additional_bases(typ)
Expand Down
6 changes: 6 additions & 0 deletions pyanalyze/error_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,16 @@ class ErrorCode(enum.Enum):
no_return_may_return = 68
implicit_reexport = 69
invalid_context_manager = 70
suggested_return_type = 71
suggested_parameter_type = 72


# Allow testing unannotated functions without too much fuss.
# Going by the name, the codes in this set are switched off when running
# tests (confirm at the use site). The two ``suggested_*`` codes are listed
# alongside the ``missing_*_annotation`` codes because they also report on
# untyped functions.
DISABLED_IN_TESTS = {
    ErrorCode.missing_return_annotation,
    ErrorCode.missing_parameter_annotation,
    ErrorCode.suggested_return_type,
    ErrorCode.suggested_parameter_type,
}


Expand Down Expand Up @@ -193,6 +197,8 @@ class ErrorCode(enum.Enum):
ErrorCode.no_return_may_return: "Function is annotated as NoReturn but may return",
ErrorCode.implicit_reexport: "Use of implicitly re-exported name",
ErrorCode.invalid_context_manager: "Use of invalid object in with or async with",
ErrorCode.suggested_return_type: "Suggested return type",
ErrorCode.suggested_parameter_type: "Suggested parameter type",
}


Expand Down
37 changes: 37 additions & 0 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
ARGS,
KWARGS,
)
from .suggested_type import CallArgs, display_suggested_type
from .asynq_checker import AsyncFunctionKind, AsynqChecker, FunctionInfo
from .yield_checker import YieldChecker
from .type_object import TypeObject, get_mro
Expand Down Expand Up @@ -1306,6 +1307,17 @@ def visit_FunctionDef(
else:
potential_function = None

if (
potential_function is not None
and self.settings
and self.settings[ErrorCode.suggested_parameter_type]
):
sig = self.signature_from_value(KnownValue(potential_function))
if isinstance(sig, Signature):
self.checker.callable_tracker.record_callable(
node, potential_function, sig, self
)

self.yield_checker.reset_yield_checks()

# This code handles nested functions
Expand Down Expand Up @@ -1373,6 +1385,21 @@ def visit_FunctionDef(
else:
self._show_error_if_checking(node, error_code=ErrorCode.missing_return)

if (
has_return
and expected_return_value is None
and not info.is_overload
and not any(
decorator == KnownValue(abstractmethod)
for _, decorator in info.decorators
)
):
self._show_error_if_checking(
node,
error_code=ErrorCode.suggested_return_type,
detail=display_suggested_type(return_value),
)

if evaled_function:
return evaled_function

Expand Down Expand Up @@ -1427,6 +1454,10 @@ def visit_FunctionDef(
self.log(logging.DEBUG, "No argspec", (potential_function, node))
return KnownValue(potential_function)

def record_call(self, callable: object, arguments: CallArgs) -> None:
    """Pass a call on to the callable tracker when parameter suggestions are on.

    Nothing is recorded when there are no settings or when the
    ``suggested_parameter_type`` error code is disabled.
    """
    if not self.settings or not self.settings[ErrorCode.suggested_parameter_type]:
        return
    self.checker.callable_tracker.record_call(callable, arguments)

def _visit_defaults(
self, node: FunctionNode
) -> Tuple[List[Value], List[Optional[Value]]]:
Expand Down Expand Up @@ -4439,6 +4470,12 @@ def prepare_constructor_kwargs(cls, kwargs: Mapping[str, Any]) -> Mapping[str, A
kwargs.setdefault("checker", Checker(cls.config))
return kwargs

@classmethod
def perform_final_checks(
    cls, kwargs: Mapping[str, Any]
) -> List[node_visitor.Failure]:
    """Run the final checks of the ``Checker`` stored under ``kwargs["checker"]``."""
    checker = kwargs["checker"]
    return checker.perform_final_checks()

@classmethod
def _run_on_files(
cls,
Expand Down
9 changes: 8 additions & 1 deletion pyanalyze/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ def get_files_to_check(cls, include_tests: bool) -> List[str]:
def prepare_constructor_kwargs(cls, kwargs: Mapping[str, Any]) -> Mapping[str, Any]:
return kwargs

@classmethod
def perform_final_checks(cls, kwargs: Mapping[str, Any]) -> List[Failure]:
    """Hook for checks that need the whole run; the base version reports nothing.

    Subclasses can override this to return extra failures computed from the
    constructor keyword arguments.
    """
    no_failures: List[Failure] = []
    return no_failures

@classmethod
def main(cls) -> int:
"""Can be used as a main function. Calls the checker on files given on the command line."""
Expand Down Expand Up @@ -520,6 +524,7 @@ def show_error(
obey_ignore: bool = True,
ignore_comment: str = IGNORE_COMMENT,
detail: Optional[str] = None,
save: bool = True,
) -> Optional[Failure]:
"""Shows an error associated with this node.
Expand Down Expand Up @@ -647,7 +652,8 @@ def show_error(
self._changes_for_fixer[self.filename].append(replacement)

error["message"] = message
self.all_failures.append(error)
if save:
self.all_failures.append(error)
sys.stderr.write(message)
sys.stderr.flush()
if self.fail_after_first:
Expand Down Expand Up @@ -710,6 +716,7 @@ def _run_on_files(cls, files: Iterable[str], **kwargs: Any) -> List[Failure]:
else:
for failures, _ in map(cls._check_file_single_arg, args):
all_failures += failures
all_failures += cls.perform_final_checks(kwargs)
return all_failures

@classmethod
Expand Down
8 changes: 7 additions & 1 deletion pyanalyze/reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ class ErrorContext:
all_failures: List[Failure]

def show_error(
    self,
    node: AST,
    message: str,
    error_code: Enum,
    *,
    detail: Optional[str] = None,
    save: bool = True,
) -> Optional[Failure]:
    """Report an error for *node*; concrete error contexts must override this.

    Only the post-change signature is kept: the removed one-line parameter
    list from the diff is dropped so the parameters are declared once.

    Args:
        node: AST node the error is attached to.
        message: Human-readable description of the problem.
        error_code: Enum member identifying the kind of error.
        detail: Optional extra text to show with the error.
        save: Keyword-only flag. In the node visitor implementation shown in
            the same change, ``save=False`` skips appending the error to
            ``all_failures``; other implementations should be confirmed.

    Returns:
        The failure that was produced, or ``None``.

    Raises:
        NotImplementedError: always, in this base definition.
    """
    raise NotImplementedError

Expand Down
4 changes: 4 additions & 0 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ def check_call_preprocessed(
if bound_args is None:
return self.get_default_return()
variables = {key: composite.value for key, (_, composite) in bound_args.items()}

if self.callable is not None:
visitor.record_call(self.callable, variables)

return_value = self.return_value
typevar_values: Dict[TypeVar, Value] = {}
if self.all_typevars:
Expand Down
178 changes: 178 additions & 0 deletions pyanalyze/suggested_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""
Suggest types for untyped code.
"""
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Mapping, Sequence, Union

from pyanalyze.safe import safe_isinstance

from .error_code import ErrorCode
from .node_visitor import Failure
from .value import (
AnnotatedValue,
AnySource,
AnyValue,
CallableValue,
CanAssignError,
GenericValue,
KnownValue,
SequenceIncompleteValue,
SubclassValue,
TypedDictValue,
TypedValue,
Value,
MultiValuedValue,
VariableNameValue,
replace_known_sequence_value,
unite_values,
)
from .reexport import ErrorContext
from .signature import Signature

CallArgs = Mapping[str, Value]
FunctionNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]


@dataclass
class CallableData:
    """What is known about one function definition and the calls made to it.

    ``node`` is the function's AST node, ``ctx`` is the object used to report
    errors, ``sig`` is the signature recorded for the function, and ``calls``
    holds one mapping of parameter name to argument value per recorded call.
    """

    node: FunctionNode
    ctx: ErrorContext
    sig: Signature
    calls: List[CallArgs] = field(default_factory=list)

    def check(self) -> Iterator[Failure]:
        """Yield a failure suggesting a type for each unannotated parameter.

        A parameter is considered only if it has no annotation in the AST and
        its annotation in ``sig`` is an ``AnyValue``. The suggestion is the
        union of the simplified argument values seen across ``calls``, with
        ``AnyValue`` results left out. Nothing is yielded when no calls were
        recorded or when ``show_error`` returns ``None``.
        """
        if not self.calls:
            return
        for param in _extract_params(self.node):
            if param.annotation is not None:
                continue
            sig_param = self.sig.parameters.get(param.arg)
            if sig_param is None or not isinstance(sig_param.annotation, AnyValue):
                continue  # e.g. inferred type for self
            # A recorded call is not guaranteed to contain every parameter
            # name found in the AST; skip such calls instead of raising
            # KeyError and aborting the whole check.
            observed = [
                prepare_type(call[param.arg])
                for call in self.calls
                if param.arg in call
            ]
            informative = [v for v in observed if not isinstance(v, AnyValue)]
            if not informative:
                continue
            suggested = display_suggested_type(unite_values(*informative))
            failure = self.ctx.show_error(
                param,
                f"Suggested type for parameter {param.arg}",
                ErrorCode.suggested_parameter_type,
                detail=suggested,
                # Otherwise we record it twice in tests. We should ultimately
                # refactor error tracking to make it less hacky for things that
                # show errors outside of files.
                save=False,
            )
            if failure is not None:
                yield failure


@dataclass
class CallableTracker:
    """Collect function definitions and the calls made to them.

    ``callable_to_data`` maps a callable to the data recorded at its
    definition; ``callable_to_calls`` maps a callable to the argument
    mappings of every recorded call. Callables that cannot be hashed cannot
    be used as keys and are silently ignored.
    """

    callable_to_data: Dict[object, CallableData] = field(default_factory=dict)
    callable_to_calls: Dict[object, List[CallArgs]] = field(
        default_factory=lambda: defaultdict(list)
    )

    def record_callable(
        self, node: FunctionNode, callable: object, sig: Signature, ctx: ErrorContext
    ) -> None:
        """Record when we encounter a callable."""
        data = CallableData(node, ctx, sig)
        try:
            self.callable_to_data[callable] = data
        except TypeError:
            # Unhashable callable: it cannot be tracked, so skip it rather
            # than crash the analysis.
            return

    def record_call(self, callable: object, arguments: Mapping[str, Value]) -> None:
        """Record the actual arguments passed in in a call."""
        try:
            calls = self.callable_to_calls[callable]
        except TypeError:
            # Unhashable callable: see record_callable.
            return
        calls.append(arguments)

    def check(self) -> List[Failure]:
        """Attach recorded calls to their definitions and collect the failures.

        Calls to callables whose definition was never recorded are ignored.
        """
        failures: List[Failure] = []
        for callable, calls in self.callable_to_calls.items():
            data = self.callable_to_data.get(callable)
            if data is None:
                continue
            data.calls += calls
            failures += data.check()
        return failures


def display_suggested_type(value: Value) -> str:
    """Render *value* as the text shown alongside a type suggestion.

    The value is simplified with ``prepare_type`` first. A non-empty
    ``MultiValuedValue`` is rendered as a ``"Union"`` error with one child
    per member; anything else is rendered from its ``str()``.
    """
    simplified = prepare_type(value)
    if not isinstance(simplified, MultiValuedValue) or not simplified.vals:
        return str(CanAssignError(str(simplified)))
    members = [CanAssignError(str(member)) for member in simplified.vals]
    return str(CanAssignError("Union", members))


def prepare_type(value: Value) -> Value:
    """Simplify a type to turn it into a suggestion.

    The ``isinstance`` checks below run in a fixed order; keep that order,
    since some of these value classes may derive from one another.
    """
    if isinstance(value, AnnotatedValue):
        # Drop the annotation wrapper and simplify what it wraps.
        return prepare_type(value.value)

    if isinstance(value, SequenceIncompleteValue):
        if value.typ is not tuple:
            return GenericValue(value.typ, [prepare_type(arg) for arg in value.args])
        # Tuples keep their per-element members.
        return SequenceIncompleteValue(
            tuple, [prepare_type(elt) for elt in value.members]
        )

    if isinstance(value, (TypedDictValue, CallableValue)):
        return value

    if isinstance(value, GenericValue):
        # TODO maybe turn DictIncompleteValue into TypedDictValue?
        return GenericValue(value.typ, [prepare_type(arg) for arg in value.args])

    if isinstance(value, VariableNameValue):
        return AnyValue(AnySource.unannotated)

    if isinstance(value, KnownValue):
        if value.val is None or safe_isinstance(value.val, type):
            return value
        if callable(value.val):
            return value  # TODO get the signature instead and return a CallableValue?
        replaced = replace_known_sequence_value(value)
        if not isinstance(replaced, KnownValue):
            return prepare_type(replaced)
        # Any other literal is reduced to its runtime type.
        return TypedValue(type(replaced.val))

    if isinstance(value, MultiValuedValue):

        def is_type_literal(candidate: Value) -> bool:
            return isinstance(candidate, KnownValue) and safe_isinstance(
                candidate.val, type
            )

        members = [prepare_type(subval) for subval in value.vals]
        type_literals = [member for member in members if is_type_literal(member)]
        if len(type_literals) <= 1:
            return unite_values(*members)
        # Several literal classes are collapsed into one "subclass of their
        # shared base" value; every other member is kept as is.
        concrete = [lit.val for lit in type_literals if isinstance(lit.val, type)]
        shared = SubclassValue(TypedValue(get_shared_type(concrete)))
        remaining = [member for member in members if not is_type_literal(member)]
        return unite_values(shared, *remaining)

    return value


def get_shared_type(types: Sequence[type]) -> type:
    """Return the most specific class that every class in *types* derives from.

    Candidates are taken from the MRO of the first class, in order, so the
    result is the first entry of that MRO that also appears in the MRO of
    every other class.

    Args:
        types: One or more classes.

    Returns:
        The shared base class; ``object`` when nothing more specific is shared.

    Raises:
        ValueError: If *types* is empty.
    """
    if not types:
        raise ValueError("get_shared_type() requires at least one type")
    # Read __mro__ rather than calling .mro(): on a metaclass such as
    # ``type`` itself, ``.mro`` resolves to an unbound method and calling it
    # without an argument raises TypeError.
    first, *rest = (t.__mro__ for t in types)
    rest_sets = [set(mro) for mro in rest]
    for candidate in first:
        if all(candidate in mro for mro in rest_sets):
            return candidate
    # Explicit raise instead of ``assert False`` so it survives ``python -O``.
    raise AssertionError("should at least have found object")


def _extract_params(node: FunctionNode) -> Iterator[ast.arg]:
    """Yield every parameter of a function definition in declaration order.

    The order is positional-only parameters, regular positional parameters,
    ``*args``, keyword-only parameters and ``**kwargs``.
    """
    arguments = node.args
    # Positional-only parameters (``def f(a, /)``) live in their own list,
    # which only exists on Python 3.8+; without it they were never yielded.
    yield from getattr(arguments, "posonlyargs", [])
    yield from arguments.args
    if arguments.vararg is not None:
        yield arguments.vararg
    yield from arguments.kwonlyargs
    if arguments.kwarg is not None:
        yield arguments.kwarg
Loading

0 comments on commit b1ee354

Please sign in to comment.