Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-34776: Fix dataclasses to support __future__ "annotations" mode #9518

Merged
merged 8 commits into from
Dec 9, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 53 additions & 34 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,23 +344,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
# worries about external callers.
if locals is None:
locals = {}
# __builtins__ may be the "builtins" module or
# the value of its "__dict__",
# so make sure "__builtins__" is the module.
if globals is not None and '__builtins__' not in globals:
globals['__builtins__'] = builtins
if 'BUILTINS' not in locals:
locals['BUILTINS'] = builtins
return_annotation = ''
if return_type is not MISSING:
locals['_return_type'] = return_type
return_annotation = '->_return_type'
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
body = '\n'.join(f' {b}' for b in body)

# Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'
txt = f' def {name}({args}){return_annotation}:\n{body}'

exec(txt, globals, locals)
return locals[name]
local_vars = ', '.join(locals.keys())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: add a comment explaining this.

txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

ns = {}
exec(txt, globals, ns)
return ns['__create_fn__'](**locals)


def _field_assign(frozen, name, value, self_name):
Expand All @@ -371,7 +372,7 @@ def _field_assign(frozen, name, value, self_name):
# self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name.
if frozen:
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
return f'{self_name}.{name}={value}'


Expand Down Expand Up @@ -448,7 +449,7 @@ def _init_param(f):
return f'{f.name}:_type_{f.name}{default}'


def _init_fn(fields, frozen, has_post_init, self_name):
def _init_fn(fields, frozen, has_post_init, self_name, globals):
# fields contains both real fields and InitVar pseudo-fields.

# Make sure we don't have fields without defaults following fields
Expand All @@ -466,12 +467,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
raise TypeError(f'non-default argument {f.name!r} '
'follows default argument')

globals = {'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
locals = {f'_type_{f.name}': f.type for f in fields}
locals.update({
'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
})

body_lines = []
for f in fields:
line = _field_init(f, frozen, globals, self_name)
line = _field_init(f, frozen, locals, self_name)
# line is None means that this field doesn't require
# initialization (it's a pseudo-field). Just skip it.
if line:
Expand All @@ -487,7 +491,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
if not body_lines:
body_lines = ['pass']

locals = {f'_type_{f.name}': f.type for f in fields}
return _create_fn('__init__',
[self_name] + [_init_param(f) for f in fields if f.init],
body_lines,
Expand All @@ -496,19 +499,18 @@ def _init_fn(fields, frozen, has_post_init, self_name):
return_type=None)


def _repr_fn(fields):
def _repr_fn(fields, globals):
return _create_fn('__repr__',
('self',),
['return self.__class__.__qualname__ + f"(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'])
')"'],
globals=globals)


def _frozen_get_del_attr(cls, fields):
# XXX: globals is modified on the first call to _create_fn, then
# the modified version is used in the second call. Is this okay?
globals = {'cls': cls,
def _frozen_get_del_attr(cls, fields, globals):
locals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
if fields:
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
Expand All @@ -520,17 +522,19 @@ def _frozen_get_del_attr(cls, fields):
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
locals=locals,
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
locals=locals,
globals=globals),
)


def _cmp_fn(name, op, self_tuple, other_tuple):
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
# Create a comparison function. If the fields in the object are
# named 'x' and 'y', then self_tuple is the string
# '(self.x,self.y)' and other_tuple is the string
Expand All @@ -540,14 +544,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
('self', 'other'),
[ 'if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
'return NotImplemented'])
'return NotImplemented'],
globals=globals)


def _hash_fn(fields):
def _hash_fn(fields, globals):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
('self',),
[f'return hash({self_tuple})'])
[f'return hash({self_tuple})'],
globals=globals)


def _is_classvar(a_type, typing):
Expand Down Expand Up @@ -719,14 +725,14 @@ def _set_new_attribute(cls, name, value):
# take. The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.

def _hash_set_none(cls, fields):
def _hash_set_none(cls, fields, globals):
return None

def _hash_add(cls, fields):
def _hash_add(cls, fields, globals):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
return _hash_fn(flds)
return _hash_fn(flds, globals)

def _hash_exception(cls, fields):
def _hash_exception(cls, fields, globals):
# Raise an exception.
raise TypeError(f'Cannot overwrite attribute __hash__ '
f'in class {cls.__name__}')
Expand Down Expand Up @@ -768,6 +774,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# is defined by the base class, which is found first.
fields = {}

if cls.__module__ in sys.modules:
globals = sys.modules[cls.__module__].__dict__
else:
# Theoretically this can happen if someone writes
# a custom string to cls.__module__. In which case
# such dataclass won't be fully introspectable
# (w.r.t. typing.get_type_hints) but will still function
# correctly.
globals = {}

setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
unsafe_hash, frozen))

Expand Down Expand Up @@ -877,6 +893,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
globals
1st1 marked this conversation as resolved.
Show resolved Hide resolved
))

# Get the fields as a list, and include only real fields. This is
Expand All @@ -885,7 +902,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):

if repr:
flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))

if eq:
# Create _eq__ method. There's no need for a __ne__ method,
Expand All @@ -895,7 +912,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
other_tuple = _tuple_str('other', flds)
_set_new_attribute(cls, '__eq__',
_cmp_fn('__eq__', '==',
self_tuple, other_tuple))
self_tuple, other_tuple,
globals=globals))

if order:
# Create and set the ordering methods.
Expand All @@ -908,13 +926,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
('__ge__', '>='),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple)):
_cmp_fn(name, op, self_tuple, other_tuple,
globals=globals)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}. Consider using '
'functools.total_ordering')

if frozen:
for fn in _frozen_get_del_attr(cls, field_list):
for fn in _frozen_get_del_attr(cls, field_list, globals):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
Expand All @@ -927,7 +946,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
if hash_action:
# No need to call _set_new_attribute here, since by the time
# we're here the overwriting is unconditional.
cls.__hash__ = hash_action(cls, field_list)
cls.__hash__ = hash_action(cls, field_list, globals)

if not getattr(cls, '__doc__'):
# Create a class doc-string.
Expand Down
12 changes: 12 additions & 0 deletions Lib/test/dataclass_textanno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import dataclasses


class Foo:
pass


@dataclasses.dataclass
class Bar:
foo: Foo
12 changes: 12 additions & 0 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
from typing import get_type_hints
from collections import deque, OrderedDict, namedtuple
from functools import total_ordering

Expand Down Expand Up @@ -2882,6 +2883,17 @@ def test_classvar_module_level_import(self):
# won't exist on the instance.
self.assertNotIn('not_iv4', c.__dict__)

def test_text_annotations(self):
from test import dataclass_textanno

self.assertEqual(
get_type_hints(dataclass_textanno.Bar),
{'foo': dataclass_textanno.Foo})
self.assertEqual(
get_type_hints(dataclass_textanno.Bar.__init__),
{'foo': dataclass_textanno.Foo,
'return': type(None)})


class TestMakeDataclass(unittest.TestCase):
def test_simple(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix dataclasses to support __future__ "annotations" mode
Copy link

@Vlad-Shcherbina Vlad-Shcherbina Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictly speaking, this problem is not directly related to __future__ annotations.
It's about postponed evaluation of type annotations in general, including forward references in pre-PEP-563 annotations.
Consider rewording along the lines of "Fix dataclasses to support forward references in type annotations".

Here is an example that does not use PEP 563:

from typing import get_type_hints
from dataclasses import dataclass

class T:
    pass

@dataclass()
class C2:
    x: 'T'

print(get_type_hints(C2.__init__))

Before your change:

Traceback (most recent call last):
  File ".\zzz.py", line 11, in <module>
    print(get_type_hints(C2.__init__))
  File "C:\Python37\lib\typing.py", line 1001, in get_type_hints
    value = _eval_type(value, globalns, localns)
  File "C:\Python37\lib\typing.py", line 260, in _eval_type
    return t._evaluate(globalns, localns)
  File "C:\Python37\lib\typing.py", line 464, in _evaluate
    eval(self.__forward_code__, globalns, localns),
  File "<string>", line 1, in <module>
NameError: name 'T' is not defined

After your change it works as expected:

{'x': <class '__main__.T'>, 'return': <class 'NoneType'>}