Skip to content

Commit

Permalink
More test coverage (#603)
Browse files Browse the repository at this point in the history
* test for default disambiguator with None

* Clean up tuple structuring

* Add test for error in default disambiguator

* More BaseValidationErrors tests

* Add exception note grouping test

* Remove some dead code?

* disambiguators: test edge case
  • Loading branch information
Tinche authored Nov 24, 2024
1 parent dbe138b commit c3596e4
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 31 deletions.
29 changes: 10 additions & 19 deletions src/cattrs/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,15 +885,15 @@ def _structure_optional(self, obj, union):
# We can't actually have a Union of a Union, so this is safe.
return self._structure_func.dispatch(other)(obj, other)

def _structure_tuple(self, obj: Any, tup: type[T]) -> T:
def _structure_tuple(self, obj: Iterable, tup: type[T]) -> T:
"""Deal with structuring into a tuple."""
tup_params = None if tup in (Tuple, tuple) else tup.__args__
has_ellipsis = tup_params and tup_params[-1] is Ellipsis
if tup_params is None or (has_ellipsis and tup_params[0] in ANIES):
# Just a Tuple. (No generic information.)
return tuple(obj)
if has_ellipsis:
# We're dealing with a homogenous tuple, Tuple[int, ...]
# We're dealing with a homogenous tuple, tuple[int, ...]
tup_type = tup_params[0]
conv = self._structure_func.dispatch(tup_type)
if self.detailed_validation:
Expand All @@ -920,13 +920,6 @@ def _structure_tuple(self, obj: Any, tup: type[T]) -> T:

# We're dealing with a heterogenous tuple.
exp_len = len(tup_params)
try:
len_obj = len(obj)
except TypeError:
pass # most likely an unsized iterator, eg generator
else:
if len_obj > exp_len:
exp_len = len_obj
if self.detailed_validation:
errors = []
res = []
Expand All @@ -940,8 +933,8 @@ def _structure_tuple(self, obj: Any, tup: type[T]) -> T:
)
exc.__notes__ = [*getattr(exc, "__notes__", []), msg]
errors.append(exc)
if len(res) < exp_len:
problem = "Not enough" if len(res) < len(tup_params) else "Too many"
if len(obj) != exp_len:
problem = "Not enough" if len(res) < exp_len else "Too many"
exc = ValueError(f"{problem} values in {obj!r} to structure as {tup!r}")
msg = f"Structuring {tup}"
exc.__notes__ = [*getattr(exc, "__notes__", []), msg]
Expand All @@ -950,13 +943,12 @@ def _structure_tuple(self, obj: Any, tup: type[T]) -> T:
raise IterableValidationError(f"While structuring {tup!r}", errors, tup)
return tuple(res)

res = tuple(
if len(obj) != exp_len:
problem = "Not enough" if len(obj) < len(tup_params) else "Too many"
raise ValueError(f"{problem} values in {obj!r} to structure as {tup!r}")
return tuple(
[self._structure_func.dispatch(t)(e, t) for t, e in zip(tup_params, obj)]
)
if len(res) < exp_len:
problem = "Not enough" if len(res) < len(tup_params) else "Too many"
raise ValueError(f"{problem} values in {obj!r} to structure as {tup!r}")
return res

def _get_dis_func(
self,
Expand All @@ -971,11 +963,10 @@ def _get_dis_func(
# logic.
union_types = tuple(e for e in union_types if e is not NoneType)

# TODO: technically both disambiguators could support TypedDicts and
# dataclasses...
# TODO: technically both disambiguators could support TypedDicts too
if not all(has(get_origin(e) or e) for e in union_types):
raise StructureHandlerNotFoundError(
"Only unions of attrs classes supported "
"Only unions of attrs classes and dataclasses supported "
"currently. Register a structure hook manually.",
type_=union,
)
Expand Down
9 changes: 6 additions & 3 deletions src/cattrs/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

from typing_extensions import Self

from cattrs._compat import ExceptionGroup


Expand All @@ -17,13 +20,13 @@ def __init__(self, message: str, type_: type) -> None:
class BaseValidationError(ExceptionGroup):
cl: type

def __new__(cls, message, excs, cl: type):
def __new__(cls, message: str, excs: Sequence[Exception], cl: type):
obj = super().__new__(cls, message, excs)
obj.cl = cl
return obj

def derive(self, excs):
return ClassValidationError(self.message, excs, self.cl)
def derive(self, excs: Sequence[Exception]) -> Self:
return self.__class__(self.message, excs, self.cl)


class IterableValidationNote(str):
Expand Down
4 changes: 0 additions & 4 deletions src/cattrs/gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,6 @@ def make_dict_unstructure_fn(
if is_generic(cl):
mapping = generate_mapping(cl, mapping)

for base in getattr(origin, "__orig_bases__", ()):
if is_generic(base) and not str(base).startswith("typing.Generic"):
mapping = generate_mapping(base, mapping)
break
if origin is not None:
cl = origin

Expand Down
53 changes: 52 additions & 1 deletion tests/test_disambiguators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from cattrs import Converter
from cattrs.disambiguators import create_default_dis_func, is_supported_union
from cattrs.errors import StructureHandlerNotFoundError
from cattrs.gen import make_dict_structure_fn, override

from .untyped import simple_classes
Expand Down Expand Up @@ -76,7 +77,40 @@ class H:

with pytest.raises(TypeError):
# The discriminator chosen does not actually help
create_default_dis_func(c, C, D)
create_default_dis_func(c, G, H)

# Not an attrs class or dataclass
class J:
i: int

with pytest.raises(StructureHandlerNotFoundError):
c.get_structure_hook(Union[A, J])

@define
class K:
x: Literal[2]

fn = create_default_dis_func(c, G, K)
with pytest.raises(ValueError):
# The input should be a mapping
fn([])

# A normal class with a required attribute
@define
class L:
b: str

# C and L both have a required attribute, so there will be no fallback.
fn = create_default_dis_func(c, C, L)
with pytest.raises(ValueError):
# We can't disambiguate based on this payload, so we error
fn({"c": 1})

# A has no attributes, so it ends up being the fallback
fn = create_default_dis_func(c, A, C)
with pytest.raises(ValueError):
# The input should be a mapping
fn([])


@given(simple_classes(defaults=False))
Expand Down Expand Up @@ -232,6 +266,23 @@ class D:
assert no_lits({"a": "a"}) is D


def test_default_none():
"""The default disambiguator can handle `None`."""
c = Converter()

@define
class A:
a: int

@define
class B:
b: str

hook = c.get_structure_hook(Union[A, B, None])
assert hook({"a": 1}, Union[A, B, None]) == A(1)
assert hook(None, Union[A, B, None]) is None


def test_converter_no_literals(converter: Converter):
"""A converter can be configured to skip literals."""

Expand Down
24 changes: 24 additions & 0 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,30 @@ class OuterStr:
assert genconverter.structure(raw, OuterStr) == OuterStr(Inner("1"))


def test_unstructure_generic_inheritance(genconverter):
"""Classes inheriting from generic classes work."""
genconverter.register_unstructure_hook(int, lambda v: v + 1)
genconverter.register_unstructure_hook(str, lambda v: str(int(v) + 1))

@define
class Parent(Generic[T]):
a: T

@define
class Child(Parent, Generic[T]):
b: str

instance = Child(1, "2")
assert genconverter.unstructure(instance, Child[int]) == {"a": 2, "b": "3"}

@define
class ExplicitChild(Parent[int]):
b: str

instance = ExplicitChild(1, "2")
assert genconverter.unstructure(instance, ExplicitChild) == {"a": 2, "b": "3"}


def test_unstructure_optional(genconverter):
"""Generics with optional fields work."""

Expand Down
8 changes: 4 additions & 4 deletions tests/test_unions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type, Union
from typing import Union

import pytest
from attrs import define
Expand All @@ -9,7 +9,7 @@


@pytest.mark.parametrize("cls", (BaseConverter, Converter))
def test_custom_union_toplevel_roundtrip(cls: Type[BaseConverter]):
def test_custom_union_toplevel_roundtrip(cls: type[BaseConverter]):
"""
Test custom code union handling.
Expand Down Expand Up @@ -42,7 +42,7 @@ class B:

@pytest.mark.skipif(not is_py310_plus, reason="3.10 union syntax")
@pytest.mark.parametrize("cls", (BaseConverter, Converter))
def test_310_custom_union_toplevel_roundtrip(cls: Type[BaseConverter]):
def test_310_custom_union_toplevel_roundtrip(cls: type[BaseConverter]):
"""
Test custom code union handling.
Expand Down Expand Up @@ -74,7 +74,7 @@ class B:


@pytest.mark.parametrize("cls", (BaseConverter, Converter))
def test_custom_union_clsfield_roundtrip(cls: Type[BaseConverter]):
def test_custom_union_clsfield_roundtrip(cls: type[BaseConverter]):
"""
Test custom code union handling.
Expand Down
56 changes: 56 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,59 @@ def test_notes_pickling():
assert note == "foo"
assert note.name == "name"
assert note.type is int


def test_error_derive():
"""Our ExceptionGroups should derive properly."""
c = Converter(detailed_validation=True)

@define
class Test:
a: int
b: str = field(validator=in_(["a", "b"]))
c: str

with pytest.raises(ClassValidationError) as exc:
c.structure({"a": "a", "b": "c"}, Test)

match, rest = exc.value.split(KeyError)

assert len(match.exceptions) == 1
assert len(rest.exceptions) == 1

assert match.cl == exc.value.cl
assert rest.cl == exc.value.cl


def test_iterable_note_grouping():
"""IterableValidationErrors can group their subexceptions by notes."""
exc1 = ValueError()
exc2 = KeyError()
exc3 = TypeError()

exc2.__notes__ = [note := IterableValidationNote("Test Note", 0, int)]
exc3.__notes__ = ["A string note"]

exc = IterableValidationError("Test", [exc1, exc2, exc3], list[int])

with_notes, without_notes = exc.group_exceptions()

assert with_notes == [(exc2, note)]
assert without_notes == [exc1, exc3]


def test_class_note_grouping():
"""ClassValidationErrors can group their subexceptions by notes."""
exc1 = ValueError()
exc2 = KeyError()
exc3 = TypeError()

exc2.__notes__ = [note := AttributeValidationNote("Test Note", "a", int)]
exc3.__notes__ = ["A string note"]

exc = ClassValidationError("Test", [exc1, exc2, exc3], int)

with_notes, without_notes = exc.group_exceptions()

assert with_notes == [(exc2, note)]
assert without_notes == [exc1, exc3]

0 comments on commit c3596e4

Please sign in to comment.