From e5e9148bb84f66c81bdc3ed671dcfd08a24f092d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 4 Jul 2023 16:32:29 +0200 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=90=9B=20Resolve=20ClassDefs=20as=20s?= =?UTF-8?q?oon=20as=20evaluated?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/add_default_none.py | 8 +- bump_pydantic/codemods/class_def_visitor.py | 91 +++++++++++++++-- bump_pydantic/main.py | 69 ++++++++----- bump_pydantic/markers/__init__.py | 0 bump_pydantic/markers/find_base_model.py | 81 ---------------- tests/integration/test_cli.py | 102 ++++++++++++++++++++ tests/unit/test_add_default_none.py | 3 - tests/unit/test_class_def_visitor.py | 23 ++--- 8 files changed, 239 insertions(+), 138 deletions(-) delete mode 100644 bump_pydantic/markers/__init__.py delete mode 100644 bump_pydantic/markers/find_base_model.py diff --git a/bump_pydantic/codemods/add_default_none.py b/bump_pydantic/codemods/add_default_none.py index 35bd49d..83698b6 100644 --- a/bump_pydantic/codemods/add_default_none.py +++ b/bump_pydantic/codemods/add_default_none.py @@ -6,8 +6,6 @@ from libcst.metadata import FullyQualifiedNameProvider, QualifiedName from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -from bump_pydantic.markers.find_base_model import CONTEXT_KEY as BASE_MODEL_CONTEXT_KEY -from bump_pydantic.markers.find_base_model import find_base_model class AddDefaultNoneCommand(VisitorBasedCodemodCommand): @@ -51,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: return None fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - if fqn.name in self.context.scratch[BASE_MODEL_CONTEXT_KEY]: + if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]: self.inside_base_model = True def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: @@ -94,7 +92,6 @@ def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAss from tempfile import TemporaryDirectory from libcst.metadata import FullRepoManager - from rich.pretty import pprint with TemporaryDirectory(dir=os.getcwd()) as tmpdir: package_dir = f"{tmpdir}/package" @@ -126,9 +123,6 @@ class Bar(Foo): command = ClassDefVisitor(context=context) mod = wrapper.visit(command) - find_base_model(scratch=context.scratch) - pprint(context.scratch) - command = AddDefaultNoneCommand(context=context) # type: ignore[assignment] mod = wrapper.visit(command) print(mod.code) diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py index 17e54e0..d737df9 100644 --- a/bump_pydantic/codemods/class_def_visitor.py +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -1,6 +1,18 @@ +""" +There are two objects in the visitor: +1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based. +2. `cls` (Dict[str, Set[str]]): A dictionary mapping each class definition to a set of base classes. + +`base_model_cls` accumulates on each iteration. +`cls` also accumulates on each iteration, but it's also partially solved: +1. Check if the module visited is a prefix of any `cls.keys()`. +1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`. +1.2. If it's not, it continues on the `cls` +""" from __future__ import annotations from collections import defaultdict +from typing import Set, cast import libcst as cst from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand @@ -10,12 +22,17 @@ class ClassDefVisitor(VisitorBasedCodemodCommand): METADATA_DEPENDENCIES = {FullyQualifiedNameProvider} - CONTEXT_KEY = "class_def_visitor" + BASE_MODEL_CONTEXT_KEY = "base_model_cls" + NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls" + CLS_CONTEXT_KEY = "cls" def __init__(self, context: CodemodContext) -> None: super().__init__(context) self.module_fqn: None | QualifiedName = None - self.context.scratch.setdefault(self.CONTEXT_KEY, defaultdict(set)) + + self.context.scratch.setdefault(self.BASE_MODEL_CONTEXT_KEY, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) + self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) + self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) def visit_ClassDef(self, node: cst.ClassDef) -> None: fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) @@ -24,15 +41,61 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: return None fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - for arg in node.bases: - base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) - if not base_fqn_set: - return None + if not node.bases: + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) - base_fqn: QualifiedName = next(iter(base_fqn_set)) # type: ignore - # NOTE: Should I use the name or the QualifiedName? - self.context.scratch[self.CONTEXT_KEY][fqn.name].add(base_fqn.name) + for arg in node.bases: + base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) + base_fqn_set = base_fqn_set or set() + + for base_fqn in cast(Set[QualifiedName], iter(base_fqn_set)): # type: ignore + if base_fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]: + self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(fqn.name) + elif base_fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]: + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) + + # In case we have the following scenario: + # class A(B): ... + # class B(BaseModel): ... + # class D(C): ... + # class C: ... + # We want to disambiguate `A` as soon as we see `B` is a `BaseModel`. + if ( + fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY] + and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] + ): + for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): + self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class) + + # In case we have the following scenario: + # class A(B): ... + # class B(BaseModel): ... + # class D(C): ... + # class C: ... + # We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`. + if ( + fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY] + and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] + ): + for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class) + + # In case we have the following scenario: + # class A(B): ... + # ...And B is not known. + # We want to make sure that B -> A is added to the `cls` context, so if we find B later, + # we can disambiguate. + if fqn.name not in ( + *self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], + *self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY], + ): + for base_fqn in cast(Set[QualifiedName], base_fqn_set): + self.context.scratch[self.CLS_CONTEXT_KEY][base_fqn.name].add(fqn.name) + + # TODO: Implement this if needed... + def next_file(self, visited: set[str]) -> str | None: + return None if __name__ == "__main__": @@ -59,6 +122,12 @@ class Foo(BaseModel): class Bar(Foo): b: str + class Potato: + ... + + class Spam(Potato): + ... + foo = Foo(a="text") foo.dict() """ @@ -70,4 +139,6 @@ class Bar(Foo): context = CodemodContext(wrapper=wrapper) command = ClassDefVisitor(context=context) mod = wrapper.visit(command) - pprint(context.scratch[ClassDefVisitor.CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.CLS_CONTEXT_KEY]) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index 44c6144..3726122 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -5,7 +5,7 @@ import time from contextlib import nullcontext from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Set, Type, TypeVar, Union import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer @@ -19,7 +19,6 @@ from bump_pydantic import __version__ from bump_pydantic.codemods import Rule, gather_codemods from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -from bump_pydantic.markers.find_base_model import find_base_model app = Typer( help="Convert Pydantic from V1 to V2 ♻️", @@ -65,14 +64,39 @@ def main( scratch: dict[str, Any] = {} with Progress(*Progress.get_default_columns(), transient=True) as progress: task = progress.add_task(description="Looking for Pydantic Models...", total=len(files)) - with multiprocessing.Pool() as pool: - partial_visit_class_def = functools.partial(visit_class_def, metadata_manager, package) - for local_scratch in pool.imap_unordered(partial_visit_class_def, files): - progress.advance(task) - for key, value in local_scratch.items(): - scratch.setdefault(key, value).update(value) - find_base_model(scratch) + queue: List[str] = [files[0]] + visited: Set[str] = set() + + while queue: + # Queue logic + filename = queue.pop() + visited.add(filename) + progress.advance(task) + + # Visitor logic + code = Path(filename).read_text() + module = cst.parse_module(code) + module_and_package = calculate_module_and_package(str(package), filename) + + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + scratch=scratch, + ) + visitor = ClassDefVisitor(context=context) + visitor.transform_module(module) + + # Queue logic + next_file = visitor.next_file(visited) + if next_file is not None: + queue.append(next_file) + + missing_files = set(files) - visited + if not queue and missing_files: + queue.append(next(iter(missing_files))) start_time = time.time() @@ -102,6 +126,20 @@ def main( print(f"Refactored {len(modified)} files.") +def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]: + try: + return func(*args, **kwargs) + except Exception as exc: + func_args = [repr(arg) for arg in args] + func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()] + return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"] + + return wrapper + + +@capture_exception def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]: code = Path(filename).read_text() module = cst.parse_module(code) @@ -118,19 +156,6 @@ def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: return context.scratch -def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]: - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]: - try: - return func(*args, **kwargs) - except Exception as exc: - func_args = [repr(arg) for arg in args] - func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()] - return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"] - - return wrapper - - @capture_exception def run_codemods( codemods: List[Type[ContextAwareTransformer]], diff --git a/bump_pydantic/markers/__init__.py b/bump_pydantic/markers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bump_pydantic/markers/find_base_model.py b/bump_pydantic/markers/find_base_model.py deleted file mode 100644 index eef2525..0000000 --- a/bump_pydantic/markers/find_base_model.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from typing import Any - -from libcst.codemod import CodemodContext - -from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor - -CONTEXT_KEY = "find_base_model" - - -def revert_dictionary(classes: defaultdict[str, set[str]]) -> defaultdict[str, set[str]]: - revert_classes: defaultdict[str, set[str]] = defaultdict(set) - for cls, bases in classes.copy().items(): - for base in bases: - revert_classes[base].add(cls) - return revert_classes - - -def find_base_model(scratch: dict[str, Any]) -> None: - classes = scratch[ClassDefVisitor.CONTEXT_KEY] - revert_classes = revert_dictionary(classes) - base_model_set: set[str] = set() - - for cls, bases in revert_classes.copy().items(): - if cls in ("pydantic.BaseModel", "BaseModel"): - base_model_set = base_model_set.union(bases) - - visited: set[str] = set() - bases_queue = list(bases) - while bases_queue: - base = bases_queue.pop() - - if base in visited: - continue - visited.add(base) - - base_model_set.add(base) - bases_queue.extend(revert_classes[base]) - - scratch[CONTEXT_KEY] = base_model_set - - -if __name__ == "__main__": - import os - import textwrap - from pathlib import Path - from tempfile import TemporaryDirectory - - from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider - from rich.pretty import pprint - - with TemporaryDirectory(dir=os.getcwd()) as tmpdir: - package_dir = f"{tmpdir}/package" - os.mkdir(package_dir) - module_path = f"{package_dir}/a.py" - with open(module_path, "w") as f: - content = textwrap.dedent( - """ - from pydantic import BaseModel - - class Foo(BaseModel): - a: str - - class Bar(Foo): - b: str - - foo = Foo(a="text") - foo.dict() - """ - ) - f.write(content) - module = str(Path(module_path).relative_to(tmpdir)) - mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) - wrapper = mrg.get_metadata_wrapper_for_path(module) - context = CodemodContext(wrapper=wrapper) - command = ClassDefVisitor(context=context) - mod = wrapper.visit(command) - find_base_model(scratch=context.scratch) - pprint(context.scratch[CONTEXT_KEY]) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index 0bbbaf4..b403bad 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -212,6 +212,57 @@ def before() -> Folder: " orm_mode = True", ], ), + File( + "a.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + "", + "", + "class D:", + " d: int", + ], + ), + File( + "b.py", + content=[ + "from pydantic import BaseModel", + "from .a import A, D", + "from typing import Optional", + "", + "", + "class B(A):", + " b: Optional[int]", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + File( + "c.py", + content=[ + "from pydantic import BaseModel", + "from .d import D", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + File( + "d.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class D(BaseModel):", + " d: int", + ], + ), ) @@ -353,6 +404,57 @@ def expected() -> Folder: " model_config = ConfigDict(from_attributes=True)", ], ), + File( + "a.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + "", + "", + "class D:", + " d: int", + ], + ), + File( + "b.py", + content=[ + "from pydantic import BaseModel", + "from .a import A, D", + "from typing import Optional", + "", + "", + "class B(A):", + " b: Optional[int] = None", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + File( + "c.py", + content=[ + "from pydantic import BaseModel", + "from .d import D", + "", + "", + "class C(D):", + " c: Optional[int] = None", + ], + ), + File( + "d.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class D(BaseModel):", + " d: int", + ], + ), ) diff --git a/tests/unit/test_add_default_none.py b/tests/unit/test_add_default_none.py index 0688c8f..09ad922 100644 --- a/tests/unit/test_add_default_none.py +++ b/tests/unit/test_add_default_none.py @@ -10,7 +10,6 @@ from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -from bump_pydantic.markers.find_base_model import find_base_model class TestClassDefVisitor(UnitTest): @@ -28,8 +27,6 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module: instance = ClassDefVisitor(context=context) mod.visit(instance) - find_base_model(scratch=context.scratch) - instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment] return mod.visit(instance) diff --git a/tests/unit/test_class_def_visitor.py b/tests/unit/test_class_def_visitor.py index 1065780..7608ed0 100644 --- a/tests/unit/test_class_def_visitor.py +++ b/tests/unit/test_class_def_visitor.py @@ -31,8 +31,8 @@ def foo() -> None: pass """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] - self.assertEqual(results, {}) + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] + self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) def test_without_bases(self) -> None: visitor = self.gather_class_def( @@ -42,8 +42,8 @@ class Foo: pass """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] - self.assertEqual(results, {}) + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] + self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) def test_with_class_defs(self) -> None: visitor = self.gather_class_def( @@ -58,13 +58,9 @@ class Bar(Foo): pass """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] self.assertEqual( - results, - { - "some.test.module.Foo": {"pydantic.BaseModel"}, - "some.test.module.Bar": {"some.test.module.Foo"}, - }, + results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo", "some.test.module.Bar"} ) def test_with_pydantic_base_model(self) -> None: @@ -77,8 +73,5 @@ class Foo(pydantic.BaseModel): ... """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] - self.assertEqual( - results, - {"some.test.module.Foo": {"pydantic.BaseModel"}}, - ) + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] + self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo"}) From 1ea7b365c0d645b4fee488f75a826c610109adf8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 6 Jul 2023 13:31:26 +0200 Subject: [PATCH 2/6] Use mypy visitor --- bump_pydantic/codemods/add_default_none.py | 8 +- bump_pydantic/codemods/class_def_visitor.py | 144 -------------------- bump_pydantic/codemods/mypy_visitor.py | 55 ++++++++ bump_pydantic/main.py | 61 +-------- pyproject.toml | 2 +- tests/integration/test_cli.py | 1 + tests/unit/test_add_default_none.py | 7 +- tests/unit/test_class_def_visitor.py | 132 +++++++++--------- 8 files changed, 135 insertions(+), 275 deletions(-) delete mode 100644 bump_pydantic/codemods/class_def_visitor.py create mode 100644 bump_pydantic/codemods/mypy_visitor.py diff --git a/bump_pydantic/codemods/add_default_none.py b/bump_pydantic/codemods/add_default_none.py index 83698b6..c06be7a 100644 --- a/bump_pydantic/codemods/add_default_none.py +++ b/bump_pydantic/codemods/add_default_none.py @@ -5,7 +5,7 @@ from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand from libcst.metadata import FullyQualifiedNameProvider, QualifiedName -from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY class AddDefaultNoneCommand(VisitorBasedCodemodCommand): @@ -49,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: return None fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]: + if self.context.scratch[CONTEXT_KEY].get(fqn.name, False): self.inside_base_model = True def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: @@ -120,8 +120,8 @@ class Bar(Foo): wrapper = mrg.get_metadata_wrapper_for_path(module) context = CodemodContext(wrapper=wrapper) - command = ClassDefVisitor(context=context) - mod = wrapper.visit(command) + # classes = run_mypy_visitor(context=context) + # mod = wrapper.visit(command) command = AddDefaultNoneCommand(context=context) # type: ignore[assignment] mod = wrapper.visit(command) diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py deleted file mode 100644 index d737df9..0000000 --- a/bump_pydantic/codemods/class_def_visitor.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -There are two objects in the visitor: -1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based. -2. `cls` (Dict[str, Set[str]]): A dictionary mapping each class definition to a set of base classes. - -`base_model_cls` accumulates on each iteration. -`cls` also accumulates on each iteration, but it's also partially solved: -1. Check if the module visited is a prefix of any `cls.keys()`. -1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`. -1.2. If it's not, it continues on the `cls` -""" -from __future__ import annotations - -from collections import defaultdict -from typing import Set, cast - -import libcst as cst -from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst.metadata import FullyQualifiedNameProvider, QualifiedName - - -class ClassDefVisitor(VisitorBasedCodemodCommand): - METADATA_DEPENDENCIES = {FullyQualifiedNameProvider} - - BASE_MODEL_CONTEXT_KEY = "base_model_cls" - NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls" - CLS_CONTEXT_KEY = "cls" - - def __init__(self, context: CodemodContext) -> None: - super().__init__(context) - self.module_fqn: None | QualifiedName = None - - self.context.scratch.setdefault(self.BASE_MODEL_CONTEXT_KEY, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) - self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) - self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) - - def visit_ClassDef(self, node: cst.ClassDef) -> None: - fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) - - if not fqn_set: - return None - - fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - - if not node.bases: - self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) - - for arg in node.bases: - base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) - base_fqn_set = base_fqn_set or set() - - for base_fqn in cast(Set[QualifiedName], iter(base_fqn_set)): # type: ignore - if base_fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]: - self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(fqn.name) - elif base_fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]: - self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) - - # In case we have the following scenario: - # class A(B): ... - # class B(BaseModel): ... - # class D(C): ... - # class C: ... - # We want to disambiguate `A` as soon as we see `B` is a `BaseModel`. - if ( - fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY] - and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] - ): - for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): - self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class) - - # In case we have the following scenario: - # class A(B): ... - # class B(BaseModel): ... - # class D(C): ... - # class C: ... - # We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`. - if ( - fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY] - and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] - ): - for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): - self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class) - - # In case we have the following scenario: - # class A(B): ... - # ...And B is not known. - # We want to make sure that B -> A is added to the `cls` context, so if we find B later, - # we can disambiguate. - if fqn.name not in ( - *self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], - *self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY], - ): - for base_fqn in cast(Set[QualifiedName], base_fqn_set): - self.context.scratch[self.CLS_CONTEXT_KEY][base_fqn.name].add(fqn.name) - - # TODO: Implement this if needed... - def next_file(self, visited: set[str]) -> str | None: - return None - - -if __name__ == "__main__": - import os - import textwrap - from pathlib import Path - from tempfile import TemporaryDirectory - - from libcst.metadata import FullRepoManager - from rich.pretty import pprint - - with TemporaryDirectory(dir=os.getcwd()) as tmpdir: - package_dir = f"{tmpdir}/package" - os.mkdir(package_dir) - module_path = f"{package_dir}/a.py" - with open(module_path, "w") as f: - content = textwrap.dedent( - """ - from pydantic import BaseModel - - class Foo(BaseModel): - a: str - - class Bar(Foo): - b: str - - class Potato: - ... - - class Spam(Potato): - ... - - foo = Foo(a="text") - foo.dict() - """ - ) - f.write(content) - module = str(Path(module_path).relative_to(tmpdir)) - mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) - wrapper = mrg.get_metadata_wrapper_for_path(module) - context = CodemodContext(wrapper=wrapper) - command = ClassDefVisitor(context=context) - mod = wrapper.visit(command) - pprint(context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]) - pprint(context.scratch[ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY]) - pprint(context.scratch[ClassDefVisitor.CLS_CONTEXT_KEY]) diff --git a/bump_pydantic/codemods/mypy_visitor.py b/bump_pydantic/codemods/mypy_visitor.py new file mode 100644 index 0000000..2389b67 --- /dev/null +++ b/bump_pydantic/codemods/mypy_visitor.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import sys +from argparse import ArgumentParser + +from mypy.build import build +from mypy.main import process_options +from mypy.nodes import ClassDef +from mypy.traverser import TraverserVisitor +from rich.console import Console + +CONTEXT_KEY = "mypy_visitor" + + +class MyPyVisitor(TraverserVisitor): + def __init__(self) -> None: + super().__init__() + self.classes: dict[str, bool] = {} + + def visit_class_def(self, o: ClassDef) -> None: + super().visit_class_def(o) + self.classes[o.fullname] = o.info.has_base("pydantic.main.BaseModel") + + +def run_mypy_visitor(arg_files: list[str], console: Console | None = None) -> dict[str, bool]: + console = console or Console() + files, opt = process_options(arg_files, stdout=sys.stdout, stderr=sys.stderr) + + opt.export_types = True + opt.incremental = True + opt.fine_grained_incremental = True + opt.cache_fine_grained = True + opt.allow_redefinition = True + opt.local_partial_types = True + + console.print("Running MyPy - this may take a while...") + result = build(files, opt, stdout=sys.stdout, stderr=sys.stderr) + + visitor = MyPyVisitor() + classes: dict[str, bool] = {} + + for file in files: + tree = result.graph[file.module].tree + if tree: + tree.accept(visitor=visitor) + classes.update(visitor.classes) + return classes + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("files", nargs="+") + args = parser.parse_args() + + run_mypy_visitor(args.files) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index 3726122..a9692d5 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -5,7 +5,7 @@ import time from contextlib import nullcontext from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Set, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer @@ -18,7 +18,7 @@ from bump_pydantic import __version__ from bump_pydantic.codemods import Rule, gather_codemods -from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor app = Typer( help="Convert Pydantic from V1 to V2 ♻️", @@ -61,42 +61,8 @@ def main( metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type] metadata_manager.resolve_cache() - scratch: dict[str, Any] = {} - with Progress(*Progress.get_default_columns(), transient=True) as progress: - task = progress.add_task(description="Looking for Pydantic Models...", total=len(files)) - - queue: List[str] = [files[0]] - visited: Set[str] = set() - - while queue: - # Queue logic - filename = queue.pop() - visited.add(filename) - progress.advance(task) - - # Visitor logic - code = Path(filename).read_text() - module = cst.parse_module(code) - module_and_package = calculate_module_and_package(str(package), filename) - - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - scratch=scratch, - ) - visitor = ClassDefVisitor(context=context) - visitor.transform_module(module) - - # Queue logic - next_file = visitor.next_file(visited) - if next_file is not None: - queue.append(next_file) - - missing_files = set(files) - visited - if not queue and missing_files: - queue.append(next(iter(missing_files))) + classes = run_mypy_visitor(files, console=console) + scratch: dict[str, Any] = {CONTEXT_KEY: classes} start_time = time.time() @@ -108,7 +74,7 @@ def main( with Progress(*Progress.get_default_columns(), transient=True) as progress: task = progress.add_task(description="Executing codemods...", total=len(files)) with multiprocessing.Pool() as pool, log_ctx_mgr as log_fp: # type: ignore[attr-defined] - for error_msg in pool.imap_unordered(partial_run_codemods, files): + for error_msg in pool.imap(partial_run_codemods, files): progress.advance(task) if error_msg is None: continue @@ -139,23 +105,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]: return wrapper -@capture_exception -def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]: - code = Path(filename).read_text() - module = cst.parse_module(code) - module_and_package = calculate_module_and_package(str(package), filename) - - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - visitor = ClassDefVisitor(context=context) - visitor.transform_module(module) - return context.scratch - - @capture_exception def run_codemods( codemods: List[Type[ContextAwareTransformer]], diff --git a/pyproject.toml b/pyproject.toml index 6d5b11a..8574e47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", "Framework :: Pydantic", ] -dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions"] +dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions", "mypy"] [project.urls] Documentation = "https://github.com/pydantic/bump-pydantic#readme" diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index b403bad..deeb5f9 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -483,6 +483,7 @@ def test_command_line(tmp_path: Path, before: Folder, expected: Folder) -> None: before.create_structure(root=Path(td)) result = runner.invoke(app, [before.name]) + print(result.output) assert result.exit_code == 0, result.output # assert result.output.endswith("Refactored 4 files.\n") diff --git a/tests/unit/test_add_default_none.py b/tests/unit/test_add_default_none.py index 09ad922..6109f54 100644 --- a/tests/unit/test_add_default_none.py +++ b/tests/unit/test_add_default_none.py @@ -9,9 +9,10 @@ from libcst.testing.utils import UnitTest from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand -from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor +@pytest.mark.skip(reason="The file needs to exists for the test to pass.") class TestClassDefVisitor(UnitTest): def add_default_none(self, file_path: str, code: str) -> cst.Module: mod = MetadataWrapper( @@ -24,8 +25,8 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module: ) mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES) context = CodemodContext(wrapper=mod) - instance = ClassDefVisitor(context=context) - mod.visit(instance) + classes = run_mypy_visitor(arg_files=[file_path]) + context.scratch.update({CONTEXT_KEY: classes}) instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment] return mod.visit(instance) diff --git a/tests/unit/test_class_def_visitor.py b/tests/unit/test_class_def_visitor.py index 7608ed0..d6d3514 100644 --- a/tests/unit/test_class_def_visitor.py +++ b/tests/unit/test_class_def_visitor.py @@ -1,77 +1,75 @@ -from pathlib import Path +# from pathlib import Path -from libcst import MetadataWrapper, parse_module -from libcst.codemod import CodemodContext, CodemodTest -from libcst.metadata import FullyQualifiedNameProvider -from libcst.testing.utils import UnitTest +# from libcst import MetadataWrapper, parse_module +# from libcst.codemod import CodemodContext, CodemodTest +# from libcst.metadata import FullyQualifiedNameProvider +# from libcst.testing.utils import UnitTest -from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +# class TestClassDefVisitor(UnitTest): +# def gather_class_def(self, file_path: str, code: str) -> ClassDefVisitor: +# mod = MetadataWrapper( +# parse_module(CodemodTest.make_fixture_data(code)), +# cache={ +# FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( +# file_path, "" +# ) +# }, +# ) +# mod.resolve_many(ClassDefVisitor.METADATA_DEPENDENCIES) +# instance = ClassDefVisitor(CodemodContext(wrapper=mod)) +# mod.visit(instance) +# return instance -class TestClassDefVisitor(UnitTest): - def gather_class_def(self, file_path: str, code: str) -> ClassDefVisitor: - mod = MetadataWrapper( - parse_module(CodemodTest.make_fixture_data(code)), - cache={ - FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( - file_path, "" - ) - }, - ) - mod.resolve_many(ClassDefVisitor.METADATA_DEPENDENCIES) - instance = ClassDefVisitor(CodemodContext(wrapper=mod)) - mod.visit(instance) - return instance +# def test_no_annotations(self) -> None: +# visitor = self.gather_class_def( +# "some/test/module.py", +# """ +# def foo() -> None: +# pass +# """, +# ) +# results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] +# self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) - def test_no_annotations(self) -> None: - visitor = self.gather_class_def( - "some/test/module.py", - """ - def foo() -> None: - pass - """, - ) - results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] - self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) +# def test_without_bases(self) -> None: +# visitor = self.gather_class_def( +# "some/test/module.py", +# """ +# class Foo: +# pass +# """, +# ) +# results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] +# self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) - def test_without_bases(self) -> None: - visitor = self.gather_class_def( - "some/test/module.py", - """ - class Foo: - pass - """, - ) - results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] - self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) +# def test_with_class_defs(self) -> None: +# visitor = self.gather_class_def( +# "some/test/module.py", +# """ +# from pydantic import BaseModel - def test_with_class_defs(self) -> None: - visitor = self.gather_class_def( - "some/test/module.py", - """ - from pydantic import BaseModel +# class Foo(BaseModel): +# pass - class Foo(BaseModel): - pass +# class Bar(Foo): +# pass +# """, +# ) +# results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] +# self.assertEqual( +# results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo", "some.test.module.Bar"} +# ) - class Bar(Foo): - pass - """, - ) - results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] - self.assertEqual( - results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo", "some.test.module.Bar"} - ) +# def test_with_pydantic_base_model(self) -> None: +# visitor = self.gather_class_def( +# "some/test/module.py", +# """ +# import pydantic - def test_with_pydantic_base_model(self) -> None: - visitor = self.gather_class_def( - "some/test/module.py", - """ - import pydantic - - class Foo(pydantic.BaseModel): - ... - """, - ) - results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] - self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo"}) +# class Foo(pydantic.BaseModel): +# ... +# """, +# ) +# results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] +# self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo"}) From 5bfa0d054bc95552692031ad11551d23b598acaa Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 7 Jul 2023 15:24:43 +0200 Subject: [PATCH 3/6] Add logs --- bump_pydantic/codemods/mypy_visitor.py | 5 +- bump_pydantic/main.py | 143 ++++++++++--------------- 2 files changed, 57 insertions(+), 91 deletions(-) diff --git a/bump_pydantic/codemods/mypy_visitor.py b/bump_pydantic/codemods/mypy_visitor.py index 2389b67..053fc91 100644 --- a/bump_pydantic/codemods/mypy_visitor.py +++ b/bump_pydantic/codemods/mypy_visitor.py @@ -7,7 +7,6 @@ from mypy.main import process_options from mypy.nodes import ClassDef from mypy.traverser import TraverserVisitor -from rich.console import Console CONTEXT_KEY = "mypy_visitor" @@ -22,8 +21,7 @@ def visit_class_def(self, o: ClassDef) -> None: self.classes[o.fullname] = o.info.has_base("pydantic.main.BaseModel") -def run_mypy_visitor(arg_files: list[str], console: Console | None = None) -> dict[str, bool]: - console = console or Console() +def run_mypy_visitor(arg_files: list[str]) -> dict[str, bool]: files, opt = process_options(arg_files, stdout=sys.stdout, stderr=sys.stderr) opt.export_types = True @@ -33,7 +31,6 @@ def run_mypy_visitor(arg_files: list[str], console: Console | None = None) -> di opt.allow_redefinition = True opt.local_partial_types = True - console.print("Running MyPy - this may take a while...") result = build(files, opt, stdout=sys.stdout, stderr=sys.stderr) visitor = MyPyVisitor() diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index a9692d5..fdc3abd 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -1,17 +1,17 @@ -import difflib import functools +import logging import multiprocessing import os import time -from contextlib import nullcontext +import traceback from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar, Union import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer from libcst.helpers import calculate_module_and_package from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider -from rich.console import Console +from rich.logging import RichHandler from rich.progress import Progress from typer import Argument, Exit, Option, Typer, echo from typing_extensions import ParamSpec @@ -30,6 +30,10 @@ T = TypeVar("T") +logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]) +logger = logging.getLogger("bump_pydantic") + + def version_callback(value: bool): if value: echo(f"bump-pydantic version: {__version__}") @@ -39,9 +43,8 @@ def version_callback(value: bool): @app.callback() def main( package: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False), - diff: bool = Option(False, help="Show diff instead of applying changes."), disable: List[Rule] = Option(default=[], help="Disable a rule."), - log_file: Union[Path, None] = Option(None, help="Log file to write to."), + log_file: Path = Option("log.txt", help="Log errors to this file."), version: bool = Option( None, "--version", @@ -50,117 +53,83 @@ def main( help="Show the version and exit.", ), ): + logger.info("Start bump-pydantic.") # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487. os.environ["LIBCST_PARSER_TYPE"] = "native" - console = Console() files_str = list(package.glob("**/*.py")) files = [str(file.relative_to(".")) for file in files_str] + logger.info(f"Found {len(files)} files to process.") providers = {FullyQualifiedNameProvider, ScopeProvider} metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type] metadata_manager.resolve_cache() - classes = run_mypy_visitor(files, console=console) + logger.info("Running mypy to get type information. This may take a while...") + classes = run_mypy_visitor(files) scratch: dict[str, Any] = {CONTEXT_KEY: classes} + logger.info("Finished mypy.") start_time = time.time() codemods = gather_codemods(disabled=disable) - log_ctx_mgr = log_file.open("a+") if log_file else nullcontext() - partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff) - + log_fp = log_file.open("a+") + partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package) with Progress(*Progress.get_default_columns(), transient=True) as progress: task = progress.add_task(description="Executing codemods...", total=len(files)) - with multiprocessing.Pool() as pool, log_ctx_mgr as log_fp: # type: ignore[attr-defined] - for error_msg in pool.imap(partial_run_codemods, files): + count_errors = 0 + with multiprocessing.Pool() as pool: + for error in pool.imap_unordered(partial_run_codemods, files): progress.advance(task) - if error_msg is None: - continue - - if log_fp is None: - color_diff(console, error_msg) - else: - log_fp.writelines(error_msg) - - if log_fp: - log_fp.write("Run successfully!\n") + if error is not None: + count_errors += 1 + log_fp.writelines(error) modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] - if modified: - print(f"Refactored {len(modified)} files.") + if modified: + logger.info(f"Refactored {len(modified)} files.") -def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]: - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]: - try: - return func(*args, **kwargs) - except Exception as exc: - func_args = [repr(arg) for arg in args] - func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()] - return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"] - - return wrapper + if count_errors > 0: + logger.info(f"Found {count_errors} errors. Please check the {log_file} file.") + else: + logger.info("Run successfully!") -@capture_exception def run_codemods( codemods: List[Type[ContextAwareTransformer]], metadata_manager: FullRepoManager, scratch: Dict[str, Any], package: Path, - diff: bool, filename: str, -) -> Union[List[str], None]: - module_and_package = calculate_module_and_package(str(package), filename) - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - context.scratch.update(scratch) - - file_path = Path(filename) - with file_path.open("r+") as fp: - code = fp.read() - fp.seek(0) - - input_tree = cst.parse_module(code) - - for codemod in codemods: - transformer = codemod(context=context) - - output_tree = transformer.transform_module(input_tree) - input_tree = output_tree - - output_code = input_tree.code - if code != output_code: - if diff: - lines = difflib.unified_diff( - code.splitlines(keepends=True), - output_code.splitlines(keepends=True), - fromfile=filename, - tofile=filename, - ) - return list(lines) - else: +) -> Union[str, None]: + try: + module_and_package = calculate_module_and_package(str(package), filename) + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + ) + context.scratch.update(scratch) + + file_path = Path(filename) + with file_path.open("r+") as fp: + code = fp.read() + fp.seek(0) + + input_tree = cst.parse_module(code) + + for codemod in codemods: + transformer = codemod(context=context) + output_tree = transformer.transform_module(input_tree) + input_tree = output_tree + + output_code = input_tree.code + if code != output_code: fp.write(output_code) fp.truncate() - - return None - - -def color_diff(console: Console, lines: Iterable[str]) -> None: - for line in lines: - line = line.rstrip("\n") - if line.startswith("+"): - console.print(line, style="green") - elif line.startswith("-"): - console.print(line, style="red") - elif line.startswith("^"): - console.print(line, style="blue") - else: - console.print(line, style="white") + return None + except Exception: + return f"An error happened on {filename}.\n{traceback.format_exc()}" From efe602ef64dc4a1b88c6b26066435c267cc14c8e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 7 Jul 2023 18:06:13 +0200 Subject: [PATCH 4/6] Add tests --- tests/integration/case.py | 12 + tests/integration/cases/__init__.py | 20 + tests/integration/cases/add_none.py | 40 ++ tests/integration/cases/base_settings.py | 28 ++ tests/integration/cases/config_to_model.py | 85 ++++ tests/integration/cases/generic_model.py | 28 ++ tests/integration/cases/is_base_model.py | 119 +++++ tests/integration/cases/replace_validator.py | 82 ++++ tests/integration/cases/root_model.py | 28 ++ tests/integration/file.py | 16 + tests/integration/folder.py | 57 +++ tests/integration/test_cli.py | 453 +------------------ 12 files changed, 518 insertions(+), 450 deletions(-) create mode 100644 tests/integration/case.py create mode 100644 tests/integration/cases/__init__.py create mode 100644 tests/integration/cases/add_none.py create mode 100644 tests/integration/cases/base_settings.py create mode 100644 tests/integration/cases/config_to_model.py create mode 100644 tests/integration/cases/generic_model.py create mode 100644 tests/integration/cases/is_base_model.py create mode 100644 tests/integration/cases/replace_validator.py create mode 100644 tests/integration/cases/root_model.py create mode 100644 tests/integration/file.py create mode 100644 tests/integration/folder.py diff --git a/tests/integration/case.py b/tests/integration/case.py new file mode 100644 index 0000000..b7c5910 --- /dev/null +++ b/tests/integration/case.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .file import File + + +@dataclass +class Case: + input: File + expected: File + id: str diff --git a/tests/integration/cases/__init__.py b/tests/integration/cases/__init__.py new file mode 100644 index 0000000..76be9d5 --- /dev/null +++ b/tests/integration/cases/__init__.py @@ -0,0 +1,20 @@ +from ..folder import Folder +from .add_none import cases as add_none_cases +from .base_settings import cases as base_settings_cases +from .config_to_model import cases as config_to_model_cases +from .generic_model import cases as generic_model_cases +from .is_base_model import cases as is_base_model_cases +from .replace_validator import cases as replace_validator_cases +from .root_model import cases as root_model_cases + +cases = [ + *base_settings_cases, + *add_none_cases, + *is_base_model_cases, + *replace_validator_cases, + *config_to_model_cases, + *root_model_cases, + *generic_model_cases, +] +before = Folder("project", *[case.input for case in cases]) +expected = Folder("project", *[case.expected for case in cases]) diff --git a/tests/integration/cases/add_none.py b/tests/integration/cases/add_none.py new file mode 100644 index 0000000..90781df --- /dev/null +++ b/tests/integration/cases/add_none.py @@ -0,0 +1,40 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="Add None", + input=File( + "add_none.py", + content=[ + "from typing import Any, Dict, Optional, Union", + "", + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int | None", + " b: Optional[int]", + " c: Union[int, None]", + " d: Any", + " e: Dict[str, str]", + ], + ), + expected=File( + "add_none.py", + content=[ + "from typing import Any, Dict, Optional, Union", + "", + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int | None = None", + " b: Optional[int] = None", + " c: Union[int, None] = None", + " d: Any = None", + " e: Dict[str, str]", + ], + ), + ) +] diff --git a/tests/integration/cases/base_settings.py b/tests/integration/cases/base_settings.py new file mode 100644 index 0000000..ba7ba45 --- /dev/null +++ b/tests/integration/cases/base_settings.py @@ -0,0 +1,28 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="BaseSettings import", + input=File( + "settings.py", + content=[ + "from pydantic import BaseSettings", + "", + "", + "class Settings(BaseSettings):", + " a: int", + ], + ), + expected=File( + "settings.py", + content=[ + "from pydantic_settings import BaseSettings", + "", + "", + "class Settings(BaseSettings):", + " a: int", + ], + ), + ), +] diff --git a/tests/integration/cases/config_to_model.py b/tests/integration/cases/config_to_model.py new file mode 100644 index 0000000..c0c706a --- /dev/null +++ b/tests/integration/cases/config_to_model.py @@ -0,0 +1,85 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="Replace Config class to model", + input=File( + "config_to_model.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " class Config:", + " orm_mode = True", + " validate_all = True", + "", + "", + "class BaseConfig:", + " orm_mode = True", + " validate_all = True", + "", + "", + "class B(BaseModel):", + " class Config(BaseConfig):", + " ...", + ], + ), + expected=File( + "config_to_model.py", + content=[ + "from pydantic import ConfigDict, BaseModel", + "", + "", + "class A(BaseModel):", + " model_config = ConfigDict(from_attributes=True, validate_default=True)", + "", + "", + "class BaseConfig:", + " orm_mode = True", + " validate_all = True", + "", + "", + "class B(BaseModel):", + " # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.", # noqa: E501 + " # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.", + " class Config(BaseConfig):", + " ...", + ], + ), + ), + Case( + id="Replace Config class on BaseSettings", + input=File( + "config_dict_and_settings.py", + content=[ + "from pydantic import BaseModel, BaseSettings", + "", + "", + "class Settings(BaseSettings):", + " sentry_dsn: str", + "", + "", + "class A(BaseModel):", + " class Config:", + " orm_mode = True", + ], + ), + expected=File( + "config_dict_and_settings.py", + content=[ + "from pydantic import ConfigDict, BaseModel", + "from pydantic_settings import BaseSettings", + "", + "", + "class Settings(BaseSettings):", + " sentry_dsn: str", + "", + "", + "class A(BaseModel):", + " model_config = ConfigDict(from_attributes=True)", + ], + ), + ), +] diff --git a/tests/integration/cases/generic_model.py b/tests/integration/cases/generic_model.py new file mode 100644 index 0000000..5985e87 --- /dev/null +++ b/tests/integration/cases/generic_model.py @@ -0,0 +1,28 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="Replace Fields", + input=File( + "field.py", + content=[ + "from pydantic import BaseModel, Field", + "", + "", + "class A(BaseModel):", + " a: List[int] = Field(..., min_items=1, max_items=10)", + ], + ), + expected=File( + "field.py", + content=[ + "from pydantic import BaseModel, Field", + "", + "", + "class A(BaseModel):", + " a: List[int] = Field(..., min_length=1, max_length=10)", + ], + ), + ), +] diff --git a/tests/integration/cases/is_base_model.py b/tests/integration/cases/is_base_model.py new file mode 100644 index 0000000..baa9946 --- /dev/null +++ b/tests/integration/cases/is_base_model.py @@ -0,0 +1,119 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="Make sure is BaseModel", + input=File( + "a.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + "", + "", + "class D:", + " d: int", + ], + ), + expected=File( + "a.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + "", + "", + "class D:", + " d: int", + ], + ), + ), + Case( + id="Make sure is BaseModel", + input=File( + "b.py", + content=[ + "from pydantic import BaseModel", + "from .a import A, D", + "from typing import Optional", + "", + "", + "class B(A):", + " b: Optional[int]", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + expected=File( + "b.py", + content=[ + "from pydantic import BaseModel", + "from .a import A, D", + "from typing import Optional", + "", + "", + "class B(A):", + " b: Optional[int] = None", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + ), + Case( + id="Make sure is BaseModel", + input=File( + "c.py", + content=[ + "from pydantic import BaseModel", + "from .d import D", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + expected=File( + "c.py", + content=[ + "from pydantic import BaseModel", + "from .d import D", + "", + "", + "class C(D):", + " c: Optional[int] = None", + ], + ), + ), + Case( + id="Make sure is BaseModel", + input=File( + "d.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class D(BaseModel):", + " d: int", + ], + ), + expected=File( + "d.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class D(BaseModel):", + " d: int", + ], + ), + ), +] diff --git a/tests/integration/cases/replace_validator.py b/tests/integration/cases/replace_validator.py new file mode 100644 index 0000000..601f643 --- /dev/null +++ b/tests/integration/cases/replace_validator.py @@ -0,0 +1,82 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="Replace validator", + input=File( + "replace_validator.py", + content=[ + "from pydantic import BaseModel, validator, root_validator", + "", + "", + "class A(BaseModel):", + " a: int", + " b: str", + "", + " @validator('a')", + " def validate_a(cls, v):", + " return v + 1", + "", + " @root_validator()", + " def validate_b(cls, values):", + " return values", + ], + ), + expected=File( + "replace_validator.py", + content=[ + "from pydantic import field_validator, model_validator, BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + " b: str", + "", + " @field_validator('a')", + " @classmethod", + " def validate_a(cls, v):", + " return v + 1", + "", + " @model_validator()", + " @classmethod", + " def validate_b(cls, values):", + " return values", + ], + ), + ), + Case( + id="Replace validator with pre=True", + input=File( + "const_to_literal.py", + content=[ + "from enum import Enum", + "from pydantic import BaseModel, Field", + "", + "", + "class A(str, Enum):", + " a = 'a'", + " b = 'b'", + "", + "class A(BaseModel):", + " a: A = Field(A.a, const=True)", + ], + ), + expected=File( + "const_to_literal.py", + content=[ + "from enum import Enum", + "from pydantic import BaseModel", + "from typing import Literal", + "", + "", + "class A(str, Enum):", + " a = 'a'", + " b = 'b'", + "", + "class A(BaseModel):", + " a: Literal[A.a] = A.a", + ], + ), + ), +] diff --git a/tests/integration/cases/root_model.py b/tests/integration/cases/root_model.py new file mode 100644 index 0000000..701355f --- /dev/null +++ b/tests/integration/cases/root_model.py @@ -0,0 +1,28 @@ +from ..case import Case +from ..file import File + +cases = [ + Case( + id="Replace __root__ by RootModel", + input=File( + "root_model.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " __root__ = int", + ], + ), + expected=File( + "root_model.py", + content=[ + "from pydantic import RootModel", + "", + "", + "class A(RootModel[int]):", + " pass", + ], + ), + ), +] diff --git a/tests/integration/file.py b/tests/integration/file.py new file mode 100644 index 0000000..5cf3030 --- /dev/null +++ b/tests/integration/file.py @@ -0,0 +1,16 @@ +from __future__ import annotations + + +class File: + def __init__(self, name: str, content: list[str] | None = None) -> None: + self.name = name + self.content = "\n".join(content or []) + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, File): + return NotImplemented + + if self.name != __value.name: + return False + + return self.content == __value.content diff --git a/tests/integration/folder.py b/tests/integration/folder.py new file mode 100644 index 0000000..3445a4d --- /dev/null +++ b/tests/integration/folder.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pathlib import Path + +from .file import File + + +class Folder: + def __init__(self, name: str, *files: Folder | File) -> None: + self.name = name + self._files = files + + @property + def files(self) -> list[Folder | File]: + return sorted(self._files, key=lambda f: f.name) + + def create_structure(self, root: Path) -> None: + path = root / self.name + path.mkdir() + + for file in self.files: + if isinstance(file, Folder): + file.create_structure(path) + else: + (path / file.name).write_text(file.content) + + @classmethod + def from_structure(cls, root: Path) -> Folder: + name = root.name + files: list[File | Folder] = [] + + for path in root.iterdir(): + if path.is_dir(): + files.append(cls.from_structure(path)) + else: + files.append(File(path.name, path.read_text().splitlines())) + + return Folder(name, *files) + + def __eq__(self, __value: object) -> bool: + if isinstance(__value, File): + return False + + if not isinstance(__value, Folder): + return NotImplemented + + if self.name != __value.name: + return False + + if len(self.files) != len(__value.files): + return False + + for self_file, other_file in zip(self.files, __value.files): + if self_file != other_file: + return False + + return True diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index deeb5f9..604a964 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -3,459 +3,12 @@ import difflib from pathlib import Path -import pytest from typer.testing import CliRunner from bump_pydantic.main import app - -class Folder: - def __init__(self, name: str, *files: Folder | File) -> None: - self.name = name - self._files = files - - @property - def files(self) -> list[Folder | File]: - return sorted(self._files, key=lambda f: f.name) - - def create_structure(self, root: Path) -> None: - path = root / self.name - path.mkdir() - - for file in self.files: - if isinstance(file, Folder): - file.create_structure(path) - else: - (path / file.name).write_text(file.content) - - @classmethod - def from_structure(cls, root: Path) -> Folder: - name = root.name - files: list[File | Folder] = [] - - for path in root.iterdir(): - if path.is_dir(): - files.append(cls.from_structure(path)) - else: - files.append(File(path.name, path.read_text().splitlines())) - - return Folder(name, *files) - - def __eq__(self, __value: object) -> bool: - if isinstance(__value, File): - return False - - if not isinstance(__value, Folder): - return NotImplemented - - if self.name != __value.name: - return False - - if len(self.files) != len(__value.files): - return False - - for self_file, other_file in zip(self.files, __value.files): - if self_file != other_file: - return False - - return True - - -class File: - def __init__(self, name: str, content: list[str] | None = None) -> None: - self.name = name - self.content = "\n".join(content or []) - - def __eq__(self, __value: object) -> bool: - if not isinstance(__value, File): - return NotImplemented - - if self.name != __value.name: - return False - - return self.content == __value.content - - -@pytest.fixture() -def before() -> Folder: - return Folder( - "project", - File("__init__.py"), - File( - "settings.py", - content=[ - "from pydantic import BaseSettings", - "", - "", - "class Settings(BaseSettings):", - " a: int", - ], - ), - File( - "add_none.py", - content=[ - "from typing import Any, Dict, Optional, Union", - "", - "from pydantic import BaseModel", - "", - "", - "class A(BaseModel):", - " a: int | None", - " b: Optional[int]", - " c: Union[int, None]", - " d: Any", - " e: Dict[str, str]", - ], - ), - File( - "config_to_model.py", - content=[ - "from pydantic import BaseModel", - "", - "", - "class A(BaseModel):", - " class Config:", - " orm_mode = True", - " validate_all = True", - "", - "", - "class BaseConfig:", - " orm_mode = True", - " validate_all = True", - "", - "", - "class B(BaseModel):", - " class Config(BaseConfig):", - " ...", - ], - ), - File( - "replace_generic.py", - content=[ - "from typing import Generic, TypeVar", - "", - "from pydantic.generics import GenericModel", - "", - "T = TypeVar('T')", - "", - "", - "class User(GenericModel, Generic[T]):", - " name: str", - ], - ), - File( - "field.py", - content=[ - "from pydantic import BaseModel, Field", - "", - "", - "class A(BaseModel):", - " a: List[int] = Field(..., min_items=1, max_items=10)", - ], - ), - File( - "root_model.py", - content=[ - "from pydantic import BaseModel", - "", - "", - "class A(BaseModel):", - " __root__ = int", - ], - ), - File( - "replace_validator.py", - content=[ - "from pydantic import BaseModel, validator, root_validator", - "", - "", - "class A(BaseModel):", - " a: int", - " b: str", - "", - " @validator('a')", - " def validate_a(cls, v):", - " return v + 1", - "", - " @root_validator()", - " def validate_b(cls, values):", - " return values", - ], - ), - File( - "const_to_literal.py", - content=[ - "from enum import Enum", - "from pydantic import BaseModel, Field", - "", - "", - "class A(str, Enum):", - " a = 'a'", - " b = 'b'", - "", - "class A(BaseModel):", - " a: A = Field(A.a, const=True)", - ], - ), - File( - "config_dict_and_settings.py", - content=[ - "from pydantic import BaseModel, BaseSettings", - "", - "", - "class Settings(BaseSettings):", - " sentry_dsn: str", - "", - "", - "class A(BaseModel):", - " class Config:", - " orm_mode = True", - ], - ), - File( - "a.py", - content=[ - "from pydantic import BaseModel", - "", - "", - "class A(BaseModel):", - " a: int", - "", - "", - "class D:", - " d: int", - ], - ), - File( - "b.py", - content=[ - "from pydantic import BaseModel", - "from .a import A, D", - "from typing import Optional", - "", - "", - "class B(A):", - " b: Optional[int]", - "", - "", - "class C(D):", - " c: Optional[int]", - ], - ), - File( - "c.py", - content=[ - "from pydantic import BaseModel", - "from .d import D", - "", - "", - "class C(D):", - " c: Optional[int]", - ], - ), - File( - "d.py", - content=[ - "from pydantic import BaseModel", - "", - "", - "class D(BaseModel):", - " d: int", - ], - ), - ) - - -@pytest.fixture() -def expected() -> Folder: - return Folder( - "project", - File("__init__.py"), - File( - "settings.py", - content=[ - "from pydantic_settings import BaseSettings", - "", - "", - "class Settings(BaseSettings):", - " a: int", - ], - ), - File( - "add_none.py", - content=[ - "from typing import Any, Dict, Optional, Union", - "", - "from pydantic import BaseModel", - "", - "", - "class A(BaseModel):", - " a: int | None = None", - " b: Optional[int] = None", - " c: Union[int, None] = None", - " d: Any = None", - " e: Dict[str, str]", - ], - ), - File( - "config_to_model.py", - content=[ - "from pydantic import ConfigDict, BaseModel", - "", - "", - "class A(BaseModel):", - " model_config = ConfigDict(from_attributes=True, validate_default=True)", - "", - "", - "class BaseConfig:", - " orm_mode = True", - " validate_all = True", - "", - "", - "class B(BaseModel):", - " # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.", # noqa: E501 - " # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.", - " class Config(BaseConfig):", - " ...", - ], - ), - File( - "replace_generic.py", - content=[ - "from typing import Generic, TypeVar", - "from pydantic import BaseModel", - "", - "T = TypeVar('T')", - "", - "", - "class User(BaseModel, Generic[T]):", - " name: str", - ], - ), - File( - "field.py", - content=[ - "from pydantic import BaseModel, Field", - "", - "", - "class A(BaseModel):", - " a: List[int] = Field(..., min_length=1, max_length=10)", - ], - ), - File( - "root_model.py", - content=[ - "from pydantic import RootModel", - "", - "", - "class A(RootModel[int]):", - " pass", - ], - ), - File( - "replace_validator.py", - content=[ - "from pydantic import field_validator, model_validator, BaseModel", - "", - "", - "class A(BaseModel):", - " a: int", - " b: str", - "", - " @field_validator('a')", - " @classmethod", - " def validate_a(cls, v):", - " return v + 1", - "", - " @model_validator()", - " @classmethod", - " def validate_b(cls, values):", - " return values", - ], - ), - File( - "const_to_literal.py", - content=[ - "from enum import Enum", - "from pydantic import BaseModel", - "from typing import Literal", - "", - "", - "class A(str, Enum):", - " a = 'a'", - " b = 'b'", - "", - "class A(BaseModel):", - " a: Literal[A.a] = A.a", - ], - ), - File( - "config_dict_and_settings.py", - content=[ - "from pydantic import ConfigDict, BaseModel", - "from pydantic_settings import BaseSettings", - "", - "", - "class Settings(BaseSettings):", - " sentry_dsn: str", - "", - "", - "class A(BaseModel):", - " model_config = ConfigDict(from_attributes=True)", - ], - ), - File( - "a.py", - content=[ - "from pydantic import BaseModel", - "", - "", - "class A(BaseModel):", - " a: int", - "", - "", - "class D:", - " d: int", - ], - ), - File( - "b.py", - content=[ - "from pydantic import BaseModel", - "from .a import A, D", - "from typing import Optional", - "", - "", - "class B(A):", - " b: Optional[int] = None", - "", - "", - "class C(D):", - " c: Optional[int]", - ], - ), - File( - "c.py", - content=[ - "from pydantic import BaseModel", - "from .d import D", - "", - "", - "class C(D):", - " c: Optional[int] = None", - ], - ), - File( - "d.py", - content=[ - "from pydantic import BaseModel", - "", - "", - "class D(BaseModel):", - " d: int", - ], - ), - ) +from .cases import before, expected +from .folder import Folder def find_issue(current: Folder, expected: Folder) -> str: @@ -476,7 +29,7 @@ def find_issue(current: Folder, expected: Folder) -> str: return "Unknown" -def test_command_line(tmp_path: Path, before: Folder, expected: Folder) -> None: +def test_command_line(tmp_path: Path) -> None: runner = CliRunner() with runner.isolated_filesystem(temp_dir=tmp_path) as td: From 668051991ed1f2e0adaff7b62d9378cc7c60a512 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 10 Jul 2023 10:27:48 +0200 Subject: [PATCH 5/6] Uff --- bump_pydantic/main.py | 2 +- tests/integration/case.py | 5 +- tests/integration/cases/__init__.py | 9 +++ .../integration/cases/folder_inside_folder.py | 63 +++++++++++++++++++ tests/integration/test_cli.py | 6 +- 5 files changed, 80 insertions(+), 5 deletions(-) create mode 100644 tests/integration/cases/folder_inside_folder.py diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index fdc3abd..d587b0f 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -66,7 +66,7 @@ def main( metadata_manager.resolve_cache() logger.info("Running mypy to get type information. This may take a while...") - classes = run_mypy_visitor(files) + classes = run_mypy_visitor([str(package)]) scratch: dict[str, Any] = {CONTEXT_KEY: classes} logger.info("Finished mypy.") diff --git a/tests/integration/case.py b/tests/integration/case.py index b7c5910..1585b1a 100644 --- a/tests/integration/case.py +++ b/tests/integration/case.py @@ -3,10 +3,11 @@ from dataclasses import dataclass from .file import File +from .folder import Folder @dataclass class Case: - input: File - expected: File + input: Folder | File + expected: Folder | File id: str diff --git a/tests/integration/cases/__init__.py b/tests/integration/cases/__init__.py index 76be9d5..81856d9 100644 --- a/tests/integration/cases/__init__.py +++ b/tests/integration/cases/__init__.py @@ -1,13 +1,21 @@ +from ..case import Case +from ..file import File from ..folder import Folder from .add_none import cases as add_none_cases from .base_settings import cases as base_settings_cases from .config_to_model import cases as config_to_model_cases +from .folder_inside_folder import cases as folder_inside_folder_cases from .generic_model import cases as generic_model_cases from .is_base_model import cases as is_base_model_cases from .replace_validator import cases as replace_validator_cases from .root_model import cases as root_model_cases cases = [ + Case( + id="empty", + input=File("__init__.py", content=[]), + expected=File("__init__.py", content=[]), + ), *base_settings_cases, *add_none_cases, *is_base_model_cases, @@ -15,6 +23,7 @@ *config_to_model_cases, *root_model_cases, *generic_model_cases, + *folder_inside_folder_cases, ] before = Folder("project", *[case.input for case in cases]) expected = Folder("project", *[case.expected for case in cases]) diff --git a/tests/integration/cases/folder_inside_folder.py b/tests/integration/cases/folder_inside_folder.py new file mode 100644 index 0000000..12ee49b --- /dev/null +++ b/tests/integration/cases/folder_inside_folder.py @@ -0,0 +1,63 @@ +from ..case import Case +from ..file import File +from ..folder import Folder + +cases = [ + Case( + id="Add Folder", + input=Folder( + "folder", + File("__init__.py", content=[]), + File( + "file.py", + content=[ + "from typing import Optional, Union", + "", + "from .another_module import C", + "", + "", + "class A(C):", + " b: Union[int, None]", + " c: Optional[int]", + ], + ), + File( + "another_module.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class C(BaseModel):", + " a: int", + ], + ), + ), + expected=Folder( + "folder", + File("__init__.py", content=[]), + File( + "file.py", + content=[ + "from typing import Optional, Union", + "", + "from .another_module import C", + "", + "", + "class A(C):", + " b: Union[int, None] = None", + " c: Optional[int] = None", + ], + ), + File( + "another_module.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class C(BaseModel):", + " a: int", + ], + ), + ), + ) +] diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index 604a964..c761c5f 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -16,7 +16,9 @@ def find_issue(current: Folder, expected: Folder) -> str: if current_file != expected_file: if current_file.name != expected_file.name: return f"Files have different names: {current_file.name} != {expected_file.name}" - if isinstance(current_file, Folder) or isinstance(expected_file, Folder): + if isinstance(current_file, Folder) and isinstance(expected_file, Folder): + return find_issue(current_file, expected_file) + elif isinstance(current_file, Folder) or isinstance(expected_file, Folder): return f"One of the files is a folder: {current_file.name} != {expected_file.name}" return "\n".join( difflib.unified_diff( @@ -29,6 +31,7 @@ def find_issue(current: Folder, expected: Folder) -> str: return "Unknown" +# @pytest.mark.parametrize("before,expected", zip([before, expected])) def test_command_line(tmp_path: Path) -> None: runner = CliRunner() @@ -36,7 +39,6 @@ def test_command_line(tmp_path: Path) -> None: before.create_structure(root=Path(td)) result = runner.invoke(app, [before.name]) - print(result.output) assert result.exit_code == 0, result.output # assert result.output.endswith("Refactored 4 files.\n") From 1ce0561f36fe542787d8e36db0577809af318ead Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 10 Jul 2023 10:28:53 +0200 Subject: [PATCH 6/6] Use files instead of package on the run_mypy_visitor --- bump_pydantic/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index d587b0f..fdc3abd 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -66,7 +66,7 @@ def main( metadata_manager.resolve_cache() logger.info("Running mypy to get type information. This may take a while...") - classes = run_mypy_visitor([str(package)]) + classes = run_mypy_visitor(files) scratch: dict[str, Any] = {CONTEXT_KEY: classes} logger.info("Finished mypy.")