Skip to content

Commit

Permalink
Improve gen_test_serializable for more flexibility on error checking
Browse files Browse the repository at this point in the history
  • Loading branch information
LiteApplication committed May 25, 2024
1 parent d6f2921 commit 98ee687
Show file tree
Hide file tree
Showing 12 changed files with 312 additions and 263 deletions.
91 changes: 34 additions & 57 deletions changes/285.internal.1.md
Original file line number Diff line number Diff line change
@@ -1,57 +1,34 @@
- Changed the way `Serializable` classes are handled:

Here is how a basic `Serializable` class looks like:

```python
@final
@dataclass
class ToyClass(Serializable):
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""

a: int
b: str | int

@override
def __attrs_post_init__(self):
"""Initialize the object."""
if isinstance(self.b, int):
self.b = str(self.b)

super().__attrs_post_init__() # This will call validate()

@override
def serialize_to(self, buf: Buffer):
"""Write the object to a buffer."""
self.b = cast(str, self.b) # Handled by the __attrs_post_init__ method
buf.write_varint(self.a)
buf.write_utf(self.b)

@classmethod
@override
def deserialize(cls, buf: Buffer) -> ToyClass:
"""Deserialize the object from a buffer."""
a = buf.read_varint()
if a == 0:
raise ZeroDivisionError("a must be non-zero")
b = buf.read_utf()
return cls(a, b)

@override
def validate(self) -> None:
"""Validate the object's attributes."""
if self.a == 0:
raise ZeroDivisionError("a must be non-zero")
if len(self.b) > 10:
raise ValueError("b must be less than 10 characters")

```

The `Serializable` class implement the following methods:

- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.

And the following optional methods:

- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.
- `__attrs_post_init__() -> None`: Initializes the object. Call `super().__attrs_post_init__()` to validate the object.
- **Function**: `gen_serializable_test`
- Generates tests for serializable classes, covering serialization, deserialization, validation, and error handling.
- **Parameters**:
- `context` (dict): Context to add the test functions to (usually `globals()`).
- `cls` (type): The serializable class to test.
- `fields` (list): Tuples of field names and types of the serializable class.
- `serialize_deserialize` (list, optional): Tuples for testing successful serialization/deserialization.
- `validation_fail` (list, optional): Tuples for testing validation failures with expected exceptions.
- `deserialization_fail` (list, optional): Tuples for testing deserialization failures with expected exceptions.
- **Note**: Implement `__eq__` in the class for accurate comparison.

- The `gen_serializable_test` function generates a test class with the following tests:

.. literalinclude:: /../tests/mcproto/utils/test_serializable.py
:language: python
:start-after: # region Test ToyClass
:end-before: # endregion Test ToyClass

- The generated test class will have the following tests:

```python
class TestGenToyClass:
def test_serialization(self):
# 3 subtests for the cases 1, 2, 3 (serialize_deserialize)

def test_deserialization(self):
# 3 subtests for the cases 1, 2, 3 (serialize_deserialize)

def test_validation(self):
# 3 subtests for the cases 4, 5, 6 (validation_fail)

def test_exceptions(self):
# 3 subtests for the cases 7, 8, 9 (deserialization_fail)
```
69 changes: 16 additions & 53 deletions changes/285.internal.2.md
Original file line number Diff line number Diff line change
@@ -1,53 +1,16 @@
- Added a test generator for `Serializable` classes:

The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments:

- `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`).
> Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`.
- `cls`: The `Serializable` class to generate tests for.
- `fields`: A list of fields where the test values will be placed.

> In the example above, the `ToyClass` class has two fields: `a` and `b`.
- `test_data`: A list of tuples containing either:
- `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).`
- `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object.
- `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize.

The `gen_serializable_test` function generates a test class with the following tests:

```python
gen_serializable_test(
context=globals(),
cls=ToyClass,
fields=[("a", int), ("b", str)],
test_data=[
((1, "hello"), b"\x01\x05hello"),
((2, "world"), b"\x02\x05world"),
((3, 1234567890), b"\x03\x0a1234567890"),
((0, "hello"), ZeroDivisionError("a must be non-zero")), # With an error message
((1, "hello world"), ValueError), # No error message
((1, 12345678900), ValueError("b must be less than 10 .*")), # With an error message and regex
(ZeroDivisionError, b"\x00"),
(ZeroDivisionError, b"\x01\x05hello"),
(IOError, b"\x01"),
],
)
```

The generated test class will have the following tests:

```python
class TestGenToyClass:
def test_serialization(self):
# 2 subtests for the cases 1 and 2

def test_deserialization(self):
# 2 subtests for the cases 1 and 2

def test_validation(self):
# 2 subtests for the cases 3 and 4

def test_exceptions(self):
# 2 subtests for the cases 5 and 6
```
- **Class**: `Serializable`
- Base class for types that should be (de)serializable into/from `mcproto.Buffer` data.
- **Methods**:
- `__attrs_post_init__()`: Runs validation after object initialization, override to define custom behavior.
- `serialize() -> Buffer`: Returns the object as a `Buffer`.
- `serialize_to(buf: Buffer)`: Abstract method to write the object to a `Buffer`.
- `validate()`: Validates the object's attributes; can be overridden for custom validation.
- `deserialize(cls, buf: Buffer) -> Self`: Abstract method to construct the object from a `Buffer`.
- **Note**: Use the `dataclass` decorator when adding parameters to subclasses.

- Exemple:

.. literalinclude:: /../tests/mcproto/utils/test_serializable.py
:language: python
:start-after: # region ToyClass
:end-before: # endregion ToyClass
2 changes: 1 addition & 1 deletion mcproto/utils/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __new__(cls: type[Self], *a: Any, **kw: Any) -> Self:
class Serializable(ABC):
"""Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data.
Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.dataclass`
Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.define`
decorator.
"""

Expand Down
128 changes: 82 additions & 46 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import asyncio
import inspect
import re
import unittest.mock
from collections.abc import Callable, Coroutine
from typing import Any, Generic, TypeVar
from typing import Any, Generic, NamedTuple, TypeVar
from typing_extensions import TypeGuard

import pytest
Expand All @@ -17,7 +18,14 @@
P = ParamSpec("P")
T_Mock = TypeVar("T_Mock", bound=unittest.mock.Mock)

__all__ = ["synchronize", "SynchronizedMixin", "UnpropagatingMockMixin", "CustomMockMixin", "gen_serializable_test"]
__all__ = [
"synchronize",
"SynchronizedMixin",
"UnpropagatingMockMixin",
"CustomMockMixin",
"gen_serializable_test",
"TestExc",
]


def synchronize(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
Expand Down Expand Up @@ -169,27 +177,42 @@ def __init__(self, **kwargs):
super().__init__(spec_set=self.spec_set, **kwargs) # type: ignore # Mixin class, this __init__ is valid


def isexception(obj: object) -> TypeGuard[type[Exception] | Exception]:
def isexception(obj: object) -> TypeGuard[type[Exception] | TestExc]:
"""Check if the object is an exception."""
return (isinstance(obj, type) and issubclass(obj, Exception)) or isinstance(obj, Exception)
return (isinstance(obj, type) and issubclass(obj, Exception)) or isinstance(obj, TestExc)


def get_exception(exception: type[Exception] | Exception) -> tuple[type[Exception], str | None]:
"""Get the exception type and message."""
if isinstance(exception, type):
return exception, None
return type(exception), str(exception)
class TestExc(NamedTuple):
"""Named tuple to check if an exception is raised with a specific message.
:param exception: The exception type.
:param match: If specified, a string containing a regular expression, or a regular expression object, that is
tested against the string representation of the exception using :func:`re.search`.
:param kwargs: The keyword arguments passed to the exception.
If :attr:`kwargs` is not None, the exception instance will need to have the same attributes with the same values.
"""

exception: type[Exception] | tuple[type[Exception], ...]
match: str | re.Pattern[str] | None = None
kwargs: dict[str, Any] | None = None

@classmethod
def from_exception(cls, exception: type[Exception] | tuple[type[Exception], ...] | TestExc) -> TestExc:
"""Create a :class:`TestExc` from an exception, does nothing if the object is already a :class:`TestExc`."""
if isinstance(exception, TestExc):
return exception
return cls(exception)


def gen_serializable_test(
context: dict[str, Any],
cls: type[Serializable],
fields: list[tuple[str, type | str]],
test_data: list[
tuple[tuple[Any, ...], bytes]
| tuple[tuple[Any, ...], type[Exception] | Exception]
| tuple[type[Exception] | Exception, bytes]
],
serialize_deserialize: list[tuple[tuple[Any, ...], bytes]] | None = None,
validation_fail: list[tuple[tuple[Any, ...], type[Exception] | TestExc]] | None = None,
deserialization_fail: list[tuple[bytes, type[Exception] | TestExc]] | None = None,
):
"""Generate tests for a serializable class.
Expand All @@ -199,15 +222,14 @@ def gen_serializable_test(
:param context: The context to add the test functions to. This is usually `globals()`.
:param cls: The serializable class to test.
:param fields: A list of tuples containing the field names and types of the serializable class.
:param test_data: A list of test data. Each element is a tuple containing either:
- A tuple of parameters to pass to the serializable class constructor and the expected bytes after
serialization
- A tuple of parameters to pass to the serializable class constructor and the expected exception during
validation
- An exception to expect during deserialization and the bytes to deserialize
Exception can be either a type or an instance of an exception, in the latter case the exception message will
be used to match the exception, and can contain regex patterns.
:param serialize_deserialize: A list of tuples containing:
- The tuple representing the arguments to pass to the :class:`mcproto.utils.abc.Serializable` class
- The expected bytes
:param validation_fail: A list of tuples containing the arguments to pass to the
:class:`mcproto.utils.abc.Serializable` class and the expected exception, either as is or wrapped in a
:class:`TestExc` object.
:param deserialization_fail: A list of tuples containing the bytes to pass to the :meth:`deserialize` method of the
class and the expected exception, either as is or wrapped in a :class:`TestExc` object.
Example usage:
Expand All @@ -221,28 +243,30 @@ def gen_serializable_test(
.. note::
The test cases will use :meth:`__eq__` to compare the objects, so make sure to implement it in the class if
you are not using a dataclass.
you are not using the autogenerated method from :func:`attrs.define`.
"""
# Separate the test data into parameters for each test function
# This holds the parameters for the serialization and deserialization tests
parameters: list[tuple[dict[str, Any], bytes]] = []

# This holds the parameters for the validation tests
validation_fail: list[tuple[dict[str, Any], type[Exception] | Exception]] = []
validation_fail_kw: list[tuple[dict[str, Any], TestExc]] = []

for data, exp_bytes in [] if serialize_deserialize is None else serialize_deserialize:
kwargs = dict(zip([f[0] for f in fields], data))
parameters.append((kwargs, exp_bytes))

# This holds the parameters for the deserialization error tests
deserialization_fail: list[tuple[bytes, type[Exception] | Exception]] = []
for data, exc in [] if validation_fail is None else validation_fail:
kwargs = dict(zip([f[0] for f in fields], data))
exc_wrapped = TestExc.from_exception(exc)
validation_fail_kw.append((kwargs, exc_wrapped))

for data_or_exc, expected_bytes_or_exc in test_data:
if isinstance(data_or_exc, tuple) and isinstance(expected_bytes_or_exc, bytes):
kwargs = dict(zip([f[0] for f in fields], data_or_exc))
parameters.append((kwargs, expected_bytes_or_exc))
elif isexception(data_or_exc) and isinstance(expected_bytes_or_exc, bytes):
deserialization_fail.append((expected_bytes_or_exc, data_or_exc))
elif isinstance(data_or_exc, tuple) and isexception(expected_bytes_or_exc):
kwargs = dict(zip([f[0] for f in fields], data_or_exc))
validation_fail.append((kwargs, expected_bytes_or_exc))
# Just make sure that the exceptions are wrapped in TestExc
deserialization_fail = (
[]
if deserialization_fail is None
else [(data, TestExc.from_exception(exc)) for data, exc in deserialization_fail]
)

def generate_name(param: dict[str, Any] | bytes, i: int) -> str:
"""Generate a name for the test case."""
Expand Down Expand Up @@ -301,33 +325,45 @@ def test_deserialization(self, kwargs: dict[str, Any], expected_bytes: bytes):

@pytest.mark.parametrize(
("kwargs", "exc"),
validation_fail,
ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail)),
validation_fail_kw,
ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail_kw)),
)
def test_validation(self, kwargs: dict[str, Any], exc: type[Exception] | Exception):
def test_validation(self, kwargs: dict[str, Any], exc: TestExc):
"""Test validation of the object."""
exc, msg = get_exception(exc)
with pytest.raises(exc, match=msg):
with pytest.raises(exc.exception, match=exc.match) as exc_info:
cls(**kwargs)

# If exc.kwargs is not None, check them against the exception
if exc.kwargs is not None:
for key, value in exc.kwargs.items():
assert value == getattr(
exc_info.value, key
), f"{key}: {value!r} != {getattr(exc_info.value, key)!r}"

@pytest.mark.parametrize(
("content", "exc"),
deserialization_fail,
ids=tuple(generate_name(content, i) for i, (content, _) in enumerate(deserialization_fail)),
)
def test_deserialization_error(self, content: bytes, exc: type[Exception] | Exception):
def test_deserialization_error(self, content: bytes, exc: TestExc):
"""Test deserialization error handling."""
buf = Buffer(content)
exc, msg = get_exception(exc)
with pytest.raises(exc, match=msg):
with pytest.raises(exc.exception, match=exc.match) as exc_info:
cls.deserialize(buf)

# If exc.kwargs is not None, check them against the exception
if exc.kwargs is not None:
for key, value in exc.kwargs.items():
assert value == getattr(
exc_info.value, key
), f"{key}: {value!r} != {getattr(exc_info.value, key)!r}"

if len(parameters) == 0:
# If there are no serialization tests, remove them
del TestClass.test_serialization
del TestClass.test_deserialization

if len(validation_fail) == 0:
if len(validation_fail_kw) == 0:
# If there are no validation tests, remove them
del TestClass.test_validation

Expand Down
4 changes: 3 additions & 1 deletion tests/mcproto/packets/handshaking/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
("server_port", int),
("next_state", NextState),
],
test_data=[
serialize_deserialize=[
(
(757, "mc.aircs.racing", 25565, NextState.LOGIN),
bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"),
Expand All @@ -29,6 +29,8 @@
(757, "hypixel.net", 25565, NextState.STATUS),
bytes.fromhex("f5050b6879706978656c2e6e657463dd01"),
),
],
validation_fail=[
# Invalid next state
((757, "localhost", 25565, 3), ValueError),
],
Expand Down
Loading

0 comments on commit 98ee687

Please sign in to comment.