Refactor: DRY adhoc_tool code_quality_tool (#20255)
Before moving to step 2 of the plan described in
#17729 (comment)
, this cleans up a gross duplication of rule code that I introduced in
#20135 between `adhoc_tool` and
the new `code_quality_tool`.

This PR extracts the shared logic into the concept of a ToolRunner and a
rule to hydrate it in `adhoc_process_support`.

Both `adhoc_tool` and `code_quality_tool` have the latent idea of a tool
runner and considerable machinery to build it. Starting from something
like
```python
from dataclasses import dataclass

from pants.engine.target import Target


@dataclass(frozen=True)
class ToolRunnerRequest:
    runnable_address_str: str
    args: tuple[str, ...]
    execution_dependencies: tuple[str, ...]
    runnable_dependencies: tuple[str, ...]
    target: Target
```
they need to locate the actual runnable by its address string, figure out
its base digest, args, and environment, and co-locate the execution and
runnable dependencies. We now capture that information as a "runner":
```python
from dataclasses import dataclass
from typing import Mapping

from pants.engine.fs import Digest


@dataclass(frozen=True)
class ToolRunner:
    digest: Digest
    args: tuple[str, ...]
    extra_env: Mapping[str, str]
    append_only_caches: Mapping[str, str]
    immutable_input_digests: Mapping[str, Digest]
```
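
The hydrating rule in `adhoc_process_support` then does that assembly in one
place. Here is a condensed sketch, assuming the two dataclasses above are in
scope: the rule name `create_tool_runner` is an assumption, and the merging of
the runnable dependencies' sandbox contents (PATH, caches, immutable digests)
that the real rule performs is elided; the `adhoc_process_support.py` diff is
the authoritative version.
```python
# Condensed sketch only; not the exact implementation added in this PR.
from pants.build_graph.address import Address, AddressInput
from pants.core.goals.run import RunFieldSet, RunInSandboxRequest
from pants.core.util_rules.adhoc_process_support import (
    ResolvedExecutionDependencies,
    ResolveExecutionDependenciesRequest,
)
from pants.core.util_rules.environments import EnvironmentNameRequest
from pants.engine.addresses import Addresses
from pants.engine.environment import EnvironmentName
from pants.engine.fs import Digest, MergeDigests
from pants.engine.rules import Get, rule
from pants.engine.target import FieldSetsPerTarget, FieldSetsPerTargetRequest, Targets
from pants.util.frozendict import FrozenDict


@rule
async def create_tool_runner(request: ToolRunnerRequest) -> ToolRunner:
    # Locate the runnable target named by the address string.
    runnable_address = await Get(
        Address,
        AddressInput,
        AddressInput.parse(
            request.runnable_address_str,
            relative_to=request.target.address.spec_path,
            description_of_origin=f"the runnable for {request.target.address}",
        ),
    )
    runnable_targets = await Get(Targets, Addresses, Addresses((runnable_address,)))

    # Work out how to run it, in its own environment, so the sandbox matches
    # the environment the process will eventually execute in.
    field_sets = await Get(
        FieldSetsPerTarget, FieldSetsPerTargetRequest(RunFieldSet, runnable_targets)
    )
    environment_name = await Get(
        EnvironmentName,
        EnvironmentNameRequest,
        EnvironmentNameRequest.from_target(runnable_targets[0]),
    )
    run_field_set: RunFieldSet = field_sets.field_sets[0]
    run_request = await Get(
        RunInSandboxRequest, {environment_name: EnvironmentName, run_field_set: RunFieldSet}
    )

    # Resolve execution/runnable dependencies and merge their digest with the
    # run request's digest to form the runner's base digest.
    execution_environment = await Get(
        ResolvedExecutionDependencies,
        ResolveExecutionDependenciesRequest(
            address=runnable_address,
            execution_dependencies=request.execution_dependencies,
            runnable_dependencies=request.runnable_dependencies,
        ),
    )
    digest = await Get(
        Digest, MergeDigests((execution_environment.digest, run_request.digest))
    )

    return ToolRunner(
        digest=digest,
        args=run_request.args + tuple(request.args),
        extra_env=FrozenDict(run_request.extra_env or {}),
        append_only_caches=run_request.append_only_caches or FrozenDict(),
        immutable_input_digests=run_request.immutable_input_digests or FrozenDict(),
    )
```
Both call sites now just `Get(ToolRunner, ToolRunnerRequest(...))` and use the
result, as the diffs below show.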

After this, `adhoc_tool` and `code_quality_tool` diverge in what they do
with it: `adhoc_tool` uses the runner to generate code, and
`code_quality_tool` uses it to run batches of lint/fmt/fix on source
files.

## Food for thought ...

It should not escape our attention that this `ToolRunner` could also be
surfaced as a Target, to be used by `adhoc_tool` and `code_quality_tool`
rather than each specifying all these fields together. It would also
help to reduce confusion when handling all the kinds of 'dependencies'
arguments that `adhoc_tool` takes.
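
Purely as an illustration of that idea (the `tool_runner` target type and the
`runner=` field below do not exist in this PR; they are hypothetical), a BUILD
file could then declare the runner once and point the tools at it:
```python
# Hypothetical BUILD file: a dedicated runner target that adhoc_tool and
# code_quality_tool reference, instead of each repeating the runnable and
# dependency fields.
tool_runner(
    name="flake8_runner",
    runnable=":flake8_pex",
    execution_dependencies=[":flake8_config"],
)

code_quality_tool(
    name="flake8_check",
    runner=":flake8_runner",  # hypothetical field
    args=["--config=.flake8"],
    file_glob_include=["**/*.py"],
)
```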
gauthamnair authored Dec 30, 2023
1 parent 6d8078b commit 3ffbba3
Showing 3 changed files with 153 additions and 208 deletions.
115 changes: 19 additions & 96 deletions src/python/pants/backend/adhoc/adhoc_tool.py
@@ -20,31 +20,19 @@
AdhocToolStdoutFilenameField,
AdhocToolWorkdirField,
)
from pants.build_graph.address import Address, AddressInput
from pants.core.goals.run import RunFieldSet, RunInSandboxRequest
from pants.core.target_types import FileSourceField
from pants.core.util_rules.adhoc_process_support import (
AdhocProcessRequest,
AdhocProcessResult,
ExtraSandboxContents,
MergeExtraSandboxContents,
ResolvedExecutionDependencies,
ResolveExecutionDependenciesRequest,
ToolRunner,
ToolRunnerRequest,
)
from pants.core.util_rules.adhoc_process_support import rules as adhoc_process_support_rules
from pants.core.util_rules.environments import EnvironmentNameRequest
from pants.engine.addresses import Addresses
from pants.engine.environment import EnvironmentName
from pants.engine.fs import Digest, MergeDigests, Snapshot
from pants.engine.internals.native_engine import EMPTY_DIGEST
from pants.engine.fs import Digest, Snapshot
from pants.engine.rules import Get, collect_rules, rule
from pants.engine.target import (
FieldSetsPerTarget,
FieldSetsPerTargetRequest,
GeneratedSources,
GenerateSourcesRequest,
Targets,
)
from pants.engine.target import GeneratedSources, GenerateSourcesRequest
from pants.engine.unions import UnionRule
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel
@@ -63,6 +51,7 @@ async def run_in_sandbox_request(
) -> GeneratedSources:
target = request.protocol_target
description = f"the `{target.alias}` at {target.address}"

environment_name = await Get(
EnvironmentName, EnvironmentNameRequest, EnvironmentNameRequest.from_target(target)
)
@@ -71,104 +60,38 @@
if not runnable_address_str:
raise Exception(f"Must supply a value for `runnable` for {description}.")

runnable_address = await Get(
Address,
AddressInput,
AddressInput.parse(
runnable_address_str,
relative_to=target.address.spec_path,
description_of_origin=f"The `{AdhocToolRunnableField.alias}` field of {description}",
tool_runner = await Get(
ToolRunner,
ToolRunnerRequest(
runnable_address_str=runnable_address_str,
args=target.get(AdhocToolArgumentsField).value or (),
execution_dependencies=target.get(AdhocToolExecutionDependenciesField).value or (),
runnable_dependencies=target.get(AdhocToolRunnableDependenciesField).value or (),
target=request.protocol_target,
named_caches=FrozenDict(target.get(AdhocToolNamedCachesField).value or {}),
),
)

addresses = Addresses((runnable_address,))
addresses.expect_single()

runnable_targets = await Get(Targets, Addresses, addresses)
field_sets = await Get(
FieldSetsPerTarget, FieldSetsPerTargetRequest(RunFieldSet, runnable_targets)
)
run_field_set: RunFieldSet = field_sets.field_sets[0]

working_directory = target[AdhocToolWorkdirField].value or ""
root_output_directory = target[AdhocToolOutputRootDirField].value or ""

# Must be run in target environment so that the binaries/envvars match the execution
# environment when we actually run the process.
run_request = await Get(
RunInSandboxRequest, {environment_name: EnvironmentName, run_field_set: RunFieldSet}
)

execution_environment = await Get(
ResolvedExecutionDependencies,
ResolveExecutionDependenciesRequest(
target.address,
target.get(AdhocToolExecutionDependenciesField).value,
target.get(AdhocToolRunnableDependenciesField).value,
),
)
dependencies_digest = execution_environment.digest
runnable_dependencies = execution_environment.runnable_dependencies

extra_env: dict[str, str] = dict(run_request.extra_env or {})
extra_path = extra_env.pop("PATH", None)

extra_sandbox_contents = []

extra_sandbox_contents.append(
ExtraSandboxContents(
EMPTY_DIGEST,
extra_path,
run_request.immutable_input_digests or FrozenDict(),
run_request.append_only_caches or FrozenDict(),
run_request.extra_env or FrozenDict(),
)
)

if runnable_dependencies:
extra_sandbox_contents.append(
ExtraSandboxContents(
EMPTY_DIGEST,
f"{{chroot}}/{runnable_dependencies.path_component}",
runnable_dependencies.immutable_input_digests,
runnable_dependencies.append_only_caches,
runnable_dependencies.extra_env,
)
)

merged_extras = await Get(
ExtraSandboxContents, MergeExtraSandboxContents(tuple(extra_sandbox_contents))
)
extra_env = dict(merged_extras.extra_env)
if merged_extras.path:
extra_env["PATH"] = merged_extras.path

input_digest = await Get(Digest, MergeDigests((dependencies_digest, run_request.digest)))

output_files = target.get(AdhocToolOutputFilesField).value or ()
output_directories = target.get(AdhocToolOutputDirectoriesField).value or ()

extra_args = target.get(AdhocToolArgumentsField).value or ()

append_only_caches = {
**merged_extras.append_only_caches,
**(target.get(AdhocToolNamedCachesField).value or {}),
}

process_request = AdhocProcessRequest(
description=description,
address=target.address,
working_directory=working_directory,
root_output_directory=root_output_directory,
argv=tuple(run_request.args + extra_args),
argv=tool_runner.args,
timeout=None,
input_digest=input_digest,
immutable_input_digests=FrozenDict.frozen(merged_extras.immutable_input_digests),
append_only_caches=FrozenDict(append_only_caches),
input_digest=tool_runner.digest,
immutable_input_digests=FrozenDict.frozen(tool_runner.immutable_input_digests),
append_only_caches=FrozenDict(tool_runner.append_only_caches),
output_files=output_files,
output_directories=output_directories,
fetch_env_vars=target.get(AdhocToolExtraEnvVarsField).value or (),
supplied_env_var_values=FrozenDict(extra_env),
supplied_env_var_values=FrozenDict(tool_runner.extra_env),
log_on_process_errors=None,
log_output=target[AdhocToolLogOutputField].value,
capture_stderr_file=target[AdhocToolStderrFilenameField].value,
126 changes: 15 additions & 111 deletions src/python/pants/backend/adhoc/code_quality_tool.py
@@ -1,27 +1,18 @@
# Copyright 2023 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from dataclasses import dataclass
from typing import ClassVar, Iterable, Mapping
from typing import ClassVar, Iterable

from pants.core.goals.fix import Fix, FixFilesRequest, FixResult
from pants.core.goals.fmt import Fmt, FmtFilesRequest, FmtResult
from pants.core.goals.lint import Lint, LintFilesRequest, LintResult
from pants.core.goals.run import RunFieldSet, RunInSandboxRequest
from pants.core.util_rules.adhoc_process_support import (
ExtraSandboxContents,
MergeExtraSandboxContents,
ResolvedExecutionDependencies,
ResolveExecutionDependenciesRequest,
)
from pants.core.util_rules.adhoc_process_support import ToolRunner, ToolRunnerRequest
from pants.core.util_rules.adhoc_process_support import rules as adhoc_process_support_rules
from pants.core.util_rules.environments import EnvironmentNameRequest
from pants.core.util_rules.partitions import Partitions
from pants.engine.addresses import Addresses
from pants.engine.environment import EnvironmentName
from pants.engine.fs import PathGlobs
from pants.engine.goal import Goal
from pants.engine.internals.native_engine import (
EMPTY_DIGEST,
Address,
AddressInput,
Digest,
@@ -34,8 +25,6 @@
from pants.engine.rules import Rule, collect_rules, rule
from pants.engine.target import (
COMMON_TARGET_FIELDS,
FieldSetsPerTarget,
FieldSetsPerTargetRequest,
SpecialCasedDependencies,
StringField,
StringSequenceField,
@@ -199,18 +188,9 @@ async def find_code_quality_tool(request: CodeQualityToolAddressString) -> CodeQ
)


@dataclass(frozen=True)
class CodeQualityToolBatchRunner:
digest: Digest
args: tuple[str, ...]
extra_env: Mapping[str, str]
append_only_caches: Mapping[str, str]
immutable_input_digests: Mapping[str, Digest]


@dataclass(frozen=True)
class CodeQualityToolBatch:
runner: CodeQualityToolBatchRunner
runner: ToolRunner
sources_snapshot: Snapshot
output_files: tuple[str, ...]

@@ -237,91 +217,15 @@ async def process_files(batch: CodeQualityToolBatch) -> FallibleProcessResult:


@rule
async def hydrate_code_quality_tool(
request: CodeQualityToolAddressString,
) -> CodeQualityToolBatchRunner:
cqt = await Get(CodeQualityTool, CodeQualityToolAddressString, request)

runnable_address = await Get(
Address,
AddressInput,
AddressInput.parse(
cqt.runnable_address_str,
relative_to=cqt.target.address.spec_path,
description_of_origin=f"Runnable target for code quality tool {cqt.target.address.spec_path}",
),
)

addresses = Addresses((runnable_address,))
addresses.expect_single()

runnable_targets = await Get(Targets, Addresses, addresses)

target = runnable_targets[0]

run_field_sets, environment_name, execution_environment = await MultiGet(
Get(FieldSetsPerTarget, FieldSetsPerTargetRequest(RunFieldSet, runnable_targets)),
Get(EnvironmentName, EnvironmentNameRequest, EnvironmentNameRequest.from_target(target)),
Get(
ResolvedExecutionDependencies,
ResolveExecutionDependenciesRequest(
address=runnable_address,
execution_dependencies=cqt.execution_dependencies,
runnable_dependencies=cqt.runnable_dependencies,
),
),
)

run_field_set: RunFieldSet = run_field_sets.field_sets[0]

run_request = await Get(
RunInSandboxRequest, {environment_name: EnvironmentName, run_field_set: RunFieldSet}
)

dependencies_digest = execution_environment.digest
runnable_dependencies = execution_environment.runnable_dependencies

extra_env: dict[str, str] = dict(run_request.extra_env or {})
extra_path = extra_env.pop("PATH", None)

extra_sandbox_contents = []

extra_sandbox_contents.append(
ExtraSandboxContents(
EMPTY_DIGEST,
extra_path,
run_request.immutable_input_digests or FrozenDict(),
run_request.append_only_caches or FrozenDict(),
run_request.extra_env or FrozenDict(),
)
)

if runnable_dependencies:
extra_sandbox_contents.append(
ExtraSandboxContents(
EMPTY_DIGEST,
f"{{chroot}}/{runnable_dependencies.path_component}",
runnable_dependencies.immutable_input_digests,
runnable_dependencies.append_only_caches,
runnable_dependencies.extra_env,
)
)

merged_extras, main_digest = await MultiGet(
Get(ExtraSandboxContents, MergeExtraSandboxContents(tuple(extra_sandbox_contents))),
Get(Digest, MergeDigests((dependencies_digest, run_request.digest))),
)

extra_env = dict(merged_extras.extra_env)
if merged_extras.path:
extra_env["PATH"] = merged_extras.path

return CodeQualityToolBatchRunner(
digest=main_digest,
args=run_request.args + tuple(cqt.args),
extra_env=FrozenDict(extra_env),
append_only_caches=merged_extras.append_only_caches,
immutable_input_digests=merged_extras.immutable_input_digests,
async def runner_request_for_code_quality_tool(
cqt: CodeQualityTool,
) -> ToolRunnerRequest:
return ToolRunnerRequest(
runnable_address_str=cqt.runnable_address_str,
args=cqt.args,
execution_dependencies=cqt.execution_dependencies,
runnable_dependencies=cqt.runnable_dependencies,
target=cqt.target,
)


@@ -391,7 +295,7 @@ async def partition_inputs(
async def run_code_quality(request: CodeQualityProcessingRequest.Batch) -> LintResult:
sources_snapshot, code_quality_tool_runner = await MultiGet(
Get(Snapshot, PathGlobs(request.elements)),
Get(CodeQualityToolBatchRunner, CodeQualityToolAddressString(address=self.target)),
Get(ToolRunner, CodeQualityToolAddressString(address=self.target)),
)

proc_result = await Get(
@@ -445,7 +349,7 @@ async def run_code_quality(request: CodeQualityProcessingRequest.Batch) -> FmtRe
sources_snapshot = request.snapshot

code_quality_tool_runner = await Get(
CodeQualityToolBatchRunner, CodeQualityToolAddressString(address=self.target)
ToolRunner, CodeQualityToolAddressString(address=self.target)
)

proc_result = await Get(
@@ -507,7 +411,7 @@ async def run_code_quality(request: CodeQualityProcessingRequest.Batch) -> FixRe
sources_snapshot = request.snapshot

code_quality_tool_runner = await Get(
CodeQualityToolBatchRunner, CodeQualityToolAddressString(address=self.target)
ToolRunner, CodeQualityToolAddressString(address=self.target)
)

proc_result = await Get(