diff --git a/src/pipautocompile/io.py b/src/pipautocompile/io.py index 21aaac7..bec2223 100644 --- a/src/pipautocompile/io.py +++ b/src/pipautocompile/io.py @@ -6,7 +6,6 @@ if TYPE_CHECKING: from re import RegexFlag - from typing import Iterable from typing import Iterator from _typeshed import StrOrBytesPath @@ -32,11 +31,7 @@ def file_contains_pattern( return False -def find_spec_files( - path: StrPath = ".", - patterns: Iterable[str] = ("**/requirements.in", "**/requirements/*.in"), -) -> Iterator[Path]: - for pattern in patterns: - for spec in Path(path).glob(pattern): - if spec.is_file(): - yield spec +def find_spec_files(pattern: str, path: StrPath = ".") -> Iterator[Path]: + for spec in Path(path).glob(pattern): + if spec.is_file(): + yield spec diff --git a/src/pipautocompile/main.py b/src/pipautocompile/main.py index 8aefb6f..8c1cefe 100644 --- a/src/pipautocompile/main.py +++ b/src/pipautocompile/main.py @@ -48,11 +48,19 @@ default="--allow-unsafe --generate-hashes --no-reuse-hashes --upgrade", show_default=True, ) +@click.option( + "--spec-pattern", + help="Glob pattern to match spec files; may be used more than once.", + multiple=True, + default=("**/requirements.in", "**/requirements/*.in"), + show_default=True, +) def cli( docker_build_stage: str, docker_ssh_agent_passthrough: bool, git_recurse_submodules: bool, pip_compile_args_str: str, + spec_pattern: tuple[str], ) -> None: """Automate pip-compile for multiple environments.""" @@ -61,7 +69,8 @@ def cli( "CUSTOM_COMPILE_COMMAND": quote_args("pip-autocompile", *sys.argv[1:]) } - for spec_dir, specs in groupby(sorted(find_spec_files()), key=lambda s: s.parent): + spec_files = sorted({f for p in spec_pattern for f in find_spec_files(p)}) + for spec_dir, specs in groupby(spec_files, key=lambda s: s.parent): if not git_recurse_submodules and inside_submodule(spec_dir): continue diff --git a/tests/test_io.py b/tests/test_io.py index 54f25a8..2c8d292 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,15 +1,11 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING import pytest from pipautocompile.io import find_spec_files -if TYPE_CHECKING: - from typing import Any - @pytest.fixture def spec_tree(tmp_path: Path) -> Path: @@ -33,64 +29,43 @@ def spec_tree(tmp_path: Path) -> Path: return tmp_path -@pytest.fixture(params=(True, False), ids=("relative", "absolute")) -def spec_root(request: pytest.FixtureRequest, spec_tree: Path) -> Path: - if request.param: # type: ignore[attr-defined] - request.getfixturevalue("monkeypatch").chdir(spec_tree) - return Path() - else: - return spec_tree - - +@pytest.mark.parametrize("chdir", (True, False), ids=("relative", "absolute")) @pytest.mark.parametrize( - argnames=("kwargs", "expected_specs"), + argnames=("pattern", "expected_specs"), argvalues=( - pytest.param( - {}, - { - "nested/requirements/foo.in", - "nested/requirements/bar.in", - "nested/requirements.in", - "requirements/50_base.in", - "requirements/60_dev.in", - "requirements/61_prod.in", - "requirements.in", - }, - id="default patterns", + ( + "*.spec", + set(), ), - pytest.param( - {"patterns": {"requirements.in"}}, + ( + "requirements.in", { "requirements.in", }, - id="requirements.in", ), - pytest.param( - {"patterns": {"requirements/*.in"}}, + ( + "requirements/*.in", { "requirements/50_base.in", "requirements/60_dev.in", "requirements/61_prod.in", }, - id="requirements/*.in", ), - pytest.param( - {"patterns": {"*/requirements.in"}}, + ( + "*/requirements.in", { "nested/requirements.in", }, - id="*/requirements.in", ), - pytest.param( - {"patterns": {"**/requirements.in"}}, + ( + "**/requirements.in", { "requirements.in", "nested/requirements.in", }, - id="**/requirements.in", ), - pytest.param( - {"patterns": {"**/requirements/*.in"}}, + ( + "**/requirements/*.in", { "nested/requirements/foo.in", "nested/requirements/bar.in", @@ -98,35 +73,26 @@ def spec_root(request: pytest.FixtureRequest, spec_tree: Path) -> Path: "requirements/60_dev.in", "requirements/61_prod.in", }, - id="**/requirements/*.in", ), - pytest.param( - {"patterns": {"**/requirements/*.spec"}}, + ( + "**/requirements/*.spec", { "requirements/50_base.spec", "requirements/60_dev.spec", "requirements/61_prod.spec", }, - id="**/requirements/*.spec", - ), - pytest.param( - {"patterns": {"**/requirements/*.in", "**/requirements/*.spec"}}, - { - "nested/requirements/foo.in", - "nested/requirements/bar.in", - "requirements/50_base.in", - "requirements/60_dev.in", - "requirements/61_prod.in", - "requirements/50_base.spec", - "requirements/60_dev.spec", - "requirements/61_prod.spec", - }, - id="**/requirements/*.(in|spec)", ), ), ) def test_find_spec_files( - spec_root: Path, kwargs: dict[str, Any], expected_specs: set[str] + request: pytest.FixtureRequest, + spec_tree: Path, + chdir: bool, + pattern: str, + expected_specs: set[str], ) -> None: - rooted_specs = {spec_root / spec for spec in expected_specs} - assert set(find_spec_files(spec_root, **kwargs)) == rooted_specs + if chdir: + request.getfixturevalue("monkeypatch").chdir(spec_tree) + spec_tree = Path() + rooted_specs = {spec_tree / spec for spec in expected_specs} + assert set(find_spec_files(pattern, path=spec_tree)) == rooted_specs