Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Replace MypyVisitor by ClassDefVisitor #99

Merged
merged 2 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions bump_pydantic/codemods/add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName

from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor


class AddDefaultNoneCommand(VisitorBasedCodemodCommand):
Expand Down Expand Up @@ -49,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
if self.context.scratch[CONTEXT_KEY].get(fqn.name, False):
if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]:
self.inside_base_model = True

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
Expand Down Expand Up @@ -138,9 +138,6 @@ class Bar(Foo):
wrapper = mrg.get_metadata_wrapper_for_path(module)
context = CodemodContext(wrapper=wrapper)

# classes = run_mypy_visitor(context=context)
# mod = wrapper.visit(command)

command = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
print(mod.code)
147 changes: 147 additions & 0 deletions bump_pydantic/codemods/class_def_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
There are two objects in the visitor:
1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based.
2. `cls` (Dict[str, Set[str]]): A dictionary mapping each class definition to a set of base classes.

`base_model_cls` accumulates on each iteration.
`cls` also accumulates on each iteration, but it's also partially solved:
1. Check if the module visited is a prefix of any `cls.keys()`.
1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`.
1.2. If it's not, it continues on the `cls`
"""
from __future__ import annotations

from collections import defaultdict
from typing import Set, cast

import libcst as cst
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName


class ClassDefVisitor(VisitorBasedCodemodCommand):
METADATA_DEPENDENCIES = {FullyQualifiedNameProvider}

BASE_MODEL_CONTEXT_KEY = "base_model_cls"
NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls"
CLS_CONTEXT_KEY = "cls"

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
self.module_fqn: None | QualifiedName = None

self.context.scratch.setdefault(
self.BASE_MODEL_CONTEXT_KEY,
{"pydantic.BaseModel", "pydantic.main.BaseModel"},
)
self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set())
self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set))

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)

if not fqn_set:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore

if not node.bases:
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name)

for arg in node.bases:
base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value)
base_fqn_set = base_fqn_set or set()

for base_fqn in cast(Set[QualifiedName], iter(base_fqn_set)): # type: ignore
if base_fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]:
self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(fqn.name)
elif base_fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]:
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name)

# In case we have the following scenario:
# class A(B): ...
# class B(BaseModel): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `A` as soon as we see `B` is a `BaseModel`.
if (
fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class)

# In case we have the following scenario:
# class A(B): ...
# class B(BaseModel): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`.
if (
fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class)

# In case we have the following scenario:
# class A(B): ...
# ...And B is not known.
# We want to make sure that B -> A is added to the `cls` context, so if we find B later,
# we can disambiguate.
if fqn.name not in (
*self.context.scratch[self.BASE_MODEL_CONTEXT_KEY],
*self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY],
):
for base_fqn in cast(Set[QualifiedName], base_fqn_set):
self.context.scratch[self.CLS_CONTEXT_KEY][base_fqn.name].add(fqn.name)

# TODO: Implement this if needed...
def next_file(self, visited: set[str]) -> str | None:
return None


if __name__ == "__main__":
import os
import textwrap
from pathlib import Path
from tempfile import TemporaryDirectory

from libcst.metadata import FullRepoManager
from rich.pretty import pprint

with TemporaryDirectory(dir=os.getcwd()) as tmpdir:
package_dir = f"{tmpdir}/package"
os.mkdir(package_dir)
module_path = f"{package_dir}/a.py"
with open(module_path, "w") as f:
content = textwrap.dedent(
"""
from pydantic import BaseModel

class Foo(BaseModel):
a: str

class Bar(Foo):
b: str

class Potato:
...

class Spam(Potato):
...

foo = Foo(a="text")
foo.dict()
"""
)
f.write(content)
module = str(Path(module_path).relative_to(tmpdir))
mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider})
wrapper = mrg.get_metadata_wrapper_for_path(module)
context = CodemodContext(wrapper=wrapper)
command = ClassDefVisitor(context=context)
mod = wrapper.visit(command)
pprint(context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY])
pprint(context.scratch[ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY])
pprint(context.scratch[ClassDefVisitor.CLS_CONTEXT_KEY])
52 changes: 0 additions & 52 deletions bump_pydantic/codemods/mypy_visitor.py

This file was deleted.

43 changes: 37 additions & 6 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Set, Tuple, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
Expand All @@ -18,7 +18,7 @@

from bump_pydantic import __version__
from bump_pydantic.codemods import Rule, gather_codemods
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor

app = Typer(
help="Convert Pydantic from V1 to V2 ♻️",
Expand Down Expand Up @@ -69,10 +69,41 @@ def main(
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

console.log("Running mypy to get type information. This may take a while...")
classes = run_mypy_visitor(files)
scratch: dict[str, Any] = {CONTEXT_KEY: classes}
console.log("Finished mypy.")
scratch: dict[str, Any] = {}
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Looking for Pydantic Models...", total=len(files))
queue: List[str] = [files[0]]
visited: Set[str] = set()

while queue:
# Queue logic
filename = queue.pop()
visited.add(filename)
progress.advance(task)

# Visitor logic
code = Path(filename).read_text()
module = cst.parse_module(code)
module_and_package = calculate_module_and_package(str(package), filename)

context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
scratch=scratch,
)
visitor = ClassDefVisitor(context=context)
visitor.transform_module(module)

# Queue logic
next_file = visitor.next_file(visited)
if next_file is not None:
queue.append(next_file)

missing_files = set(files) - visited
if not queue and missing_files:
queue.append(next(iter(missing_files)))

start_time = time.time()

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Pydantic",
]
dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions", "mypy"]
dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions"]

[project.urls]
Documentation = "https://github.com/pydantic/bump-pydantic#readme"
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from libcst.testing.utils import UnitTest

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor


@pytest.mark.skip(reason="The file needs to exists for the test to pass.")
class TestClassDefVisitor(UnitTest):
def add_default_none(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
Expand All @@ -25,8 +24,8 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module:
)
mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES)
context = CodemodContext(wrapper=mod)
classes = run_mypy_visitor(arg_files=[file_path])
context.scratch.update({CONTEXT_KEY: classes})
instance = ClassDefVisitor(context=context)
mod.visit(instance)

instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
return mod.visit(instance)
Expand Down