From b4aae412ce759a6b2309c221a32ff588c364a3f7 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 17 Jul 2023 18:29:58 +0200 Subject: [PATCH 1/2] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Replace=20`MypyVisitor?= =?UTF-8?q?`=20by=20`ClassDefVisitor`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/add_default_none.py | 7 +- bump_pydantic/codemods/class_def_visitor.py | 147 ++++++++++++++++++++ bump_pydantic/codemods/mypy_visitor.py | 52 ------- bump_pydantic/main.py | 59 +++++++- pyproject.toml | 2 +- tests/unit/test_add_default_none.py | 7 +- 6 files changed, 206 insertions(+), 68 deletions(-) create mode 100644 bump_pydantic/codemods/class_def_visitor.py delete mode 100644 bump_pydantic/codemods/mypy_visitor.py diff --git a/bump_pydantic/codemods/add_default_none.py b/bump_pydantic/codemods/add_default_none.py index be0bf1f..20f511f 100644 --- a/bump_pydantic/codemods/add_default_none.py +++ b/bump_pydantic/codemods/add_default_none.py @@ -5,7 +5,7 @@ from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand from libcst.metadata import FullyQualifiedNameProvider, QualifiedName -from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor class AddDefaultNoneCommand(VisitorBasedCodemodCommand): @@ -49,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: return None fqn: QualifiedName = next(iter(fqn_set)) # type: ignore - if self.context.scratch[CONTEXT_KEY].get(fqn.name, False): + if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]: self.inside_base_model = True def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: @@ -138,9 +138,6 @@ class Bar(Foo): wrapper = mrg.get_metadata_wrapper_for_path(module) context = CodemodContext(wrapper=wrapper) - # classes = run_mypy_visitor(context=context) - # mod = wrapper.visit(command) - command = AddDefaultNoneCommand(context=context) # type: ignore[assignment] mod = wrapper.visit(command) print(mod.code) diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py new file mode 100644 index 0000000..f911fc3 --- /dev/null +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -0,0 +1,147 @@ +""" +There are two objects in the visitor: +1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based. +2. `cls` (Dict[str, Set[str]]): A dictionary mapping each class definition to a set of base classes. + +`base_model_cls` accumulates on each iteration. +`cls` also accumulates on each iteration, but it's also partially solved: +1. Check if the module visited is a prefix of any `cls.keys()`. +1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`. +1.2. If it's not, it continues on the `cls` +""" +from __future__ import annotations + +from collections import defaultdict +from typing import Set, cast + +import libcst as cst +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.metadata import FullyQualifiedNameProvider, QualifiedName + + +class ClassDefVisitor(VisitorBasedCodemodCommand): + METADATA_DEPENDENCIES = {FullyQualifiedNameProvider} + + BASE_MODEL_CONTEXT_KEY = "base_model_cls" + NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls" + CLS_CONTEXT_KEY = "cls" + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + self.module_fqn: None | QualifiedName = None + + self.context.scratch.setdefault( + self.BASE_MODEL_CONTEXT_KEY, + {"pydantic.BaseModel", "pydantic.main.BaseModel"}, + ) + self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) + self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) + + if not fqn_set: + return None + + fqn: QualifiedName = next(iter(fqn_set)) # type: ignore + + if not node.bases: + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) + + for arg in node.bases: + base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) + base_fqn_set = base_fqn_set or set() + + for base_fqn in cast(Set[QualifiedName], iter(base_fqn_set)): # type: ignore + if base_fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]: + self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(fqn.name) + elif base_fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]: + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) + + # In case we have the following scenario: + # class A(B): ... + # class B(BaseModel): ... + # class D(C): ... + # class C: ... + # We want to disambiguate `A` as soon as we see `B` is a `BaseModel`. + if ( + fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY] + and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] + ): + for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): + self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class) + + # In case we have the following scenario: + # class A(B): ... + # class B(BaseModel): ... + # class D(C): ... + # class C: ... + # We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`. + if ( + fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY] + and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] + ): + for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class) + + # In case we have the following scenario: + # class A(B): ... + # ...And B is not known. + # We want to make sure that B -> A is added to the `cls` context, so if we find B later, + # we can disambiguate. + if fqn.name not in ( + *self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], + *self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY], + ): + for base_fqn in cast(Set[QualifiedName], base_fqn_set): + self.context.scratch[self.CLS_CONTEXT_KEY][base_fqn.name].add(fqn.name) + + # TODO: Implement this if needed... + def next_file(self, visited: set[str]) -> str | None: + return None + + +if __name__ == "__main__": + import os + import textwrap + from pathlib import Path + from tempfile import TemporaryDirectory + + from libcst.metadata import FullRepoManager + from rich.pretty import pprint + + with TemporaryDirectory(dir=os.getcwd()) as tmpdir: + package_dir = f"{tmpdir}/package" + os.mkdir(package_dir) + module_path = f"{package_dir}/a.py" + with open(module_path, "w") as f: + content = textwrap.dedent( + """ + from pydantic import BaseModel + + class Foo(BaseModel): + a: str + + class Bar(Foo): + b: str + + class Potato: + ... + + class Spam(Potato): + ... + + foo = Foo(a="text") + foo.dict() + """ + ) + f.write(content) + module = str(Path(module_path).relative_to(tmpdir)) + mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) + wrapper = mrg.get_metadata_wrapper_for_path(module) + context = CodemodContext(wrapper=wrapper) + command = ClassDefVisitor(context=context) + mod = wrapper.visit(command) + pprint(context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY]) + pprint(context.scratch[ClassDefVisitor.CLS_CONTEXT_KEY]) diff --git a/bump_pydantic/codemods/mypy_visitor.py b/bump_pydantic/codemods/mypy_visitor.py deleted file mode 100644 index 053fc91..0000000 --- a/bump_pydantic/codemods/mypy_visitor.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -import sys -from argparse import ArgumentParser - -from mypy.build import build -from mypy.main import process_options -from mypy.nodes import ClassDef -from mypy.traverser import TraverserVisitor - -CONTEXT_KEY = "mypy_visitor" - - -class MyPyVisitor(TraverserVisitor): - def __init__(self) -> None: - super().__init__() - self.classes: dict[str, bool] = {} - - def visit_class_def(self, o: ClassDef) -> None: - super().visit_class_def(o) - self.classes[o.fullname] = o.info.has_base("pydantic.main.BaseModel") - - -def run_mypy_visitor(arg_files: list[str]) -> dict[str, bool]: - files, opt = process_options(arg_files, stdout=sys.stdout, stderr=sys.stderr) - - opt.export_types = True - opt.incremental = True - opt.fine_grained_incremental = True - opt.cache_fine_grained = True - opt.allow_redefinition = True - opt.local_partial_types = True - - result = build(files, opt, stdout=sys.stdout, stderr=sys.stderr) - - visitor = MyPyVisitor() - classes: dict[str, bool] = {} - - for file in files: - tree = result.graph[file.module].tree - if tree: - tree.accept(visitor=visitor) - classes.update(visitor.classes) - return classes - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("files", nargs="+") - args = parser.parse_args() - - run_mypy_visitor(args.files) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index 3bdfe30..bc85cf1 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -5,7 +5,7 @@ import time import traceback from pathlib import Path -from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Iterable, List, Set, Tuple, Type, TypeVar, Union import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer @@ -18,7 +18,7 @@ from bump_pydantic import __version__ from bump_pydantic.codemods import Rule, gather_codemods -from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor app = Typer( help="Convert Pydantic from V1 to V2 ♻️", @@ -69,10 +69,41 @@ def main( metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type] metadata_manager.resolve_cache() - console.log("Running mypy to get type information. This may take a while...") - classes = run_mypy_visitor(files) - scratch: dict[str, Any] = {CONTEXT_KEY: classes} - console.log("Finished mypy.") + scratch: dict[str, Any] = {} + with Progress(*Progress.get_default_columns(), transient=True) as progress: + task = progress.add_task(description="Looking for Pydantic Models...", total=len(files)) + queue: List[str] = [files[0]] + visited: Set[str] = set() + + while queue: + # Queue logic + filename = queue.pop() + visited.add(filename) + progress.advance(task) + + # Visitor logic + code = Path(filename).read_text() + module = cst.parse_module(code) + module_and_package = calculate_module_and_package(str(package), filename) + + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + scratch=scratch, + ) + visitor = ClassDefVisitor(context=context) + visitor.transform_module(module) + + # Queue logic + next_file = visitor.next_file(visited) + if next_file is not None: + queue.append(next_file) + + missing_files = set(files) - visited + if not queue and missing_files: + queue.append(next(iter(missing_files))) start_time = time.time() @@ -160,6 +191,22 @@ def run_codemods( return f"An error happened on {filename}.\n{traceback.format_exc()}", None +def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]: + code = Path(filename).read_text() + module = cst.parse_module(code) + module_and_package = calculate_module_and_package(str(package), filename) + + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + ) + visitor = ClassDefVisitor(context=context) + visitor.transform_module(module) + return context.scratch + + def color_diff(console: Console, lines: Iterable[str]) -> None: for line in lines: line = line.rstrip("\n") diff --git a/pyproject.toml b/pyproject.toml index 9d69136..9836a1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", "Framework :: Pydantic", ] -dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions", "mypy"] +dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions"] [project.urls] Documentation = "https://github.com/pydantic/bump-pydantic#readme" diff --git a/tests/unit/test_add_default_none.py b/tests/unit/test_add_default_none.py index 7301ada..a0f1202 100644 --- a/tests/unit/test_add_default_none.py +++ b/tests/unit/test_add_default_none.py @@ -9,10 +9,9 @@ from libcst.testing.utils import UnitTest from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand -from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor -@pytest.mark.skip(reason="The file needs to exists for the test to pass.") class TestClassDefVisitor(UnitTest): def add_default_none(self, file_path: str, code: str) -> cst.Module: mod = MetadataWrapper( @@ -25,8 +24,8 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module: ) mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES) context = CodemodContext(wrapper=mod) - classes = run_mypy_visitor(arg_files=[file_path]) - context.scratch.update({CONTEXT_KEY: classes}) + instance = ClassDefVisitor(context=context) + mod.visit(instance) instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment] return mod.visit(instance) From 152900dbbc4f1cda1596194cb7f321380a31ad3b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 17 Jul 2023 18:30:50 +0200 Subject: [PATCH 2/2] remove visit_class_def --- bump_pydantic/main.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index bc85cf1..8cc83ae 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -191,22 +191,6 @@ def run_codemods( return f"An error happened on {filename}.\n{traceback.format_exc()}", None -def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]: - code = Path(filename).read_text() - module = cst.parse_module(code) - module_and_package = calculate_module_and_package(str(package), filename) - - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - visitor = ClassDefVisitor(context=context) - visitor.transform_module(module) - return context.scratch - - def color_diff(console: Console, lines: Iterable[str]) -> None: for line in lines: line = line.rstrip("\n")