Skip to content

Commit

Permalink
improve type hints, apply type related fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
csernazs committed Jan 13, 2025
1 parent 75d69e5 commit c62f4e0
Showing 1 changed file with 53 additions and 34 deletions.
87 changes: 53 additions & 34 deletions pytest_httpserver/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Callable
from typing import ClassVar
from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
Expand All @@ -34,6 +35,9 @@

if TYPE_CHECKING:
from ssl import SSLContext
from types import TracebackType

from werkzeug.serving import BaseWSGIServer

URI_DEFAULT = ""
METHOD_ALL = "__ALL"
Expand Down Expand Up @@ -113,11 +117,13 @@ def complete(self, result: bool): # noqa: FBT001

@property
def result(self) -> bool:
return self._result
return bool(self._result)

@property
def elapsed_time(self) -> float:
"""Elapsed time in seconds"""
if self._stop is None:
raise TypeError("unsupported operand type(s) for -: 'NoneType' and 'float'")
return self._stop - self._start


Expand All @@ -139,7 +145,7 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
func = getattr(Authorization, "from_header", None)
if func is None: # Werkzeug < 2.3.0
func = werkzeug.http.parse_authorization_header # type: ignore[attr-defined]
return func(actual) == func(expected)
return func(actual) == func(expected) # type: ignore

@staticmethod
def default_header_value_matcher(actual: str | None, expected: str) -> bool:
Expand Down Expand Up @@ -174,7 +180,7 @@ def match(self, request_query_string: bytes) -> bool:
return values[0] == values[1]

@abc.abstractmethod
def get_comparing_values(self, request_query_string: bytes) -> tuple:
def get_comparing_values(self, request_query_string: bytes) -> tuple[Any, Any]:
pass


Expand All @@ -195,10 +201,10 @@ def __init__(self, query_string: bytes | str):

self.query_string = query_string

def get_comparing_values(self, request_query_string: bytes) -> tuple:
def get_comparing_values(self, request_query_string: bytes) -> tuple[bytes, bytes]:
if isinstance(self.query_string, str):
query_string = self.query_string.encode()
elif isinstance(self.query_string, bytes):
elif isinstance(self.query_string, bytes): # type: ignore
query_string = self.query_string
else:
raise TypeError("query_string must be a string, or a bytes-like object")
Expand All @@ -211,7 +217,7 @@ class MappingQueryMatcher(QueryMatcher):
Matches a query string to a dictionary or MultiDict specified
"""

def __init__(self, query_dict: Mapping | MultiDict):
def __init__(self, query_dict: Mapping[str, str] | MultiDict[str, str]):
"""
:param query_dict: if dictionary (Mapping) is specified, it will be used as a
key-value mapping where both key and value should be string. If there are multiple
Expand All @@ -221,7 +227,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
"""
self.query_dict = query_dict

def get_comparing_values(self, request_query_string: bytes) -> tuple:
def get_comparing_values(self, request_query_string: bytes) -> tuple[Mapping[str, str], Mapping[str, str]]:
query = MultiDict(urllib.parse.parse_qsl(request_query_string.decode("utf-8")))
if isinstance(self.query_dict, MultiDict):
return (query, self.query_dict)
Expand All @@ -241,14 +247,14 @@ def __init__(self, result: bool): # noqa: FBT001
"""
self.result = result

def get_comparing_values(self, request_query_string): # noqa: ARG002
def get_comparing_values(self, request_query_string: bytes): # noqa: ARG002
if self.result:
return (True, True)
else:
return (True, False)


def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping) -> QueryMatcher:
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping[str, str]) -> QueryMatcher:
if isinstance(query_string, QueryMatcher):
return query_string

Expand Down Expand Up @@ -312,7 +318,7 @@ def __init__(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
json: Any = UNDEFINED,
):
Expand Down Expand Up @@ -410,7 +416,7 @@ def match_json(self, request: Request) -> bool:

return json_received == self.json

def difference(self, request: Request) -> list[tuple]:
def difference(self, request: Request) -> list[tuple[str, str, str | URIPattern]]:
"""
Calculates the difference between the matcher and the request.
Expand All @@ -422,7 +428,7 @@ def difference(self, request: Request) -> list[tuple]:
matches the fields set in the matcher object.
"""

retval: list[tuple] = []
retval: list[tuple[str, Any, Any]] = []

if not self.match_uri(request):
retval.append(("uri", request.path, self.uri))
Expand All @@ -433,8 +439,8 @@ def difference(self, request: Request) -> list[tuple]:
if not self.query_matcher.match(request.query_string):
retval.append(("query_string", request.query_string, self.query_string))

request_headers = {}
expected_headers = {}
request_headers: dict[str, str | None] = {}
expected_headers: dict[str, str] = {}
for key, value in self.headers.items():
if not self.header_value_matcher(key, request.headers.get(key), value):
request_headers[key] = request.headers.get(key)
Expand Down Expand Up @@ -467,7 +473,7 @@ class RequestHandlerBase(abc.ABC):

def respond_with_json(
self,
response_json,
response_json: Any,
status: int = 200,
headers: Mapping[str, str] | None = None,
content_type: str = "application/json",
Expand Down Expand Up @@ -578,7 +584,7 @@ def __repr__(self) -> str:
return retval


class RequestHandlerList(list):
class RequestHandlerList(List[RequestHandler]):
"""
Represents a list of :py:class:`RequestHandler` objects.
Expand Down Expand Up @@ -638,9 +644,9 @@ def __init__(
"""
self.host = host
self.port = port
self.server = None
self.server_thread = None
self.assertions: list[str] = []
self.server: BaseWSGIServer | None = None
self.server_thread: threading.Thread | None = None
self.assertions: list[str | AssertionError] = []
self.handler_errors: list[Exception] = []
self.log: list[tuple[Request, Response]] = []
self.ssl_context = ssl_context
Expand Down Expand Up @@ -727,7 +733,7 @@ def thread_target(self):
This should not be called directly, but can be overridden to tailor it to your needs.
"""

assert self.server is not None
self.server.serve_forever()

def is_running(self) -> bool:
Expand All @@ -736,7 +742,7 @@ def is_running(self) -> bool:
"""
return bool(self.server)

def start(self):
def start(self) -> None:
"""
Start the server in a thread.
Expand All @@ -755,9 +761,16 @@ def start(self):
if self.is_running():
raise HTTPServerError("Server is already running")

app = Request.application(self.application)

self.server = make_server(
self.host, self.port, self.application, ssl_context=self.ssl_context, threaded=self.threaded
self.host,
self.port,
app,
ssl_context=self.ssl_context,
threaded=self.threaded,
)

self.port = self.server.port # Update port (needed if `port` was set to 0)
self.server_thread = threading.Thread(target=self.thread_target)
self.server_thread.start()
Expand All @@ -772,14 +785,16 @@ def stop(self):
Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
will be raised.
"""
assert self.server is not None
assert self.server_thread is not None
if not self.is_running():
raise HTTPServerError("Server is not running")
self.server.shutdown()
self.server_thread.join()
self.server = None
self.server_thread = None

def add_assertion(self, obj):
def add_assertion(self, obj: str | AssertionError):
"""
Add a new assertion
Expand Down Expand Up @@ -848,8 +863,7 @@ def dispatch(self, request: Request) -> Response:
:return: the response object what the handler responded, or a response which contains the error
"""

@Request.application # type: ignore
def application(self, request: Request):
def application(self, request: Request) -> Response:
"""
Entry point of werkzeug.
Expand All @@ -875,7 +889,12 @@ def __enter__(self):
self.start()
return self

def __exit__(self, *args, **kwargs):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
):
"""
Provide the context API
Expand All @@ -886,7 +905,7 @@ def __exit__(self, *args, **kwargs):
self.stop()

@staticmethod
def format_host(host):
def format_host(host: str):
"""
Formats a hostname so it can be used in a URL.
Notably, this adds brackets around IPV6 addresses when
Expand Down Expand Up @@ -929,8 +948,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute

def __init__(
self,
host=DEFAULT_LISTEN_HOST,
port=DEFAULT_LISTEN_PORT,
host: str = DEFAULT_LISTEN_HOST,
port: int = DEFAULT_LISTEN_PORT,
ssl_context: SSLContext | None = None,
default_waiting_settings: WaitingSettings | None = None,
*,
Expand Down Expand Up @@ -979,7 +998,7 @@ def expect_request(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
handler_type: HandlerType = HandlerType.PERMANENT,
json: Any = UNDEFINED,
Expand Down Expand Up @@ -1062,7 +1081,7 @@ def expect_oneshot_request(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
json: Any = UNDEFINED,
) -> RequestHandler:
Expand Down Expand Up @@ -1117,7 +1136,7 @@ def expect_ordered_request(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
json: Any = UNDEFINED,
) -> RequestHandler:
Expand Down Expand Up @@ -1175,13 +1194,13 @@ def format_matchers(self) -> str:
This method is primarily used when reporting errors.
"""

def format_handlers(handlers):
def format_handlers(handlers: list[RequestHandler]):
if handlers:
return [" {!r}".format(handler.matcher) for handler in handlers]
else:
return [" none"]

lines = []
lines: list[str] = []
lines.append("Ordered matchers:")
lines.extend(format_handlers(self.ordered_handlers))
lines.append("")
Expand Down

0 comments on commit c62f4e0

Please sign in to comment.