diff --git a/changes/131.internal.md b/changes/131.internal.md new file mode 100644 index 00000000..52d011e1 --- /dev/null +++ b/changes/131.internal.md @@ -0,0 +1 @@ +Any overridden methods in any classes now have to explicitly use the `typing.override` decorator (see [PEP 698](https://peps.python.org/pep-0698/)) diff --git a/docs/conf.py b/docs/conf.py index 8ca26190..d3641c6a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,6 +14,7 @@ from pathlib import Path from packaging.version import parse as parse_version +from typing_extensions import override if sys.version_info >= (3, 11): from tomllib import load as toml_parse @@ -172,6 +173,7 @@ def mock_autodoc() -> None: from sphinx.ext import autodoc class MockedClassDocumenter(autodoc.ClassDocumenter): + @override def add_line(self, line: str, source: str, *lineno: int) -> None: if line == " Bases: :py:class:`object`": return diff --git a/docs/extensions/attributetable.py b/docs/extensions/attributetable.py index 729cd0b7..99740fdf 100644 --- a/docs/extensions/attributetable.py +++ b/docs/extensions/attributetable.py @@ -14,6 +14,7 @@ from sphinx.util.docutils import SphinxDirective from sphinx.util.typing import OptionSpec from sphinx.writers.html5 import HTML5Translator +from typing_extensions import override class AttributeTable(nodes.General, nodes.Element): @@ -111,6 +112,7 @@ def parse_name(self, content: str) -> tuple[str, str]: return modulename, name + @override def run(self) -> list[AttributeTablePlaceholder]: """If you're curious on the HTML this is meant to generate: diff --git a/mcproto/auth/account.py b/mcproto/auth/account.py index 8cfdee01..353d73d1 100644 --- a/mcproto/auth/account.py +++ b/mcproto/auth/account.py @@ -1,6 +1,7 @@ from __future__ import annotations import httpx +from typing_extensions import override from mcproto.types.uuid import UUID as McUUID # noqa: N811 @@ -22,6 +23,7 @@ def __init__(self, mismatched_variable: str, current: object, expected: object) self.expected = expected super().__init__(repr(self)) + @override def __repr__(self) -> str: msg = f"Account has mismatched {self.missmatched_variable}: " msg += f"current={self.current!r}, expected={self.expected!r}." diff --git a/mcproto/auth/microsoft/oauth.py b/mcproto/auth/microsoft/oauth.py index 25b6363c..c660404c 100644 --- a/mcproto/auth/microsoft/oauth.py +++ b/mcproto/auth/microsoft/oauth.py @@ -5,6 +5,7 @@ from typing import TypedDict import httpx +from typing_extensions import override __all__ = [ "MicrosoftOauthResponseErrorType", @@ -67,6 +68,7 @@ def msg(self) -> str: return f"Unknown error: {self.error!r}" return f"Error {self.err_type.name}: {self.err_type.value!r}" + @override def __repr__(self) -> str: return f"{self.__class__.__name__}({self.msg!r})" diff --git a/mcproto/auth/microsoft/xbox.py b/mcproto/auth/microsoft/xbox.py index 8d64a1e7..d454eb0a 100644 --- a/mcproto/auth/microsoft/xbox.py +++ b/mcproto/auth/microsoft/xbox.py @@ -4,6 +4,7 @@ from typing import NamedTuple import httpx +from typing_extensions import override __all__ = [ "XSTSErrorType", @@ -76,6 +77,7 @@ def msg(self) -> str: return " ".join(msg_parts) + @override def __repr__(self) -> str: return f"{self.__class__.__name__}({self.msg})" diff --git a/mcproto/auth/msa.py b/mcproto/auth/msa.py index 83291529..48f541bd 100644 --- a/mcproto/auth/msa.py +++ b/mcproto/auth/msa.py @@ -3,7 +3,7 @@ from enum import Enum import httpx -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.auth.account import Account from mcproto.types.uuid import UUID as McUUID # noqa: N811 @@ -59,6 +59,7 @@ def msg(self) -> str: return " ".join(msg_parts) + @override def __repr__(self) -> str: return f"{self.__class__.__name__}({self.msg})" diff --git a/mcproto/auth/yggdrasil.py b/mcproto/auth/yggdrasil.py index f0baa802..3bd41291 100644 --- a/mcproto/auth/yggdrasil.py +++ b/mcproto/auth/yggdrasil.py @@ -5,7 +5,7 @@ from uuid import uuid4 import httpx -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.auth.account import Account from mcproto.types.uuid import UUID as McUUID # noqa: N811 @@ -102,6 +102,7 @@ def msg(self) -> str: return " ".join(msg_parts) + @override def __repr__(self) -> str: return f"{self.__class__.__name__}({self.msg})" diff --git a/mcproto/buffer.py b/mcproto/buffer.py index 8f2878ae..35a13fb8 100644 --- a/mcproto/buffer.py +++ b/mcproto/buffer.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing_extensions import override + from mcproto.protocol.base_io import BaseSyncReader, BaseSyncWriter __all__ = ["Buffer"] @@ -14,10 +16,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pos = 0 + @override def write(self, data: bytes) -> None: """Write/Store given ``data`` into the buffer.""" self.extend(data) + @override def read(self, length: int) -> bytearray: """Read data stored in the buffer. @@ -52,6 +56,7 @@ def read(self, length: int) -> bytearray: finally: self.pos = end + @override def clear(self, only_already_read: bool = False) -> None: """Clear out the stored data and reset position. diff --git a/mcproto/connection.py b/mcproto/connection.py index 11f2b076..376c3910 100644 --- a/mcproto/connection.py +++ b/mcproto/connection.py @@ -9,7 +9,7 @@ import asyncio_dgram from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from typing_extensions import ParamSpec, Self +from typing_extensions import ParamSpec, Self, override from mcproto.protocol.base_io import BaseAsyncReader, BaseAsyncWriter, BaseSyncReader, BaseSyncWriter @@ -94,6 +94,7 @@ def _write(self, data: bytes, /) -> None: """Send raw ``data`` through this specific connection.""" raise NotImplementedError + @override def write(self, data: bytes, /) -> None: """Send given ``data`` over the connection. @@ -116,6 +117,7 @@ def _read(self, length: int, /) -> bytearray: """ raise NotImplementedError + @override def read(self, length: int, /) -> bytearray: """Receive data sent through the connection. @@ -206,6 +208,7 @@ async def _write(self, data: bytes, /) -> None: """Send raw ``data`` through this specific connection.""" raise NotImplementedError + @override async def write(self, data: bytes, /) -> None: """Send given ``data`` over the connection. @@ -228,6 +231,7 @@ async def _read(self, length: int, /) -> bytearray: """ raise NotImplementedError + @override async def read(self, length: int, /) -> bytearray: """Receive data sent through the connection. @@ -263,28 +267,15 @@ def __init__(self, socket: T_SOCK): super().__init__() self.socket = socket + @override @classmethod def make_client(cls, address: tuple[str, int], timeout: float) -> Self: - """Construct a client connection (Client -> Server) to given server ``address``. - - :param address: Address of the server to connection to. - :param timeout: - Amount of seconds to wait for the connection to be established. - If connection can't be established within this time, :exc:`TimeoutError` will be raised. - This timeout is then also used for any further data receiving. - """ sock = socket.create_connection(address, timeout=timeout) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return cls(sock) + @override def _read(self, length: int) -> bytearray: - """Receive raw data from this specific connection. - - :param length: - Amount of bytes to be received. If the requested amount can't be received - (server didn't send that much data/server didn't send any data), an :exc:`IOError` - will be raised. - """ result = bytearray() while len(result) < length: new = self.socket.recv(length - len(result)) @@ -301,12 +292,12 @@ def _read(self, length: int) -> bytearray: return result + @override def _write(self, data: bytes) -> None: - """Send raw ``data`` through this specific connection.""" self.socket.send(data) + @override def _close(self) -> None: - """Close the underlying connection.""" # Gracefully end the connection first (shutdown), informing the other side # we're disconnecting, and waiting for them to disconnect cleanly (TCP FIN) try: @@ -329,28 +320,15 @@ def __init__(self, reader: T_STREAMREADER, writer: T_STREAMWRITER, timeout: floa self.writer = writer self.timeout = timeout + @override @classmethod async def make_client(cls, address: tuple[str, int], timeout: float) -> Self: - """Construct a client connection (Client -> Server) to given server ``address``. - - :param address: Address of the server to connection to. - :param timeout: - Amount of seconds to wait for the connection to be established. - If connection can't be established within this time, :exc:`TimeoutError` will be raised. - This timeout is then also used for any further data receiving. - """ conn = asyncio.open_connection(address[0], address[1]) reader, writer = await asyncio.wait_for(conn, timeout=timeout) return cls(reader, writer, timeout) + @override async def _read(self, length: int) -> bytearray: - """Receive data sent through the connection. - - :param length: - Amount of bytes to be received. If the requested amount can't be received - (server didn't send that much data/server didn't send any data), an :exc:`IOError` - will be raised. - """ result = bytearray() while len(result) < length: new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout) @@ -367,12 +345,12 @@ async def _read(self, length: int) -> bytearray: return result + @override async def _write(self, data: bytes) -> None: - """Send raw ``data`` through this specific connection.""" self.writer.write(data) + @override async def _close(self) -> None: - """Close the underlying connection.""" # Close automatically performs a graceful TCP connection shutdown too self.writer.close() @@ -394,42 +372,27 @@ def __init__(self, socket: T_SOCK, address: tuple[str, int]): self.socket = socket self.address = address + @override @classmethod def make_client(cls, address: tuple[str, int], timeout: float) -> Self: - """Construct a client connection (Client -> Server) to given server ``address``. - - :param address: Address of the server to connection to. - :param timeout: - Amount of seconds to wait for the connection to be established. - If connection can't be established within this time, :exc:`TimeoutError` will be raised. - This timeout is then also used for any further data receiving. - """ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(timeout) return cls(sock, address) + @override def _read(self, length: int | None = None) -> bytearray: - """Receive data sent through the connection. - - :param length: - For UDP connections, ``length`` parameter is ignored and not required. - Instead, UDP connections always read exactly :attr:`.BUFFER_SIZE` bytes. - - If the requested amount can't be received (server didn't send that much - data/server didn't send any data), an :exc:`IOError` will be raised. - """ result = bytearray() while len(result) == 0: received_data, server_addr = self.socket.recvfrom(self.BUFFER_SIZE) result.extend(received_data) return result + @override def _write(self, data: bytes) -> None: - """Send raw ``data`` through this specific connection.""" self.socket.sendto(data, self.address) + @override def _close(self) -> None: - """Close the underlying connection.""" self.socket.close() @@ -443,39 +406,25 @@ def __init__(self, stream: T_DATAGRAM_CLIENT, timeout: float): self.stream = stream self.timeout = timeout + @override @classmethod async def make_client(cls, address: tuple[str, int], timeout: float) -> Self: - """Construct a client connection (Client -> Server) to given server ``address``. - - :param address: Address of the server to connection to. - :param timeout: - Amount of seconds to wait for the connection to be established. - If connection can't be established within this time, :exc:`TimeoutError` will be raised. - This timeout is then also used for any further data receiving. - """ conn = asyncio_dgram.connect(address) stream = await asyncio.wait_for(conn, timeout=timeout) return cls(stream, timeout) + @override async def _read(self, length: int | None = None) -> bytearray: - """Receive data sent through the connection. - - :param length: - For UDP connections, ``length`` parameter is ignored and not required. - - If the requested amount can't be received (server didn't send that much - data/server didn't send any data), an :exc:`IOError` will be raised. - """ result = bytearray() while len(result) == 0: received_data, server_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) result.extend(received_data) return result + @override async def _write(self, data: bytes) -> None: - """Send raw ``data`` through this specific connection.""" await self.stream.send(data) + @override async def _close(self) -> None: - """Close the underlying connection.""" self.stream.close() diff --git a/mcproto/multiplayer.py b/mcproto/multiplayer.py index c63114c7..778a01ae 100644 --- a/mcproto/multiplayer.py +++ b/mcproto/multiplayer.py @@ -7,6 +7,7 @@ import httpx from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +from typing_extensions import override from mcproto.auth.account import Account @@ -75,6 +76,7 @@ def msg(self) -> str: return " ".join(msg_parts) + @override def __repr__(self) -> str: return f"{self.__class__.__name__}({self.msg})" @@ -97,6 +99,7 @@ def __init__(self, response: httpx.Response, client_username: str, server_hash: self.client_ip = client_ip super().__init__(repr(self)) + @override def __repr__(self) -> str: msg = "Unable to verify user join for " msg += f"username={self.client_username!r}, server_hash={self.server_hash!r}, client_ip={self.client_ip!r}" diff --git a/mcproto/packets/handshaking/handshake.py b/mcproto/packets/handshaking/handshake.py index 80ddcb15..46dbc219 100644 --- a/mcproto/packets/handshaking/handshake.py +++ b/mcproto/packets/handshaking/handshake.py @@ -3,7 +3,7 @@ from enum import IntEnum from typing import ClassVar, final -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.packets.packet import GameState, ServerBoundPacket @@ -58,8 +58,8 @@ def __init__( self.server_port = server_port self.next_state = next_state + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_varint(self.protocol_version) buf.write_utf(self.server_address) @@ -67,9 +67,9 @@ def serialize(self) -> Buffer: buf.write_varint(self.next_state.value) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" return cls( protocol_version=buf.read_varint(), server_address=buf.read_utf(), diff --git a/mcproto/packets/login/login.py b/mcproto/packets/login/login.py index d5e0dcf7..782aa505 100644 --- a/mcproto/packets/login/login.py +++ b/mcproto/packets/login/login.py @@ -5,7 +5,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, load_der_public_key -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket @@ -42,16 +42,16 @@ def __init__(self, *, username: str, uuid: UUID): self.username = username self.uuid = uuid + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_utf(self.username) buf.extend(self.uuid.serialize()) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" username = buf.read_utf() uuid = UUID.deserialize(buf) return cls(username=username, uuid=uuid) @@ -80,8 +80,8 @@ def __init__(self, *, server_id: str | None = None, public_key: RSAPublicKey, ve self.public_key = public_key self.verify_token = verify_token + @override def serialize(self) -> Buffer: - """Serialize the packet.""" public_key_raw = self.public_key.public_bytes(encoding=Encoding.DER, format=PublicFormat.SubjectPublicKeyInfo) buf = Buffer() @@ -90,9 +90,9 @@ def serialize(self) -> Buffer: buf.write_bytearray(self.verify_token) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" server_id = buf.read_utf() public_key_raw = buf.read_bytearray() verify_token = buf.read_bytearray() @@ -122,16 +122,16 @@ def __init__(self, *, shared_secret: bytes, verify_token: bytes): self.shared_secret = shared_secret self.verify_token = verify_token + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_bytearray(self.shared_secret) buf.write_bytearray(self.verify_token) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" shared_secret = buf.read_bytearray() verify_token = buf.read_bytearray() return cls(shared_secret=shared_secret, verify_token=verify_token) @@ -155,16 +155,16 @@ def __init__(self, uuid: UUID, username: str): self.uuid = uuid self.username = username + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.extend(self.uuid.serialize()) buf.write_utf(self.username) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" uuid = UUID.deserialize(buf) username = buf.read_utf() return cls(uuid, username) @@ -186,13 +186,13 @@ def __init__(self, reason: ChatMessage): """ self.reason = reason + @override def serialize(self) -> Buffer: - """Serialize the packet.""" return self.reason.serialize() + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" reason = ChatMessage.deserialize(buf) return cls(reason) @@ -217,17 +217,17 @@ def __init__(self, message_id: int, channel: str, data: bytes): self.channel = channel self.data = data + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_varint(self.message_id) buf.write_utf(self.channel) buf.write(self.data) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" message_id = buf.read_varint() channel = buf.read_utf() data = buf.read(buf.remaining) # All of the remaining data in the buffer @@ -252,16 +252,16 @@ def __init__(self, message_id: int, data: bytes | None): self.message_id = message_id self.data = data + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_varint(self.message_id) buf.write_optional(self.data, buf.write) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" message_id = buf.read_varint() data = buf.read_optional(lambda: buf.read(buf.remaining)) return cls(message_id, data) @@ -288,14 +288,14 @@ def __init__(self, threshold: int): """ self.threshold = threshold + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_varint(self.threshold) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" threshold = buf.read_varint() return cls(threshold) diff --git a/mcproto/packets/packet.py b/mcproto/packets/packet.py index a9cd2186..37737e70 100644 --- a/mcproto/packets/packet.py +++ b/mcproto/packets/packet.py @@ -5,7 +5,7 @@ from enum import IntEnum from typing import ClassVar -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.utils.abc import RequiredParamsABCMixin, Serializable @@ -46,9 +46,9 @@ class Packet(Serializable, RequiredParamsABCMixin): PACKET_ID: ClassVar[int] GAME_STATE: ClassVar[GameState] + @override @classmethod def deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" try: return cls._deserialize(buf) except IOError as exc: @@ -57,7 +57,6 @@ def deserialize(cls, buf: Buffer, /) -> Self: @classmethod @abstractmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" raise NotImplementedError @@ -153,5 +152,6 @@ def msg(self) -> str: return " ".join(msg_parts) + @override def __repr__(self) -> str: return f"{self.__class__.__name__}({self.msg})" diff --git a/mcproto/packets/status/ping.py b/mcproto/packets/status/ping.py index 6aebcc6a..6ab4e355 100644 --- a/mcproto/packets/status/ping.py +++ b/mcproto/packets/status/ping.py @@ -2,7 +2,7 @@ from typing import ClassVar, final -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket @@ -29,14 +29,14 @@ def __init__(self, payload: int): """ self.payload = payload + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() buf.write_value(StructFormat.LONGLONG, self.payload) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" payload = buf.read_value(StructFormat.LONGLONG) return cls(payload) diff --git a/mcproto/packets/status/status.py b/mcproto/packets/status/status.py index 81abd826..4d1ad173 100644 --- a/mcproto/packets/status/status.py +++ b/mcproto/packets/status/status.py @@ -3,7 +3,7 @@ import json from typing import Any, ClassVar, final -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket @@ -20,13 +20,13 @@ class StatusRequest(ServerBoundPacket): PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.STATUS + @override def serialize(self) -> Buffer: # pragma: no cover, nothing to test here. - """Serialize the packet.""" return Buffer() + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to test here. - """Deserialize the packet.""" return cls() @@ -46,16 +46,16 @@ def __init__(self, data: dict[str, Any]): """ self.data = data + @override def serialize(self) -> Buffer: - """Serialize the packet.""" buf = Buffer() s = json.dumps(self.data) buf.write_utf(s) return buf + @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize the packet.""" s = buf.read_utf() data_ = json.loads(s) return cls(data_) diff --git a/mcproto/types/chat.py b/mcproto/types/chat.py index 669407f6..8b915aa2 100644 --- a/mcproto/types/chat.py +++ b/mcproto/types/chat.py @@ -3,7 +3,7 @@ import json from typing import TypedDict, Union, final -from typing_extensions import Self, TypeAlias +from typing_extensions import Self, TypeAlias, override from mcproto.buffer import Buffer from mcproto.types.abc import MCType @@ -53,6 +53,7 @@ def as_dict(self) -> RawChatMessageDict: # pragma: no cover raise TypeError(f"Found unexpected type ({self.raw.__class__!r}) ({self.raw!r}) in `raw` attribute") + @override def __eq__(self, other: object) -> bool: """Check equality between two chat messages. @@ -65,16 +66,16 @@ def __eq__(self, other: object) -> bool: return self.raw == other.raw + @override def serialize(self) -> Buffer: - """Serialize the chat message.""" txt = json.dumps(self.raw) buf = Buffer() buf.write_utf(txt) return buf + @override @classmethod def deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize a chat message.""" txt = buf.read_utf() dct = json.loads(txt) return cls(dct) diff --git a/mcproto/types/uuid.py b/mcproto/types/uuid.py index 18bfa28b..97fddde8 100644 --- a/mcproto/types/uuid.py +++ b/mcproto/types/uuid.py @@ -3,7 +3,7 @@ import uuid from typing import final -from typing_extensions import Self +from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.types.abc import MCType @@ -21,13 +21,13 @@ class UUID(MCType, uuid.UUID): __slots__ = () + @override def serialize(self) -> Buffer: - """Serialize the UUID.""" buf = Buffer() buf.write(self.bytes) return buf + @override @classmethod def deserialize(cls, buf: Buffer, /) -> Self: - """Deserialize a UUID.""" return cls(bytes=bytes(buf.read(16))) diff --git a/pyproject.toml b/pyproject.toml index 33172e7b..79c3cf3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ reportUnnecessaryCast = "error" reportUnnecessaryComparison = "error" reportUnnecessaryContains = "error" reportUnnecessaryTypeIgnoreComment = "error" +reportImplicitOverride = "error" reportShadowedImports = "error" [tool.ruff] diff --git a/tests/helpers.py b/tests/helpers.py index 4923ec54..44de2032 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,7 @@ from collections.abc import Callable, Coroutine from typing import Any, Generic, TypeVar -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, override T = TypeVar("T") P = ParamSpec("P") @@ -51,6 +51,7 @@ class SynchronizedMixin: _WRAPPED_ATTRIBUTE: str + @override def __getattribute__(self, __name: str) -> Any: """Return attributes of the wrapped object, if the attribute is a coroutine function, synchronize it. @@ -72,6 +73,7 @@ def __getattribute__(self, __name: str) -> Any: return super().__getattribute__(__name) + @override def __setattr__(self, __name: str, __value: object) -> None: """Allow for changing attributes of the wrapped object. diff --git a/tests/mcproto/protocol/helpers.py b/tests/mcproto/protocol/helpers.py index 5cf689a9..21679ed7 100644 --- a/tests/mcproto/protocol/helpers.py +++ b/tests/mcproto/protocol/helpers.py @@ -2,6 +2,8 @@ from unittest.mock import AsyncMock, Mock +from typing_extensions import override + class WriteFunctionMock(Mock): """Mock write function, storing the written data.""" @@ -10,6 +12,7 @@ def __init__(self, *a, **kw): super().__init__(*a, **kw) self.combined_data = bytearray() + @override def __call__(self, data: bytes) -> None: # pyright: ignore[reportIncompatibleMethodOverride] """Override mock's ``__call__`` to extend our :attr:`.combined_data` bytearray. @@ -20,6 +23,7 @@ def __call__(self, data: bytes) -> None: # pyright: ignore[reportIncompatibleMe self.combined_data.extend(data) return super().__call__(data) + @override def assert_has_data(self, data: bytearray, ensure_called: bool = True) -> None: """Ensure that the combined write data by the mocked function matches expected ``data``.""" if ensure_called: @@ -42,6 +46,7 @@ def __init__(self, *a, combined_data: bytearray | None = None, **kw): combined_data = bytearray() self.combined_data = combined_data + @override def __call__(self, length: int) -> bytearray: # pyright: ignore[reportIncompatibleMethodOverride] """Override mock's __call__ to make it return part of our :attr:`.combined_data` bytearray. @@ -54,6 +59,7 @@ def __call__(self, length: int) -> bytearray: # pyright: ignore[reportIncompati del self.combined_data[:length] return super().__call__(length) + @override def assert_read_everything(self, ensure_called: bool = True) -> None: """Ensure that the passed :attr:`.combined_data` was fully read and depleted.""" if ensure_called: diff --git a/tests/mcproto/protocol/test_base_io.py b/tests/mcproto/protocol/test_base_io.py index d3262f0b..96ad4221 100644 --- a/tests/mcproto/protocol/test_base_io.py +++ b/tests/mcproto/protocol/test_base_io.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, Mock import pytest +from typing_extensions import override from mcproto.protocol.base_io import ( BaseAsyncReader, @@ -31,6 +32,7 @@ class SyncWriter(BaseSyncWriter): """Initializable concrete implementation of :class:`~mcproto.protocol.base_io.BaseSyncWriter` ABC.""" + @override def write(self, data: bytes) -> None: """Concrete implementation of abstract write method. @@ -55,6 +57,7 @@ def write(self, data: bytes) -> None: class SyncReader(BaseSyncReader): """Testable concrete implementation of :class:`~mcproto.protocol.base_io.BaseSyncReader` ABC.""" + @override def read(self, length: int) -> bytearray: """Concrete implementation of abstract read method. @@ -79,6 +82,7 @@ def read(self, length: int) -> bytearray: class AsyncWriter(BaseAsyncWriter): """Initializable concrete implementation of :class:`~mcproto.protocol.base_io.BaseAsyncWriter` ABC.""" + @override async def write(self, data: bytes) -> None: """Concrete implementation of abstract write method. @@ -103,6 +107,7 @@ async def write(self, data: bytes) -> None: class AsyncReader(BaseAsyncReader): """Testable concrete implementation of BaseAsyncReader ABC.""" + @override async def read(self, length: int) -> bytearray: """Concrete implementation of abstract read method. @@ -576,36 +581,36 @@ def test_read_optional_false(self, method_mock: Mock | AsyncMock, read_mock: Rea class TestBaseSyncWriter(WriterTests[SyncWriter]): """Tests for individual write methods implemented in :class:`~mcproto.protocol.base_io.BaseSyncWriter`.""" + @override @classmethod def setup_class(cls): - """Initialize writer instance to be tested.""" cls.writer = SyncWriter() class TestBaseSyncReader(ReaderTests[SyncReader]): """Tests for individual write methods implemented in :class:`~mcproto.protocol.base_io.BaseSyncReader`.""" + @override @classmethod def setup_class(cls): - """Initialize reader instance to be tested.""" cls.reader = SyncReader() class TestBaseAsyncWriter(WriterTests[AsyncWriter]): """Tests for individual write methods implemented in :class:`~mcproto.protocol.base_io.BaseSyncReader`.""" + @override @classmethod def setup_class(cls): - """Initialize writer instance to be tested.""" cls.writer = WrappedAsyncWriter() # type: ignore class TestBaseAsyncReader(ReaderTests[AsyncReader]): """Tests for individual write methods implemented in :class:`~mcproto.protocol.base_io.BaseSyncReader`.""" + @override @classmethod def setup_class(cls): - """Initialize writer instance to be tested.""" cls.reader = WrappedAsyncReader() # type: ignore diff --git a/tests/mcproto/test_connection.py b/tests/mcproto/test_connection.py index 49f92469..7a8b4013 100644 --- a/tests/mcproto/test_connection.py +++ b/tests/mcproto/test_connection.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock import pytest +from typing_extensions import override from mcproto.connection import TCPAsyncConnection, TCPSyncConnection from tests.helpers import CustomMockMixin @@ -28,22 +29,26 @@ def __init__(self, *args, read_data: bytearray | None = None, **kwargs) -> None: self._send = WriteFunctionMock() self._closed = False + @override def send(self, data: bytearray) -> None: """Mock version of send method, raising :exc:`OSError` if the socket was closed.""" if self._closed: raise OSError(errno.EBADF, "Bad file descriptor") return self._send(data) + @override def recv(self, length: int) -> bytearray: """Mock version of recv method, raising :exc:`OSError` if the socket was closed.""" if self._closed: raise OSError(errno.EBADF, "Bad file descriptor") return self._recv(length) + @override def close(self) -> None: """Mock version of close method, setting :attr:`_closed` bool flag.""" self._closed = True + @override def shutdown(self, __how: int, /) -> None: """Mock version of shutdown, without any real implementation.""" pass @@ -60,12 +65,14 @@ def __init__(self, *args, **kwargs): self._write = WriteFunctionMock() self._closed = False + @override def write(self, data: bytearray) -> None: """Mock version of write method, raising :exc:`OSError` if the writer was closed.""" if self._closed: raise OSError(errno.EBADF, "Bad file descriptor") return self._write(data) + @override def close(self) -> None: """Mock version of close method, setting :attr:`_closed` bool flag.""" self._closed = True @@ -81,6 +88,7 @@ def __init__(self, *args, read_data: bytearray | None = None, **kwargs) -> None: self.mock_add_spec(["_read"]) self._read = ReadFunctionAsyncMock(combined_data=read_data) + @override def read(self, length: int) -> bytearray: """Mock version of read, using the mocked read method.""" return self._read(length)