Skip to content

Commit

Permalink
refactor: pull get_matching_paths out of api.run
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer committed Aug 10, 2022
1 parent f3a7b22 commit 7208b96
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

## Breaking API Changes
- The `files` argument of `api.run` is now a `Collection[pathlib.Path]` that represents an exact collection of files to be formatted, instead of a list of paths to search for files. Use `api.get_matching_paths(paths, mode)` to return the set of exact paths expected by `api.run`

## Formatting Changes + Bug Fixes
- fixed a bug that could cause lines with long jinja tags to be one character over the line length limit, and could result in unstable formatting ([#237](https://github.com/tconbeer/sqlfmt/pull/237) - thank you [@nfcampos](https://github.com/nfcampos)!)
- fixed a bug that formatted array literals like they were indexing operations ([#235](https://github.com/tconbeer/sqlfmt/pull/235) - thank you [@nfcampos](https://github.com/nfcampos)!)
Expand Down
7 changes: 3 additions & 4 deletions src/sqlfmt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def format_string(source_string: str, mode: Mode) -> str:
return str(formatted_query)


def run(files: List[Path], mode: Mode) -> Report:
def run(files: Collection[Path], mode: Mode) -> Report:
"""
Runs sqlfmt on all files in list of given paths (files), using the specified mode.
Expand All @@ -44,8 +44,7 @@ def run(files: List[Path], mode: Mode) -> Report:
else:
cache = load_cache()

matched_paths = _get_matching_paths(files, mode)
results = _format_many(matched_paths, cache, mode)
results = _format_many(files, cache, mode)

report = Report(results, mode)

Expand All @@ -56,7 +55,7 @@ def run(files: List[Path], mode: Mode) -> Report:
return report


def _get_matching_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]:
def get_matching_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]:
"""
Takes a list of paths (files or directories) and a mode as an input, and
yields paths to individual files that match the input paths (or are contained in
Expand Down
4 changes: 3 additions & 1 deletion src/sqlfmt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def sqlfmt(
}
config.update(non_default_options)
mode = Mode(**config) # type: ignore
report = api.run(files=files, mode=mode)

matched_files = api.get_matching_paths(files, mode=mode)
report = api.run(files=matched_files, mode=mode)
report.display_report()

if report.number_errored > 0:
Expand Down
16 changes: 9 additions & 7 deletions tests/unit_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from sqlfmt.api import (
_format_many,
_get_matching_paths,
_update_source_files,
format_string,
get_matching_paths,
run,
)
from sqlfmt.exception import SqlfmtBracketError, SqlfmtError
Expand All @@ -28,7 +28,7 @@ def unformatted_files(unformatted_dir: Path) -> List[Path]:

def test_file_discovery(all_output_modes: Mode) -> None:
p = Path("tests/data/unit_tests/test_api/test_file_discovery")
res = list(_get_matching_paths(p.iterdir(), all_output_modes))
res = list(get_matching_paths(p.iterdir(), all_output_modes))

expected = (
p / "top_level_file.sql",
Expand All @@ -52,7 +52,7 @@ def test_file_discovery(all_output_modes: Mode) -> None:
def test_file_discovery_with_excludes(exclude: List[str]) -> None:
mode = Mode(exclude=exclude)
p = Path("tests/data/unit_tests/test_api/test_file_discovery")
res = _get_matching_paths(p.iterdir(), mode)
res = get_matching_paths(p.iterdir(), mode)

expected = {
# p / "top_level_file.sql",
Expand Down Expand Up @@ -179,11 +179,11 @@ def test_update_source_files_unformatted(
def test_run_unformatted_update(
unformatted_dir: Path, default_mode: Mode, monkeypatch: pytest.MonkeyPatch
) -> None:

files = get_matching_paths([unformatted_dir], default_mode)
# confirm that we call the _update_source function
monkeypatch.delattr("sqlfmt.api._update_source_files")
with pytest.raises(NameError):
_ = run(files=[unformatted_dir], mode=default_mode)
_ = run(files=files, mode=default_mode)


def test_run_preformatted(
Expand Down Expand Up @@ -230,7 +230,8 @@ def test_run_unformatted(unformatted_files: List[Path], all_output_modes: Mode)


def test_run_error(error_dir: Path, all_output_modes: Mode) -> None:
files = [error_dir]
p = [error_dir]
files = get_matching_paths(p, all_output_modes)
report = run(files=files, mode=all_output_modes)
assert report.number_changed == 0
assert report.number_unchanged == 0
Expand Down Expand Up @@ -265,5 +266,6 @@ def test_run_single_process_does_not_use_multiprocessing(

# confirm that we do not call _multiprocess_map; if we do,
# this will raise
files = get_matching_paths([unformatted_dir], single_process_mode)
monkeypatch.delattr("sqlfmt.api._multiprocess_map")
_ = run(files=[unformatted_dir], mode=single_process_mode)
_ = run(files=files, mode=single_process_mode)

0 comments on commit 7208b96

Please sign in to comment.