Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Upgrade rule parsing to follow calls #14238

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions src/python/pants/engine/rule_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import ast
import inspect
import itertools
import logging
import sys
import types
from functools import partial
from typing import List

from pants.engine.internals.selectors import AwaitableConstraints, GetParseError

logger = logging.getLogger(__name__)

def _is_awaitable_constraint(call_node) -> bool:
return isinstance(call_node.func, ast.Name) and call_node.func.id in ("Get", "Effect")


def _get_starting_indent(source):
"""Used to remove leading indentation from `source` so ast.parse() doesn't raise an
exception."""
if source.startswith(" "):
return sum(1 for _ in itertools.takewhile(lambda c: c in {" ", b" "}, source))
return 0


def _get_lookup_names(attr: ast.Attribute | ast.Name):
names = []
while isinstance(attr, ast.Attribute):
names.append(attr.attr)
attr = attr.value
# NB: attr could be a constant, like `",".join()`
names.append(getattr(attr, "id", None))
return names


class _AwaitableCollector(ast.NodeVisitor):
def __init__(self, func):
self.func = func
self.source = inspect.getsource(func) or "<string>"
beginning_indent = _get_starting_indent(self.source)
if beginning_indent:
self.source = "\n".join(line[beginning_indent:] for line in self.source.split("\n"))
self.source_file = inspect.getsourcefile(func)
self.owning_module = sys.modules[func.__module__]
self.awaitables = []
self.visit(ast.parse(self.source))

def _resolve_constrain_arg_type(self, name, lineno):
lineno += self.func.__code__.co_firstlineno - 1
resolved = (
getattr(self.owning_module, name, None)
or self.owning_module.__builtins__.get(name, None)
) # fmt: skip
if resolved is None:
raise ValueError(
f"Could not resolve type `{name}` in top level of module "
f"{self.owning_module.__name__} defined in {self.source_file}:{lineno}"
)
elif not isinstance(resolved, type):
raise ValueError(
f"Expected a `type` constructor for `{name}`, but got: {resolved} (type "
f"`{type(resolved).__name__}`) in {self.source_file}:{lineno}"
)
return resolved

def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints:
is_effect = call_node.func.id == "Effect"
get_args = call_node.args
parse_error = partial(GetParseError, get_args=get_args, source_file_name=self.source_file)

if len(get_args) not in (2, 3):
raise parse_error(
f"Expected either two or three arguments, but got {len(get_args)} arguments."
)

output_expr = get_args[0]
if not isinstance(output_expr, ast.Name):
raise parse_error(
"The first argument should be the output type, like `Digest` or `ProcessResult`."
)
output_type = output_expr

input_args = get_args[1:]
if len(input_args) == 1:
input_constructor = input_args[0]
if not isinstance(input_constructor, ast.Call):
raise parse_error(
f"Because you are using the shorthand form {call_node.func.id}(OutputType, "
"InputType(constructor args), the second argument should be a constructor "
"call, like `MergeDigest(...)` or `Process(...)`."
)
if not hasattr(input_constructor.func, "id"):
raise parse_error(
f"Because you are using the shorthand form {call_node.func.id}(OutputType, "
"InputType(constructor args), the second argument should be a top-level "
"constructor function call, like `MergeDigest(...)` or `Process(...)`, rather "
"than a method call."
)
input_type = input_constructor.func # type: ignore[attr-defined]
else:
input_type, _ = input_args
if not isinstance(input_type, ast.Name):
raise parse_error(
f"Because you are using the longhand form {call_node.func.id}(OutputType, "
"InputType, input), the second argument should be a type, like `MergeDigests` or "
"`Process`."
)

return AwaitableConstraints(
self._resolve_constrain_arg_type(output_type.id, output_type.lineno),
self._resolve_constrain_arg_type(input_type.id, input_type.lineno),
is_effect,
)

def visit_Call(self, call_node: ast.Call) -> None:
if _is_awaitable_constraint(call_node):
self.awaitables.append(self._get_awaitable(call_node))
else:
func_node = call_node.func
lookup_names = _get_lookup_names(func_node)
attr = self.func.__globals__.get(lookup_names.pop(), None)
while attr is not None and lookup_names:
attr = getattr(attr, lookup_names.pop(), None)

if (
attr is not None
and attr is not self.func
# NB: This only is True for free functions and staticmethods
and isinstance(attr, types.FunctionType)
thejcannon marked this conversation as resolved.
Show resolved Hide resolved
# @TODO: This wouldn't be true for plugins
and attr.__module__.startswith("pants")
):
self.awaitables.extend(collect_awaitables(attr))

self.generic_visit(call_node)


# @TODO: memoization?
def collect_awaitables(func) -> List[AwaitableConstraints]:
return _AwaitableCollector(func).awaitables
43 changes: 3 additions & 40 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ast
import inspect
import itertools
import logging
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -38,6 +39,7 @@
from pants.util.memo import memoized
from pants.util.meta import frozen_after_init
from pants.util.ordered_set import FrozenOrderedSet, OrderedSet
from pants.engine.rule_visitor import collect_awaitables


class _RuleVisitor(ast.NodeVisitor):
Expand Down Expand Up @@ -122,46 +124,7 @@ def wrapper(func):
if not inspect.isfunction(func):
raise ValueError("The @rule decorator must be applied innermost of all decorators.")

owning_module = sys.modules[func.__module__]
source = inspect.getsource(func) or "<string>"
source_file = inspect.getsourcefile(func)
beginning_indent = _get_starting_indent(source)
if beginning_indent:
source = "\n".join(line[beginning_indent:] for line in source.split("\n"))
module_ast = ast.parse(source)

def resolve_type(name):
resolved = getattr(owning_module, name, None) or owning_module.__builtins__.get(
name, None
)
if resolved is None:
raise ValueError(
f"Could not resolve type `{name}` in top level of module "
f"{owning_module.__name__} defined in {source_file}"
)
elif not isinstance(resolved, type):
raise ValueError(
f"Expected a `type` constructor for `{name}`, but got: {resolved} (type "
f"`{type(resolved).__name__}`) in {source_file}"
)
return resolved

rule_func_node = assert_single_element(
node
for node in ast.iter_child_nodes(module_ast)
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == func.__name__
)

parents_table = {}
for parent in ast.walk(rule_func_node):
for child in ast.iter_child_nodes(parent):
parents_table[child] = parent

rule_visitor = _RuleVisitor(source_file_name=source_file, resolve_type=resolve_type)
rule_visitor.visit(rule_func_node)

awaitables = FrozenOrderedSet(rule_visitor.awaitables)
awaitables = FrozenOrderedSet(collect_awaitables(func))

validate_requirements(func_id, parameter_types, awaitables, cacheable)

Expand Down