Skip to content

Commit

Permalink
Enforce PEP 698 (override decorator)
Browse files Browse the repository at this point in the history
This enables the `reportImplicitOverride` pyright toggle, enabling the
enforcement that any overridden methods in classes be marked with the
`typing.override` (or pre 3.12, with `typing_extensions.override`)
  • Loading branch information
ItsDrike committed Apr 29, 2024
1 parent 659c318 commit adb5f8c
Show file tree
Hide file tree
Showing 23 changed files with 109 additions and 116 deletions.
1 change: 1 addition & 0 deletions changes/131.internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Any overridden methods in any classes now have to explicitly use the `typing.override` decorator (see [PEP 698](https://peps.python.org/pep-0698/))
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pathlib import Path

from packaging.version import parse as parse_version
from typing_extensions import override

if sys.version_info >= (3, 11):
from tomllib import load as toml_parse
Expand Down Expand Up @@ -172,6 +173,7 @@ def mock_autodoc() -> None:
from sphinx.ext import autodoc

class MockedClassDocumenter(autodoc.ClassDocumenter):
@override
def add_line(self, line: str, source: str, *lineno: int) -> None:
if line == " Bases: :py:class:`object`":
return
Expand Down
2 changes: 2 additions & 0 deletions docs/extensions/attributetable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sphinx.util.docutils import SphinxDirective
from sphinx.util.typing import OptionSpec
from sphinx.writers.html5 import HTML5Translator
from typing_extensions import override


class AttributeTable(nodes.General, nodes.Element):
Expand Down Expand Up @@ -111,6 +112,7 @@ def parse_name(self, content: str) -> tuple[str, str]:

return modulename, name

@override
def run(self) -> list[AttributeTablePlaceholder]:
"""If you're curious on the HTML this is meant to generate:
Expand Down
2 changes: 2 additions & 0 deletions mcproto/auth/account.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import httpx
from typing_extensions import override

from mcproto.types.uuid import UUID as McUUID # noqa: N811

Expand All @@ -22,6 +23,7 @@ def __init__(self, mismatched_variable: str, current: object, expected: object)
self.expected = expected
super().__init__(repr(self))

@override
def __repr__(self) -> str:
msg = f"Account has mismatched {self.missmatched_variable}: "
msg += f"current={self.current!r}, expected={self.expected!r}."
Expand Down
2 changes: 2 additions & 0 deletions mcproto/auth/microsoft/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TypedDict

import httpx
from typing_extensions import override

__all__ = [
"MicrosoftOauthResponseErrorType",
Expand Down Expand Up @@ -67,6 +68,7 @@ def msg(self) -> str:
return f"Unknown error: {self.error!r}"
return f"Error {self.err_type.name}: {self.err_type.value!r}"

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.msg!r})"

Expand Down
2 changes: 2 additions & 0 deletions mcproto/auth/microsoft/xbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import NamedTuple

import httpx
from typing_extensions import override

__all__ = [
"XSTSErrorType",
Expand Down Expand Up @@ -76,6 +77,7 @@ def msg(self) -> str:

return " ".join(msg_parts)

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.msg})"

Expand Down
3 changes: 2 additions & 1 deletion mcproto/auth/msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum

import httpx
from typing_extensions import Self
from typing_extensions import Self, override

from mcproto.auth.account import Account
from mcproto.types.uuid import UUID as McUUID # noqa: N811
Expand Down Expand Up @@ -59,6 +59,7 @@ def msg(self) -> str:

return " ".join(msg_parts)

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.msg})"

Expand Down
3 changes: 2 additions & 1 deletion mcproto/auth/yggdrasil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import uuid4

import httpx
from typing_extensions import Self
from typing_extensions import Self, override

from mcproto.auth.account import Account
from mcproto.types.uuid import UUID as McUUID # noqa: N811
Expand Down Expand Up @@ -102,6 +102,7 @@ def msg(self) -> str:

return " ".join(msg_parts)

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.msg})"

Expand Down
5 changes: 5 additions & 0 deletions mcproto/buffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing_extensions import override

from mcproto.protocol.base_io import BaseSyncReader, BaseSyncWriter

__all__ = ["Buffer"]
Expand All @@ -14,10 +16,12 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pos = 0

@override
def write(self, data: bytes) -> None:
"""Write/Store given ``data`` into the buffer."""
self.extend(data)

@override
def read(self, length: int) -> bytearray:
"""Read data stored in the buffer.
Expand Down Expand Up @@ -52,6 +56,7 @@ def read(self, length: int) -> bytearray:
finally:
self.pos = end

@override
def clear(self, only_already_read: bool = False) -> None:
"""Clear out the stored data and reset position.
Expand Down
93 changes: 21 additions & 72 deletions mcproto/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncio_dgram
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from typing_extensions import ParamSpec, Self
from typing_extensions import ParamSpec, Self, override

from mcproto.protocol.base_io import BaseAsyncReader, BaseAsyncWriter, BaseSyncReader, BaseSyncWriter

Expand Down Expand Up @@ -94,6 +94,7 @@ def _write(self, data: bytes, /) -> None:
"""Send raw ``data`` through this specific connection."""
raise NotImplementedError

@override
def write(self, data: bytes, /) -> None:
"""Send given ``data`` over the connection.
Expand All @@ -116,6 +117,7 @@ def _read(self, length: int, /) -> bytearray:
"""
raise NotImplementedError

@override
def read(self, length: int, /) -> bytearray:
"""Receive data sent through the connection.
Expand Down Expand Up @@ -206,6 +208,7 @@ async def _write(self, data: bytes, /) -> None:
"""Send raw ``data`` through this specific connection."""
raise NotImplementedError

@override
async def write(self, data: bytes, /) -> None:
"""Send given ``data`` over the connection.
Expand All @@ -228,6 +231,7 @@ async def _read(self, length: int, /) -> bytearray:
"""
raise NotImplementedError

@override
async def read(self, length: int, /) -> bytearray:
"""Receive data sent through the connection.
Expand Down Expand Up @@ -263,28 +267,15 @@ def __init__(self, socket: T_SOCK):
super().__init__()
self.socket = socket

@override
@classmethod
def make_client(cls, address: tuple[str, int], timeout: float) -> Self:
"""Construct a client connection (Client -> Server) to given server ``address``.
:param address: Address of the server to connection to.
:param timeout:
Amount of seconds to wait for the connection to be established.
If connection can't be established within this time, :exc:`TimeoutError` will be raised.
This timeout is then also used for any further data receiving.
"""
sock = socket.create_connection(address, timeout=timeout)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return cls(sock)

@override
def _read(self, length: int) -> bytearray:
"""Receive raw data from this specific connection.
:param length:
Amount of bytes to be received. If the requested amount can't be received
(server didn't send that much data/server didn't send any data), an :exc:`IOError`
will be raised.
"""
result = bytearray()
while len(result) < length:
new = self.socket.recv(length - len(result))
Expand All @@ -301,12 +292,12 @@ def _read(self, length: int) -> bytearray:

return result

@override
def _write(self, data: bytes) -> None:
"""Send raw ``data`` through this specific connection."""
self.socket.send(data)

@override
def _close(self) -> None:
"""Close the underlying connection."""
# Gracefully end the connection first (shutdown), informing the other side
# we're disconnecting, and waiting for them to disconnect cleanly (TCP FIN)
try:
Expand All @@ -329,28 +320,15 @@ def __init__(self, reader: T_STREAMREADER, writer: T_STREAMWRITER, timeout: floa
self.writer = writer
self.timeout = timeout

@override
@classmethod
async def make_client(cls, address: tuple[str, int], timeout: float) -> Self:
"""Construct a client connection (Client -> Server) to given server ``address``.
:param address: Address of the server to connection to.
:param timeout:
Amount of seconds to wait for the connection to be established.
If connection can't be established within this time, :exc:`TimeoutError` will be raised.
This timeout is then also used for any further data receiving.
"""
conn = asyncio.open_connection(address[0], address[1])
reader, writer = await asyncio.wait_for(conn, timeout=timeout)
return cls(reader, writer, timeout)

@override
async def _read(self, length: int) -> bytearray:
"""Receive data sent through the connection.
:param length:
Amount of bytes to be received. If the requested amount can't be received
(server didn't send that much data/server didn't send any data), an :exc:`IOError`
will be raised.
"""
result = bytearray()
while len(result) < length:
new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout)
Expand All @@ -367,12 +345,12 @@ async def _read(self, length: int) -> bytearray:

return result

@override
async def _write(self, data: bytes) -> None:
"""Send raw ``data`` through this specific connection."""
self.writer.write(data)

@override
async def _close(self) -> None:
"""Close the underlying connection."""
# Close automatically performs a graceful TCP connection shutdown too
self.writer.close()

Expand All @@ -394,42 +372,27 @@ def __init__(self, socket: T_SOCK, address: tuple[str, int]):
self.socket = socket
self.address = address

@override
@classmethod
def make_client(cls, address: tuple[str, int], timeout: float) -> Self:
"""Construct a client connection (Client -> Server) to given server ``address``.
:param address: Address of the server to connection to.
:param timeout:
Amount of seconds to wait for the connection to be established.
If connection can't be established within this time, :exc:`TimeoutError` will be raised.
This timeout is then also used for any further data receiving.
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(timeout)
return cls(sock, address)

@override
def _read(self, length: int | None = None) -> bytearray:
"""Receive data sent through the connection.
:param length:
For UDP connections, ``length`` parameter is ignored and not required.
Instead, UDP connections always read exactly :attr:`.BUFFER_SIZE` bytes.
If the requested amount can't be received (server didn't send that much
data/server didn't send any data), an :exc:`IOError` will be raised.
"""
result = bytearray()
while len(result) == 0:
received_data, server_addr = self.socket.recvfrom(self.BUFFER_SIZE)
result.extend(received_data)
return result

@override
def _write(self, data: bytes) -> None:
"""Send raw ``data`` through this specific connection."""
self.socket.sendto(data, self.address)

@override
def _close(self) -> None:
"""Close the underlying connection."""
self.socket.close()


Expand All @@ -443,39 +406,25 @@ def __init__(self, stream: T_DATAGRAM_CLIENT, timeout: float):
self.stream = stream
self.timeout = timeout

@override
@classmethod
async def make_client(cls, address: tuple[str, int], timeout: float) -> Self:
"""Construct a client connection (Client -> Server) to given server ``address``.
:param address: Address of the server to connection to.
:param timeout:
Amount of seconds to wait for the connection to be established.
If connection can't be established within this time, :exc:`TimeoutError` will be raised.
This timeout is then also used for any further data receiving.
"""
conn = asyncio_dgram.connect(address)
stream = await asyncio.wait_for(conn, timeout=timeout)
return cls(stream, timeout)

@override
async def _read(self, length: int | None = None) -> bytearray:
"""Receive data sent through the connection.
:param length:
For UDP connections, ``length`` parameter is ignored and not required.
If the requested amount can't be received (server didn't send that much
data/server didn't send any data), an :exc:`IOError` will be raised.
"""
result = bytearray()
while len(result) == 0:
received_data, server_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout)
result.extend(received_data)
return result

@override
async def _write(self, data: bytes) -> None:
"""Send raw ``data`` through this specific connection."""
await self.stream.send(data)

@override
async def _close(self) -> None:
"""Close the underlying connection."""
self.stream.close()
3 changes: 3 additions & 0 deletions mcproto/multiplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import httpx
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from typing_extensions import override

from mcproto.auth.account import Account

Expand Down Expand Up @@ -75,6 +76,7 @@ def msg(self) -> str:

return " ".join(msg_parts)

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.msg})"

Expand All @@ -97,6 +99,7 @@ def __init__(self, response: httpx.Response, client_username: str, server_hash:
self.client_ip = client_ip
super().__init__(repr(self))

@override
def __repr__(self) -> str:
msg = "Unable to verify user join for "
msg += f"username={self.client_username!r}, server_hash={self.server_hash!r}, client_ip={self.client_ip!r}"
Expand Down
Loading

0 comments on commit adb5f8c

Please sign in to comment.