Skip to content

Commit

Permalink
Make dataclasses/attrs comparison recursive, fixes pytest-dev#4675
Browse files Browse the repository at this point in the history
  • Loading branch information
ibriquem authored and nicoddemus committed Jun 9, 2020
1 parent 3de85a9 commit e2e7f15
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog/4675.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make dataclasses/attrs comparison recursive.
48 changes: 27 additions & 21 deletions src/_pytest/assertion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,26 +148,7 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
explanation = None
try:
if op == "==":
if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose)
else:
if issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, verbose)
elif isset(left) and isset(right):
explanation = _compare_eq_set(left, right, verbose)
elif isdict(left) and isdict(right):
explanation = _compare_eq_dict(left, right, verbose)
elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
type_fn = (isdatacls, isattrs)
explanation = _compare_eq_cls(left, right, verbose, type_fn)
elif verbose > 0:
explanation = _compare_eq_verbose(left, right)
if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, verbose)
if explanation is not None:
explanation.extend(expl)
else:
explanation = expl
explanation = _compare_eq_any(left, right, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
Expand All @@ -187,6 +168,28 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
return [summary] + explanation


def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
explanation = [] # type: List[str]
if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose)
else:
if issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, verbose)
elif isset(left) and isset(right):
explanation = _compare_eq_set(left, right, verbose)
elif isdict(left) and isdict(right):
explanation = _compare_eq_dict(left, right, verbose)
elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
type_fn = (isdatacls, isattrs)
explanation = _compare_eq_cls(left, right, verbose, type_fn)
elif verbose > 0:
explanation = _compare_eq_verbose(left, right)
if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, verbose)
explanation.extend(expl)
return explanation


def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
"""Return the explanation for the diff between text.
Expand Down Expand Up @@ -439,7 +442,10 @@ def _compare_eq_cls(
explanation += ["Differing attributes:"]
for field in diff:
explanation += [
("%s: %r != %r") % (field, getattr(left, field), getattr(right, field))
("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)),
"",
"Drill down into differing attribute %s:" % field,
*_compare_eq_any(getattr(left, field), getattr(right, field), verbose),
]
return explanation

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from dataclasses import field


@dataclass
class SimpleDataObject:
field_a: int = field()
field_b: int = field()


@dataclass
class ComplexDataObject2:
field_a: SimpleDataObject = field()
field_b: SimpleDataObject = field()


@dataclass
class ComplexDataObject:
field_a: SimpleDataObject = field()
field_b: ComplexDataObject2 = field()


def test_recursive_dataclasses():

left = ComplexDataObject(
SimpleDataObject(1, "b"),
ComplexDataObject2(SimpleDataObject(1, "b"), SimpleDataObject(2, "c"),),
)
right = ComplexDataObject(
SimpleDataObject(1, "b"),
ComplexDataObject2(SimpleDataObject(1, "b"), SimpleDataObject(3, "c"),),
)

assert left == right
80 changes: 80 additions & 0 deletions testing/test_assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,48 @@ def test_dataclasses(self, testdir):
"*Omitting 1 identical items, use -vv to show*",
"*Differing attributes:*",
"*field_b: 'b' != 'c'*",
"*- c*",
"*+ b*",
]
)

@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
def test_recursive_dataclasses(self, testdir):
p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py")
result = testdir.runpytest(p)
result.assert_outcomes(failed=1, passed=0)
result.stdout.fnmatch_lines(
[
"*Omitting 1 identical items, use -vv to show*",
"*Differing attributes:*",
"*field_b: ComplexDataObject2(*SimpleDataObject(field_a=2, field_b='c')) != ComplexDataObject2(*SimpleDataObject(field_a=3, field_b='c'))*", # noqa
"*Drill down into differing attribute field_b:*",
"*Omitting 1 identical items, use -vv to show*",
"*Differing attributes:*",
"*Full output truncated*",
]
)

@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
def test_recursive_dataclasses_verbose(self, testdir):
p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py")
result = testdir.runpytest(p, "-vv")
result.assert_outcomes(failed=1, passed=0)
result.stdout.fnmatch_lines(
[
"*Matching attributes:*",
"*['field_a']*",
"*Differing attributes:*",
"*field_b: ComplexDataObject2(*SimpleDataObject(field_a=2, field_b='c')) != ComplexDataObject2(*SimpleDataObject(field_a=3, field_b='c'))*", # noqa
"*Matching attributes:*",
"*['field_a']*",
"*Differing attributes:*",
"*field_b: SimpleDataObject(field_a=2, field_b='c') "
"!= SimpleDataObject(field_a=3, field_b='c')*",
"*Matching attributes:*",
"*['field_b']*",
"*Differing attributes:*",
"*field_a: 2 != 3",
]
)

Expand Down Expand Up @@ -832,6 +874,44 @@ class SimpleDataObject:
for line in lines[1:]:
assert "field_a" not in line

def test_attrs_recursive(self) -> None:
@attr.s
class OtherDataObject:
field_c = attr.ib()
field_d = attr.ib()

@attr.s
class SimpleDataObject:
field_a = attr.ib()
field_b = attr.ib()

left = SimpleDataObject(OtherDataObject(1, "a"), "b")
right = SimpleDataObject(OtherDataObject(1, "b"), "b")

lines = callequal(left, right)
assert "Matching attributes" not in lines
for line in lines[1:]:
assert "field_b:" not in line
assert "field_c:" not in line

def test_attrs_recursive_verbose(self) -> None:
@attr.s
class OtherDataObject:
field_c = attr.ib()
field_d = attr.ib()

@attr.s
class SimpleDataObject:
field_a = attr.ib()
field_b = attr.ib()

left = SimpleDataObject(OtherDataObject(1, "a"), "b")
right = SimpleDataObject(OtherDataObject(1, "b"), "b")

lines = callequal(left, right)
assert "field_d: 'a' != 'b'" in lines
print("\n".join(lines))

def test_attrs_verbose(self) -> None:
@attr.s
class SimpleDataObject:
Expand Down

0 comments on commit e2e7f15

Please sign in to comment.