diff --git a/CHANGELOG.md b/CHANGELOG.md index fa7d2045..627cf336 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +## Breaking API Changes +- The `files` argument of `api.run` is now a `Collection[pathlib.Path]` that represents an exact collection of files to be formatted, instead of a list of paths to search for files. Use `api.get_matching_paths(paths, mode)` to return the set of exact paths expected by `api.run` + ## Formatting Changes + Bug Fixes - fixed a bug that could cause lines with long jinja tags to be one character over the line length limit, and could result in unstable formatting ([#237](https://github.com/tconbeer/sqlfmt/pull/237) - thank you [@nfcampos](https://github.com/nfcampos)!) - fixed a bug that formatted array literals like they were indexing operations ([#235](https://github.com/tconbeer/sqlfmt/pull/235) - thank you [@nfcampos](https://github.com/nfcampos)!) diff --git a/src/sqlfmt/api.py b/src/sqlfmt/api.py index b6d455fb..0434c4dd 100755 --- a/src/sqlfmt/api.py +++ b/src/sqlfmt/api.py @@ -28,7 +28,7 @@ def format_string(source_string: str, mode: Mode) -> str: return str(formatted_query) -def run(files: List[Path], mode: Mode) -> Report: +def run(files: Collection[Path], mode: Mode) -> Report: """ Runs sqlfmt on all files in list of given paths (files), using the specified mode. @@ -44,8 +44,7 @@ def run(files: List[Path], mode: Mode) -> Report: else: cache = load_cache() - matched_paths = _get_matching_paths(files, mode) - results = _format_many(matched_paths, cache, mode) + results = _format_many(files, cache, mode) report = Report(results, mode) @@ -56,7 +55,7 @@ def run(files: List[Path], mode: Mode) -> Report: return report -def _get_matching_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]: +def get_matching_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]: """ Takes a list of paths (files or directories) and a mode as an input, and yields paths to individual files that match the input paths (or are contained in diff --git a/src/sqlfmt/cli.py b/src/sqlfmt/cli.py index beb13fdf..7d08961a 100755 --- a/src/sqlfmt/cli.py +++ b/src/sqlfmt/cli.py @@ -138,7 +138,9 @@ def sqlfmt( } config.update(non_default_options) mode = Mode(**config) # type: ignore - report = api.run(files=files, mode=mode) + + matched_files = api.get_matching_paths(files, mode=mode) + report = api.run(files=matched_files, mode=mode) report.display_report() if report.number_errored > 0: diff --git a/tests/unit_tests/test_api.py b/tests/unit_tests/test_api.py index 3fd15f4e..6f9b5795 100644 --- a/tests/unit_tests/test_api.py +++ b/tests/unit_tests/test_api.py @@ -7,9 +7,9 @@ from sqlfmt.api import ( _format_many, - _get_matching_paths, _update_source_files, format_string, + get_matching_paths, run, ) from sqlfmt.exception import SqlfmtBracketError, SqlfmtError @@ -28,7 +28,7 @@ def unformatted_files(unformatted_dir: Path) -> List[Path]: def test_file_discovery(all_output_modes: Mode) -> None: p = Path("tests/data/unit_tests/test_api/test_file_discovery") - res = list(_get_matching_paths(p.iterdir(), all_output_modes)) + res = list(get_matching_paths(p.iterdir(), all_output_modes)) expected = ( p / "top_level_file.sql", @@ -52,7 +52,7 @@ def test_file_discovery(all_output_modes: Mode) -> None: def test_file_discovery_with_excludes(exclude: List[str]) -> None: mode = Mode(exclude=exclude) p = Path("tests/data/unit_tests/test_api/test_file_discovery") - res = _get_matching_paths(p.iterdir(), mode) + res = get_matching_paths(p.iterdir(), mode) expected = { # p / "top_level_file.sql", @@ -179,11 +179,11 @@ def test_update_source_files_unformatted( def test_run_unformatted_update( unformatted_dir: Path, default_mode: Mode, monkeypatch: pytest.MonkeyPatch ) -> None: - + files = get_matching_paths([unformatted_dir], default_mode) # confirm that we call the _update_source function monkeypatch.delattr("sqlfmt.api._update_source_files") with pytest.raises(NameError): - _ = run(files=[unformatted_dir], mode=default_mode) + _ = run(files=files, mode=default_mode) def test_run_preformatted( @@ -230,7 +230,8 @@ def test_run_unformatted(unformatted_files: List[Path], all_output_modes: Mode) def test_run_error(error_dir: Path, all_output_modes: Mode) -> None: - files = [error_dir] + p = [error_dir] + files = get_matching_paths(p, all_output_modes) report = run(files=files, mode=all_output_modes) assert report.number_changed == 0 assert report.number_unchanged == 0 @@ -265,5 +266,6 @@ def test_run_single_process_does_not_use_multiprocessing( # confirm that we do not call _multiprocess_map; if we do, # this will raise + files = get_matching_paths([unformatted_dir], single_process_mode) monkeypatch.delattr("sqlfmt.api._multiprocess_map") - _ = run(files=[unformatted_dir], mode=single_process_mode) + _ = run(files=files, mode=single_process_mode)