From e5e9148bb84f66c81bdc3ed671dcfd08a24f092d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 4 Jul 2023 16:32:29 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Resolve=20ClassDefs=20as=20soon?= =?UTF-8?q?=20as=20evaluated?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/add_default_none.py | 8 +- bump_pydantic/codemods/class_def_visitor.py | 91 +++++++++++++++-- bump_pydantic/main.py | 69 ++++++++----- bump_pydantic/markers/__init__.py | 0 bump_pydantic/markers/find_base_model.py | 81 ---------------- tests/integration/test_cli.py | 102 ++++++++++++++++++++ tests/unit/test_add_default_none.py | 3 - tests/unit/test_class_def_visitor.py | 23 ++--- 8 files changed, 239 insertions(+), 138 deletions(-) delete mode 100644 bump_pydantic/markers/__init__.py delete mode 100644 bump_pydantic/markers/find_base_model.py diff --git a/bump_pydantic/codemods/add_default_none.py b/bump_pydantic/codemods/add_default_none.py index 35bd49d..83698b6 100644 --- a/bump_pydantic/codemods/add_default_none.py +++ b/bump_pydantic/codemods/add_default_none.py @@ -6,8 +6,6 @@ from libcst.metadata import FullyQualifiedNameProvider, QualifiedName from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -from bump_pydantic.markers.find_base_model import CONTEXT_KEY as BASE_MODEL_CONTEXT_KEY -from bump_pydantic.markers.find_base_model import find_base_model class AddDefaultNoneCommand(VisitorBasedCodemodCommand): @@ -51,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: return None fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - if fqn.name in self.context.scratch[BASE_MODEL_CONTEXT_KEY]: + if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]: self.inside_base_model = True def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: @@ -94,7 +92,6 @@ def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAss from tempfile import TemporaryDirectory from libcst.metadata import FullRepoManager - from rich.pretty import pprint with TemporaryDirectory(dir=os.getcwd()) as tmpdir: package_dir = f"{tmpdir}/package" @@ -126,9 +123,6 @@ class Bar(Foo): command = ClassDefVisitor(context=context) mod = wrapper.visit(command) - find_base_model(scratch=context.scratch) - pprint(context.scratch) - command = AddDefaultNoneCommand(context=context) # type: ignore[assignment] mod = wrapper.visit(command) print(mod.code) diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py index 17e54e0..d737df9 100644 --- a/bump_pydantic/codemods/class_def_visitor.py +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -1,6 +1,18 @@ +""" +There are two objects in the visitor: +1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based. +2. `cls` (Dict[str, Set[str]]): A dictionary mapping each class definition to a set of base classes. + +`base_model_cls` accumulates on each iteration. +`cls` also accumulates on each iteration, but it's also partially solved: +1. Check if the module visited is a prefix of any `cls.keys()`. +1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`. +1.2. If it's not, it continues on the `cls` +""" from __future__ import annotations from collections import defaultdict +from typing import Set, cast import libcst as cst from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand @@ -10,12 +22,17 @@ class ClassDefVisitor(VisitorBasedCodemodCommand): METADATA_DEPENDENCIES = {FullyQualifiedNameProvider} - CONTEXT_KEY = "class_def_visitor" + BASE_MODEL_CONTEXT_KEY = "base_model_cls" + NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls" + CLS_CONTEXT_KEY = "cls" def __init__(self, context: CodemodContext) -> None: super().__init__(context) self.module_fqn: None | QualifiedName = None - self.context.scratch.setdefault(self.CONTEXT_KEY, defaultdict(set)) + + self.context.scratch.setdefault(self.BASE_MODEL_CONTEXT_KEY, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) + self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) + self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) def visit_ClassDef(self, node: cst.ClassDef) -> None: fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) @@ -24,15 +41,61 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: return None fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - for arg in node.bases: - base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) - if not base_fqn_set: - return None + if not node.bases: + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) - base_fqn: QualifiedName = next(iter(base_fqn_set)) # type: ignore - # NOTE: Should I use the name or the QualifiedName? - self.context.scratch[self.CONTEXT_KEY][fqn.name].add(base_fqn.name) + for arg in node.bases: + base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) + base_fqn_set = base_fqn_set or set() + + for base_fqn in cast(Set[QualifiedName], iter(base_fqn_set)): # type: ignore + if base_fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]: + self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(fqn.name) + elif base_fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]: + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) + + # In case we have the following scenario: + # class A(B): ... + # class B(BaseModel): ... + # class D(C): ... + # class C: ... + # We want to disambiguate `A` as soon as we see `B` is a `BaseModel`. + if ( + fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY] + and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] + ): + for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): + self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class) + + # In case we have the following scenario: + # class A(B): ... + # class B(BaseModel): ... + # class D(C): ... + # class C: ... + # We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`. + if ( + fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY] + and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] + ): + for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class) + + # In case we have the following scenario: + # class A(B): ... + # ...And B is not known. + # We want to make sure that B -> A is added to the `cls` context, so if we find B later, + # we can disambiguate. + if fqn.name not in ( + *self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], + *self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY], + ): + for base_fqn in cast(Set[QualifiedName], base_fqn_set): + self.context.scratch[self.CLS_CONTEXT_KEY][base_fqn.name].add(fqn.name) + + # TODO: Implement this if needed... + def next_file(self, visited: set[str]) -> str | None: + return None if __name__ == "__main__": @@ -59,6 +122,12 @@ class Foo(BaseModel): class Bar(Foo): b: str + class Potato: + ... + + class Spam(Potato): + ... + foo = Foo(a="text") foo.dict() """ @@ -70,4 +139,6 @@ class Bar(Foo): context = CodemodContext(wrapper=wrapper) command = ClassDefVisitor(context=context) mod = wrapper.visit(command) - pprint(context.scratch[ClassDefVisitor.CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.CLS_CONTEXT_KEY]) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index 44c6144..3726122 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -5,7 +5,7 @@ import time from contextlib import nullcontext from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Set, Type, TypeVar, Union import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer @@ -19,7 +19,6 @@ from bump_pydantic import __version__ from bump_pydantic.codemods import Rule, gather_codemods from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -from bump_pydantic.markers.find_base_model import find_base_model app = Typer( help="Convert Pydantic from V1 to V2 ♻️", @@ -65,14 +64,39 @@ def main( scratch: dict[str, Any] = {} with Progress(*Progress.get_default_columns(), transient=True) as progress: task = progress.add_task(description="Looking for Pydantic Models...", total=len(files)) - with multiprocessing.Pool() as pool: - partial_visit_class_def = functools.partial(visit_class_def, metadata_manager, package) - for local_scratch in pool.imap_unordered(partial_visit_class_def, files): - progress.advance(task) - for key, value in local_scratch.items(): - scratch.setdefault(key, value).update(value) - find_base_model(scratch) + queue: List[str] = [files[0]] + visited: Set[str] = set() + + while queue: + # Queue logic + filename = queue.pop() + visited.add(filename) + progress.advance(task) + + # Visitor logic + code = Path(filename).read_text() + module = cst.parse_module(code) + module_and_package = calculate_module_and_package(str(package), filename) + + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + scratch=scratch, + ) + visitor = ClassDefVisitor(context=context) + visitor.transform_module(module) + + # Queue logic + next_file = visitor.next_file(visited) + if next_file is not None: + queue.append(next_file) + + missing_files = set(files) - visited + if not queue and missing_files: + queue.append(next(iter(missing_files))) start_time = time.time() @@ -102,6 +126,20 @@ def main( print(f"Refactored {len(modified)} files.") +def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]: + try: + return func(*args, **kwargs) + except Exception as exc: + func_args = [repr(arg) for arg in args] + func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()] + return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"] + + return wrapper + + +@capture_exception def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]: code = Path(filename).read_text() module = cst.parse_module(code) @@ -118,19 +156,6 @@ def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: return context.scratch -def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]: - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]: - try: - return func(*args, **kwargs) - except Exception as exc: - func_args = [repr(arg) for arg in args] - func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()] - return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"] - - return wrapper - - @capture_exception def run_codemods( codemods: List[Type[ContextAwareTransformer]], diff --git a/bump_pydantic/markers/__init__.py b/bump_pydantic/markers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bump_pydantic/markers/find_base_model.py b/bump_pydantic/markers/find_base_model.py deleted file mode 100644 index eef2525..0000000 --- a/bump_pydantic/markers/find_base_model.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from typing import Any - -from libcst.codemod import CodemodContext - -from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor - -CONTEXT_KEY = "find_base_model" - - -def revert_dictionary(classes: defaultdict[str, set[str]]) -> defaultdict[str, set[str]]: - revert_classes: defaultdict[str, set[str]] = defaultdict(set) - for cls, bases in classes.copy().items(): - for base in bases: - revert_classes[base].add(cls) - return revert_classes - - -def find_base_model(scratch: dict[str, Any]) -> None: - classes = scratch[ClassDefVisitor.CONTEXT_KEY] - revert_classes = revert_dictionary(classes) - base_model_set: set[str] = set() - - for cls, bases in revert_classes.copy().items(): - if cls in ("pydantic.BaseModel", "BaseModel"): - base_model_set = base_model_set.union(bases) - - visited: set[str] = set() - bases_queue = list(bases) - while bases_queue: - base = bases_queue.pop() - - if base in visited: - continue - visited.add(base) - - base_model_set.add(base) - bases_queue.extend(revert_classes[base]) - - scratch[CONTEXT_KEY] = base_model_set - - -if __name__ == "__main__": - import os - import textwrap - from pathlib import Path - from tempfile import TemporaryDirectory - - from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider - from rich.pretty import pprint - - with TemporaryDirectory(dir=os.getcwd()) as tmpdir: - package_dir = f"{tmpdir}/package" - os.mkdir(package_dir) - module_path = f"{package_dir}/a.py" - with open(module_path, "w") as f: - content = textwrap.dedent( - """ - from pydantic import BaseModel - - class Foo(BaseModel): - a: str - - class Bar(Foo): - b: str - - foo = Foo(a="text") - foo.dict() - """ - ) - f.write(content) - module = str(Path(module_path).relative_to(tmpdir)) - mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) - wrapper = mrg.get_metadata_wrapper_for_path(module) - context = CodemodContext(wrapper=wrapper) - command = ClassDefVisitor(context=context) - mod = wrapper.visit(command) - find_base_model(scratch=context.scratch) - pprint(context.scratch[CONTEXT_KEY]) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index 0bbbaf4..b403bad 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -212,6 +212,57 @@ def before() -> Folder: " orm_mode = True", ], ), + File( + "a.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + "", + "", + "class D:", + " d: int", + ], + ), + File( + "b.py", + content=[ + "from pydantic import BaseModel", + "from .a import A, D", + "from typing import Optional", + "", + "", + "class B(A):", + " b: Optional[int]", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + File( + "c.py", + content=[ + "from pydantic import BaseModel", + "from .d import D", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + File( + "d.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class D(BaseModel):", + " d: int", + ], + ), ) @@ -353,6 +404,57 @@ def expected() -> Folder: " model_config = ConfigDict(from_attributes=True)", ], ), + File( + "a.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + "", + "", + "class D:", + " d: int", + ], + ), + File( + "b.py", + content=[ + "from pydantic import BaseModel", + "from .a import A, D", + "from typing import Optional", + "", + "", + "class B(A):", + " b: Optional[int] = None", + "", + "", + "class C(D):", + " c: Optional[int]", + ], + ), + File( + "c.py", + content=[ + "from pydantic import BaseModel", + "from .d import D", + "", + "", + "class C(D):", + " c: Optional[int] = None", + ], + ), + File( + "d.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class D(BaseModel):", + " d: int", + ], + ), ) diff --git a/tests/unit/test_add_default_none.py b/tests/unit/test_add_default_none.py index 0688c8f..09ad922 100644 --- a/tests/unit/test_add_default_none.py +++ b/tests/unit/test_add_default_none.py @@ -10,7 +10,6 @@ from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -from bump_pydantic.markers.find_base_model import find_base_model class TestClassDefVisitor(UnitTest): @@ -28,8 +27,6 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module: instance = ClassDefVisitor(context=context) mod.visit(instance) - find_base_model(scratch=context.scratch) - instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment] return mod.visit(instance) diff --git a/tests/unit/test_class_def_visitor.py b/tests/unit/test_class_def_visitor.py index 1065780..7608ed0 100644 --- a/tests/unit/test_class_def_visitor.py +++ b/tests/unit/test_class_def_visitor.py @@ -31,8 +31,8 @@ def foo() -> None: pass """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] - self.assertEqual(results, {}) + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] + self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) def test_without_bases(self) -> None: visitor = self.gather_class_def( @@ -42,8 +42,8 @@ class Foo: pass """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] - self.assertEqual(results, {}) + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] + self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel"}) def test_with_class_defs(self) -> None: visitor = self.gather_class_def( @@ -58,13 +58,9 @@ class Bar(Foo): pass """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] self.assertEqual( - results, - { - "some.test.module.Foo": {"pydantic.BaseModel"}, - "some.test.module.Bar": {"some.test.module.Foo"}, - }, + results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo", "some.test.module.Bar"} ) def test_with_pydantic_base_model(self) -> None: @@ -77,8 +73,5 @@ class Foo(pydantic.BaseModel): ... """, ) - results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] - self.assertEqual( - results, - {"some.test.module.Foo": {"pydantic.BaseModel"}}, - ) + results = visitor.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY] + self.assertEqual(results, {"pydantic.BaseModel", "pydantic.main.BaseModel", "some.test.module.Foo"})