From 2e98b961839fc09c8fee2d04de09219c96edc4ec Mon Sep 17 00:00:00 2001 From: Alexis Rossfelder Date: Mon, 6 May 2024 20:27:44 +0200 Subject: [PATCH] Add more types * Angle * Bitset and FixedBitset * Position * Vec3 * Quaternion * Slot * Identifier * TextComponent - Rename ChatMessage to JSONTextComponent --- mcproto/packets/login/login.py | 6 +- mcproto/types/angle.py | 75 ++++++++ mcproto/types/bitset.py | 193 +++++++++++++++++++ mcproto/types/chat.py | 121 ++++++++++-- mcproto/types/identifier.py | 51 ++++++ mcproto/types/quaternion.py | 91 +++++++++ mcproto/types/slot.py | 78 ++++++++ mcproto/types/vec3.py | 164 +++++++++++++++++ tests/mcproto/packets/login/test_login.py | 6 +- tests/mcproto/types/test_angle.py | 68 +++++++ tests/mcproto/types/test_bitset.py | 204 +++++++++++++++++++++ tests/mcproto/types/test_chat.py | 62 ++++++- tests/mcproto/types/test_identifier.py | 19 ++ tests/mcproto/types/test_quaternion.py | 121 ++++++++++++ tests/mcproto/types/test_slot.py | 32 ++++ tests/mcproto/types/test_vec3.py | 214 ++++++++++++++++++++++ 16 files changed, 1476 insertions(+), 29 deletions(-) create mode 100644 mcproto/types/angle.py create mode 100644 mcproto/types/bitset.py create mode 100644 mcproto/types/identifier.py create mode 100644 mcproto/types/quaternion.py create mode 100644 mcproto/types/slot.py create mode 100644 mcproto/types/vec3.py create mode 100644 tests/mcproto/types/test_angle.py create mode 100644 tests/mcproto/types/test_bitset.py create mode 100644 tests/mcproto/types/test_identifier.py create mode 100644 tests/mcproto/types/test_quaternion.py create mode 100644 tests/mcproto/types/test_slot.py create mode 100644 tests/mcproto/types/test_vec3.py diff --git a/mcproto/packets/login/login.py b/mcproto/packets/login/login.py index f5708f03..31f9563d 100644 --- a/mcproto/packets/login/login.py +++ b/mcproto/packets/login/login.py @@ -9,7 +9,7 @@ from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket -from mcproto.types.chat import ChatMessage +from mcproto.types.chat import JSONTextComponent from mcproto.types.uuid import UUID from mcproto.utils.abc import dataclass @@ -176,7 +176,7 @@ class LoginDisconnect(ClientBoundPacket): PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - reason: ChatMessage + reason: JSONTextComponent @override def serialize_to(self, buf: Buffer) -> None: @@ -185,7 +185,7 @@ def serialize_to(self, buf: Buffer) -> None: @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - reason = ChatMessage.deserialize(buf) + reason = JSONTextComponent.deserialize(buf) return cls(reason) diff --git a/mcproto/types/angle.py b/mcproto/types/angle.py new file mode 100644 index 00000000..d1346f76 --- /dev/null +++ b/mcproto/types/angle.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import final +import math + +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass +from mcproto.types.vec3 import Vec3 + + +@dataclass +@final +class Angle(MCType): + """Represents a rotation angle for an entity. + + :param value: The angle value in 1/256th of a full rotation. + """ + + angle: int + + @override + def serialize_to(self, buf: Buffer) -> None: + payload = int(self.angle) & 0xFF + # Convert to a signed byte. + if payload & 0x80: + payload -= 1 << 8 + buf.write_value(StructFormat.BYTE, payload) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Angle: + payload = buf.read_value(StructFormat.BYTE) + return cls(angle=int(payload * 360 / 256)) + + @override + def validate(self) -> None: + """Constrain the angle to the range [0, 256).""" + self.angle %= 256 + + def in_direction(self, base: Vec3, distance: float) -> Vec3: + """Calculate the position in the direction of the angle in the xz-plane. + + 0/256: Positive z-axis + 64/-192: Negative x-axis + 128/-128: Negative z-axis + 192/-64: Positive x-axis + + :param base: The base position. + :param distance: The distance to move. + :return: The new position. + """ + x = base.x - distance * math.sin(self.to_radians()) + z = base.z + distance * math.cos(self.to_radians()) + return Vec3(x=x, y=base.y, z=z) + + @classmethod + def from_degrees(cls, degrees: float) -> Angle: + """Create an angle from degrees.""" + return cls(angle=int(degrees * 256 / 360)) + + def to_degrees(self) -> float: + """Return the angle in degrees.""" + return self.angle * 360 / 256 + + @classmethod + def from_radians(cls, radians: float) -> Angle: + """Create an angle from radians.""" + return cls(angle=int(math.degrees(radians) * 256 / 360)) + + def to_radians(self) -> float: + """Return the angle in radians.""" + return math.radians(self.angle * 360 / 256) diff --git a/mcproto/types/bitset.py b/mcproto/types/bitset.py new file mode 100644 index 00000000..2b0927d5 --- /dev/null +++ b/mcproto/types/bitset.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import math + +from typing import ClassVar +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass + + +@dataclass +class FixedBitset(MCType): + """Represents a fixed-size bitset.""" + + __n: ClassVar[int] = -1 + + data: bytearray + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write(bytes(self.data)) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> FixedBitset: + data = buf.read(math.ceil(cls.__n / 8)) + return cls(data=data) + + @override + def validate(self) -> None: + """Validate the bitset.""" + if self.__n == -1: + raise ValueError("Bitset size is not defined.") + if len(self.data) != math.ceil(self.__n / 8): + raise ValueError(f"Bitset size is {len(self.data) * 8}, expected {self.__n}.") + + @staticmethod + def of_size(n: int) -> type[FixedBitset]: + """Return a new FixedBitset class with the given size. + + :param n: The size of the bitset. + """ + new_class = type(f"FixedBitset{n}", (FixedBitset,), {}) + new_class.__n = n + return new_class + + @classmethod + def from_int(cls, n: int) -> FixedBitset: + """Return a new FixedBitset with the given integer value. + + :param n: The integer value. + """ + if cls.__n == -1: + raise ValueError("Bitset size is not defined.") + if n < 0: + # Manually compute two's complement + n = -n + data = bytearray(n.to_bytes(math.ceil(cls.__n / 8), "big")) + for i in range(len(data)): + data[i] ^= 0xFF + data[-1] += 1 + else: + data = bytearray(n.to_bytes(math.ceil(cls.__n / 8), "big")) + return cls(data=data) + + def __setitem__(self, index: int, value: bool) -> None: + byte_index = index // 8 + bit_index = index % 8 + if value: + self.data[byte_index] |= 1 << bit_index + else: + self.data[byte_index] &= ~(1 << bit_index) + + def __getitem__(self, index: int) -> bool: + byte_index = index // 8 + bit_index = index % 8 + return bool(self.data[byte_index] & (1 << bit_index)) + + def __len__(self) -> int: + return self.__n + + def __and__(self, other: FixedBitset) -> FixedBitset: + if self.__n != other.__n: + raise ValueError("Bitsets must have the same size.") + return type(self)(data=bytearray(a & b for a, b in zip(self.data, other.data))) + + def __or__(self, other: FixedBitset) -> FixedBitset: + if self.__n != other.__n: + raise ValueError("Bitsets must have the same size.") + return type(self)(data=bytearray(a | b for a, b in zip(self.data, other.data))) + + def __xor__(self, other: FixedBitset) -> FixedBitset: + if self.__n != other.__n: + raise ValueError("Bitsets must have the same size.") + return type(self)(data=bytearray(a ^ b for a, b in zip(self.data, other.data))) + + def __invert__(self) -> FixedBitset: + return type(self)(data=bytearray(~a & 0xFF for a in self.data)) + + def __bytes__(self) -> bytes: + return bytes(self.data) + + @override + def __eq__(self, value: object) -> bool: + if not isinstance(value, FixedBitset): + return NotImplemented + return self.data == value.data and self.__n == value.__n + + +@dataclass +class Bitset(MCType): + """Represents a lenght-prefixed bitset with a variable size. + + :param size: The number of longs in the array representing the bitset. + :param data: The bits of the bitset. + """ + + size: int + data: list[int] + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_varint(self.size) + for i in range(self.size): + buf.write_value(StructFormat.LONGLONG, self.data[i]) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Bitset: + size = buf.read_varint() + if buf.remaining < size * 8: + raise IOError("Not enough data to read bitset.") + data = [buf.read_value(StructFormat.LONGLONG) for _ in range(size)] + return cls(size=size, data=data) + + @override + def validate(self) -> None: + """Validate the bitset.""" + if self.size != len(self.data): + raise ValueError(f"Bitset size is {self.size}, expected {len(self.data)}.") + + @classmethod + def from_int(cls, n: int, size: int | None = None) -> Bitset: + """Return a new Bitset with the given integer value. + + :param n: The integer value. + :param size: The number of longs in the array representing the bitset. + """ + if size is None: + size = math.ceil(float(n.bit_length()) / 64.0) + data = [n >> (i * 64) & 0xFFFFFFFFFFFFFFFF for i in range(size)] + return cls(size=size, data=data) + + def __getitem__(self, index: int) -> bool: + byte_index = index // 64 + bit_index = index % 64 + + return bool(self.data[byte_index] & (1 << bit_index)) + + def __setitem__(self, index: int, value: bool) -> None: + byte_index = index // 64 + bit_index = index % 64 + + if value: + self.data[byte_index] |= 1 << bit_index + else: + self.data[byte_index] &= ~(1 << bit_index) + + def __len__(self) -> int: + return self.size * 64 + + def __and__(self, other: Bitset) -> Bitset: + if self.size != other.size: + raise ValueError("Bitsets must have the same size.") + return Bitset(size=self.size, data=[a & b for a, b in zip(self.data, other.data)]) + + def __or__(self, other: Bitset) -> Bitset: + if self.size != other.size: + raise ValueError("Bitsets must have the same size.") + return Bitset(size=self.size, data=[a | b for a, b in zip(self.data, other.data)]) + + def __xor__(self, other: Bitset) -> Bitset: + if self.size != other.size: + raise ValueError("Bitsets must have the same size.") + return Bitset(size=self.size, data=[a ^ b for a, b in zip(self.data, other.data)]) + + def __invert__(self) -> Bitset: + return Bitset(size=self.size, data=[~a for a in self.data]) + + def __bytes__(self) -> bytes: + return b"".join(a.to_bytes(8, "big") for a in self.data) diff --git a/mcproto/types/chat.py b/mcproto/types/chat.py index fe978631..e674bb53 100644 --- a/mcproto/types/chat.py +++ b/mcproto/types/chat.py @@ -1,26 +1,27 @@ from __future__ import annotations import json -from typing import TypedDict, Union, final +from typing import Tuple, TypedDict, Union, cast, final from typing_extensions import Self, TypeAlias, override from mcproto.buffer import Buffer from mcproto.types.abc import MCType, dataclass +from mcproto.types.nbt import NBTag, StringNBT, ByteNBT, FromObjectSchema, FromObjectType __all__ = [ - "ChatMessage", - "RawChatMessage", - "RawChatMessageDict", + "TextComponent", + "RawTextComponentDict", + "RawTextComponent", ] -class RawChatMessageDict(TypedDict, total=False): +class RawTextComponentDict(TypedDict, total=False): """Dictionary structure of JSON chat messages when serialized.""" text: str translation: str - extra: list[RawChatMessageDict] + extra: list[RawTextComponentDict] color: str bold: bool @@ -30,22 +31,27 @@ class RawChatMessageDict(TypedDict, total=False): obfuscated: bool -RawChatMessage: TypeAlias = Union[RawChatMessageDict, "list[RawChatMessageDict]", str] +RawTextComponent: TypeAlias = Union[RawTextComponentDict, "list[RawTextComponentDict]", str] + + +def _deep_copy_dict(data: RawTextComponentDict) -> RawTextComponentDict: + """Deep copy a dictionary structure.""" + json_data = json.dumps(data) + return json.loads(json_data) @dataclass -@final -class ChatMessage(MCType): +class JSONTextComponent(MCType): """Minecraft chat message representation.""" - raw: RawChatMessage + raw: RawTextComponent - def as_dict(self) -> RawChatMessageDict: + def as_dict(self) -> RawTextComponentDict: """Convert received ``raw`` into a stadard :class:`dict` form.""" if isinstance(self.raw, list): - return RawChatMessageDict(extra=self.raw) + return RawTextComponentDict(extra=self.raw) if isinstance(self.raw, str): - return RawChatMessageDict(text=self.raw) + return RawTextComponentDict(text=self.raw) if isinstance(self.raw, dict): # pyright: ignore[reportUnnecessaryIsInstance] return self.raw @@ -61,7 +67,7 @@ def __eq__(self, other: object) -> bool: a chat message that appears the same, but was representing in a different way will fail this equality check. """ - if not isinstance(other, ChatMessage): + if not isinstance(other, JSONTextComponent): return NotImplemented return self.raw == other.raw @@ -93,3 +99,90 @@ def validate(self) -> None: raise AttributeError( "Expected each element in `raw` to have either 'text' or 'extra' key, got neither" ) + + +@final +class TextComponent(JSONTextComponent): + """Minecraft chat message representation. + + This class provides the new chat message format using NBT data instead of JSON. + """ + + __slots__ = () + + @override + def serialize_to(self, buf: Buffer) -> None: + payload = self._convert_to_dict(self.raw) + payload = cast(FromObjectType, payload) # We just ensured that the data is converted to the correct format + nbt = NBTag.from_object(data=payload, schema=self._build_schema()) # This will validate the data + nbt.serialize_to(buf) + + @override + @classmethod + def deserialize(cls, buf: Buffer, /) -> Self: + nbt = NBTag.deserialize(buf, with_name=False) + # Ensure the schema is compatible with the one defined in the class + data, schema = cast(Tuple[FromObjectType, FromObjectSchema], nbt.to_object(include_schema=True)) + + def recursive_validate(recieved: FromObjectSchema, expected: FromObjectSchema) -> None: + if isinstance(recieved, dict): + if not isinstance(expected, dict): + raise TypeError(f"Expected {expected!r}, got dict") + for key, value in recieved.items(): + if key not in expected: + raise KeyError(f"Unexpected key {key!r}") + recursive_validate(value, expected[key]) + elif isinstance(recieved, list): + if not isinstance(expected, list): + raise TypeError(f"Expected {expected!r}, got list") + for rec in recieved: + recursive_validate(rec, expected[0]) + elif recieved != expected: + raise TypeError(f"Expected {expected!r}, got {recieved!r}") + + recursive_validate(schema, cls._build_schema()) + data = cast(RawTextComponentDict, data) # We just ensured that the data is compatible with the schema + return cls(data) + + @staticmethod + def _build_schema() -> FromObjectSchema: + """Build the schema for the NBT data representing the chat message.""" + schema: FromObjectSchema = { + "text": StringNBT, + "color": StringNBT, + "bold": ByteNBT, + "italic": ByteNBT, + "underlined": ByteNBT, + "strikethrough": ByteNBT, + "obfuscated": ByteNBT, + } + # Allow the schema to be recursive + schema["extra"] = [schema] # type: ignore + return schema + + @staticmethod + def _convert_to_dict(msg: RawTextComponent) -> RawTextComponentDict: + """Convert a chat message into a dictionary representation.""" + if isinstance(msg, str): + return {"text": msg} + + if isinstance(msg, list): + main = TextComponent._convert_to_dict(msg[0]) + if "extra" not in main: + main["extra"] = [] + for elem in msg[1:]: + main["extra"].append(TextComponent._convert_to_dict(elem)) + return main + + if isinstance(msg, dict): # pyright: ignore[reportUnnecessaryIsInstance] + return _deep_copy_dict(msg) # We don't want to modify self.raw for example + + raise TypeError(f"Unexpected type {msg!r} ({msg.__class__.__name__})") # pragma: no cover + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextComponent): + return NotImplemented + self_dict = self._convert_to_dict(self.raw) + other_dict = self._convert_to_dict(other.raw) + return self_dict == other_dict diff --git a/mcproto/types/identifier.py b/mcproto/types/identifier.py new file mode 100644 index 00000000..cdd8deca --- /dev/null +++ b/mcproto/types/identifier.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import re +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.types.abc import MCType, dataclass + + +@dataclass +class Identifier(MCType): + """A Minecraft identifier. + + :param namespace: The namespace of the identifier. + :param path: The path of the identifier. + """ + + namespace: str + path: str + + @override + def serialize_to(self, buf: Buffer) -> None: + data = f"{self.namespace}:{self.path}" + buf.write_utf(data) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Identifier: + data = buf.read_utf() + namespace, path = data.split(":", 1) + return cls(namespace, path) + + @override + def validate(self) -> None: + if len(self.namespace) == 0: + raise ValueError("Namespace cannot be empty.") + + if len(self.path) == 0: + raise ValueError("Path cannot be empty.") + + if len(self.namespace) + len(self.path) + 1 > 32767: + raise ValueError("Identifier is too long.") + + namespace_regex = r"^[a-z0-9-_]+$" + path_regex = r"^[a-z0-9-_/]+$" + + if not re.match(namespace_regex, self.namespace): + raise ValueError(f"Namespace must match regex {namespace_regex}, got {self.namespace!r}") + + if not re.match(path_regex, self.path): + raise ValueError(f"Path must match regex {path_regex}, got {self.path!r}") diff --git a/mcproto/types/quaternion.py b/mcproto/types/quaternion.py new file mode 100644 index 00000000..c2f85ae7 --- /dev/null +++ b/mcproto/types/quaternion.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import math + +from typing import final +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass + + +@dataclass +@final +class Quaternion(MCType): + """Represents a quaternion. + + :param x: The x component. + :param y: The y component. + :param z: The z component. + :param w: The w component. + """ + + x: float | int + y: float | int + z: float | int + w: float | int + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_value(StructFormat.FLOAT, self.x) + buf.write_value(StructFormat.FLOAT, self.y) + buf.write_value(StructFormat.FLOAT, self.z) + buf.write_value(StructFormat.FLOAT, self.w) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Quaternion: + x = buf.read_value(StructFormat.FLOAT) + y = buf.read_value(StructFormat.FLOAT) + z = buf.read_value(StructFormat.FLOAT) + w = buf.read_value(StructFormat.FLOAT) + return cls(x=x, y=y, z=z, w=w) + + @override + def validate(self) -> None: + """Validate the quaternion's components.""" + # Check that the components are floats or integers. + if not all(isinstance(comp, (float, int)) for comp in (self.x, self.y, self.z, self.w)): # type: ignore + raise TypeError( + f"Quaternion components must be floats or integers, got {self.x!r}, {self.y!r}, {self.z!r}, {self.w!r}" + ) + + # Check that the components are not NaN. + if any(not math.isfinite(comp) for comp in (self.x, self.y, self.z, self.w)): + raise ValueError( + f"Quaternion components must not be NaN, got {self.x!r}, {self.y!r}, {self.z!r}, {self.w!r}." + ) + + def __add__(self, other: Quaternion) -> Quaternion: + # Use the type of self to return a Quaternion or a subclass. + return type(self)(x=self.x + other.x, y=self.y + other.y, z=self.z + other.z, w=self.w + other.w) + + def __sub__(self, other: Quaternion) -> Quaternion: + return type(self)(x=self.x - other.x, y=self.y - other.y, z=self.z - other.z, w=self.w - other.w) + + def __neg__(self) -> Quaternion: + return type(self)(x=-self.x, y=-self.y, z=-self.z, w=-self.w) + + def __mul__(self, other: float) -> Quaternion: + return type(self)(x=self.x * other, y=self.y * other, z=self.z * other, w=self.w * other) + + def __truediv__(self, other: float) -> Quaternion: + return type(self)(x=self.x / other, y=self.y / other, z=self.z / other, w=self.w / other) + + def to_tuple(self) -> tuple[float, float, float, float]: + """Convert the quaternion to a tuple.""" + return (self.x, self.y, self.z, self.w) + + def norm_squared(self) -> float: + """Return the squared norm of the quaternion.""" + return self.x**2 + self.y**2 + self.z**2 + self.w**2 + + def norm(self) -> float: + """Return the norm of the quaternion.""" + return math.sqrt(self.norm_squared()) + + def normalize(self) -> Quaternion: + """Return the normalized quaternion.""" + norm = self.norm() + return Quaternion(x=self.x / norm, y=self.y / norm, z=self.z / norm, w=self.w / norm) diff --git a/mcproto/types/slot.py b/mcproto/types/slot.py new file mode 100644 index 00000000..d8e21d48 --- /dev/null +++ b/mcproto/types/slot.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import cast, final + +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.nbt import CompoundNBT, EndNBT, NBTag +from mcproto.types.abc import MCType, dataclass + +__all__ = ["Slot"] + +""" +https://wiki.vg/Slot_Data +""" + + +@dataclass +@final +class Slot(MCType): + """Represents a slot in an inventory. + + :param present: Whether the slot has an item in it. + :param item_id: (optional) The item ID of the item in the slot. + :param item_count: (optional) The count of items in the slot. + :param nbt: (optional) The NBT data of the item in the slot. + + The optional parameters are present if and only if the slot is present. + """ + + present: bool + item_id: int | None = None + item_count: int | None = None + nbt: NBTag | None = None + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_value(StructFormat.BOOL, self.present) + if self.present: + self.item_id = cast(int, self.item_id) + self.item_count = cast(int, self.item_count) + self.nbt = cast(NBTag, self.nbt) + buf.write_varint(self.item_id) + buf.write_value(StructFormat.BYTE, self.item_count) + self.nbt.serialize_to(buf, with_name=False) # In 1.20.2 and later, the NBT is not named, there is only the + # type (TAG_End or TAG_Compound) and the payload. + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Slot: + present = buf.read_value(StructFormat.BOOL) + if not present: + return cls(present=False) + item_id = buf.read_varint() + item_count = buf.read_value(StructFormat.BYTE) + nbt = NBTag.deserialize(buf, with_name=False) + return cls(present=True, item_id=item_id, item_count=item_count, nbt=nbt) + + @override + def validate(self) -> None: + # If the slot is present, all the fields must be present. + if self.present: + if self.item_id is None: + raise ValueError("Item ID is missing.") + if self.item_count is None: + raise ValueError("Item count is missing.") + if self.nbt is None: + self.nbt = EndNBT() + elif not isinstance(self.nbt, (CompoundNBT, EndNBT)): + raise TypeError("NBT data associated with a slot must be in a CompoundNBT.") + else: + if self.item_id is not None: + raise ValueError("Item ID must be None if there is no item in the slot.") + if self.item_count is not None: + raise ValueError("Item count must be None if there is no item in the slot.") + if self.nbt is not None and not isinstance(self.nbt, EndNBT): + raise ValueError("NBT data must be None if there is no item in the slot.") diff --git a/mcproto/types/vec3.py b/mcproto/types/vec3.py new file mode 100644 index 00000000..1cc8fcbe --- /dev/null +++ b/mcproto/types/vec3.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import math + +from typing import cast, final +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass + + +@dataclass +class Vec3(MCType): + """Represents a 3D vector. + + :param x: The x component. + :param y: The y component. + :param z: The z component. + """ + + x: float | int + y: float | int + z: float | int + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_value(StructFormat.FLOAT, self.x) + buf.write_value(StructFormat.FLOAT, self.y) + buf.write_value(StructFormat.FLOAT, self.z) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Vec3: + x = buf.read_value(StructFormat.FLOAT) + y = buf.read_value(StructFormat.FLOAT) + z = buf.read_value(StructFormat.FLOAT) + return cls(x=x, y=y, z=z) + + @override + def validate(self) -> None: + """Validate the vector's components.""" + # Check that the components are floats or integers. + if not all(isinstance(comp, (float, int)) for comp in (self.x, self.y, self.z)): # type: ignore + raise TypeError(f"Vector components must be floats or integers, got {self.x!r}, {self.y!r}, {self.z!r}") + + # Check that the components are not NaN. + if any(not math.isfinite(comp) for comp in (self.x, self.y, self.z)): + raise ValueError(f"Vector components must not be NaN, got {self.x!r}, {self.y!r}, {self.z!r}.") + + def __add__(self, other: Vec3) -> Vec3: + # Use the type of self to return a Vec3 or a subclass. + return type(self)(x=self.x + other.x, y=self.y + other.y, z=self.z + other.z) + + def __sub__(self, other: Vec3) -> Vec3: + return type(self)(x=self.x - other.x, y=self.y - other.y, z=self.z - other.z) + + def __neg__(self) -> Vec3: + return type(self)(x=-self.x, y=-self.y, z=-self.z) + + def __mul__(self, other: float) -> Vec3: + return type(self)(x=self.x * other, y=self.y * other, z=self.z * other) + + def __truediv__(self, other: float) -> Vec3: + return type(self)(x=self.x / other, y=self.y / other, z=self.z / other) + + def to_position(self) -> Position: + """Convert the vector to a position.""" + return Position(x=int(self.x), y=int(self.y), z=int(self.z)) + + def to_tuple(self) -> tuple[float, float, float]: + """Convert the vector to a tuple.""" + return (self.x, self.y, self.z) + + def to_vec3(self) -> Vec3: + """Convert the vector to a Vec3. + + This function creates a new Vec3 object with the same components. + """ + return Vec3(x=self.x, y=self.y, z=self.z) + + def norm_squared(self) -> float: + """Return the squared norm of the vector.""" + return self.x**2 + self.y**2 + self.z**2 + + def norm(self) -> float: + """Return the norm of the vector.""" + return math.sqrt(self.norm_squared()) + + def normalize(self) -> Vec3: + """Return the normalized vector.""" + norm = self.norm() + return Vec3(x=self.x / norm, y=self.y / norm, z=self.z / norm) + + +@final +class Position(Vec3): + """Represents a position in the world. + + :param x: The x coordinate (26 bits). + :param y: The y coordinate (12 bits). + :param z: The z coordinate (26 bits). + """ + + __slots__ = () + + @override + def serialize_to(self, buf: Buffer) -> None: + self.x = cast(int, self.x) + self.y = cast(int, self.y) + self.z = cast(int, self.z) + encoded = ((self.x & 0x3FFFFFF) << 38) | ((self.z & 0x3FFFFFF) << 12) | (self.y & 0xFFF) + + # Convert the bit mess to a signed integer for packing. + if encoded & 0x8000000000000000: + encoded -= 1 << 64 + buf.write_value(StructFormat.LONGLONG, encoded) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Position: + encoded = buf.read_value(StructFormat.LONGLONG) + x = (encoded >> 38) & 0x3FFFFFF + z = (encoded >> 12) & 0x3FFFFFF + y = encoded & 0xFFF + + # Convert back to signed integers. + if x >= 1 << 25: + x -= 1 << 26 + if y >= 1 << 11: + y -= 1 << 12 + if z >= 1 << 25: + z -= 1 << 26 + + return cls(x=x, y=y, z=z) + + @override + def validate(self) -> None: + """Validate the position's coordinates. + + They are all signed integers, but the x and z coordinates are 26 bits + and the y coordinate is 12 bits. + """ + super().validate() # Validate the Vec3 components. + + self.x = int(self.x) + self.y = int(self.y) + self.z = int(self.z) + if not (-1 << 25 <= self.x < 1 << 25): + raise OverflowError(f"Invalid x coordinate: {self.x}") + if not (-1 << 11 <= self.y < 1 << 11): + raise OverflowError(f"Invalid y coordinate: {self.y}") + if not (-1 << 25 <= self.z < 1 << 25): + raise OverflowError(f"Invalid z coordinate: {self.z}") + + +POS_UP = Position(0, 1, 0) +POS_DOWN = Position(0, -1, 0) +POS_NORTH = Position(0, 0, -1) +POS_SOUTH = Position(0, 0, 1) +POS_EAST = Position(1, 0, 0) +POS_WEST = Position(-1, 0, 0) + +POS_ZERO = Position(0, 0, 0) diff --git a/tests/mcproto/packets/login/test_login.py b/tests/mcproto/packets/login/test_login.py index 71067022..dfcb72bc 100644 --- a/tests/mcproto/packets/login/test_login.py +++ b/tests/mcproto/packets/login/test_login.py @@ -11,7 +11,7 @@ LoginSuccess, ) from mcproto.packets.packet import InvalidPacketContentError -from mcproto.types.chat import ChatMessage +from mcproto.types.chat import JSONTextComponent from mcproto.types.uuid import UUID from tests.helpers import gen_serializable_test from tests.mcproto.test_encryption import RSA_PUBLIC_KEY @@ -90,10 +90,10 @@ def test_login_encryption_request_noid(): gen_serializable_test( context=globals(), cls=LoginDisconnect, - fields=[("reason", ChatMessage)], + fields=[("reason", JSONTextComponent)], test_data=[ ( - (ChatMessage("You are banned."),), + (JSONTextComponent("You are banned."),), bytes.fromhex("1122596f75206172652062616e6e65642e22"), ), ], diff --git a/tests/mcproto/types/test_angle.py b/tests/mcproto/types/test_angle.py new file mode 100644 index 00000000..2de04fd9 --- /dev/null +++ b/tests/mcproto/types/test_angle.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import pytest + +from mcproto.types.vec3 import Position, POS_NORTH, POS_SOUTH, POS_EAST, POS_WEST, POS_ZERO +from mcproto.types.angle import Angle +from tests.helpers import gen_serializable_test + +PI = 3.14159265358979323846 +EPSILON = 1e-6 + +gen_serializable_test( + context=globals(), + cls=Angle, + fields=[("angle", int)], + test_data=[ + ((0,), b"\x00"), + ((256,), b"\x00"), + ((-1,), b"\xff"), + ((-256,), b"\x00"), + ((2,), b"\x02"), + ((-2,), b"\xfe"), + ], +) + + +@pytest.mark.parametrize( + ("angle", "base", "distance", "expected"), + [ + (Angle(0), POS_ZERO, 1, POS_SOUTH), + (Angle(64), POS_ZERO, 1, POS_WEST), + (Angle(128), POS_ZERO, 1, POS_NORTH), + (Angle(192), POS_ZERO, 1, POS_EAST), + ], +) +def test_in_direction(angle: Angle, base: Position, distance: int, expected: Position): + """Test that the in_direction method moves the base position in the correct direction.""" + assert (angle.in_direction(base, distance) - expected).norm() < EPSILON + + +@pytest.mark.parametrize( + ("base2", "degrees"), + [ + (0, 0), + (64, 90), + (128, 180), + (192, 270), + ], +) +def test_degrees(base2: int, degrees: int): + """Test that the from_degrees and to_degrees methods work correctly.""" + assert Angle.from_degrees(degrees) == Angle(base2) + assert Angle(base2).to_degrees() == degrees + + +@pytest.mark.parametrize( + ("rad", "angle"), + [ + (0, 0), + (PI / 2, 64), + (PI, 128), + (3 * PI / 2, 192), + ], +) +def test_radians(rad: float, angle: int): + """Test that the from_radians and to_radians methods work correctly.""" + assert Angle.from_radians(rad) == Angle(angle) + assert abs(Angle(angle).to_radians() - rad) < EPSILON diff --git a/tests/mcproto/types/test_bitset.py b/tests/mcproto/types/test_bitset.py new file mode 100644 index 00000000..0a349a93 --- /dev/null +++ b/tests/mcproto/types/test_bitset.py @@ -0,0 +1,204 @@ +from typing import List + +from mcproto.types.bitset import FixedBitset, Bitset +from tests.helpers import gen_serializable_test + +import pytest + + +gen_serializable_test( + context=globals(), + cls=FixedBitset.of_size(64), + fields=[("data", bytearray)], + test_data=[ + ((bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00"),), b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ((bytearray(b"\xff\xff\xff\xff\xff\xff\xff\xff"),), b"\xff\xff\xff\xff\xff\xff\xff\xff"), + ((bytearray(b"\x55\x55\x55\x55\x55\x55\x55\x55"),), b"\x55\x55\x55\x55\x55\x55\x55\x55"), + ], +) + +gen_serializable_test( + context=globals(), + cls=FixedBitset.of_size(16), + fields=[("data", List[int])], + test_data=[ + ((bytearray(b"\x00"),), ValueError), + ], +) + +gen_serializable_test( + context=globals(), + cls=Bitset, + fields=[("size", int), ("data", List[int])], + test_data=[ + ((1, [1]), b"\x01\x00\x00\x00\x00\x00\x00\x00\x01"), + ( + (2, [1, -1]), + b"\x02\x00\x00\x00\x00\x00\x00\x00\x01\xff\xff\xff\xff\xff\xff\xff\xff", + ), + (IOError, b"\x01"), + ((3, [1]), ValueError), + ], +) + + +def test_fixed_bitset_indexing(): + """Test indexing and setting values in a FixedBitset.""" + b = FixedBitset.of_size(12).from_int(0) + assert b[0] is False + assert b[12] is False + + b[0] = True + assert b[0] is True + assert b[12] is False + + b[12] = True + assert b[12] is True + assert b[0] is True + + b[0] = False + assert b[0] is False + assert b[12] is True + + +def test_bitset_indexing(): + """Test indexing and setting values in a Bitset.""" + b = Bitset.from_int(0, size=2) + assert b[0] is False + assert b[127] is False + + b[0] = True + assert b[0] is True + + b[127] = True + assert b[127] is True + + b[0] = False + assert b[0] is False + + +def test_fixed_bitset_and(): + """Test bitwise AND operation between FixedBitsets.""" + b1 = FixedBitset.of_size(64).from_int(0xFFFFFFFFFFFFFFFF) + b2 = FixedBitset.of_size(64).from_int(0) + + result = b1 & b2 + assert bytes(result) == b"\x00\x00\x00\x00\x00\x00\x00\x00" + + +def test_bitset_and(): + """Test bitwise AND operation between Bitsets.""" + b1 = Bitset(2, [0x0101010101010101, 0x0101010101010100]) + b2 = Bitset(2, [1, 1]) + + result = b1 & b2 + assert result == Bitset(2, [1, 0]) + + +def test_fixed_bitset_or(): + """Test bitwise OR operation between FixedBitsets.""" + b1 = FixedBitset.of_size(8).from_int(0xFE) + b2 = FixedBitset.of_size(8).from_int(0x01) + + result = b1 | b2 + assert bytes(result) == b"\xff" + + +def test_bitset_or(): + """Test bitwise OR operation between Bitsets.""" + b1 = Bitset(2, [0x0101010101010101, 0x0101010101010100]) + b2 = Bitset(2, [1, 1]) + + result = b1 | b2 + assert bytes(result) == b"\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" + + +def test_fixed_bitset_xor(): + """Test bitwise XOR operation between FixedBitsets.""" + b1 = FixedBitset.of_size(64)(bytearray(b"\xff\xff\xff\xff\xff\xff\xff\xff")) + b2 = FixedBitset.of_size(64)(bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00")) + + result = b1 ^ b2 + assert result == FixedBitset.of_size(64).from_int(-1) + + +def test_bitset_xor(): + """Test bitwise XOR operation between Bitsets.""" + b1 = Bitset(2, [0x0101010101010101, 0x0101010101010101]) + b2 = Bitset(2, [0, 0]) + + result = b1 ^ b2 + assert result == Bitset(2, [0x0101010101010101, 0x0101010101010101]) + + +def test_fixed_bitset_invert(): + """Test bitwise inversion operation on FixedBitsets.""" + b = FixedBitset.of_size(64)(bytearray(b"\xff\xff\xff\xff\xff\xff\xff\xff")) + + inverted = ~b + assert inverted == FixedBitset.of_size(64).from_int(0) + + +def test_bitset_invert(): + """Test bitwise inversion operation on Bitsets.""" + b = Bitset(2, [0, 0]) + + inverted = ~b + assert inverted == Bitset(2, [-1, -1]) + + +def test_fixed_bitset_size_undefined(): + """Test that FixedBitset raises ValueError when size is not defined.""" + with pytest.raises(ValueError): + FixedBitset.from_int(0) + + with pytest.raises(ValueError): + FixedBitset(bytearray(b"\x00\x00\x00\x00")) + + +def test_bitset_len(): + """Test that FixedBitset has the correct length.""" + b = FixedBitset.of_size(64).from_int(0) + assert len(b) == 64 + + b = FixedBitset.of_size(8).from_int(0) + assert len(b) == 8 + + b = Bitset(2, [0, 0]) + assert len(b) == 128 + + +def test_fixed_bitset_operations_length_mismatch(): + """Test that FixedBitset operations raise ValueError when lengths don't match.""" + b1 = FixedBitset.of_size(64).from_int(0) + b2 = FixedBitset.of_size(8).from_int(0) + b3 = "not a bitset" + + with pytest.raises(ValueError): + b1 & b2 # type: ignore + + with pytest.raises(ValueError): + b1 | b2 # type: ignore + + with pytest.raises(ValueError): + b1 ^ b2 # type: ignore + + assert b1 != b3 + + +def test_bitset_operations_length_mismatch(): + """Test that Bitset operations raise ValueError when lengths don't match.""" + b1 = Bitset(2, [0, 0]) + b2 = Bitset.from_int(1) + b3 = "not a bitset" + + with pytest.raises(ValueError): + b1 & b2 # type: ignore + + with pytest.raises(ValueError): + b1 | b2 # type: ignore + + with pytest.raises(ValueError): + b1 ^ b2 # type: ignore + + assert b1 != b3 diff --git a/tests/mcproto/types/test_chat.py b/tests/mcproto/types/test_chat.py index 3c8f05a6..4098dd05 100644 --- a/tests/mcproto/types/test_chat.py +++ b/tests/mcproto/types/test_chat.py @@ -2,8 +2,9 @@ import pytest -from mcproto.types.chat import ChatMessage, RawChatMessage, RawChatMessageDict +from mcproto.types.chat import JSONTextComponent, RawTextComponent, RawTextComponentDict, TextComponent from tests.helpers import gen_serializable_test +from mcproto.types.nbt import CompoundNBT, StringNBT, ByteNBT, ListNBT @pytest.mark.parametrize( @@ -23,9 +24,9 @@ ), ], ) -def test_as_dict(raw: RawChatMessage, expected_dict: RawChatMessageDict): - """Test converting raw ChatMessage input into dict produces expected dict.""" - chat = ChatMessage(raw) +def test_as_dict(raw: RawTextComponent, expected_dict: RawTextComponentDict): + """Test converting raw TextComponent input into dict produces expected dict.""" + chat = JSONTextComponent(raw) assert chat.as_dict() == expected_dict @@ -44,15 +45,15 @@ def test_as_dict(raw: RawChatMessage, expected_dict: RawChatMessageDict): ), ], ) -def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: bool): - """Test comparing ChatMessage instances produces expected equality result.""" - assert (ChatMessage(raw1) == ChatMessage(raw2)) is expected_result +def test_equality(raw1: RawTextComponent, raw2: RawTextComponent, expected_result: bool): + """Test comparing TextComponent instances produces expected equality result.""" + assert (JSONTextComponent(raw1) == JSONTextComponent(raw2)) is expected_result gen_serializable_test( context=globals(), - cls=ChatMessage, - fields=[("raw", RawChatMessage)], + cls=JSONTextComponent, + fields=[("raw", RawTextComponent)], test_data=[ ( ("A Minecraft Server",), @@ -74,3 +75,46 @@ def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: b (([[]],), TypeError), ], ) + +gen_serializable_test( + context=globals(), + cls=TextComponent, + fields=[("raw", RawTextComponent)], + test_data=[ + (({"text": "abc"},), bytes(CompoundNBT([StringNBT("abc", name="text")]).serialize())), + ( + ([{"text": "abc"}, {"text": "def"}],), + bytes( + CompoundNBT( + [ + StringNBT("abc", name="text"), + ListNBT([CompoundNBT([StringNBT("def", name="text")])], name="extra"), + ] + ).serialize() + ), + ), + (("A Minecraft Server",), bytes(CompoundNBT([StringNBT("A Minecraft Server", name="text")]).serialize())), + ( + ([{"text": "abc", "extra": [{"text": "def"}]}, {"text": "ghi"}],), + bytes( + CompoundNBT( + [ + StringNBT("abc", name="text"), + ListNBT( + [ + CompoundNBT([StringNBT("def", name="text")]), + CompoundNBT([StringNBT("ghi", name="text")]), + ], + name="extra", + ), + ] + ).serialize() + ), + ), + # Type shitfuckery + (TypeError, bytes(CompoundNBT([CompoundNBT([ByteNBT(0, "Something")], "text")]).serialize())), + (KeyError, bytes(CompoundNBT([ByteNBT(0, "unknownkey")]).serialize())), + (TypeError, bytes(CompoundNBT([ListNBT([StringNBT("Expected str")], "text")]).serialize())), + (TypeError, bytes(CompoundNBT([StringNBT("Wrong type", "extra")]).serialize())), + ], +) diff --git a/tests/mcproto/types/test_identifier.py b/tests/mcproto/types/test_identifier.py new file mode 100644 index 00000000..d7cbb731 --- /dev/null +++ b/tests/mcproto/types/test_identifier.py @@ -0,0 +1,19 @@ +from mcproto.types.identifier import Identifier +from tests.helpers import gen_serializable_test + + +gen_serializable_test( + context=globals(), + cls=Identifier, + fields=[("namespace", str), ("path", str)], + test_data=[ + (("minecraft", "stone"), b"\x0fminecraft:stone"), + (("minecraft", "stone_brick"), b"\x15minecraft:stone_brick"), + (("minecraft", "stone_brick_slab"), b"\x1aminecraft:stone_brick_slab"), + (("minecr*ft", "stone_brick_slab_top"), ValueError), # Invalid namespace + (("minecraft", "stone_brick_slab_t@p"), ValueError), # Invalid path + (("", "something"), ValueError), # Empty namespace + (("minecraft", ""), ValueError), # Empty path + (("minecraft", "a" * 32767), ValueError), # Too long + ], +) diff --git a/tests/mcproto/types/test_quaternion.py b/tests/mcproto/types/test_quaternion.py new file mode 100644 index 00000000..096e685b --- /dev/null +++ b/tests/mcproto/types/test_quaternion.py @@ -0,0 +1,121 @@ +from __future__ import annotations +import struct +from typing import cast +import pytest +import math +from mcproto.types.quaternion import Quaternion +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Quaternion, + fields=[("x", float), ("y", float), ("z", float), ("w", float)], + test_data=[ + ((0.0, 0.0, 0.0, 0.0), struct.pack(">ffff", 0.0, 0.0, 0.0, 0.0)), + ((-1.0, -1.0, -1.0, -1.0), struct.pack(">ffff", -1.0, -1.0, -1.0, -1.0)), + ((1.0, 2.0, 3.0, 4.0), struct.pack(">ffff", 1.0, 2.0, 3.0, 4.0)), + ((1.5, 2.5, 3.5, 4.5), struct.pack(">ffff", 1.5, 2.5, 3.5, 4.5)), + # Invalid values + ((1.0, 2.0, "3.0", 4.0), TypeError), + ((float("nan"), 2.0, 3.0, 4.0), ValueError), + ((1.0, float("inf"), 3.0, 4.0), ValueError), + ((1.0, 2.0, -float("inf"), 4.0), ValueError), + ], +) + + +def test_quaternion_addition(): + """Test that two Quaternion objects can be added together (resulting in a new Quaternion object).""" + v1 = Quaternion(x=1.0, y=2.0, z=3.0, w=4.0) + v2 = Quaternion(x=4.5, y=5.25, z=6.125, w=7.0625) + v3 = v1 + v2 + assert type(v3) == Quaternion + assert v3.x == 5.5 + assert v3.y == 7.25 + assert v3.z == 9.125 + assert v3.w == 11.0625 + + +def test_quaternion_subtraction(): + """Test that two Quaternion objects can be subtracted (resulting in a new Quaternion object).""" + v1 = Quaternion(x=1.0, y=2.0, z=3.0, w=4.0) + v2 = Quaternion(x=4.5, y=5.25, z=6.125, w=7.0625) + v3 = v2 - v1 + assert type(v3) == Quaternion + assert v3.x == 3.5 + assert v3.y == 3.25 + assert v3.z == 3.125 + assert v3.w == 3.0625 + + +def test_quaternion_negative(): + """Test that a Quaternion object can be negated.""" + v1 = Quaternion(x=1.0, y=2.5, z=3.0, w=4.5) + v2 = -v1 + assert type(v2) == Quaternion + assert v2.x == -1.0 + assert v2.y == -2.5 + assert v2.z == -3.0 + assert v2.w == -4.5 + + +def test_quaternion_multiplication_int(): + """Test that a Quaternion object can be multiplied by an integer.""" + v1 = Quaternion(x=1.0, y=2.25, z=3.0, w=4.5) + v2 = v1 * 2 + assert v2.x == 2.0 + assert v2.y == 4.5 + assert v2.z == 6.0 + assert v2.w == 9.0 + + +def test_quaternion_multiplication_float(): + """Test that a Quaternion object can be multiplied by a float.""" + v1 = Quaternion(x=2.0, y=4.5, z=6.0, w=9.0) + v2 = v1 * 1.5 + assert type(v2) == Quaternion + assert v2.x == 3.0 + assert v2.y == 6.75 + assert v2.z == 9.0 + assert v2.w == 13.5 + + +def test_quaternion_norm_squared(): + """Test that the squared norm of a Quaternion object can be calculated.""" + v = Quaternion(x=3.0, y=4.0, z=5.0, w=6.0) + assert v.norm_squared() == 86.0 + + +def test_quaternion_norm(): + """Test that the norm of a Quaternion object can be calculated.""" + v = Quaternion(x=3.0, y=4.0, z=5.0, w=6.0) + assert (v.norm() - 86.0**0.5) < 1e-6 + + +@pytest.mark.parametrize( + ("x", "y", "z", "w", "expected"), + [ + (0, 0, 0, 0, ZeroDivisionError), + (1, 0, 0, 0, Quaternion(x=1, y=0, z=0, w=0)), + (0, 1, 0, 0, Quaternion(x=0, y=1, z=0, w=0)), + (0, 0, 1, 0, Quaternion(x=0, y=0, z=1, w=0)), + (0, 0, 0, 1, Quaternion(x=0, y=0, z=0, w=1)), + (1, 1, 1, 1, Quaternion(x=1, y=1, z=1, w=1) / math.sqrt(4)), + (-1, -1, -1, -1, Quaternion(x=-1, y=-1, z=-1, w=-1) / math.sqrt(4)), + ], +) +def test_quaternion_normalize(x: float, y: float, z: float, w: float, expected: Quaternion | type): + """Test that a Quaternion object can be normalized.""" + v = Quaternion(x=x, y=y, z=z, w=w) + if isinstance(expected, type): + expected = cast(type[Exception], expected) + with pytest.raises(expected): + v.normalize() + else: + assert (v.normalize() - expected).norm() < 1e-6 + + +def test_quaternion_tuple(): + """Test that a Quaternion object can be converted to a tuple.""" + v = Quaternion(x=1.0, y=2.0, z=3.0, w=4.0) + assert v.to_tuple() == (1.0, 2.0, 3.0, 4.0) diff --git a/tests/mcproto/types/test_slot.py b/tests/mcproto/types/test_slot.py new file mode 100644 index 00000000..7db93de6 --- /dev/null +++ b/tests/mcproto/types/test_slot.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from mcproto.types.nbt import ByteNBT, CompoundNBT, EndNBT, IntNBT, NBTag +from mcproto.types.slot import Slot +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Slot, + fields=[("present", bool), ("item_id", int), ("item_count", int), ("nbt", NBTag)], + test_data=[ + ((False, None, None, None), b"\x00"), + ((True, 1, 1, None), b"\x01\x01\x01\x00"), # EndNBT() is automatically added + ((True, 1, 1, EndNBT()), b"\x01\x01\x01\x00"), + ( + (True, 2, 3, CompoundNBT([IntNBT(4, "int_nbt"), ByteNBT(5, "byte_nbt")])), + b"\x01\x02\x03" + CompoundNBT([IntNBT(4, "int_nbt"), ByteNBT(5, "byte_nbt")]).serialize(), + ), + # Present but no item_id + ((True, None, 1, None), ValueError), + # Present but no item_count + ((True, 1, None, None), ValueError), + # Present but the NBT has the wrong type + ((True, 1, 1, IntNBT(1, "int_nbt")), TypeError), + # Not present but item_id is present + ((False, 1, 1, None), ValueError), + # Not present but item_count is present + ((False, None, 1, None), ValueError), + # Not present but NBT is present + ((False, None, None, CompoundNBT([])), ValueError), + ], +) diff --git a/tests/mcproto/types/test_vec3.py b/tests/mcproto/types/test_vec3.py new file mode 100644 index 00000000..01203d08 --- /dev/null +++ b/tests/mcproto/types/test_vec3.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import struct +from typing import cast +import pytest +import math + +from mcproto.types.vec3 import Position, Vec3 +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Position, + fields=[("x", int), ("y", int), ("z", int)], + test_data=[ + ((0, 0, 0), b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ((-1, -1, -1), b"\xff\xff\xff\xff\xff\xff\xff\xff"), + # from https://wiki.vg/Protocol#Position + ( + (18357644, 831, -20882616), + bytes([0b01000110, 0b00000111, 0b01100011, 0b00101100, 0b00010101, 0b10110100, 0b10000011, 0b00111111]), + ), + # X out of bounds + ((1 << 25, 0, 0), OverflowError), + ((-(1 << 25) - 1, 0, 0), OverflowError), + # Y out of bounds + ((0, 1 << 11, 0), OverflowError), + ((0, -(1 << 11) - 1, 0), OverflowError), + # Z out of bounds + ((0, 0, 1 << 25), OverflowError), + ((0, 0, -(1 << 25) - 1), OverflowError), + ], +) + +gen_serializable_test( + context=globals(), + cls=Vec3, + fields=[("x", float), ("y", float), ("z", float)], + test_data=[ + ((0.0, 0.0, 0.0), struct.pack(">fff", 0.0, 0.0, 0.0)), + ((-1.0, -1.0, -1.0), struct.pack(">fff", -1.0, -1.0, -1.0)), + ((1.0, 2.0, 3.0), struct.pack(">fff", 1.0, 2.0, 3.0)), + ((1.5, 2.5, 3.5), struct.pack(">fff", 1.5, 2.5, 3.5)), + # Invalid values + ((1.0, 2.0, "3.0"), TypeError), + ((float("nan"), 2.0, 3.0), ValueError), + ((1.0, float("inf"), 3.0), ValueError), + ((1.0, 2.0, -float("inf")), ValueError), + ], +) + + +def test_position_addition(): + """Test that two Position objects can be added together (resuling in a new Position object).""" + p1 = Position(x=1, y=2, z=3) + p2 = Position(x=4, y=5, z=6) + p3 = p1 + p2 + assert type(p3) == Position + assert p3.x == 5 + assert p3.y == 7 + assert p3.z == 9 + + +def test_position_subtraction(): + """Test that two Position objects can be subtracted (resuling in a new Position object).""" + p1 = Position(x=1, y=2, z=3) + p2 = Position(x=2, y=4, z=6) + p3 = p2 - p1 + assert type(p3) == Position + assert p3.x == 1 + assert p3.y == 2 + assert p3.z == 3 + + +def test_position_negative(): + """Test that a Position object can be negated.""" + p1 = Position(x=1, y=2, z=3) + p2 = -p1 + assert type(p2) == Position + assert p2.x == -1 + assert p2.y == -2 + assert p2.z == -3 + + +def test_position_multiplication_int(): + """Test that a Position object can be multiplied by an integer.""" + p1 = Position(x=1, y=2, z=3) + p2 = p1 * 2 + assert p2.x == 2 + assert p2.y == 4 + assert p2.z == 6 + + +def test_position_multiplication_float(): + """Test that a Position object can be multiplied by a float.""" + p1 = Position(x=2, y=4, z=6) + p2 = p1 * 1.5 + assert type(p2) == Position + assert p2.x == 3 + assert p2.y == 6 + assert p2.z == 9 + + +def test_vec3_to_position(): + """Test that a Vec3 object can be converted to a Position object.""" + v = Vec3(x=1.5, y=2.5, z=3.5) + p = v.to_position() + assert type(p) == Position + assert p.x == 1 + assert p.y == 2 + assert p.z == 3 + + +def test_position_to_vec3(): + """Test that a Position object can be converted to a Vec3 object.""" + p = Position(x=1, y=2, z=3) + v = p.to_vec3() + assert type(v) == Vec3 + assert v.x == 1.0 + assert v.y == 2.0 + assert v.z == 3.0 + + +def test_position_to_tuple(): + """Test that a Position object can be converted to a tuple.""" + p = Position(x=1, y=2, z=3) + t = p.to_tuple() + assert type(t) == tuple + assert t == (1, 2, 3) + + +def test_vec3_addition(): + """Test that two Vec3 objects can be added together (resuling in a new Vec3 object).""" + v1 = Vec3(x=1.0, y=2.0, z=3.0) + v2 = Vec3(x=4.5, y=5.25, z=6.125) + v3 = v1 + v2 + assert type(v3) == Vec3 + assert v3.x == 5.5 + assert v3.y == 7.25 + assert v3.z == 9.125 + + +def test_vec3_subtraction(): + """Test that two Vec3 objects can be subtracted (resuling in a new Vec3 object).""" + v1 = Vec3(x=1.0, y=2.0, z=3.0) + v2 = Vec3(x=4.5, y=5.25, z=6.125) + v3 = v2 - v1 + assert type(v3) == Vec3 + assert v3.x == 3.5 + assert v3.y == 3.25 + assert v3.z == 3.125 + + +def test_vec3_negative(): + """Test that a Vec3 object can be negated.""" + v1 = Vec3(x=1.0, y=2.5, z=3.0) + v2 = -v1 + assert type(v2) == Vec3 + assert v2.x == -1.0 + assert v2.y == -2.5 + assert v2.z == -3.0 + + +def test_vec3_multiplication_int(): + """Test that a Vec3 object can be multiplied by an integer.""" + v1 = Vec3(x=1.0, y=2.25, z=3.0) + v2 = v1 * 2 + assert v2.x == 2.0 + assert v2.y == 4.5 + assert v2.z == 6.0 + + +def test_vec3_multiplication_float(): + """Test that a Vec3 object can be multiplied by a float.""" + v1 = Vec3(x=2.0, y=4.5, z=6.0) + v2 = v1 * 1.5 + assert type(v2) == Vec3 + assert v2.x == 3.0 + assert v2.y == 6.75 + assert v2.z == 9.0 + + +def test_vec3_norm_squared(): + """Test that the squared norm of a Vec3 object can be calculated.""" + v = Vec3(x=3.0, y=4.0, z=5.0) + assert v.norm_squared() == 50.0 + + +def test_vec3_norm(): + """Test that the norm of a Vec3 object can be calculated.""" + v = Vec3(x=3.0, y=4.0, z=5.0) + assert (v.norm() - 50.0**0.5) < 1e-6 + + +@pytest.mark.parametrize( + ("x", "y", "z", "expected"), + [ + (0, 0, 0, ZeroDivisionError), + (1, 0, 0, Vec3(x=1, y=0, z=0)), + (0, 1, 0, Vec3(x=0, y=1, z=0)), + (0, 0, 1, Vec3(x=0, y=0, z=1)), + (1, 1, 1, Vec3(x=1, y=1, z=1) / math.sqrt(3)), + (-1, -1, -1, Vec3(x=-1, y=-1, z=-1) / math.sqrt(3)), + ], +) +def test_vec3_normalize(x: float, y: float, z: float, expected: Vec3 | type): + """Test that a Vec3 object can be normalized.""" + v = Vec3(x=x, y=y, z=z) + if isinstance(expected, type): + expected = cast(type[Exception], expected) + with pytest.raises(expected): + v.normalize() + else: + assert (v.normalize() - expected).norm() < 1e-6