Skip to content

Commit

Permalink
bpo-34776: Fix dataclasses to support __future__ "annotations" mode (p…
Browse files Browse the repository at this point in the history
…ythonGH-9518)

(cherry picked from commit d219cc4)

Co-authored-by: Yury Selivanov <[email protected]>
  • Loading branch information
1st1 authored and miss-islington committed Dec 9, 2019
1 parent a0078d9 commit c346808
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 34 deletions.
87 changes: 53 additions & 34 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,23 +368,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
# worries about external callers.
if locals is None:
locals = {}
# __builtins__ may be the "builtins" module or
# the value of its "__dict__",
# so make sure "__builtins__" is the module.
if globals is not None and '__builtins__' not in globals:
globals['__builtins__'] = builtins
if 'BUILTINS' not in locals:
locals['BUILTINS'] = builtins
return_annotation = ''
if return_type is not MISSING:
locals['_return_type'] = return_type
return_annotation = '->_return_type'
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
body = '\n'.join(f' {b}' for b in body)

# Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'
txt = f' def {name}({args}){return_annotation}:\n{body}'

exec(txt, globals, locals)
return locals[name]
local_vars = ', '.join(locals.keys())
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

ns = {}
exec(txt, globals, ns)
return ns['__create_fn__'](**locals)


def _field_assign(frozen, name, value, self_name):
Expand All @@ -395,7 +396,7 @@ def _field_assign(frozen, name, value, self_name):
# self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name.
if frozen:
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
return f'{self_name}.{name}={value}'


Expand Down Expand Up @@ -472,7 +473,7 @@ def _init_param(f):
return f'{f.name}:_type_{f.name}{default}'


def _init_fn(fields, frozen, has_post_init, self_name):
def _init_fn(fields, frozen, has_post_init, self_name, globals):
# fields contains both real fields and InitVar pseudo-fields.

# Make sure we don't have fields without defaults following fields
Expand All @@ -490,12 +491,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
raise TypeError(f'non-default argument {f.name!r} '
'follows default argument')

globals = {'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
locals = {f'_type_{f.name}': f.type for f in fields}
locals.update({
'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
})

body_lines = []
for f in fields:
line = _field_init(f, frozen, globals, self_name)
line = _field_init(f, frozen, locals, self_name)
# line is None means that this field doesn't require
# initialization (it's a pseudo-field). Just skip it.
if line:
Expand All @@ -511,7 +515,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
if not body_lines:
body_lines = ['pass']

locals = {f'_type_{f.name}': f.type for f in fields}
return _create_fn('__init__',
[self_name] + [_init_param(f) for f in fields if f.init],
body_lines,
Expand All @@ -520,20 +523,19 @@ def _init_fn(fields, frozen, has_post_init, self_name):
return_type=None)


def _repr_fn(fields):
def _repr_fn(fields, globals):
fn = _create_fn('__repr__',
('self',),
['return self.__class__.__qualname__ + f"(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'])
')"'],
globals=globals)
return _recursive_repr(fn)


def _frozen_get_del_attr(cls, fields):
# XXX: globals is modified on the first call to _create_fn, then
# the modified version is used in the second call. Is this okay?
globals = {'cls': cls,
def _frozen_get_del_attr(cls, fields, globals):
locals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
if fields:
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
Expand All @@ -545,17 +547,19 @@ def _frozen_get_del_attr(cls, fields):
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
locals=locals,
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
locals=locals,
globals=globals),
)


def _cmp_fn(name, op, self_tuple, other_tuple):
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
# Create a comparison function. If the fields in the object are
# named 'x' and 'y', then self_tuple is the string
# '(self.x,self.y)' and other_tuple is the string
Expand All @@ -565,14 +569,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
('self', 'other'),
[ 'if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
'return NotImplemented'])
'return NotImplemented'],
globals=globals)


def _hash_fn(fields):
def _hash_fn(fields, globals):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
('self',),
[f'return hash({self_tuple})'])
[f'return hash({self_tuple})'],
globals=globals)


def _is_classvar(a_type, typing):
Expand Down Expand Up @@ -744,14 +750,14 @@ def _set_new_attribute(cls, name, value):
# take. The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.

def _hash_set_none(cls, fields):
def _hash_set_none(cls, fields, globals):
return None

def _hash_add(cls, fields):
def _hash_add(cls, fields, globals):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
return _hash_fn(flds)
return _hash_fn(flds, globals)

def _hash_exception(cls, fields):
def _hash_exception(cls, fields, globals):
# Raise an exception.
raise TypeError(f'Cannot overwrite attribute __hash__ '
f'in class {cls.__name__}')
Expand Down Expand Up @@ -793,6 +799,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# is defined by the base class, which is found first.
fields = {}

if cls.__module__ in sys.modules:
globals = sys.modules[cls.__module__].__dict__
else:
# Theoretically this can happen if someone writes
# a custom string to cls.__module__. In which case
# such dataclass won't be fully introspectable
# (w.r.t. typing.get_type_hints) but will still function
# correctly.
globals = {}

setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
unsafe_hash, frozen))

Expand Down Expand Up @@ -902,6 +918,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
globals,
))

# Get the fields as a list, and include only real fields. This is
Expand All @@ -910,7 +927,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):

if repr:
flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))

if eq:
# Create _eq__ method. There's no need for a __ne__ method,
Expand All @@ -920,7 +937,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
other_tuple = _tuple_str('other', flds)
_set_new_attribute(cls, '__eq__',
_cmp_fn('__eq__', '==',
self_tuple, other_tuple))
self_tuple, other_tuple,
globals=globals))

if order:
# Create and set the ordering methods.
Expand All @@ -933,13 +951,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
('__ge__', '>='),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple)):
_cmp_fn(name, op, self_tuple, other_tuple,
globals=globals)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}. Consider using '
'functools.total_ordering')

if frozen:
for fn in _frozen_get_del_attr(cls, field_list):
for fn in _frozen_get_del_attr(cls, field_list, globals):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
Expand All @@ -952,7 +971,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
if hash_action:
# No need to call _set_new_attribute here, since by the time
# we're here the overwriting is unconditional.
cls.__hash__ = hash_action(cls, field_list)
cls.__hash__ = hash_action(cls, field_list, globals)

if not getattr(cls, '__doc__'):
# Create a class doc-string.
Expand Down
12 changes: 12 additions & 0 deletions Lib/test/dataclass_textanno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import dataclasses


class Foo:
pass


@dataclasses.dataclass
class Bar:
foo: Foo
12 changes: 12 additions & 0 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
from typing import get_type_hints
from collections import deque, OrderedDict, namedtuple
from functools import total_ordering

Expand Down Expand Up @@ -2918,6 +2919,17 @@ def test_classvar_module_level_import(self):
# won't exist on the instance.
self.assertNotIn('not_iv4', c.__dict__)

def test_text_annotations(self):
from test import dataclass_textanno

self.assertEqual(
get_type_hints(dataclass_textanno.Bar),
{'foo': dataclass_textanno.Foo})
self.assertEqual(
get_type_hints(dataclass_textanno.Bar.__init__),
{'foo': dataclass_textanno.Foo,
'return': type(None)})


class TestMakeDataclass(unittest.TestCase):
def test_simple(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix dataclasses to support forward references in type annotations

0 comments on commit c346808

Please sign in to comment.