Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"Unify" fmt and lint rules for formatters #14903

Merged
merged 11 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 9 additions & 28 deletions src/python/pants/backend/codegen/protobuf/lint/buf/format_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
ProtobufSourceField,
)
from pants.core.goals.fmt import FmtRequest, FmtResult
from pants.core.goals.lint import LintResult, LintResults, LintTargetsRequest
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.core.util_rules.system_binaries import (
Expand All @@ -21,7 +20,7 @@
from pants.engine.fs import Digest, MergeDigests
from pants.engine.internals.native_engine import Snapshot
from pants.engine.platform import Platform
from pants.engine.process import FallibleProcessResult, Process, ProcessResult
from pants.engine.process import Process, ProcessResult
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import FieldSet, Target
from pants.engine.unions import UnionRule
Expand All @@ -41,25 +40,19 @@ def opt_out(cls, tgt: Target) -> bool:
return tgt.get(SkipBufFormatField).value


class BufFormatRequest(LintTargetsRequest, FmtRequest):
class BufFormatRequest(FmtRequest):
field_set_type = BufFieldSet
name = "buf-format"


@dataclass(frozen=True)
class SetupRequest:
request: BufFormatRequest
check_only: bool


@dataclass(frozen=True)
class Setup:
process: Process
original_snapshot: Snapshot


@rule(level=LogLevel.DEBUG)
async def setup_buf_format(setup_request: SetupRequest, buf: BufSubsystem) -> Setup:
async def setup_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> Setup:
diff_binary = await Get(DiffBinary, DiffBinaryRequest())
download_buf_get = Get(
DownloadedExternalTool, ExternalToolRequest, buf.get_request(Platform.current)
Expand All @@ -75,16 +68,16 @@ async def setup_buf_format(setup_request: SetupRequest, buf: BufSubsystem) -> Se
)
source_files_get = Get(
SourceFiles,
SourceFilesRequest(field_set.sources for field_set in setup_request.request.field_sets),
SourceFilesRequest(field_set.sources for field_set in request.field_sets),
)
downloaded_buf, binary_shims, source_files = await MultiGet(
download_buf_get, binary_shims_get, source_files_get
)

source_files_snapshot = (
source_files.snapshot
if setup_request.request.prior_formatter_result is None
else setup_request.request.prior_formatter_result
if request.prior_formatter_result is None
else request.prior_formatter_result
)

input_digest = await Get(
Expand All @@ -95,9 +88,7 @@ async def setup_buf_format(setup_request: SetupRequest, buf: BufSubsystem) -> Se
argv = [
downloaded_buf.exe,
"format",
# If linting, use `-d` to error with a diff and `--exit-code` to exit with a non-zero exit code if
# the file is not already formatted. Else, write the change with `-w`.
*(["-d", "--exit-code"] if setup_request.check_only else ["-w"]),
"-w",
*buf.format_args,
"--path",
",".join(source_files_snapshot.files),
Expand All @@ -106,7 +97,7 @@ async def setup_buf_format(setup_request: SetupRequest, buf: BufSubsystem) -> Se
argv=argv,
input_digest=input_digest,
output_files=source_files_snapshot.files,
description=f"Run buf format on {pluralize(len(setup_request.request.field_sets), 'file')}.",
description=f"Run buf format on {pluralize(len(request.field_sets), 'file')}.",
level=LogLevel.DEBUG,
env={
"PATH": binary_shims.bin_directory,
Expand All @@ -119,7 +110,7 @@ async def setup_buf_format(setup_request: SetupRequest, buf: BufSubsystem) -> Se
async def run_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> FmtResult:
if buf.skip_format:
return FmtResult.skip(formatter_name=request.name)
setup = await Get(Setup, SetupRequest(request, check_only=False))
setup = await Get(Setup, BufFormatRequest, request)
result = await Get(ProcessResult, Process, setup.process)
output_snapshot = await Get(Snapshot, Digest, result.output_digest)
return FmtResult(
Expand All @@ -131,18 +122,8 @@ async def run_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> FmtRes
)


@rule(desc="Lint with buf format", level=LogLevel.DEBUG)
async def run_buf_lint(request: BufFormatRequest, buf: BufSubsystem) -> LintResults:
if buf.skip_format:
return LintResults([], linter_name=request.name)
setup = await Get(Setup, SetupRequest(request, check_only=True))
result = await Get(FallibleProcessResult, Process, setup.process)
return LintResults([LintResult.from_fallible_process_result(result)], linter_name=request.name)


def rules():
return [
*collect_rules(),
UnionRule(FmtRequest, BufFormatRequest),
UnionRule(LintTargetsRequest, BufFormatRequest),
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pants.backend.codegen.protobuf.target_types import ProtobufSourcesGeneratorTarget
from pants.backend.codegen.protobuf.target_types import rules as target_types_rules
from pants.core.goals.fmt import FmtResult
from pants.core.goals.lint import LintResult, LintResults
from pants.core.util_rules import config_files, external_tool, source_files
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.addresses import Address
Expand All @@ -29,7 +28,6 @@ def rule_runner() -> RuleRunner:
*external_tool.rules(),
*source_files.rules(),
*target_types_rules(),
QueryRule(LintResults, [BufFormatRequest]),
QueryRule(FmtResult, [BufFormatRequest]),
QueryRule(SourceFiles, [SourceFilesRequest]),
],
Expand All @@ -46,7 +44,7 @@ def run_buf(
targets: list[Target],
*,
extra_args: list[str] | None = None,
) -> tuple[tuple[LintResult, ...], FmtResult]:
) -> FmtResult:
rule_runner.set_options(
[
"--backend-packages=pants.backend.codegen.protobuf.lint.buf",
Expand All @@ -55,7 +53,6 @@ def run_buf(
env_inherit={"PATH"},
)
field_sets = [BufFieldSet.create(tgt) for tgt in targets]
results = rule_runner.request(LintResults, [BufFormatRequest(field_sets)])
input_sources = rule_runner.request(
SourceFiles,
[
Expand All @@ -69,7 +66,7 @@ def run_buf(
],
)

return results.results, fmt_result
return fmt_result


def get_snapshot(rule_runner: RuleRunner, source_files: dict[str, str]) -> Snapshot:
Expand All @@ -81,11 +78,7 @@ def get_snapshot(rule_runner: RuleRunner, source_files: dict[str, str]) -> Snaps
def test_passing(rule_runner: RuleRunner) -> None:
rule_runner.write_files({"f.proto": GOOD_FILE, "BUILD": "protobuf_sources(name='t')"})
tgt = rule_runner.get_target(Address("", target_name="t", relative_file_path="f.proto"))
lint_results, fmt_result = run_buf(rule_runner, [tgt])
assert len(lint_results) == 1
assert lint_results[0].exit_code == 0
assert lint_results[0].stdout == ""
assert lint_results[0].stderr == ""
fmt_result = run_buf(rule_runner, [tgt])
assert fmt_result.stdout == ""
assert fmt_result.output == get_snapshot(rule_runner, {"f.proto": GOOD_FILE})
assert fmt_result.did_change is False
Expand All @@ -94,10 +87,7 @@ def test_passing(rule_runner: RuleRunner) -> None:
def test_failing(rule_runner: RuleRunner) -> None:
rule_runner.write_files({"f.proto": BAD_FILE, "BUILD": "protobuf_sources(name='t')"})
tgt = rule_runner.get_target(Address("", target_name="t", relative_file_path="f.proto"))
lint_results, fmt_result = run_buf(rule_runner, [tgt])
assert len(lint_results) == 1
assert lint_results[0].exit_code == 100
assert "f.proto.orig" in lint_results[0].stdout
fmt_result = run_buf(rule_runner, [tgt])
assert fmt_result.output == get_snapshot(rule_runner, {"f.proto": GOOD_FILE})
assert fmt_result.did_change is True

Expand All @@ -110,11 +100,7 @@ def test_multiple_targets(rule_runner: RuleRunner) -> None:
rule_runner.get_target(Address("", target_name="t", relative_file_path="good.proto")),
rule_runner.get_target(Address("", target_name="t", relative_file_path="bad.proto")),
]
lint_results, fmt_result = run_buf(rule_runner, tgts)
assert len(lint_results) == 1
assert lint_results[0].exit_code == 100
assert "bad.proto.orig" in lint_results[0].stdout
assert "good.proto" not in lint_results[0].stdout
fmt_result = run_buf(rule_runner, tgts)
assert fmt_result.output == get_snapshot(
rule_runner, {"good.proto": GOOD_FILE, "bad.proto": GOOD_FILE}
)
Expand All @@ -124,11 +110,7 @@ def test_multiple_targets(rule_runner: RuleRunner) -> None:
def test_passthrough_args(rule_runner: RuleRunner) -> None:
rule_runner.write_files({"f.proto": GOOD_FILE, "BUILD": "protobuf_sources(name='t')"})
tgt = rule_runner.get_target(Address("", target_name="t", relative_file_path="f.proto"))
lint_results, fmt_result = run_buf(rule_runner, [tgt], extra_args=["--buf-format-args=--debug"])
assert len(lint_results) == 1
assert lint_results[0].exit_code == 0
assert lint_results[0].stdout == ""
assert "DEBUG" in lint_results[0].stderr
fmt_result = run_buf(rule_runner, [tgt], extra_args=["--buf-format-args=--debug"])
assert fmt_result.stdout == ""
assert fmt_result.output == get_snapshot(rule_runner, {"f.proto": GOOD_FILE})
assert fmt_result.did_change is False
Expand All @@ -137,7 +119,6 @@ def test_passthrough_args(rule_runner: RuleRunner) -> None:
def test_skip(rule_runner: RuleRunner) -> None:
rule_runner.write_files({"f.proto": BAD_FILE, "BUILD": "protobuf_sources(name='t')"})
tgt = rule_runner.get_target(Address("", target_name="t", relative_file_path="f.proto"))
lint_results, fmt_result = run_buf(rule_runner, [tgt], extra_args=["--buf-format-skip"])
assert not lint_results
fmt_result = run_buf(rule_runner, [tgt], extra_args=["--buf-format-skip"])
assert fmt_result.skipped is True
assert fmt_result.did_change is False
41 changes: 7 additions & 34 deletions src/python/pants/backend/go/lint/gofmt/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from __future__ import annotations

import dataclasses
import os.path
from dataclasses import dataclass

Expand All @@ -13,12 +12,11 @@
from pants.backend.go.subsystems.golang import GoRoot
from pants.backend.go.target_types import GoPackageSourcesField
from pants.core.goals.fmt import FmtRequest, FmtResult
from pants.core.goals.lint import LintResult, LintResults, LintTargetsRequest
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.fs import Digest
from pants.engine.internals.native_engine import Snapshot
from pants.engine.internals.selectors import Get
from pants.engine.process import FallibleProcessResult, Process, ProcessResult
from pants.engine.process import Process, ProcessResult
from pants.engine.rules import collect_rules, rule
from pants.engine.target import FieldSet, Target
from pants.engine.unions import UnionRule
Expand All @@ -42,33 +40,27 @@ class GofmtRequest(FmtRequest):
name = GofmtSubsystem.options_scope


@dataclass(frozen=True)
class SetupRequest:
request: GofmtRequest
check_only: bool


@dataclass(frozen=True)
class Setup:
process: Process
original_snapshot: Snapshot


@rule(level=LogLevel.DEBUG)
async def setup_gofmt(setup_request: SetupRequest, goroot: GoRoot) -> Setup:
async def setup_gofmt(request: GofmtRequest, goroot: GoRoot) -> Setup:
source_files = await Get(
SourceFiles,
SourceFilesRequest(field_set.sources for field_set in setup_request.request.field_sets),
SourceFilesRequest(field_set.sources for field_set in request.field_sets),
)
source_files_snapshot = (
source_files.snapshot
if setup_request.request.prior_formatter_result is None
else setup_request.request.prior_formatter_result
if request.prior_formatter_result is None
else request.prior_formatter_result
Comment on lines +57 to +58
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a future change I think I'll refactor so this isn't needed. It's so confusing 😵

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please!! I can't remember why this was necessary. But it's so copy pasta now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TL;DR Because these rules run for fmt (in which case the formatter should run on the prior result) and lint (in which case there is no "prior formatter result")

)

argv = (
os.path.join(goroot.path, "bin/gofmt"),
"-l" if setup_request.check_only else "-w",
"-w",
*source_files_snapshot.files,
)
process = Process(
Expand All @@ -85,7 +77,7 @@ async def setup_gofmt(setup_request: SetupRequest, goroot: GoRoot) -> Setup:
async def gofmt_fmt(request: GofmtRequest, gofmt: GofmtSubsystem) -> FmtResult:
if gofmt.skip:
return FmtResult.skip(formatter_name=request.name)
setup = await Get(Setup, SetupRequest(request, check_only=False))
setup = await Get(Setup, GofmtRequest, request)
result = await Get(ProcessResult, Process, setup.process)
output_snapshot = await Get(Snapshot, Digest, result.output_digest)
return FmtResult(
Expand All @@ -97,28 +89,9 @@ async def gofmt_fmt(request: GofmtRequest, gofmt: GofmtSubsystem) -> FmtResult:
)


@rule(desc="Lint with gofmt", level=LogLevel.DEBUG)
async def gofmt_lint(request: GofmtRequest, gofmt: GofmtSubsystem) -> LintResults:
if gofmt.skip:
return LintResults([], linter_name=request.name)
setup = await Get(Setup, SetupRequest(request, check_only=True))
result = await Get(FallibleProcessResult, Process, setup.process)
lint_result = LintResult.from_fallible_process_result(result)
if lint_result.exit_code == 0 and lint_result.stdout.strip() != "":
# Note: gofmt returns success even if it would have reformatted the files.
# When this occurs, convert the LintResult into a failure.
lint_result = dataclasses.replace(
lint_result,
exit_code=1,
stdout=f"The following Go files require formatting:\n{lint_result.stdout}\n",
)
return LintResults([lint_result], linter_name=request.name)


def rules():
return [
*collect_rules(),
*golang.rules(),
UnionRule(FmtRequest, GofmtRequest),
UnionRule(LintTargetsRequest, GofmtRequest),
]
Loading