Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type inference for @rules #17947

Merged
merged 13 commits into from
Feb 3, 2023
14 changes: 11 additions & 3 deletions src/python/pants/base/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from pants.engine.internals.native_engine import PyFailure


class TargetDefinitionException(Exception):
class PantsException(Exception):
    """Base exception type for Pants.

    Other exception types defined in this module (for example
    `BuildConfigurationError`, `MappingError` and `RuleTypeError`) derive from this class.
    """


class TargetDefinitionException(PantsException):
"""Indicates an invalid target definition.

:API: public
Expand All @@ -29,18 +33,22 @@ def __init__(self, target, msg):
super().__init__(f"Invalid target {target}: {msg}")


class BuildConfigurationError(Exception):
class BuildConfigurationError(PantsException):
    """Indicates an error in a pants installation's configuration.

    `BackendConfigurationError` in this module is a more specific subclass.
    """


class BackendConfigurationError(BuildConfigurationError):
    """Indicates a plugin backend with a missing or malformed register module.

    Being a `BuildConfigurationError`, it is also caught by handlers for that type.
    """


class MappingError(Exception):
class MappingError(PantsException):
    """Indicates an error mapping addressable objects.

    Derives from `PantsException`, the common base for exceptions in this module.
    """


class RuleTypeError(PantsException):
    """Invalid @rule implementation.

    Raised by the rule visitor (`pants.engine.internals.rule_visitor`) when a type used in
    an awaitable constraint cannot be resolved, or when something that is not a type is
    found where a type is expected.
    """


class NativeEngineFailure(Exception):
"""A wrapper around a `Failure` instance.

Expand Down
260 changes: 228 additions & 32 deletions src/python/pants/engine/internals/rule_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,27 @@
import itertools
import logging
import sys
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, List
from typing import Any, Callable, Iterator, List

from pants.engine.internals.selectors import AwaitableConstraints, GetParseError
import typing_extensions

from pants.base.exceptions import RuleTypeError
from pants.engine.internals.selectors import (
Awaitable,
AwaitableConstraints,
Effect,
GetParseError,
MultiGet,
)
from pants.util.backport import get_annotations
from pants.util.memo import memoized
from pants.util.strutil import softwrap

logger = logging.getLogger(__name__)


def _is_awaitable_constraint(call_node: ast.Call) -> bool:
    """Return True when `call_node` is a call to a bare name spelled `Get` or `Effect`.

    Only plain names are recognised; attribute access such as `mod.Get(...)` is not.
    """
    callee = call_node.func
    if not isinstance(callee, ast.Name):
        return False
    return callee.id in ("Get", "Effect")


def _get_starting_indent(source: str) -> int:
"""Used to remove leading indentation from `source` so ast.parse() doesn't raise an
exception."""
Expand All @@ -28,6 +36,97 @@ def _get_starting_indent(source: str) -> int:
return 0


def _node_str(node: Any) -> str:
    """Render an AST node as a short, human-readable string for error messages.

    Names render as their identifier, attribute chains as dotted paths, calls as the
    rendering of the callee, and string/constant literals as their value. Anything else
    falls back to `str(node)`.
    """
    if isinstance(node, ast.Call):
        return _node_str(node.func)
    if isinstance(node, ast.Attribute):
        return ".".join((_node_str(node.value), node.attr))
    if isinstance(node, ast.Name):
        return node.id

    # Literal nodes are represented by `ast.Str` before Python 3.8 and by `ast.Constant`
    # from 3.8 onwards; only the class matching the running interpreter is consulted.
    pre_38 = sys.version_info[0:2] < (3, 8)
    if pre_38 and isinstance(node, ast.Str):
        return node.s
    if not pre_38 and isinstance(node, ast.Constant):
        return str(node.value)
    return str(node)


class _TypeStack:
    """A stack of `name -> value` namespaces used to resolve names found in a function body.

    The bottom frame is a copy of the namespace of the module that defines `func`. The
    names captured by `func` (as reported by `inspect.getclosurevars`) are layered on top,
    and callers may push/pop further frames (e.g. for function arguments). Lookups search
    from the most recently pushed frame down, and finally fall back to the defining
    module's builtins.
    """

    def __init__(self, func: Callable) -> None:
        self._stack: list[dict[str, Any]] = []
        # The module object that defines `func`; also used for the builtins fallback.
        self.root = sys.modules[func.__module__]
        self.push(self.root)
        self._push_function_closures(func)

    def __getitem__(self, name: str) -> Any:
        """Return the value bound to `name` in the innermost frame defining it, else None."""
        for ns in reversed(self._stack):
            if name in ns:
                return ns[name]
        # `__builtins__` is a CPython implementation detail: depending on the module it is
        # either the `builtins` module itself or that module's `__dict__`. Support both.
        builtins_ns = getattr(self.root, "__builtins__", None)
        if builtins_ns is None:
            return None
        if isinstance(builtins_ns, dict):
            return builtins_ns.get(name, None)
        return getattr(builtins_ns, name, None)

    def __setitem__(self, name: str, value: Any) -> None:
        """Bind `name` to `value` in the innermost (most recently pushed) frame."""
        self._stack[-1][name] = value

    def _push_function_closures(self, func: Callable) -> None:
        """Push the builtin, global and nonlocal names referenced by `func`.

        Frames are pushed outermost scope first, so that lookups (which search the most
        recently pushed frame first) follow Python's own resolution order:
        nonlocals, then globals, then builtins.
        """
        try:
            closure_vars = inspect.getclosurevars(func)
        except ValueError:
            # Raised when a closure cell is still empty; there is nothing usable to push.
            return

        for namespace in (closure_vars.builtins, closure_vars.globals, closure_vars.nonlocals):
            self.push(namespace)

    def push(self, frame: object) -> None:
        """Push a copy of `frame` (a dict, or any object with a `__dict__`) onto the stack."""
        ns = dict(frame if isinstance(frame, dict) else frame.__dict__)
        self._stack.append(ns)

    def pop(self) -> None:
        """Discard the innermost frame; the root module frame is never popped."""
        assert len(self._stack) > 1
        self._stack.pop()


def _lookup_annotation(obj: Any, attr: str) -> Any:
    """Get the value or declared type associated with attribute `attr` on `obj`.

    If `obj` actually has the attribute, its value is returned. Otherwise the annotations
    of `obj` are consulted (with string annotations evaluated) and the annotation for
    `attr` is returned, or None when there is none or the annotations cannot be resolved.

    Reading annotations can get hairy, especially on Python <3.10:
    https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
    For this reason, we've copied the `inspect.get_annotations` method from CPython `main` branch.
    """
    if not hasattr(obj, attr):
        try:
            annotations = get_annotations(obj, eval_str=True)
            return annotations.get(attr)
        except (NameError, TypeError):
            # Unresolvable forward references, or an object that carries no annotations.
            return None
    return getattr(obj, attr)
kaos marked this conversation as resolved.
Show resolved Hide resolved


def _lookup_return_type(func: Callable, check: bool = False) -> Any:
    """Return the declared return type of `func`.

    When the return annotation is a parameterized `list`, `set` or `tuple` (or a subclass
    of one of those), the tuple of its type arguments is returned instead of the
    annotation itself.

    :param func: The callable whose `return` annotation is inspected.
    :param check: When True, a missing return annotation is an error.
    :returns: The return annotation, a tuple of element types for the collection types
        above, or None when there is no annotation and `check` is False.
    :raises TypeError: If `check` is True and no return annotation could be found.
    """
    ret = _lookup_annotation(func, "return")
    origin = typing_extensions.get_origin(ret)
    if isinstance(origin, type) and issubclass(origin, (list, set, tuple)):
        return tuple(typing_extensions.get_args(ret))
    if check and ret is None:
        # Gather location details defensively: callers only handle `TypeError`, so a
        # callable without `__name__`/`__code__`, or one `inspect` cannot locate, must
        # still end up with the error below rather than an unrelated exception.
        func_name = getattr(func, "__name__", None) or repr(func)
        try:
            func_file = inspect.getsourcefile(func)
        except TypeError:
            func_file = "<unknown>"
        func_line = getattr(getattr(func, "__code__", None), "co_firstlineno", "<unknown>")
        raise TypeError(
            f"Return type annotation required for `{func_name}` in {func_file}:{func_line}"
        )
    return ret


def _returns_awaitable(func: Any) -> bool:
    """Tell whether `func` is a callable whose declared return type involves an `Awaitable`.

    The return type may be a single type or a tuple of types (as produced for collection
    annotations); any member that is a class deriving from `Awaitable` is sufficient.
    Non-callables never qualify.
    """
    if not callable(func):
        return False

    declared = _lookup_return_type(func)
    candidates = declared if isinstance(declared, tuple) else (declared,)
    for candidate in candidates:
        # Skip non-class entries (None, typing constructs, ...) before `issubclass`.
        if isinstance(candidate, type) and issubclass(candidate, Awaitable):
            return True
    return False


class _AwaitableCollector(ast.NodeVisitor):
def __init__(self, func: Callable):
self.func = func
Expand All @@ -38,10 +137,14 @@ def __init__(self, func: Callable):

self.source_file = inspect.getsourcefile(func)

self.owning_module = sys.modules[func.__module__]
self.types = _TypeStack(func)
self.awaitables: List[AwaitableConstraints] = []
self.visit(ast.parse(source))

def _format(self, node: ast.AST, msg: str) -> str:
    """Prefix `msg` with `<source file>:<line>: ` pointing at `node`.

    The node's line number is offset by the first line of `self.func` so that the
    reported line refers to a position in the source file.
    """
    first_line = self.func.__code__.co_firstlineno
    file_line = node.lineno + first_line - 1
    return f"{self.source_file}:{file_line}: {msg}"

def _lookup(self, attr: ast.expr) -> Any:
names = []
while isinstance(attr, ast.Attribute):
Expand All @@ -56,39 +159,51 @@ def _lookup(self, attr: ast.expr) -> Any:
return attr

name = names.pop()
result = (
getattr(self.owning_module, name)
if hasattr(self.owning_module, name)
else self.owning_module.__builtins__.get(name, None)
)
result = self.types[name]
while result is not None and names:
result = getattr(result, names.pop(), None)

result = _lookup_annotation(result, names.pop())
return result

def _missing_type_error(self, node: ast.AST, context: str) -> str:
    """Build the located error message used when the type of `node` cannot be resolved.

    :param node: The AST node whose type lookup failed; used for both the rendered name
        and the reported location.
    :param context: Extra explanation appended as a second paragraph.
    """
    module_name = self.types.root.__name__
    subject = _node_str(node)
    message = softwrap(
        f"""
        Could not resolve type for `{subject}` in module {module_name}.

        {context}
        """
    )
    return self._format(node, message)

def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type:
lineno = node.lineno + self.func.__code__.co_firstlineno - 1
if resolved is None:
raise ValueError(
f"Could not resolve type `{node}` in top level of module "
f"{self.owning_module.__name__} defined in {self.source_file}:{lineno}"
raise RuleTypeError(
self._missing_type_error(
node, context="This may be a limitation of the Pants rule type inference."
)
)
elif not isinstance(resolved, type):
raise ValueError(
f"Expected a `type`, but got: {resolved}"
+ f" (type `{type(resolved).__name__}`) in {self.source_file}:{lineno}"
raise RuleTypeError(
self._format(
node,
f"Expected a type, but got: {type(resolved).__name__} {_node_str(resolved)!r}",
)
)
return resolved

def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints:
func = self._lookup(call_node.func)
is_effect = func.__name__ == "Effect"
def _get_awaitable(self, call_node: ast.Call, is_effect: bool) -> AwaitableConstraints:
get_args = call_node.args
parse_error = partial(GetParseError, get_args=get_args, source_file_name=self.source_file)

if len(get_args) not in (2, 3):
# TODO: fix parse error message formatting... (TODO: create ticket)
raise parse_error(
f"Expected either two or three arguments, but got {len(get_args)} arguments."
self._format(
call_node,
f"Expected either two or three arguments, but got {len(get_args)} arguments.",
)
)

output_node = get_args[0]
Expand All @@ -99,8 +214,17 @@ def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints:
if len(input_nodes) == 1:
input_constructor = input_nodes[0]
if isinstance(input_constructor, ast.Call):
cls_or_func = self._lookup(input_constructor.func)
try:
type_ = (
_lookup_return_type(cls_or_func, check=True)
if not isinstance(cls_or_func, type)
else cls_or_func
)
except TypeError as e:
raise RuleTypeError(self._missing_type_error(input_constructor, str(e))) from e
input_nodes = [input_constructor.func]
input_types = [self._lookup(input_constructor.func)]
input_types = [type_]
elif isinstance(input_constructor, ast.Dict):
input_nodes = input_constructor.values
input_types = [self._lookup(v) for v in input_constructor.values]
Expand All @@ -119,15 +243,87 @@ def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints:
)

def visit_Call(self, call_node: ast.Call) -> None:
if _is_awaitable_constraint(call_node):
self.awaitables.append(self._get_awaitable(call_node))
else:
attr = self._lookup(call_node.func)
if hasattr(attr, "rule_helper"):
self.awaitables.extend(collect_awaitables(attr))
func = self._lookup(call_node.func)
if func is not None:
if isinstance(func, type) and issubclass(func, Awaitable):
self.awaitables.append(
self._get_awaitable(call_node, is_effect=issubclass(func, Effect))
)
elif inspect.iscoroutinefunction(func) or _returns_awaitable(func):
self.awaitables.extend(collect_awaitables(func))

self.generic_visit(call_node)

def visit_AsyncFunctionDef(self, rule: ast.AsyncFunctionDef) -> None:
    """Visit an `async def`, with the types of its annotated arguments in scope."""
    # The argument frame is pushed only for the duration of the body visit.
    with self._visit_rule_args(rule.args):
        self.generic_visit(rule)

def visit_FunctionDef(self, rule: ast.FunctionDef) -> None:
    """Visit a plain `def`, with the types of its annotated arguments in scope."""
    # The argument frame is pushed only for the duration of the body visit.
    with self._visit_rule_args(rule.args):
        self.generic_visit(rule)

@contextmanager
def _visit_rule_args(self, node: ast.arguments) -> Iterator[None]:
    """Make the annotated types of a function's parameters resolvable while in the context.

    A new frame mapping each parameter name to the value its annotation resolves to is
    pushed on `self.types`, and popped again on exit (also when an exception escapes).
    Positional-only, regular and keyword-only parameters are all considered. Only
    annotations that are plain names are handled; `*args`/`**kwargs` are skipped because
    their annotation describes the elements rather than the parameter itself.
    """
    params = [
        # `posonlyargs` does not exist on Python < 3.8.
        *getattr(node, "posonlyargs", []),
        *node.args,
        *node.kwonlyargs,
    ]
    self.types.push(
        {
            param.arg: self.types[param.annotation.id]
            for param in params
            if isinstance(param.annotation, ast.Name)
        }
    )
    try:
        yield
    finally:
        self.types.pop()

def visit_Assign(self, assign_node: ast.Assign) -> None:
    """Record the inferred type of every simple name bound by an assignment.

    The right-hand side is visited first (collecting any awaitables it contains), then
    its type is inferred:

    - a call to `MultiGet` yields the tuple of output types of the awaitables collected
      while visiting this assignment;
    - any other resolvable call yields the callee's declared return type;
    - a name or attribute yields whatever it resolves to.

    Name targets are bound to the inferred value. Tuple targets are bound element by
    element, positionally; elements that are not plain names are skipped without
    shifting the remaining ones. Nothing past a starred element is bound, because its
    position can no longer be matched to a type.
    """
    awaitables_idx = len(self.awaitables)
    self.generic_visit(assign_node)
    collected_awaitables = self.awaitables[awaitables_idx:]

    # Peel off the assignment itself and any `await` wrappers to reach the expression
    # that actually produces the value.
    node: ast.AST = assign_node
    while isinstance(node, (ast.Assign, ast.Await)):
        node = node.value

    value = None
    if isinstance(node, ast.Call):
        f = self._lookup(node.func)
        if f is MultiGet:
            value = tuple(get.output_type for get in collected_awaitables)
        elif f is not None:
            value = _lookup_return_type(f)
    elif isinstance(node, (ast.Name, ast.Attribute)):
        value = self._lookup(node)

    for tgt in assign_node.targets:
        if isinstance(tgt, ast.Name):
            self.types[tgt.id] = value
            continue
        if not isinstance(tgt, ast.Tuple):
            # subscript, etc..
            continue

        names = [el.id for el in tgt.elts if isinstance(el, ast.Name)]
        values = value or itertools.cycle([None])  # type: ignore[assignment]
        try:
            # Pair every element (not only the names) with its positional type so that
            # skipped elements do not shift the types of the names that follow. A
            # dedicated loop variable keeps `value` intact for later targets.
            for el, el_type in zip(tgt.elts, values):
                if isinstance(el, ast.Starred):
                    break
                if isinstance(el, ast.Name):
                    self.types[el.id] = el_type
        except TypeError as e:
            # The inferred value is not iterable, so it cannot be unpacked; best effort.
            logger.debug(
                self._format(
                    node,
                    softwrap(
                        f"""
                        Rule visitor failed to inspect assignment expression for
                        {names} - {values}:

                        {e}
                        """
                    ),
                )
            )


@memoized
def collect_awaitables(func: Callable) -> List[AwaitableConstraints]:
Expand Down
Loading