Skip to content

Commit

Permalink
Add custom error class representing failed packet deserialization
Browse files Browse the repository at this point in the history
Instead of relying purely on IOError, this adds a new error raised only
when packet deserialization fails for a specific, already identified
packet class.

This custom error type will contain the information on the identified
packet class, including the received buffer for deserialization which
failed to parse. This makes it easier to identify potential problems.
This custom error still inherits from `IOError`, hence avoiding breaking
the compatibility.
  • Loading branch information
ItsDrike committed Sep 20, 2023
1 parent 59b4fe7 commit 1b8491b
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 12 deletions.
1 change: 1 addition & 0 deletions changes/209.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Added `InvalidPacketContentError` exception, raised when deserializing of a specific packet fails. This error inherits from `IOError`, making it backwards compatible with the original implementation.
2 changes: 1 addition & 1 deletion mcproto/packets/handshaking/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
return cls(
protocol_version=buf.read_varint(),
Expand Down
16 changes: 8 additions & 8 deletions mcproto/packets/login/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
username = buf.read_utf()
uuid = buf.read_optional(lambda: UUID.deserialize(buf))
Expand Down Expand Up @@ -91,7 +91,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
server_id = buf.read_utf()
public_key_raw = buf.read_bytearray()
Expand Down Expand Up @@ -130,7 +130,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
shared_secret = buf.read_bytearray()
verify_token = buf.read_bytearray()
Expand Down Expand Up @@ -163,7 +163,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
uuid = UUID.deserialize(buf)
username = buf.read_utf()
Expand Down Expand Up @@ -191,7 +191,7 @@ def serialize(self) -> Buffer:
return self.reason.serialize()

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
reason = ChatMessage.deserialize(buf)
return cls(reason)
Expand Down Expand Up @@ -226,7 +226,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
message_id = buf.read_varint()
channel = buf.read_utf()
Expand Down Expand Up @@ -260,7 +260,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
message_id = buf.read_varint()
data = buf.read_optional(lambda: buf.read(buf.remaining))
Expand Down Expand Up @@ -295,7 +295,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
threshold = buf.read_varint()
return cls(threshold)
102 changes: 102 additions & 0 deletions mcproto/packets/packet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

from abc import abstractmethod
from collections.abc import Sequence
from enum import IntEnum
from typing import ClassVar

from typing_extensions import Self

from mcproto.buffer import Buffer
from mcproto.utils.abc import RequiredParamsABCMixin, Serializable

__all__ = [
Expand Down Expand Up @@ -42,6 +46,20 @@ class Packet(Serializable, RequiredParamsABCMixin):
PACKET_ID: ClassVar[int]
GAME_STATE: ClassVar[GameState]

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
try:
return cls._deserialize(buf)
except IOError as exc:
raise InvalidPacketContentError.from_packet_class(cls, buf, repr(exc)) from exc

@classmethod
@abstractmethod
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
raise NotImplementedError


class ServerBoundPacket(Packet):
"""Packet bound to a server (Client -> Server)."""
Expand All @@ -53,3 +71,87 @@ class ClientBoundPacket(Packet):
"""Packet bound to a client (Server -> Client)."""

__slots__ = ()


class InvalidPacketContentError(IOError):
"""Unable to deserialize given packet, as it didn't match the expected content.
This error can occur during deserialization of a specific packet (after the
packet class was already identified), but the deserialization process for this
packet type failed.
This can happen if the server sent a known packet, but it's content didn't match
the expected content for this packet kind.
"""

def __init__(
self,
packet_id: int,
game_state: GameState,
direction: PacketDirection,
buffer: Buffer,
message: str,
) -> None:
"""Initialize the error class.
:param packet_id: Identified packet ID.
:param game_state: Game state of the identified packet.
:param direction: Packet direction of the identified packet.
:param buffer: Buffer received for deserialization, that failed to parse.
:param message: Reason for the failure.
"""
self.packet_id = packet_id
self.game_state = game_state
self.direction = direction
self.buffer = buffer
self.message = message
super().__init__(self.msg)

@classmethod
def from_packet_class(cls, packet_class: type[Packet], buffer: Buffer, message: str) -> Self:
"""Construct the error from packet class.
This is a convenience constructor, picking up the necessary parameters about the identified packet
from the packet class automatically (packet id, game state, ...).
"""
if isinstance(packet_class, ServerBoundPacket):
direction = PacketDirection.SERVERBOUND
elif isinstance(packet_class, ClientBoundPacket):
direction = PacketDirection.CLIENTBOUND
else:
raise ValueError(
"Unable to determine the packet direction. Got a packet class which doesn't "
"inherit from ServerBoundPacket nor ClientBoundPacket class."
)

return cls(packet_class.PACKET_ID, packet_class.GAME_STATE, direction, buffer, message)

@property
def msg(self) -> str:
"""Produce a message for this error."""
msg_parts = []

if self.direction is PacketDirection.CLIENTBOUND:
msg_parts.append("Clientbound")
else:
msg_parts.append("Serverbound")

msg_parts.append("packet in")

if self.game_state is GameState.HANDSHAKING:
msg_parts.append("handshaking")
elif self.game_state is GameState.STATUS:
msg_parts.append("status")
elif self.game_state is GameState.LOGIN:
msg_parts.append("login")
else:
msg_parts.append("play")

msg_parts.append("game state")
msg_parts.append(f"with ID: 0x{self.packet_id:02x}")
msg_parts.append(f"failed to deserialize: {self.message}")

return " ".join(msg_parts)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.msg})"
2 changes: 1 addition & 1 deletion mcproto/packets/status/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
payload = buf.read_value(StructFormat.LONGLONG)
return cls(payload)
4 changes: 2 additions & 2 deletions mcproto/packets/status/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def serialize(self) -> Buffer: # pragma: no cover, nothing to test here.
return Buffer()

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to test here.
def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to test here.
"""Deserialize the packet."""
return cls()

Expand Down Expand Up @@ -54,7 +54,7 @@ def serialize(self) -> Buffer:
return buf

@classmethod
def deserialize(cls, buf: Buffer, /) -> Self:
def _deserialize(cls, buf: Buffer, /) -> Self:
"""Deserialize the packet."""
s = buf.read_utf()
data_ = json.loads(s)
Expand Down

0 comments on commit 1b8491b

Please sign in to comment.