From 8d46cb241d837f457d062e95cd60d5629e7ba030 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bern=C3=A1t=20G=C3=A1bor?= Date: Sat, 19 Oct 2024 22:00:14 -0700 Subject: [PATCH] Fix CLI args type conversion (#3) --- src/toml_fmt_common/__init__.py | 28 ++++++++++++++++------- tests/test_app.py | 40 +++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/src/toml_fmt_common/__init__.py b/src/toml_fmt_common/__init__.py index 1c6193d..0346aea 100644 --- a/src/toml_fmt_common/__init__.py +++ b/src/toml_fmt_common/__init__.py @@ -6,7 +6,13 @@ import os import sys from abc import ABC, abstractmethod -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError, Namespace +from argparse import ( + ArgumentDefaultsHelpFormatter, + ArgumentParser, + ArgumentTypeError, + Namespace, + _ArgumentGroup, # noqa: PLC2701 +) from collections import deque from copy import deepcopy from dataclasses import dataclass @@ -16,13 +22,15 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Callable, Iterable, Mapping, Sequence if sys.version_info >= (3, 11): # pragma: >=3.11 cover import tomllib else: # pragma: <3.11 cover import tomli as tomllib +ArgumentGroup = _ArgumentGroup + class FmtNamespace(Namespace): """Options for pyproject-fmt tool.""" @@ -63,7 +71,7 @@ def filename(self) -> str: raise NotImplementedError @abstractmethod - def add_format_flags(self, parser: ArgumentParser) -> None: + def add_format_flags(self, parser: ArgumentGroup) -> None: """ Add any additional flags to configure the formatter. @@ -126,7 +134,7 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]: :param args: CLI arguments :return: the parsed options """ - parser = _build_cli(info) + parser, type_conversion = _build_cli(info) parser.parse_args(namespace=info.opt, args=args) res = [] for pyproject_toml in info.opt.inputs: @@ -144,7 +152,9 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]: if isinstance(config, dict): for key in set(vars(override_opt).keys()) - {"inputs", "stdout", "check", "no_print_diff"}: if key in config: - setattr(override_opt, key, config[key]) + raw = config[key] + converted = type_conversion[key](raw) if key in type_conversion else raw + setattr(override_opt, key, converted) res.append( _Config( toml_filename=pyproject_toml, @@ -159,7 +169,7 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]: return res -def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser: +def _build_cli(of: TOMLFormatter[T]) -> tuple[ArgumentParser, Mapping[str, Callable[[Any], Any]]]: parser = ArgumentParser( formatter_class=ArgumentDefaultsHelpFormatter, prog=of.prog, @@ -200,7 +210,8 @@ def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser: help="number of spaces to use for indentation", metavar="count", ) - of.add_format_flags(format_group) # type: ignore[arg-type] + of.add_format_flags(format_group) + type_conversion = {a.dest: a.type for a in format_group._actions if a.type and a.dest} # noqa: SLF001 msg = "pyproject.toml file(s) to format, use '-' to read from stdin" parser.add_argument( "inputs", @@ -208,7 +219,7 @@ def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser: type=partial(_toml_path_creator, of.filename), help=msg, ) - return parser + return parser, type_conversion def _toml_path_creator(filename: str, argument: str) -> Path | None: @@ -289,6 +300,7 @@ def _color_diff(diff: Iterable[str]) -> Iterable[str]: __all__ = [ + "ArgumentGroup", "FmtNamespace", "TOMLFormatter", "run", diff --git a/tests/test_app.py b/tests/test_app.py index 553958d..ba31ccd 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -6,10 +6,9 @@ import pytest -from toml_fmt_common import GREEN, RED, RESET, FmtNamespace, TOMLFormatter, run +from toml_fmt_common import GREEN, RED, RESET, ArgumentGroup, FmtNamespace, TOMLFormatter, run if TYPE_CHECKING: - from argparse import ArgumentParser from pathlib import Path from pytest_mock import MockerFixture @@ -17,6 +16,7 @@ class DumpNamespace(FmtNamespace): extra: str + tuple_magic: tuple[str, ...] class Dumb(TOMLFormatter[DumpNamespace]): @@ -35,11 +35,18 @@ def filename(self) -> str: def override_cli_from_section(self) -> tuple[str, ...]: return "start", "sub" - def add_format_flags(self, parser: ArgumentParser) -> None: # noqa: PLR6301 + def add_format_flags(self, parser: ArgumentGroup) -> None: # noqa: PLR6301 parser.add_argument("extra", help="this is something extra") + parser.add_argument("-t", "--tuple-magic", default=(), type=lambda t: tuple(t.split("."))) def format(self, text: str, opt: DumpNamespace) -> str: # noqa: PLR6301 - return text if os.environ.get("NO_FMT") else f"{text}\nextras = {opt.extra!r}" + if os.environ.get("NO_FMT"): + return text + return "\n".join([ + text, + f"extras = {opt.extra!r}", + *([f"magic = {','.join(opt.tuple_magic)!r}"] if opt.tuple_magic else []), + ]) def test_dumb_help(capsys: pytest.CaptureFixture[str]) -> None: @@ -77,6 +84,31 @@ def test_dumb_format_with_override(capsys: pytest.CaptureFixture[str], tmp_path: ] +def test_dumb_format_with_override_custom_type(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None: + dumb = tmp_path / "dumb.toml" + dumb.write_text("[start.sub]\ntuple_magic = '1.2.3'") + + exit_code = run(Dumb(), ["E", str(dumb)]) + assert exit_code == 1 + + assert dumb.read_text() == "[start.sub]\ntuple_magic = '1.2.3'\nextras = 'E'\nmagic = '1,2,3'" + + out, err = capsys.readouterr() + assert not err + assert out.splitlines() == [ + f"{RED}--- {dumb}", + f"{RESET}", + f"{GREEN}+++ {dumb}", + f"{RESET}", + "@@ -1,2 +1,4 @@", + "", + " [start.sub]", + " tuple_magic = '1.2.3'", + f"{GREEN}+extras = 'E'{RESET}", + f"{GREEN}+magic = '1,2,3'{RESET}", + ] + + def test_dumb_format_no_print_diff(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None: dumb = tmp_path / "dumb.toml" dumb.write_text("[start.sub]\nextra = 'B'")