diff --git a/changes/209.feature.md b/changes/209.feature.md new file mode 100644 index 00000000..4ea701de --- /dev/null +++ b/changes/209.feature.md @@ -0,0 +1 @@ +- Added `InvalidPacketContentError` exception, raised when deserializing of a specific packet fails. This error inherits from `IOError`, making it backwards compatible with the original implementation. diff --git a/mcproto/packets/handshaking/handshake.py b/mcproto/packets/handshaking/handshake.py index a7aab98e..80ddcb15 100644 --- a/mcproto/packets/handshaking/handshake.py +++ b/mcproto/packets/handshaking/handshake.py @@ -68,7 +68,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" return cls( protocol_version=buf.read_varint(), diff --git a/mcproto/packets/login/login.py b/mcproto/packets/login/login.py index b3f69107..6f499ac1 100644 --- a/mcproto/packets/login/login.py +++ b/mcproto/packets/login/login.py @@ -50,7 +50,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" username = buf.read_utf() uuid = buf.read_optional(lambda: UUID.deserialize(buf)) @@ -91,7 +91,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" server_id = buf.read_utf() public_key_raw = buf.read_bytearray() @@ -130,7 +130,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" shared_secret = buf.read_bytearray() verify_token = buf.read_bytearray() @@ -163,7 +163,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" uuid = UUID.deserialize(buf) username = buf.read_utf() @@ -191,7 +191,7 @@ def serialize(self) -> Buffer: return self.reason.serialize() @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" reason = ChatMessage.deserialize(buf) return cls(reason) @@ -226,7 +226,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" message_id = buf.read_varint() channel = buf.read_utf() @@ -260,7 +260,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" message_id = buf.read_varint() data = buf.read_optional(lambda: buf.read(buf.remaining)) @@ -295,7 +295,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" threshold = buf.read_varint() return cls(threshold) diff --git a/mcproto/packets/packet.py b/mcproto/packets/packet.py index 6e453f54..a9cd2186 100644 --- a/mcproto/packets/packet.py +++ b/mcproto/packets/packet.py @@ -1,9 +1,13 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Sequence from enum import IntEnum from typing import ClassVar +from typing_extensions import Self + +from mcproto.buffer import Buffer from mcproto.utils.abc import RequiredParamsABCMixin, Serializable __all__ = [ @@ -42,6 +46,20 @@ class Packet(Serializable, RequiredParamsABCMixin): PACKET_ID: ClassVar[int] GAME_STATE: ClassVar[GameState] + @classmethod + def deserialize(cls, buf: Buffer, /) -> Self: + """Deserialize the packet.""" + try: + return cls._deserialize(buf) + except IOError as exc: + raise InvalidPacketContentError.from_packet_class(cls, buf, repr(exc)) from exc + + @classmethod + @abstractmethod + def _deserialize(cls, buf: Buffer, /) -> Self: + """Deserialize the packet.""" + raise NotImplementedError + class ServerBoundPacket(Packet): """Packet bound to a server (Client -> Server).""" @@ -53,3 +71,87 @@ class ClientBoundPacket(Packet): """Packet bound to a client (Server -> Client).""" __slots__ = () + + +class InvalidPacketContentError(IOError): + """Unable to deserialize given packet, as it didn't match the expected content. + + This error can occur during deserialization of a specific packet (after the + packet class was already identified), but the deserialization process for this + packet type failed. + + This can happen if the server sent a known packet, but it's content didn't match + the expected content for this packet kind. + """ + + def __init__( + self, + packet_id: int, + game_state: GameState, + direction: PacketDirection, + buffer: Buffer, + message: str, + ) -> None: + """Initialize the error class. + + :param packet_id: Identified packet ID. + :param game_state: Game state of the identified packet. + :param direction: Packet direction of the identified packet. + :param buffer: Buffer received for deserialization, that failed to parse. + :param message: Reason for the failure. + """ + self.packet_id = packet_id + self.game_state = game_state + self.direction = direction + self.buffer = buffer + self.message = message + super().__init__(self.msg) + + @classmethod + def from_packet_class(cls, packet_class: type[Packet], buffer: Buffer, message: str) -> Self: + """Construct the error from packet class. + + This is a convenience constructor, picking up the necessary parameters about the identified packet + from the packet class automatically (packet id, game state, ...). + """ + if isinstance(packet_class, ServerBoundPacket): + direction = PacketDirection.SERVERBOUND + elif isinstance(packet_class, ClientBoundPacket): + direction = PacketDirection.CLIENTBOUND + else: + raise ValueError( + "Unable to determine the packet direction. Got a packet class which doesn't " + "inherit from ServerBoundPacket nor ClientBoundPacket class." + ) + + return cls(packet_class.PACKET_ID, packet_class.GAME_STATE, direction, buffer, message) + + @property + def msg(self) -> str: + """Produce a message for this error.""" + msg_parts = [] + + if self.direction is PacketDirection.CLIENTBOUND: + msg_parts.append("Clientbound") + else: + msg_parts.append("Serverbound") + + msg_parts.append("packet in") + + if self.game_state is GameState.HANDSHAKING: + msg_parts.append("handshaking") + elif self.game_state is GameState.STATUS: + msg_parts.append("status") + elif self.game_state is GameState.LOGIN: + msg_parts.append("login") + else: + msg_parts.append("play") + + msg_parts.append("game state") + msg_parts.append(f"with ID: 0x{self.packet_id:02x}") + msg_parts.append(f"failed to deserialize: {self.message}") + + return " ".join(msg_parts) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.msg})" diff --git a/mcproto/packets/status/ping.py b/mcproto/packets/status/ping.py index 1a139509..6aebcc6a 100644 --- a/mcproto/packets/status/ping.py +++ b/mcproto/packets/status/ping.py @@ -36,7 +36,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" payload = buf.read_value(StructFormat.LONGLONG) return cls(payload) diff --git a/mcproto/packets/status/status.py b/mcproto/packets/status/status.py index 0db23e0a..81abd826 100644 --- a/mcproto/packets/status/status.py +++ b/mcproto/packets/status/status.py @@ -25,7 +25,7 @@ def serialize(self) -> Buffer: # pragma: no cover, nothing to test here. return Buffer() @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to test here. + def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to test here. """Deserialize the packet.""" return cls() @@ -54,7 +54,7 @@ def serialize(self) -> Buffer: return buf @classmethod - def deserialize(cls, buf: Buffer, /) -> Self: + def _deserialize(cls, buf: Buffer, /) -> Self: """Deserialize the packet.""" s = buf.read_utf() data_ = json.loads(s)