Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-90562: Improve zero argument support for super() in dataclasses when slots=True #124692

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Prevent user-defined code execution during attribute scanning
Bobronium committed Sep 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit d0173d8e39bf6d2b91995aba774bf744bb7ce193
24 changes: 20 additions & 4 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1237,6 +1237,25 @@ def _update_func_cell_for__class__(f, oldcls, newcls):
return False


def _safe_get_attributes(obj):
# we should avoid triggering any user-defined code
# when inspecting attributes if possible

# look for __slots__ descriptors
type_dict = object.__getattribute__(type(obj), "__dict__")
for value in type_dict.values():
if isinstance(value, types.MemberDescriptorType):
yield value.__get__(obj)

instance_dict_descriptor = type_dict.get("__dict__", None)
if not isinstance(instance_dict_descriptor, types.GetSetDescriptorType):
# __dict__ is either not present, or redefined by user
# as custom descriptor, either way, we're done here
return

yield from instance_dict_descriptor.__get__(obj).values()


def _find_inner_functions(obj, seen=None, depth=0):
if seen is None:
seen = set()
@@ -1252,10 +1271,7 @@ def _find_inner_functions(obj, seen=None, depth=0):
if depth > 2:
return None

for attr in dir(obj):
value = getattr(obj, attr, None)
if value is None:
continue
for value in _safe_get_attributes(obj):
if isinstance(value, types.FunctionType):
yield inspect.unwrap(value)
return
46 changes: 0 additions & 46 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -5140,52 +5140,6 @@ def foo(self, value):

self.assertEqual(A().foo, "bar")

def test_pure_functions_preferred_to_custom_descriptors(self):
class CustomDescriptor:
def __init__(self, f):
self._wrapper = partial(f, value="bar")

def __get__(self, instance, owner):
return self._wrapper(instance)

def __dir__(self):
raise RuntimeError("Never should be accessed")

class B:
def foo(self, value):
return value

with self.assertRaises(RuntimeError) as context:
@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(self, value): ...

self.assertEqual(context.exception.args, ("Never should be accessed",))

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(self, value):
return super().foo(value)

@property
def bar(self):
return super()

self.assertEqual(A().foo, "bar")

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(self, value):
return super().foo(value)

def bar(self):
return super()

self.assertEqual(A().foo, "bar")

def test_custom_too_nested_descriptor(self):
class UnnecessaryNestedWrapper:
def __init__(self, wrapper):