Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add --diff argument #96

Merged
merged 1 commit into from
Jul 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import difflib
import functools
import logging
import multiprocessing
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.logging import RichHandler
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer, echo
from typing_extensions import ParamSpec
Expand All @@ -30,10 +30,6 @@
T = TypeVar("T")


logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
logger = logging.getLogger("bump_pydantic")


def version_callback(value: bool):
if value:
echo(f"bump-pydantic version: {__version__}")
Expand All @@ -44,6 +40,7 @@ def version_callback(value: bool):
def main(
path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
disable: List[Rule] = Option(default=[], help="Disable a rule."),
diff: bool = Option(False, help="Show diff instead of applying changes."),
log_file: Path = Option("log.txt", help="Log errors to this file."),
version: bool = Option(
None,
Expand All @@ -53,7 +50,8 @@ def main(
help="Show the version and exit.",
),
):
logger.info("Start bump-pydantic.")
console = Console(log_time=True)
console.log("Start bump-pydantic.")
# NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
os.environ["LIBCST_PARSER_TYPE"] = "native"

Expand All @@ -65,51 +63,63 @@ def main(
files_str = list(package.glob("**/*.py"))
files = [str(file.relative_to(".")) for file in files_str]

logger.info(f"Found {len(files)} files to process.")
console.log(f"Found {len(files)} files to process")

providers = {FullyQualifiedNameProvider, ScopeProvider}
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

logger.info("Running mypy to get type information. This may take a while...")
console.log("Running mypy to get type information. This may take a while...")
classes = run_mypy_visitor(files)
scratch: dict[str, Any] = {CONTEXT_KEY: classes}
logger.info("Finished mypy.")
console.log("Finished mypy.")

start_time = time.time()

codemods = gather_codemods(disabled=disable)

log_fp = log_file.open("a+")
partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package)
partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff)
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Executing codemods...", total=len(files))
count_errors = 0
difflines: List[List[str]] = []
with multiprocessing.Pool() as pool:
for error in pool.imap_unordered(partial_run_codemods, files):
for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
progress.advance(task)

if _difflines is not None:
difflines.append(_difflines)

if error is not None:
count_errors += 1
log_fp.writelines(error)

modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

if modified:
logger.info(f"Refactored {len(modified)} files.")
if modified and not diff:
console.log(f"Refactored {len(modified)} files.")

for _difflines in difflines:
color_diff(console, _difflines)

if count_errors > 0:
logger.info(f"Found {count_errors} errors. Please check the {log_file} file.")
console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
else:
logger.info("Run successfully!")
console.log("Run successfully!")

if difflines:
raise Exit(1)


def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
metadata_manager: FullRepoManager,
scratch: Dict[str, Any],
package: Path,
diff: bool,
filename: str,
) -> Union[str, None]:
) -> Tuple[Union[str, None], Union[List[str], None]]:
try:
module_and_package = calculate_module_and_package(str(package), filename)
context = CodemodContext(
Expand All @@ -134,8 +144,30 @@ def run_codemods(

output_code = input_tree.code
if code != output_code:
fp.write(output_code)
fp.truncate()
return None
if diff:
lines = difflib.unified_diff(
code.splitlines(keepends=True),
output_code.splitlines(keepends=True),
fromfile=filename,
tofile=filename,
)
return None, list(lines)
else:
fp.write(output_code)
fp.truncate()
return None, None
except Exception:
return f"An error happened on {filename}.\n{traceback.format_exc()}"
return f"An error happened on {filename}.\n{traceback.format_exc()}", None


def color_diff(console: Console, lines: Iterable[str]) -> None:
for line in lines:
line = line.rstrip("\n")
if line.startswith("+"):
console.print(line, style="green")
elif line.startswith("-"):
console.print(line, style="red")
elif line.startswith("^"):
console.print(line, style="blue")
else:
console.print(line, style="white")