Skip to content

Commit

Permalink
🐛 Resolve ClassDefs as soon as evaluated
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 4, 2023
1 parent b4d30f0 commit e5e9148
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 138 deletions.
8 changes: 1 addition & 7 deletions bump_pydantic/codemods/add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName

from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor
from bump_pydantic.markers.find_base_model import CONTEXT_KEY as BASE_MODEL_CONTEXT_KEY
from bump_pydantic.markers.find_base_model import find_base_model


class AddDefaultNoneCommand(VisitorBasedCodemodCommand):
Expand Down Expand Up @@ -51,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
if fqn.name in self.context.scratch[BASE_MODEL_CONTEXT_KEY]:
if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]:
self.inside_base_model = True

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
Expand Down Expand Up @@ -94,7 +92,6 @@ def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAss
from tempfile import TemporaryDirectory

from libcst.metadata import FullRepoManager
from rich.pretty import pprint

with TemporaryDirectory(dir=os.getcwd()) as tmpdir:
package_dir = f"{tmpdir}/package"
Expand Down Expand Up @@ -126,9 +123,6 @@ class Bar(Foo):
command = ClassDefVisitor(context=context)
mod = wrapper.visit(command)

find_base_model(scratch=context.scratch)
pprint(context.scratch)

command = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
print(mod.code)
91 changes: 81 additions & 10 deletions bump_pydantic/codemods/class_def_visitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
"""
There are two objects in the visitor:
1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based.
2. `cls` (Dict[str, Set[str]]): A dictionary mapping each class definition to a set of base classes.
`base_model_cls` accumulates on each iteration.
`cls` also accumulates on each iteration, but it's also partially solved:
1. Check if the module visited is a prefix of any `cls.keys()`.
1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`.
1.2. If it's not, it continues on the `cls`
"""
from __future__ import annotations

from collections import defaultdict
from typing import Set, cast

import libcst as cst
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
Expand All @@ -10,12 +22,17 @@
class ClassDefVisitor(VisitorBasedCodemodCommand):
METADATA_DEPENDENCIES = {FullyQualifiedNameProvider}

CONTEXT_KEY = "class_def_visitor"
BASE_MODEL_CONTEXT_KEY = "base_model_cls"
NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls"
CLS_CONTEXT_KEY = "cls"

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
self.module_fqn: None | QualifiedName = None
self.context.scratch.setdefault(self.CONTEXT_KEY, defaultdict(set))

self.context.scratch.setdefault(self.BASE_MODEL_CONTEXT_KEY, {"pydantic.BaseModel", "pydantic.main.BaseModel"})
self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set())
self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set))

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)
Expand All @@ -24,15 +41,61 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
for arg in node.bases:
base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value)

if not base_fqn_set:
return None
if not node.bases:
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name)

base_fqn: QualifiedName = next(iter(base_fqn_set)) # type: ignore
# NOTE: Should I use the name or the QualifiedName?
self.context.scratch[self.CONTEXT_KEY][fqn.name].add(base_fqn.name)
for arg in node.bases:
base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value)
base_fqn_set = base_fqn_set or set()

for base_fqn in cast(Set[QualifiedName], iter(base_fqn_set)): # type: ignore
if base_fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]:
self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(fqn.name)
elif base_fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]:
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name)

# In case we have the following scenario:
# class A(B): ...
# class B(BaseModel): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `A` as soon as we see `B` is a `BaseModel`.
if (
fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class)

# In case we have the following scenario:
# class A(B): ...
# class B(BaseModel): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`.
if (
fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class)

# In case we have the following scenario:
# class A(B): ...
# ...And B is not known.
# We want to make sure that B -> A is added to the `cls` context, so if we find B later,
# we can disambiguate.
if fqn.name not in (
*self.context.scratch[self.BASE_MODEL_CONTEXT_KEY],
*self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY],
):
for base_fqn in cast(Set[QualifiedName], base_fqn_set):
self.context.scratch[self.CLS_CONTEXT_KEY][base_fqn.name].add(fqn.name)

# TODO: Implement this if needed...
def next_file(self, visited: set[str]) -> str | None:
return None


if __name__ == "__main__":
Expand All @@ -59,6 +122,12 @@ class Foo(BaseModel):
class Bar(Foo):
b: str
class Potato:
...
class Spam(Potato):
...
foo = Foo(a="text")
foo.dict()
"""
Expand All @@ -70,4 +139,6 @@ class Bar(Foo):
context = CodemodContext(wrapper=wrapper)
command = ClassDefVisitor(context=context)
mod = wrapper.visit(command)
pprint(context.scratch[ClassDefVisitor.CONTEXT_KEY])
pprint(context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY])
pprint(context.scratch[ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY])
pprint(context.scratch[ClassDefVisitor.CLS_CONTEXT_KEY])
69 changes: 47 additions & 22 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Set, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
Expand All @@ -19,7 +19,6 @@
from bump_pydantic import __version__
from bump_pydantic.codemods import Rule, gather_codemods
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor
from bump_pydantic.markers.find_base_model import find_base_model

app = Typer(
help="Convert Pydantic from V1 to V2 ♻️",
Expand Down Expand Up @@ -65,14 +64,39 @@ def main(
scratch: dict[str, Any] = {}
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Looking for Pydantic Models...", total=len(files))
with multiprocessing.Pool() as pool:
partial_visit_class_def = functools.partial(visit_class_def, metadata_manager, package)
for local_scratch in pool.imap_unordered(partial_visit_class_def, files):
progress.advance(task)
for key, value in local_scratch.items():
scratch.setdefault(key, value).update(value)

find_base_model(scratch)
queue: List[str] = [files[0]]
visited: Set[str] = set()

while queue:
# Queue logic
filename = queue.pop()
visited.add(filename)
progress.advance(task)

# Visitor logic
code = Path(filename).read_text()
module = cst.parse_module(code)
module_and_package = calculate_module_and_package(str(package), filename)

context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
scratch=scratch,
)
visitor = ClassDefVisitor(context=context)
visitor.transform_module(module)

# Queue logic
next_file = visitor.next_file(visited)
if next_file is not None:
queue.append(next_file)

missing_files = set(files) - visited
if not queue and missing_files:
queue.append(next(iter(missing_files)))

start_time = time.time()

Expand Down Expand Up @@ -102,6 +126,20 @@ def main(
print(f"Refactored {len(modified)} files.")


def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]:
try:
return func(*args, **kwargs)
except Exception as exc:
func_args = [repr(arg) for arg in args]
func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()]
return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"]

return wrapper


@capture_exception
def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]:
code = Path(filename).read_text()
module = cst.parse_module(code)
Expand All @@ -118,19 +156,6 @@ def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename:
return context.scratch


def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]:
try:
return func(*args, **kwargs)
except Exception as exc:
func_args = [repr(arg) for arg in args]
func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()]
return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"]

return wrapper


@capture_exception
def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
Expand Down
Empty file removed bump_pydantic/markers/__init__.py
Empty file.
81 changes: 0 additions & 81 deletions bump_pydantic/markers/find_base_model.py

This file was deleted.

Loading

0 comments on commit e5e9148

Please sign in to comment.