Skip to content

Commit

Permalink
it works!
Browse files Browse the repository at this point in the history
  • Loading branch information
thejcannon committed Jan 5, 2022
1 parent c08565f commit 83e7f2a
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from collections import defaultdict
from json.decoder import JSONDecodeError
import pathlib
import json
from dataclasses import dataclass
from typing import Tuple

from pants.backend.python.target_types import PythonSourceField
from pants.backend.python.util_rules.interpreter_constraints import InterpreterConstraints
Expand Down Expand Up @@ -33,8 +37,11 @@
# This regex is used to infer imports from strings, e.g.
# `importlib.import_module("example.subdir.Foo")`.
STRING_IMPORT_REGEX = re.compile(r"^([a-z_][a-z_\\d]*\\.){{{import_min_dots},}}[a-zA-Z_]\\w*$", re.UNICODE)
# @TODO: Need to handle more complex filenames
STRING_RESOURCE_REGEX = re.compile(r"^([\\w]*\\/){{{resource_min_slashes},}}\\w*\\.[^\\/]+$", re.UNICODE)
# This regex is used to infer resource names from strings, e.g.
# `load_resource("data/db1.json")`
# Since Unix allows basically anything for filenames, we require some "sane" subset of possibilities
# namely, word-character filenames and a mandatory extension.
STRING_RESOURCE_REGEX = re.compile(r"^([\\w]*\\/){{{resource_min_slashes},}}\\w*(\\.[^\\/\\.]+)+$", re.UNICODE)
class AstVisitor(ast.NodeVisitor):
def __init__(self, package_parts):
Expand All @@ -48,7 +55,7 @@ def maybe_add_dependency(self, s):
if self._string_imports and STRING_IMPORT_REGEX.match(s):
self.imports.add(s)
if self._string_resources and STRING_RESOURCE_REGEX.match(s):
self.resources.add(s)
self.resources.add((None, s))
def visit_Import(self, node):
for alias in node.names:
Expand Down Expand Up @@ -155,16 +162,19 @@ class ParsedPythonImports(DeduplicatedCollection[str]):
# N.B. Don't set `sort_input`, as the input is already sorted


class ParsedPythonResources(DeduplicatedCollection[str]):
"""All the discovered possible resources from a Python source file."""
class ParsedPythonResources(DeduplicatedCollection[Tuple[str, str]]):
"""All the discovered possible resources from a Python source file.
The tuple is of (containing module, relative filename), similar to
the arguments of `pkgutil.get_data`.
"""

# N.B. Don't set `sort_input`, as the input is already sorted


@dataclass(frozen=True)
class ParsedPythonDependencies:
    """Imports and resource references parsed out of a single Python source file."""

    # All discovered import module names (explicit imports plus inferred string imports).
    imports: ParsedPythonImports
    # N.B. resources are given as file paths relative to the request's source's parent dir
    resources: ParsedPythonResources


Expand All @@ -177,6 +187,8 @@ class ParsePythonDependenciesRequest:
string_resources: bool
string_resources_min_slashes: int

def _filepath_to_modname(filepath: str):
return str(pathlib.Path(filepath).with_suffix("")).replace("/", ".")

@rule
async def parse_python_dependencies(
Expand Down Expand Up @@ -218,11 +230,20 @@ async def parse_python_dependencies(
)
# See above for where we explicitly encoded as utf8. Even though utf8 is the
# default for decode(), we make that explicit here for emphasis.
output = json.loads(process_result.stdout.decode("utf8").strip())
try:
output = json.loads(process_result.stdout.decode("utf8").strip())
except JSONDecodeError:
output = defaultdict(str)

return ParsedPythonDependencies(
imports=output["imports"],
resources=output["resources"],
resources=[
(
_filepath_to_modname(request.source.file_path) if pkgname is None else pkgname,
filepath
)
for pkgname, filepath in output["resources"]
],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ def rule_runner() -> RuleRunner:
)


def assert_imports_parsed(
def assert_deps_parsed(
rule_runner: RuleRunner,
content: str,
*,
expected: list[str],
expected_imports: list[str] = [],
expected_resources: list[str] = [],
filename: str = "project/foo.py",
constraints: str = ">=3.6",
string_imports: bool = True,
Expand All @@ -58,7 +59,7 @@ def assert_imports_parsed(
}
)
tgt = rule_runner.get_target(Address("", target_name="t"))
imports = rule_runner.request(
result = rule_runner.request(
ParsedPythonDependencies,
[
ParsePythonDependenciesRequest(
Expand All @@ -70,8 +71,9 @@ def assert_imports_parsed(
string_resources_min_slashes=string_resources_min_slashes,
)
],
).imports
assert list(imports) == sorted(expected)
)
assert list(result.imports) == sorted(expected_imports)
assert list(result.resources) == sorted(("project.foo", resource) for resource in expected_resources)


def test_normal_imports(rule_runner: RuleRunner) -> None:
Expand Down Expand Up @@ -100,10 +102,10 @@ def test_normal_imports(rule_runner: RuleRunner) -> None:
__import__("pkg_resources")
"""
)
assert_imports_parsed(
assert_deps_parsed(
rule_runner,
content,
expected=[
expected_imports=[
"__future__.print_function",
"os",
"os.path",
Expand All @@ -119,7 +121,6 @@ def test_normal_imports(rule_runner: RuleRunner) -> None:
],
)


@pytest.mark.parametrize("basename", ["foo.py", "__init__.py"])
def test_relative_imports(rule_runner: RuleRunner, basename: str) -> None:
content = dedent(
Expand All @@ -130,11 +131,11 @@ def test_relative_imports(rule_runner: RuleRunner, basename: str) -> None:
from ..parent import Parent
"""
)
assert_imports_parsed(
assert_deps_parsed(
rule_runner,
content,
filename=f"project/util/{basename}",
expected=[
expected_imports=[
"project.util.sibling",
"project.util.sibling.Nibling",
"project.util.subdir.child.Child",
Expand Down Expand Up @@ -192,16 +193,55 @@ def test_imports_from_strings(rule_runner: RuleRunner, min_dots: int) -> None:
]
expected = [sym for sym in potentially_valid if sym.count(".") >= min_dots]

assert_imports_parsed(rule_runner, content, expected=expected, string_imports_min_dots=min_dots)
assert_imports_parsed(rule_runner, content, string_imports=False, expected=[])
assert_deps_parsed(rule_runner, content, expected_imports=expected, string_imports_min_dots=min_dots)
assert_deps_parsed(rule_runner, content, string_imports=False, expected_imports=[])

@pytest.mark.parametrize("min_slashes", [1, 2, 3, 4])
def test_resources_from_strings(rule_runner: RuleRunner, min_slashes: int) -> None:
content = dedent(
"""\
resources = [
# Potentially valid strings (depending on min_slashes).
'a/b.txt',
'a/Foo.txt',
'a/b/d.json',
'a/b2/d.data',
'a/b/c/d.gz.tar',
'a/b_c/d.7zip',
'a/b/c_狗.txt',
# Invalid strings (according to our regex)
'a/b..',
'.gitignore',
'"A B"/foo.txt',
'a/foo.txt/b.txt',
'extensionless',
'windows\\sep.txt',
]
"""
)

potentially_valid = [
'a/b.txt',
'a/Foo.txt',
'a/b/d.json',
'a/b2/d.data',
'a/b/c/d.gz.tar',
'a/b_c/d.7zip',
'a/b/c_狗.txt',
]
expected = [sym for sym in potentially_valid if sym.count("/") >= min_slashes]

assert_deps_parsed(rule_runner, content, expected_resources=expected, string_resources_min_slashes=min_slashes)
assert_deps_parsed(rule_runner, content, string_resources=False, expected_resources=[])


def test_gracefully_handle_syntax_errors(rule_runner: RuleRunner) -> None:
assert_imports_parsed(rule_runner, "x =", expected=[])
assert_deps_parsed(rule_runner, "x =", expected_imports=[])


def test_handle_unicode(rule_runner: RuleRunner) -> None:
assert_imports_parsed(rule_runner, "x = 'äbç'", expected=[])
assert_deps_parsed(rule_runner, "x = 'äbç'", expected_imports=[])


@skip_unless_python27_present
Expand All @@ -224,11 +264,11 @@ def test_works_with_python2(rule_runner: RuleRunner) -> None:
b"\\xa0 a non-utf8 string, make sure we ignore it"
"""
)
assert_imports_parsed(
assert_deps_parsed(
rule_runner,
content,
constraints="==2.7.*",
expected=[
expected_imports=[
"demo",
"dep.from.bytes",
"dep.from.str",
Expand Down Expand Up @@ -257,11 +297,11 @@ def test_works_with_python38(rule_runner: RuleRunner) -> None:
importlib.import_module("dep.from.str")
"""
)
assert_imports_parsed(
assert_deps_parsed(
rule_runner,
content,
constraints=">=3.8",
expected=[
expected_imports=[
"demo",
"dep.from.str",
"project.demo.Demo",
Expand Down Expand Up @@ -290,11 +330,11 @@ def test_works_with_python39(rule_runner: RuleRunner) -> None:
importlib.import_module("dep.from.str")
"""
)
assert_imports_parsed(
assert_deps_parsed(
rule_runner,
content,
constraints=">=3.9",
expected=[
expected_imports=[
"demo",
"dep.from.str",
"project.demo.Demo",
Expand Down
38 changes: 17 additions & 21 deletions src/python/pants/backend/python/dependency_inference/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

logger = logging.getLogger(__name__)

# @TODO: Find a home for these

from pants.engine.target import AllTargets, Target
from dataclasses import dataclass
Expand Down Expand Up @@ -235,37 +236,32 @@ async def infer_python_dependencies_via_imports(
),
)

inferred_deps: set[Address] = set()

detected_imports = detected_dependencies.imports
detected_resources = detected_dependencies.resources

owners_per_import = await MultiGet(
Get(PythonModuleOwners, PythonModule(imported_module))
for imported_module in detected_imports
)

request_parent_path = pathlib.Path(request.sources_field.file_path).parent
resources_by_path: Dict[pathlib.Path, Target] = {}
for file_tgt in all_resource_targets.files:
resources_by_path[pathlib.Path(file_tgt[FileSourceField].file_path)] = file_tgt
for resource_tgt in all_resource_targets.resources:
path = pathlib.Path(resource_tgt[ResourceSourceField].file_path)
try:
resources_by_path[path.relative_to(request_parent_path)] = resource_tgt
except ValueError:
# The resource path is not relative to this source's parent
continue

for resource in detected_resources:
resource_path = pathlib.Path(resource)
resources_by_path[path] = resource_tgt

for pkgname, filepath in detected_resources:
resource_path = pathlib.Path(*pkgname.split('.')).parent / filepath
inferred_resource_tgt = resources_by_path.get(resource_path)
if inferred_resource_tgt:
logger.error(f"HUZZAH! {inferred_resource_tgt.address}")
inferred_deps.add(inferred_resource_tgt.address)

owners_per_import = await MultiGet(
Get(PythonModuleOwners, PythonModule(imported_module))
for imported_module in detected_imports
)

merged_result: set[Address] = set()
unowned_imports: set[str] = set()
address = wrapped_tgt.target.address
for owners, imp in zip(owners_per_import, detected_imports):
merged_result.update(owners.unambiguous)
inferred_deps.update(owners.unambiguous)
explicitly_provided_deps.maybe_warn_of_ambiguous_dependency_inference(
owners.ambiguous,
address,
Expand All @@ -274,7 +270,7 @@ async def infer_python_dependencies_via_imports(
)
maybe_disambiguated = explicitly_provided_deps.disambiguated(owners.ambiguous)
if maybe_disambiguated:
merged_result.add(maybe_disambiguated)
inferred_deps.add(maybe_disambiguated)

if not owners.unambiguous and imp.split(".")[0] not in DEFAULT_UNOWNED_DEPENDENCIES:
unowned_imports.add(imp)
Expand All @@ -296,7 +292,7 @@ async def infer_python_dependencies_via_imports(
"One or more unowned dependencies detected. Check logs for more details."
)

return InferredDependencies(sorted(merged_result))
return InferredDependencies(sorted(inferred_deps))


class InferInitDependencies(InferDependenciesRequest):
Expand Down Expand Up @@ -351,6 +347,7 @@ async def infer_python_conftest_dependencies(
def import_rules():
return [
infer_python_dependencies_via_imports,
find_all_resources,
*pex.rules(),
*dependency_parser.rules(),
*module_mapper.rules(),
Expand All @@ -363,7 +360,6 @@ def import_rules():
def rules():
return [
*import_rules(),
find_all_resources,
infer_python_init_dependencies,
infer_python_conftest_dependencies,
*ancestor_files.rules(),
Expand Down
Loading

0 comments on commit 83e7f2a

Please sign in to comment.