diff --git a/python/ruff-ecosystem/pyproject.toml b/python/ruff-ecosystem/pyproject.toml index 68d8e2a953a7c..14fd402d1a977 100644 --- a/python/ruff-ecosystem/pyproject.toml +++ b/python/ruff-ecosystem/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "ruff-ecosystem" version = "0.0.0" -dependencies = ["unidiff==0.7.5"] +dependencies = ["unidiff==0.7.5", "tomli_w==1.0.0", "tomli==2.0.1"] [project.scripts] ruff-ecosystem = "ruff_ecosystem.cli:entrypoint" diff --git a/python/ruff-ecosystem/ruff_ecosystem/check.py b/python/ruff-ecosystem/ruff_ecosystem/check.py index 89d4320570858..c888a58f00684 100644 --- a/python/ruff-ecosystem/ruff_ecosystem/check.py +++ b/python/ruff-ecosystem/ruff_ecosystem/check.py @@ -28,7 +28,12 @@ ) if TYPE_CHECKING: - from ruff_ecosystem.projects import CheckOptions, ClonedRepository, Project + from ruff_ecosystem.projects import ( + CheckOptions, + ClonedRepository, + ConfigOverrides, + Project, + ) # Matches lines that are summaries rather than diagnostics @@ -477,25 +482,27 @@ async def compare_check( ruff_baseline_executable: Path, ruff_comparison_executable: Path, options: CheckOptions, + config_overrides: ConfigOverrides, cloned_repo: ClonedRepository, ) -> Comparison: - async with asyncio.TaskGroup() as tg: - baseline_task = tg.create_task( - ruff_check( - executable=ruff_baseline_executable.resolve(), - path=cloned_repo.path, - name=cloned_repo.fullname, - options=options, - ), - ) - comparison_task = tg.create_task( - ruff_check( - executable=ruff_comparison_executable.resolve(), - path=cloned_repo.path, - name=cloned_repo.fullname, - options=options, - ), - ) + with config_overrides.patch_config(cloned_repo.path, options.preview): + async with asyncio.TaskGroup() as tg: + baseline_task = tg.create_task( + ruff_check( + executable=ruff_baseline_executable.resolve(), + path=cloned_repo.path, + name=cloned_repo.fullname, + options=options, + ), + ) + comparison_task = tg.create_task( + ruff_check( + executable=ruff_comparison_executable.resolve(), + path=cloned_repo.path, + name=cloned_repo.fullname, + options=options, + ), + ) baseline_output, comparison_output = ( baseline_task.result(), diff --git a/python/ruff-ecosystem/ruff_ecosystem/defaults.py b/python/ruff-ecosystem/ruff_ecosystem/defaults.py index ad75701730d03..66d56c162964e 100644 --- a/python/ruff-ecosystem/ruff_ecosystem/defaults.py +++ b/python/ruff-ecosystem/ruff_ecosystem/defaults.py @@ -1,7 +1,13 @@ """ Default projects for ecosystem checks """ -from ruff_ecosystem.projects import CheckOptions, FormatOptions, Project, Repository +from ruff_ecosystem.projects import ( + CheckOptions, + ConfigOverrides, + FormatOptions, + Project, + Repository, +) # TODO(zanieb): Consider exporting this as JSON and loading from there instead DEFAULT_TARGETS = [ @@ -45,7 +51,14 @@ Project(repo=Repository(owner="pypa", name="build", ref="main")), Project(repo=Repository(owner="pypa", name="cibuildwheel", ref="main")), Project(repo=Repository(owner="pypa", name="pip", ref="main")), - Project(repo=Repository(owner="pypa", name="setuptools", ref="main")), + Project( + repo=Repository(owner="pypa", name="setuptools", ref="main"), + # Since `setuptools` opts into the "preserve" quote style which + # require preview mode, we must disable it during the `--no-preview` run + config_overrides=ConfigOverrides( + when_no_preview={"format.quote-style": "double"} + ), + ), Project(repo=Repository(owner="python", name="mypy", ref="master")), Project( repo=Repository( diff --git a/python/ruff-ecosystem/ruff_ecosystem/format.py b/python/ruff-ecosystem/ruff_ecosystem/format.py index b448c208da81d..fbb561d248139 100644 --- a/python/ruff-ecosystem/ruff_ecosystem/format.py +++ b/python/ruff-ecosystem/ruff_ecosystem/format.py @@ -18,7 +18,7 @@ from ruff_ecosystem.types import Comparison, Diff, Result, ToolError if TYPE_CHECKING: - from ruff_ecosystem.projects import ClonedRepository, FormatOptions + from ruff_ecosystem.projects import ClonedRepository, ConfigOverrides, FormatOptions def markdown_format_result(result: Result) -> str: @@ -137,10 +137,17 @@ async def compare_format( ruff_baseline_executable: Path, ruff_comparison_executable: Path, options: FormatOptions, + config_overrides: ConfigOverrides, cloned_repo: ClonedRepository, format_comparison: FormatComparison, ): - args = (ruff_baseline_executable, ruff_comparison_executable, options, cloned_repo) + args = ( + ruff_baseline_executable, + ruff_comparison_executable, + options, + config_overrides, + cloned_repo, + ) match format_comparison: case FormatComparison.ruff_then_ruff: coro = format_then_format(Formatter.ruff, *args) @@ -162,25 +169,27 @@ async def format_then_format( ruff_baseline_executable: Path, ruff_comparison_executable: Path, options: FormatOptions, + config_overrides: ConfigOverrides, cloned_repo: ClonedRepository, ) -> Sequence[str]: - # Run format to get the baseline - await format( - formatter=baseline_formatter, - executable=ruff_baseline_executable.resolve(), - path=cloned_repo.path, - name=cloned_repo.fullname, - options=options, - ) - # Then get the diff from stdout - diff = await format( - formatter=Formatter.ruff, - executable=ruff_comparison_executable.resolve(), - path=cloned_repo.path, - name=cloned_repo.fullname, - options=options, - diff=True, - ) + with config_overrides.patch_config(cloned_repo.path, options.preview): + # Run format to get the baseline + await format( + formatter=baseline_formatter, + executable=ruff_baseline_executable.resolve(), + path=cloned_repo.path, + name=cloned_repo.fullname, + options=options, + ) + # Then get the diff from stdout + diff = await format( + formatter=Formatter.ruff, + executable=ruff_comparison_executable.resolve(), + path=cloned_repo.path, + name=cloned_repo.fullname, + options=options, + diff=True, + ) return diff @@ -189,32 +198,39 @@ async def format_and_format( ruff_baseline_executable: Path, ruff_comparison_executable: Path, options: FormatOptions, + config_overrides: ConfigOverrides, cloned_repo: ClonedRepository, ) -> Sequence[str]: - # Run format without diff to get the baseline - await format( - formatter=baseline_formatter, - executable=ruff_baseline_executable.resolve(), - path=cloned_repo.path, - name=cloned_repo.fullname, - options=options, - ) + with config_overrides.patch_config(cloned_repo.path, options.preview): + # Run format without diff to get the baseline + await format( + formatter=baseline_formatter, + executable=ruff_baseline_executable.resolve(), + path=cloned_repo.path, + name=cloned_repo.fullname, + options=options, + ) + # Commit the changes commit = await cloned_repo.commit( message=f"Formatted with baseline {ruff_baseline_executable}" ) # Then reset await cloned_repo.reset() - # Then run format again - await format( - formatter=Formatter.ruff, - executable=ruff_comparison_executable.resolve(), - path=cloned_repo.path, - name=cloned_repo.fullname, - options=options, - ) + + with config_overrides.patch_config(cloned_repo.path, options.preview): + # Then run format again + await format( + formatter=Formatter.ruff, + executable=ruff_comparison_executable.resolve(), + path=cloned_repo.path, + name=cloned_repo.fullname, + options=options, + ) + # Then get the diff from the commit diff = await cloned_repo.diff(commit) + return diff diff --git a/python/ruff-ecosystem/ruff_ecosystem/main.py b/python/ruff-ecosystem/ruff_ecosystem/main.py index ca0e3037c78b5..4a3fc95e1482a 100644 --- a/python/ruff-ecosystem/ruff_ecosystem/main.py +++ b/python/ruff-ecosystem/ruff_ecosystem/main.py @@ -113,11 +113,17 @@ async def clone_and_compare( match command: case RuffCommand.check: - compare, options, kwargs = (compare_check, target.check_options, {}) + compare, options, overrides, kwargs = ( + compare_check, + target.check_options, + target.config_overrides, + {}, + ) case RuffCommand.format: - compare, options, kwargs = ( + compare, options, overrides, kwargs = ( compare_format, target.format_options, + target.config_overrides, {"format_comparison": format_comparison}, ) case _: @@ -131,6 +137,7 @@ async def clone_and_compare( baseline_executable, comparison_executable, options, + overrides, cloned_repo, **kwargs, ) diff --git a/python/ruff-ecosystem/ruff_ecosystem/projects.py b/python/ruff-ecosystem/ruff_ecosystem/projects.py index fa5f8224a8958..78b836297f005 100644 --- a/python/ruff-ecosystem/ruff_ecosystem/projects.py +++ b/python/ruff-ecosystem/ruff_ecosystem/projects.py @@ -5,13 +5,18 @@ from __future__ import annotations import abc +import contextlib import dataclasses from asyncio import create_subprocess_exec from dataclasses import dataclass, field from enum import Enum +from functools import cache from pathlib import Path from subprocess import DEVNULL, PIPE -from typing import Self +from typing import Any, Self + +import tomli +import tomli_w from ruff_ecosystem import logger from ruff_ecosystem.types import Serializable @@ -26,14 +31,115 @@ class Project(Serializable): repo: Repository check_options: CheckOptions = field(default_factory=lambda: CheckOptions()) format_options: FormatOptions = field(default_factory=lambda: FormatOptions()) + config_overrides: ConfigOverrides = field(default_factory=lambda: ConfigOverrides()) def with_preview_enabled(self: Self) -> Self: return type(self)( repo=self.repo, check_options=self.check_options.with_options(preview=True), format_options=self.format_options.with_options(preview=True), + config_overrides=self.config_overrides, ) + def __post_init__(self): + # Convert bare dictionaries for `config_overrides` into the correct type + if isinstance(self.config_overrides, dict): + # Bypass the frozen attribute + object.__setattr__( + self, "config_overrides", ConfigOverrides(always=self.config_overrides) + ) + + +@dataclass(frozen=True) +class ConfigOverrides(Serializable): + """ + A collection of key, value pairs to override in the Ruff configuration file. + + The key describes a member to override in the toml file; '.' may be used to indicate a + nested value e.g. `format.quote-style`. + + If a Ruff configuration file does not exist and overrides are provided, it will be createad. + """ + + always: dict[str, Any] = field(default_factory=dict) + when_preview: dict[str, Any] = field(default_factory=dict) + when_no_preview: dict[str, Any] = field(default_factory=dict) + + def __hash__(self) -> int: + # Avoid computing this hash repeatedly since this object is intended + # to be immutable and serializing to toml is not necessarily cheap + @cache + def as_string(): + return tomli_w.dumps( + { + "always": self.always, + "when_preview": self.when_preview, + "when_no_preview": self.when_no_preview, + } + ) + + return hash(as_string()) + + @contextlib.contextmanager + def patch_config( + self, + dirpath: Path, + preview: bool, + ) -> None: + """ + Temporarily patch the Ruff configuration file in the given directory. + """ + ruff_toml = dirpath / "ruff.toml" + pyproject_toml = dirpath / "pyproject.toml" + + # Prefer `ruff.toml` over `pyproject.toml` + if ruff_toml.exists(): + path = ruff_toml + base = [] + else: + path = pyproject_toml + base = ["tool", "ruff"] + + overrides = { + **self.always, + **(self.when_preview if preview else self.when_no_preview), + } + + if not overrides: + yield + return + + # Read the existing content if the file is present + if path.exists(): + contents = path.read_text() + toml = tomli.loads(contents) + else: + contents = None + toml = {} + + # Update the TOML, using `.` to descend into nested keys + for key, value in overrides.items(): + logger.debug(f"Setting {key}={value!r} in {path}") + + target = toml + names = base + key.split(".") + for name in names[:-1]: + if name not in target: + target[name] = {} + target = target[name] + target[names[-1]] = value + + tomli_w.dump(toml, path.open("wb")) + + try: + yield + finally: + # Restore the contents or delete the file + if contents is None: + path.unlink() + else: + path.write_text(contents) + class RuffCommand(Enum): check = "check" @@ -42,6 +148,8 @@ class RuffCommand(Enum): @dataclass(frozen=True) class CommandOptions(Serializable, abc.ABC): + preview: bool = False + def with_options(self: Self, **kwargs) -> Self: """ Return a copy of self with the given options set. @@ -62,7 +170,6 @@ class CheckOptions(CommandOptions): select: str = "" ignore: str = "" exclude: str = "" - preview: bool = False # Generating fixes is slow and verbose show_fixes: bool = False