Skip to content

Commit

Permalink
feat #231: add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer committed Aug 10, 2022
1 parent 7208b96 commit 89ef86e
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 203 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ All notable changes to this project will be documented in this file.
## Breaking API Changes
- The `files` argument of `api.run` is now a `Collection[pathlib.Path]` that represents an exact collection of files to be formatted, instead of a list of paths to search for files. Use `api.get_matching_paths(paths, mode)` to return the set of exact paths expected by `api.run`

## Features
- sqlfmt will now display a progress bar for long runs ([#231](https://github.com/tconbeer/sqlfmt/pull/231)). You can disable this with the `--no-progressbar` option
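- `api.run` now accepts an optional `callback` argument, which must be a `Callable[[Awaitable[SqlFormatResult]], None]`. The callback is executed after each file is formatted, unless the `--single-process` option is used or only a single file is being formatted (in both of those cases the files are formatted in the main process and the callback is never called).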

## Formatting Changes + Bug Fixes
- fixed a bug that could cause lines with long jinja tags to be one character over the line length limit, and could result in unstable formatting ([#237](https://github.com/tconbeer/sqlfmt/pull/237) - thank you [@nfcampos](https://github.com/nfcampos)!)
- fixed a bug that formatted array literals like they were indexing operations ([#235](https://github.com/tconbeer/sqlfmt/pull/235) - thank you [@nfcampos](https://github.com/nfcampos)!)
Expand Down
222 changes: 47 additions & 175 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jinjafmt = ["black"]
[tool.poetry.dependencies]
python = "^3.7"
click = "^8.0"
tqdm = "^4.0"
platformdirs = "^2.4.0"
importlib_metadata = { version = "*", python = "<3.8" }
gitpython = { version = "^3.1.24", optional = true }
Expand Down
67 changes: 60 additions & 7 deletions src/sqlfmt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
from functools import partial
from glob import glob
from pathlib import Path
from typing import Callable, Collection, Iterable, List, Set, TypeVar
from typing import (
Awaitable,
Callable,
Collection,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)

from tqdm import tqdm

from sqlfmt.cache import Cache, check_cache, clear_cache, load_cache, write_cache
from sqlfmt.exception import SqlfmtError
Expand All @@ -28,7 +40,11 @@ def format_string(source_string: str, mode: Mode) -> str:
return str(formatted_query)


def run(files: Collection[Path], mode: Mode) -> Report:
def run(
files: Collection[Path],
mode: Mode,
callback: Optional[Callable[[Awaitable[SqlFormatResult]], None]] = None,
) -> Report:
"""
Runs sqlfmt on all files in list of given paths (files), using the specified mode.
Expand All @@ -44,7 +60,7 @@ def run(files: Collection[Path], mode: Mode) -> Report:
else:
cache = load_cache()

results = _format_many(files, cache, mode)
results = _format_many(files, cache, mode, callback=callback)

report = Report(results, mode)

Expand Down Expand Up @@ -74,6 +90,33 @@ def get_matching_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]:
return include_set - exclude_set


def initialize_progress_bar(
    total: int, mode: Mode, force_progress_bar: bool = False
) -> Tuple[tqdm, Callable[[Awaitable[SqlFormatResult]], None]]:
    """
    Build a tqdm progress bar and a callback that advances it.

    Returns a 2-tuple of (progress_bar, progress_callback). The callback
    ignores its argument and advances the bar by one step each time it is
    called, so it can be passed as the callback argument of api.run to
    tick once per formatted file. total is the number of steps (files)
    the bar expects.

    The bar is always disabled when mode.no_progressbar is set. Otherwise,
    pass force_progress_bar to enable the progress bar, even on non-TTY
    terminals (this is handy for testing the progress bar).
    """
    disable: Optional[bool]
    if mode.no_progressbar:
        disable = True
    else:
        # None leaves the decision to tqdm: disabled on non-TTY envs,
        # enabled otherwise
        disable = False if force_progress_bar else None

    # NOTE(review): `delay` appears to be a relatively recent tqdm option;
    # confirm the project's minimum tqdm version supports it.
    bar = tqdm(
        iterable=None,
        total=total,
        leave=False,
        disable=disable,
        delay=0.5,
        unit="file",
    )

    def progress_callback(_: Awaitable[SqlFormatResult]) -> None:
        bar.update()

    return bar, progress_callback


def _get_included_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]:
"""
Takes a list of paths (files or directories) and a mode as an input, and
Expand All @@ -93,7 +136,10 @@ def _get_included_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]:


def _format_many(
paths: Collection[Path], cache: Cache, mode: Mode
paths: Collection[Path],
cache: Cache,
mode: Mode,
callback: Optional[Callable[[Awaitable[SqlFormatResult]], None]] = None,
) -> List[SqlFormatResult]:
"""
Runs sqlfmt on all files in a collection of paths, using the specified mode.
Expand All @@ -106,15 +152,19 @@ def _format_many(
format_func = partial(_format_one, cache=cache, mode=mode)
if len(paths) > 1 and not mode.single_process:
results: List[SqlFormatResult] = asyncio.get_event_loop().run_until_complete(
_multiprocess_map(format_func, paths)
_multiprocess_map(format_func, paths, callback=callback)
)
else:
results = list(map(format_func, paths))

return results


async def _multiprocess_map(func: Callable[[T], R], seq: Iterable[T]) -> List[R]:
async def _multiprocess_map(
func: Callable[[T], R],
seq: Iterable[T],
callback: Optional[Callable[[Awaitable[R]], None]] = None,
) -> List[R]:
"""
Using multiple processes, creates a Future for each application of func to
an item in seq, then gathers all Futures and returns the result.
Expand All @@ -126,7 +176,10 @@ async def _multiprocess_map(func: Callable[[T], R], seq: Iterable[T]) -> List[R]
with concurrent.futures.ProcessPoolExecutor() as pool:
tasks = []
for item in seq:
tasks.append(loop.run_in_executor(pool, func, item))
future = loop.run_in_executor(pool, func, item)
if callback:
future.add_done_callback(callback)
tasks.append(future)
results: List[R] = await asyncio.gather(*tasks)
return results

Expand Down
13 changes: 12 additions & 1 deletion src/sqlfmt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@
is_flag=True,
help=("Prints much less information to stderr"),
)
@click.option(
"--no-progressbar",
is_flag=True,
help=("Never prints a progressbar to stderr"),
)
@click.option(
"--no-color",
is_flag=True,
Expand Down Expand Up @@ -140,7 +145,13 @@ def sqlfmt(
mode = Mode(**config) # type: ignore

matched_files = api.get_matching_paths(files, mode=mode)
report = api.run(files=matched_files, mode=mode)
progress_bar, progress_callback = api.initialize_progress_bar(
total=len(matched_files), mode=mode
)

report = api.run(files=matched_files, mode=mode, callback=progress_callback)

progress_bar.close()
report.display_report()

if report.number_errored > 0:
Expand Down
1 change: 1 addition & 0 deletions src/sqlfmt/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Mode:
reset_cache: bool = False
verbose: bool = False
quiet: bool = False
no_progressbar: bool = False
no_color: bool = False
force_color: bool = False

Expand Down
13 changes: 11 additions & 2 deletions src/sqlfmt_primer/primer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from git import Repo
from platformdirs import user_cache_dir

from sqlfmt.api import run
from sqlfmt.api import get_matching_paths, initialize_progress_bar, run
from sqlfmt.cache import get_cache_file
from sqlfmt.mode import Mode

Expand Down Expand Up @@ -148,9 +148,18 @@ def sqlfmt_primer(
target_dir = get_project_source_tree(project, reset_cache, working_dir)

click.echo(f"Running sqlfmt on {project.name}", err=True)

files = get_matching_paths([target_dir], mode=mode)
progress_bar, progress_callback = initialize_progress_bar(
total=len(files), mode=mode
)

start_time = timeit.default_timer()
report = run(files=[target_dir], mode=mode)
report = run(files=files, mode=mode, callback=progress_callback)
end_time = timeit.default_timer()

progress_bar.close()

number_formatted = (
report.number_changed + report.number_unchanged + report.number_errored
)
Expand Down
24 changes: 11 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
from itertools import product
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -76,26 +77,23 @@ def reset_cache_mode(unset_no_color_env: None) -> Mode:
return Mode(reset_cache=True)


@pytest.fixture
def no_progressbar_mode(unset_no_color_env: None) -> Mode:
    """
    Provide a Mode with no_progressbar=True and every other setting left
    at its default.
    """
    # unset_no_color_env is requested only so that fixture runs first
    # (presumably it clears the NO_COLOR env var; defined outside this view)
    mode = Mode(no_progressbar=True)
    return mode


@pytest.fixture
def single_process_mode(unset_no_color_env: None) -> Mode:
    """
    Provide a Mode with single_process=True and every other setting left
    at its default.
    """
    # unset_no_color_env is requested only so that fixture runs first
    # (presumably it clears the NO_COLOR env var; defined outside this view)
    mode = Mode(single_process=True)
    return mode


@pytest.fixture(
params=[
# (check, diff, single_process)
(False, False, False),
(False, True, False),
(True, False, False),
(True, True, False),
(False, False, True),
(True, False, True),
(False, True, True),
]
)
@pytest.fixture(params=product([True, False], repeat=4))
def all_output_modes(request: Any, unset_no_color_env: None) -> Mode:
return Mode(
check=request.param[0], diff=request.param[1], single_process=request.param[2]
check=request.param[0],
diff=request.param[1],
single_process=request.param[2],
no_progressbar=request.param[3],
)


Expand Down
61 changes: 56 additions & 5 deletions tests/unit_tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import io
import os
from pathlib import Path
from typing import List, Type
from typing import Any, List, Type

import pytest
from tqdm import tqdm

from sqlfmt.api import (
_format_many,
_update_source_files,
format_string,
get_matching_paths,
initialize_progress_bar,
run,
)
from sqlfmt.exception import SqlfmtBracketError, SqlfmtError
Expand All @@ -26,9 +28,9 @@ def unformatted_files(unformatted_dir: Path) -> List[Path]:
return list(unformatted_dir.iterdir())


def test_file_discovery(all_output_modes: Mode) -> None:
def test_file_discovery(default_mode: Mode) -> None:
p = Path("tests/data/unit_tests/test_api/test_file_discovery")
res = list(get_matching_paths(p.iterdir(), all_output_modes))
res = list(get_matching_paths(p.iterdir(), default_mode))

expected = (
p / "top_level_file.sql",
Expand Down Expand Up @@ -100,9 +102,9 @@ def test_format_many_preformatted(


def test_format_many_unformatted(
unformatted_files: List[Path], all_output_modes: Mode
unformatted_files: List[Path], default_mode: Mode
) -> None:
results = list(_format_many(unformatted_files, {}, all_output_modes))
results = list(_format_many(unformatted_files, {}, default_mode))

assert len(results) == len(
unformatted_files
Expand Down Expand Up @@ -269,3 +271,52 @@ def test_run_single_process_does_not_use_multiprocessing(
files = get_matching_paths([unformatted_dir], single_process_mode)
monkeypatch.delattr("sqlfmt.api._multiprocess_map")
_ = run(files=files, mode=single_process_mode)


def test_run_with_callback(
    capsys: Any, unformatted_dir: Path, default_mode: Mode
) -> None:
    """
    run() must call the supplied callback for every file it formats.

    The callback prints a single dot to stdout, so after the run the
    captured stdout has to contain an unbroken run of at least as many
    dots as there are matched files.
    """

    def print_dot(_: Any) -> None:
        print(".", end="", flush=True)

    matched = get_matching_paths([unformatted_dir], default_mode)

    run(files=matched, mode=default_mode, callback=print_dot)

    stdout = capsys.readouterr().out
    assert "." * len(matched) in stdout


def test_initialize_progress_bar(default_mode: Mode) -> None:
    """
    A forced-on progress bar starts at n == 0, reports the requested total,
    has a running elapsed timer, and advances by one when the returned
    callback is called.
    """
    n_files = 100
    bar, on_done = initialize_progress_bar(
        total=n_files, mode=default_mode, force_progress_bar=True
    )

    assert bar
    assert isinstance(bar, tqdm)
    assert bar.format_dict.get("n") == 0
    assert bar.format_dict.get("total") == n_files
    # NOTE(review): this relies on a measurable amount of time having passed
    # since the bar was created; confirm it cannot flake on coarse clocks
    assert bar.format_dict.get("elapsed") > 0

    assert on_done
    # the callback ignores its argument, so any placeholder value will do
    on_done("foo")  # type: ignore
    assert bar.format_dict.get("n") == 1


def test_initialize_disabled_progress_bar(no_progressbar_mode: Mode) -> None:
    """
    With no_progressbar set on the mode, the bar stays disabled even when
    force_progress_bar=True is passed.
    """
    n_files = 100
    bar, on_done = initialize_progress_bar(
        total=n_files, mode=no_progressbar_mode, force_progress_bar=True
    )

    assert bar
    assert isinstance(bar, tqdm)
    assert bar.format_dict.get("n") == 0
    assert bar.format_dict.get("total") == n_files
    # a disabled bar's elapsed timer never counts up...
    assert bar.format_dict.get("elapsed") == 0

    assert on_done
    # ...and calling update() through the callback leaves n untouched
    on_done("foo")  # type: ignore
    assert bar.format_dict.get("n") == 0
8 changes: 8 additions & 0 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,11 @@ def test_preformatted_clickhouse(
args = f"{preformatted_dir.as_posix()} --check --dialect clickhouse"
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 0


def test_preformatted_no_progressbar(
    sqlfmt_runner: CliRunner, preformatted_dir: Path
) -> None:
    """
    The CLI accepts --no-progressbar: running --check with it over the
    preformatted directory must exit with code 0.
    """
    cli_args = f"{preformatted_dir.as_posix()} --check --no-progressbar"
    outcome = sqlfmt_runner.invoke(sqlfmt_main, args=cli_args)
    assert outcome.exit_code == 0

0 comments on commit 89ef86e

Please sign in to comment.