Skip to content

Commit

Permalink
Display class decorators. Fixes #4860
Browse files Browse the repository at this point in the history
  • Loading branch information
tjprescott committed Mar 7, 2023
1 parent a973600 commit 2439409
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 0 deletions.
1 change: 1 addition & 0 deletions packages/python-packages/api-stub-generator/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Version 0.3.7 (Unreleased)
Fix incorrect type annotation.
Update to follow best practices for accessing '__annotations__'.
Fixed issue where class decorators were not displayed.

## Version 0.3.6 (2022-10-27)
Suppressed unwanted base class methods in DPG libraries.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ def _handle_class_variable(self, child_obj, name, *, type_string=None, value=Non
)
)

def _parse_decorators_from_class(self, class_obj):
try:
class_node = astroid.parse(inspect.getsource(class_obj)).body[0]
class_decorators = class_node.decorators.nodes
self.decorators = [f"@{x.as_string(preserve_quotes=True)}" for x in class_decorators]
except:
self.decorators = []

def _parse_functions_from_class(self, class_obj) -> List[astroid.FunctionDef]:
try:
class_node = astroid.parse(inspect.getsource(class_obj)).body[0]
Expand Down Expand Up @@ -179,6 +187,8 @@ def _inspect(self):

is_typeddict = hasattr(self.obj, "__required_keys__") or hasattr(self.obj, "__optional_keys__")

self._parse_decorators_from_class(self.obj)

# find members in node
# enums with duplicate values are screened out by "getmembers" so
# we must rely on __members__ instead.
Expand Down Expand Up @@ -297,6 +307,11 @@ def generate_tokens(self, apiview):
"""
logging.info(f"Processing class {self.namespace_id}")
# Generate class name line
for decorator in self.decorators:
apiview.add_whitespace()
apiview.add_keyword(decorator)
apiview.add_newline()

apiview.add_whitespace()
apiview.add_line_marker(self.namespace_id)
apiview.add_keyword("class", False, True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DataClassNode(ClassNode):

def __init__(self, *, name, namespace, parent_node, obj, pkg_root_namespace):
super().__init__(name=name, namespace=namespace, parent_node=parent_node, obj=obj, pkg_root_namespace=pkg_root_namespace)
self.decorators = [x for x in self.decorators if not x.startswith("@dataclass")]
# explicitly set synthesized __init__ return type to None to fix test flakiness
for child in self.child_nodes:
if child.display_name == "__init__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from apistubgentest.models import (
AliasNewType,
AliasUnion,
ClassWithDecorators,
FakeTypedDict,
FakeObject,
GenericStack,
Expand Down Expand Up @@ -39,6 +40,23 @@ class TestClassParsing:

pkg_namespace = "apistubgentest.models"

def test_class_with_decorators(self):
obj = ClassWithDecorators
class_node = ClassNode(name=obj.__name__, namespace=obj.__name__, parent_node=None, obj=obj, pkg_root_namespace=self.pkg_namespace)
actuals = _render_lines(_tokenize(class_node))
expected = [
"@add_id",
"class ClassWithDecorators:",
"",
"def __init__(",
"self, ",
"id, ",
"*args, ",
"**kwargs",
")",
]
_check_all(actuals, expected, obj)

def test_typed_dict_class(self):
obj = FakeTypedDict
class_node = ClassNode(name=obj.__name__, namespace=obj.__name__, parent_node=None, obj=obj, pkg_root_namespace=self.pkg_namespace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ._models import (
AliasNewType,
AliasUnion,
ClassWithDecorators,
DocstringClass,
FakeError,
FakeObject,
Expand Down Expand Up @@ -39,6 +40,7 @@
__all__ = (
"AliasNewType",
"AliasUnion",
"ClassWithDecorators",
"DataClassSimple",
"DataClassWithFields",
"DataClassDynamic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator

def get_id(self):
return self.__id

def add_id(cls):
cls_init = cls.__init__

def __init__(self, id, *args, **kwargs):
self.__id = id
self.get_id = get_id
cls_init(self, *args, **kwargs)

cls.__init__ = __init__
return cls

@add_id
class ClassWithDecorators:
pass

class PublicCaseInsensitiveEnumMeta(EnumMeta):
def __getitem__(self, name: str):
Expand Down

0 comments on commit 2439409

Please sign in to comment.