Skip to content

Commit

Permalink
refactor function
Browse files Browse the repository at this point in the history
  • Loading branch information
aasiffaizal committed Dec 2, 2024
1 parent 3adb4bd commit 01de55d
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions bump_pydantic/codemods/class_def_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ def __init__(self, context: CodemodContext) -> None:
self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set())
self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set))

def _recursively_disambiguate(
self, classname: str, context_set: set[str], ambiguous_classes: dict[str, set[str]]
) -> None:
if classname in context_set and classname in ambiguous_classes:
for child_classname in ambiguous_classes.pop(classname):
def _recursively_disambiguate(self, classname: str, context_set: set[str]) -> None:
if classname in context_set and classname in self.context.scratch[self.CLS_CONTEXT_KEY]:
for child_classname in self.context.scratch[self.CLS_CONTEXT_KEY].pop(classname):
context_set.add(child_classname)
self._recursively_disambiguate(child_classname, context_set, ambiguous_classes)
self._recursively_disambiguate(child_classname, context_set)

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)
Expand Down Expand Up @@ -75,9 +73,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
# class C: ...
# We want to disambiguate `A` and then `ChildA` as soon as we see `B` is a `BaseModel`.
# We recursively add child classes to self.BASE_MODEL_CONTEXT_KEY.
self._recursively_disambiguate(
fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], self.context.scratch[self.CLS_CONTEXT_KEY]
)
self._recursively_disambiguate(fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY])

# In case we have the following scenario:
# class A(B): ...
Expand All @@ -87,11 +83,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
# class C: ...
# We want to disambiguate `D` and then `E` as soon as we see `C` is NOT a `BaseModel`.
# We recursively add child classes to self.NO_BASE_MODEL_CONTEXT_KEY.
self._recursively_disambiguate(
fqn.name,
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY],
self.context.scratch[self.CLS_CONTEXT_KEY],
)
self._recursively_disambiguate(fqn.name, self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY])

# In case we have the following scenario:
# class A(B): ...
Expand Down

0 comments on commit 01de55d

Please sign in to comment.