Skip to content

Commit

Permalink
Merge pull request HypothesisWorks#3813 from tybug/json-defaultdict
Browse files Browse the repository at this point in the history
Vendor `dataclasses.asdict`
  • Loading branch information
Zac-HD authored Dec 16, 2023
2 parents fdfa7c3 + 3ee5f1a commit 867e56a
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 4 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch fixes a bug introduced in :ref:`version 6.92.0 <v6.92.0>`, where using :func:`~python:dataclasses.dataclass` with a :class:`~python:collections.defaultdict` field as a strategy argument would error.
45 changes: 45 additions & 0 deletions hypothesis-python/src/hypothesis/internal/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import codecs
import copy
import dataclasses
import inspect
import platform
import sys
Expand Down Expand Up @@ -188,3 +190,46 @@ def bad_django_TestCase(runner):
from hypothesis.extra.django._impl import HypothesisTestCase

return not isinstance(runner, HypothesisTestCase)


# see issue #3812
if sys.version_info[:2] < (3, 12):

def dataclass_asdict(obj, *, dict_factory=dict):
"""
A vendored variant of dataclasses.asdict. Includes the bugfix for
defaultdicts (cpython/32056) for all versions. See also issues/3812.
This should be removed whenever we drop support for 3.11. We can use the
standard dataclasses.asdict after that point.
"""
if not dataclasses._is_dataclass_instance(obj): # pragma: no cover
raise TypeError("asdict() should be called on dataclass instances")
return _asdict_inner(obj, dict_factory)

else: # pragma: no cover
dataclass_asdict = dataclasses.asdict


def _asdict_inner(obj, dict_factory):
if dataclasses._is_dataclass_instance(obj):
return dict_factory(
(f.name, _asdict_inner(getattr(obj, f.name), dict_factory))
for f in dataclasses.fields(obj)
)
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
elif isinstance(obj, (list, tuple)):
return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
elif isinstance(obj, dict):
if hasattr(type(obj), "default_factory"):
result = type(obj)(obj.default_factory)
for k, v in obj.items():
result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
return result
return type(obj)(
(_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
for k, v in obj.items()
)
else:
return copy.deepcopy(obj)
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from hypothesis.internal.conjecture.utils import calc_label_from_cls, check_sample
from hypothesis.internal.entropy import get_seeder_and_restorer
from hypothesis.internal.floats import float_of
from hypothesis.internal.observability import TESTCASE_CALLBACKS
from hypothesis.internal.reflection import (
define_function_signature,
get_pretty_function_description,
Expand Down Expand Up @@ -2103,7 +2104,9 @@ def draw(self, strategy: SearchStrategy[Ex], label: Any = None) -> Ex:
self.count += 1
printer = RepresentationPrinter(context=current_build_context())
desc = f"Draw {self.count}{'' if label is None else f' ({label})'}: "
self.conjecture_data._observability_args[desc] = to_jsonable(result)
if TESTCASE_CALLBACKS:
self.conjecture_data._observability_args[desc] = to_jsonable(result)

printer.text(desc)
printer.pretty(result)
note(printer.getvalue())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import attr

from hypothesis.internal.cache import LRUReusedCache
from hypothesis.internal.compat import dataclass_asdict
from hypothesis.internal.floats import float_to_int
from hypothesis.internal.reflection import proxies
from hypothesis.vendor.pretty import pretty
Expand Down Expand Up @@ -177,7 +178,7 @@ def to_jsonable(obj: object) -> object:
and dcs.is_dataclass(obj)
and not isinstance(obj, type)
):
return to_jsonable(dcs.asdict(obj))
return to_jsonable(dataclass_asdict(obj))
if attr.has(type(obj)):
return to_jsonable(attr.asdict(obj, recurse=False)) # type: ignore
if (pyd := sys.modules.get("pydantic")) and isinstance(obj, pyd.BaseModel):
Expand Down
24 changes: 23 additions & 1 deletion hypothesis-python/tests/cover/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import math
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from functools import partial
from inspect import Parameter, Signature, signature
from typing import ForwardRef, Optional, Union

import pytest

from hypothesis.internal.compat import ceil, floor, get_type_hints
from hypothesis.internal.compat import ceil, dataclass_asdict, floor, get_type_hints

floor_ceil_values = [
-10.7,
Expand Down Expand Up @@ -106,3 +107,24 @@ def func(a, b: int, *c: str, d: Optional[int] = None):
)
def test_get_hints_through_partial(pf, names):
assert set(get_type_hints(pf)) == set(names.split())


@dataclass
class FilledWithStuff:
a: list
b: tuple
c: namedtuple
d: dict
e: defaultdict


def test_dataclass_asdict():
ANamedTuple = namedtuple("ANamedTuple", ("with_some_field"))
obj = FilledWithStuff(a=[1], b=(2), c=ANamedTuple(3), d={4: 5}, e=defaultdict(list))
assert dataclass_asdict(obj) == {
"a": [1],
"b": (2),
"c": ANamedTuple(3),
"d": {4: 5},
"e": {},
}
31 changes: 30 additions & 1 deletion hypothesis-python/tests/cover/test_searchstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import dataclasses
import functools
from collections import namedtuple
from collections import defaultdict, namedtuple

import attr
import pytest

from hypothesis.errors import InvalidArgument
Expand Down Expand Up @@ -90,3 +92,30 @@ def test_flatmap_with_invalid_expand():

def test_jsonable():
assert isinstance(to_jsonable(object()), str)


@dataclasses.dataclass()
class HasDefaultDict:
x: defaultdict


@attr.s
class AttrsClass:
n = attr.ib()


def test_jsonable_defaultdict():
obj = HasDefaultDict(defaultdict(list))
obj.x["a"] = [42]
assert to_jsonable(obj) == {"x": {"a": [42]}}


def test_jsonable_attrs():
obj = AttrsClass(n=10)
assert to_jsonable(obj) == {"n": 10}


def test_jsonable_namedtuple():
Obj = namedtuple("Obj", ("x"))
obj = Obj(10)
assert to_jsonable(obj) == {"x": 10}

0 comments on commit 867e56a

Please sign in to comment.