Skip to content

Commit

Permalink
Fix: Class disambiguation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
aasiffaizal committed Dec 2, 2024
1 parent 253e91d commit 3adb4bd
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 20 deletions.
36 changes: 22 additions & 14 deletions bump_pydantic/codemods/class_def_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def __init__(self, context: CodemodContext) -> None:
self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set())
self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set))

def _recursively_disambiguate(
self, classname: str, context_set: set[str], ambiguous_classes: dict[str, set[str]]
) -> None:
if classname in context_set and classname in ambiguous_classes:
for child_classname in ambiguous_classes.pop(classname):
context_set.add(child_classname)
self._recursively_disambiguate(child_classname, context_set, ambiguous_classes)

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)

Expand All @@ -60,30 +68,30 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name)

# In case we have the following scenario:
# class ChildA(A):
# class A(B): ...
# class B(BaseModel): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `A` as soon as we see `B` is a `BaseModel`.
if (
fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class)
# We want to disambiguate `A` and then `ChildA` as soon as we see `B` is a `BaseModel`.
# We recursively add child classes to self.BASE_MODEL_CONTEXT_KEY.
self._recursively_disambiguate(
fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], self.context.scratch[self.CLS_CONTEXT_KEY]
)

# In case we have the following scenario:
# class A(B): ...
# class B(BaseModel): ...
# class E(D): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`.
if (
fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class)
# We want to disambiguate `D` and then `E` as soon as we see `C` is NOT a `BaseModel`.
# We recursively add child classes to self.NO_BASE_MODEL_CONTEXT_KEY.
self._recursively_disambiguate(
fqn.name,
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY],
self.context.scratch[self.CLS_CONTEXT_KEY],
)

# In case we have the following scenario:
# class A(B): ...
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .field import cases as generic_model_cases
from .folder_inside_folder import cases as folder_inside_folder_cases
from .is_base_model import cases as is_base_model_cases
from .nested_inheritance import cases as nested_inheritance_cases
from .replace_validator import cases as replace_validator_cases
from .root_model import cases as root_model_cases
from .unicode import cases as unicode_cases
Expand All @@ -22,6 +23,7 @@
*base_settings_cases,
*add_none_cases,
*is_base_model_cases,
*nested_inheritance_cases,
*replace_validator_cases,
*config_to_model_cases,
*root_model_cases,
Expand Down
77 changes: 77 additions & 0 deletions tests/integration/cases/nested_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from ..case import Case
from ..file import File
from ..folder import Folder

cases = [
Case(
name="Nested Inheritance",
source=Folder(
"nested_inheritance",
File("__init__.py", content=[]),
File(
"bar.py",
content=[
"from .foo import Foo",
"",
"",
"class Bar(Foo):",
" b: str | None",
],
),
File(
"baz.py",
content=[
"from .bar import Bar",
"",
"",
"class Baz(Bar):",
" c: str | None",
],
),
File(
"foo.py",
content=[
"from pydantic import BaseModel",
"",
"",
"class Foo(BaseModel):",
" a: str | None",
],
),
),
expected=Folder(
"nested_inheritance",
File("__init__.py", content=[]),
File(
"bar.py",
content=[
"from .foo import Foo",
"",
"",
"class Bar(Foo):",
" b: str | None = None",
],
),
File(
"baz.py",
content=[
"from .bar import Bar",
"",
"",
"class Baz(Bar):",
" c: str | None = None",
],
),
File(
"foo.py",
content=[
"from pydantic import BaseModel",
"",
"",
"class Foo(BaseModel):",
" a: str | None = None",
],
),
),
)
]
6 changes: 3 additions & 3 deletions tests/unit/test_add_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def add_annotations(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
parse_module(CodemodTest.make_fixture_data(code)),
cache={
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get(
file_path, ""
)
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(
Path(""), [file_path], timeout=None
).get(file_path, "")
},
)
mod.resolve_many(AddAnnotationsCommand.METADATA_DEPENDENCIES)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
parse_module(CodemodTest.make_fixture_data(code)),
cache={
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get(
file_path, ""
)
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(
Path(""), [file_path], timeout=None
).get(file_path, "")
},
)
mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES)
Expand Down

0 comments on commit 3adb4bd

Please sign in to comment.