Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Generic for TestClient.app #8977

Merged
merged 8 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/8977.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Made ``TestClient.app`` a ``Generic`` so type checkers will know the correct type (avoiding unneeded ``client.app is not None`` checks) -- by :user:`Dreamsorcerer`.
14 changes: 7 additions & 7 deletions aiohttp/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ async def __call__(
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> TestClient[Request]: ...
) -> TestClient[Request, Application]: ...
Dismissed Show dismissed Hide dismissed
@overload
async def __call__(
self,
__param: BaseTestServer[_Request],
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> TestClient[_Request]: ...
) -> TestClient[_Request, None]: ...
Dismissed Show dismissed Hide dismissed


class AiohttpServer(Protocol):
Expand Down Expand Up @@ -349,7 +349,7 @@ async def finalize() -> None:


@pytest.fixture
def aiohttp_client_cls() -> Type[TestClient[Any]]:
def aiohttp_client_cls() -> Type[TestClient[Any, Any]]:
"""
Client class to use in ``aiohttp_client`` factory.

Expand Down Expand Up @@ -377,7 +377,7 @@ def test_login(aiohttp_client):

@pytest.fixture
def aiohttp_client(
loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any]]
loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any, Any]]
) -> Iterator[AiohttpClient]:
"""Factory to create a TestClient instance.

Expand All @@ -393,20 +393,20 @@ async def go(
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> TestClient[Request]: ...
) -> TestClient[Request, Application]: ...
Dismissed Show dismissed Hide dismissed
@overload
async def go(
__param: BaseTestServer[_Request],
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> TestClient[_Request]: ...
) -> TestClient[_Request, None]: ...
Dismissed Show dismissed Hide dismissed
async def go(
__param: Union[Application, BaseTestServer[Any]],
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> TestClient[Any]:
) -> TestClient[Any, Any]:
if isinstance(__param, Application):
server_kwargs = server_kwargs or {}
server = TestServer(__param, **server_kwargs)
Expand Down
26 changes: 22 additions & 4 deletions aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TypeVar,
Union,
cast,
overload,
)
from unittest import IsolatedAsyncioTestCase, mock

Expand Down Expand Up @@ -72,6 +73,7 @@
else:
Self = Any

_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
_Request = TypeVar("_Request", bound=BaseRequest)

REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
Expand Down Expand Up @@ -251,7 +253,7 @@ async def _make_runner(self, **kwargs: Any) -> ServerRunner:
return ServerRunner(srv, **kwargs)


class TestClient(Generic[_Request]):
class TestClient(Generic[_Request, _ApplicationNone]):
"""
A test client implementation.

Expand All @@ -261,7 +263,23 @@ class TestClient(Generic[_Request]):

__test__ = False

@overload
def __init__(
self: "TestClient[Request, Application]",
server: TestServer,
*,
cookie_jar: Optional[AbstractCookieJar] = None,
**kwargs: Any,
) -> None: ...
Dismissed Show dismissed Hide dismissed
@overload
def __init__(
self: "TestClient[_Request, None]",
server: BaseTestServer[_Request],
*,
cookie_jar: Optional[AbstractCookieJar] = None,
**kwargs: Any,
) -> None: ...
Dismissed Show dismissed Hide dismissed
def __init__( # type: ignore[misc]
self,
server: BaseTestServer[_Request],
*,
Expand Down Expand Up @@ -300,8 +318,8 @@ def server(self) -> BaseTestServer[_Request]:
return self._server

@property
def app(self) -> Optional[Application]:
return cast(Optional[Application], getattr(self._server, "app", None))
def app(self) -> _ApplicationNone:
return getattr(self._server, "app", None) # type: ignore[return-value]

@property
def session(self) -> ClientSession:
Expand Down Expand Up @@ -505,7 +523,7 @@ async def get_server(self, app: Application) -> TestServer:
"""Return a TestServer instance."""
return TestServer(app)

async def get_client(self, server: TestServer) -> TestClient[Request]:
async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
"""Return a TestClient instance."""
return TestClient(server)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3795,7 +3795,7 @@ async def handler(request: web.Request) -> web.Response:
await resp.write(b"1" * 1000)
await asyncio.sleep(0.01)

async def request(client: TestClient[web.Request]) -> None:
async def request(client: TestClient[web.Request, web.Application]) -> None:
timeout = aiohttp.ClientTimeout(total=0.5)
async with await client.get("/", timeout=timeout) as resp:
with pytest.raises(asyncio.TimeoutError):
Expand Down
14 changes: 11 additions & 3 deletions tests/test_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import gzip
import socket
import sys
from typing import Callable, Iterator, Mapping, NoReturn
from unittest import mock

Expand All @@ -20,7 +21,10 @@
make_mocked_request,
)

_TestClient = TestClient[web.Request]
if sys.version_info >= (3, 11):
from typing import assert_type

_TestClient = TestClient[web.Request, web.Application]

_hello_world_str = "Hello, world"
_hello_world_bytes = _hello_world_str.encode("utf-8")
Expand Down Expand Up @@ -71,7 +75,7 @@ def app() -> web.Application:
def test_client(
loop: asyncio.AbstractEventLoop, app: web.Application
) -> Iterator[_TestClient]:
async def make_client() -> TestClient[web.Request]:
async def make_client() -> TestClient[web.Request, web.Application]:
return TestClient(TestServer(app))

client = loop.run_until_complete(make_client())
Expand Down Expand Up @@ -239,6 +243,8 @@ async def test_test_client_props() -> None:
async with client:
assert isinstance(client.port, int)
assert client.server is not None
if sys.version_info >= (3, 11):
assert_type(client.app, web.Application)
assert client.app is not None
assert client.port == 0

Expand All @@ -255,6 +261,8 @@ async def hello(request: web.BaseRequest) -> NoReturn:
async with client:
assert isinstance(client.port, int)
assert client.server is not None
if sys.version_info >= (3, 11):
assert_type(client.app, None)
assert client.app is None
assert client.port == 0

Expand All @@ -271,7 +279,7 @@ async def test_test_server_context_manager(loop: asyncio.AbstractEventLoop) -> N

def test_client_unsupported_arg() -> None:
with pytest.raises(TypeError) as e:
TestClient("string") # type: ignore[arg-type]
TestClient("string") # type: ignore[call-overload]

assert (
str(e.value) == "server must be TestServer instance, found type: <class 'str'>"
Expand Down
6 changes: 4 additions & 2 deletions tests/test_web_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from aiohttp.test_utils import TestClient
from aiohttp.typedefs import Handler, Middleware

CLI = Callable[[Iterable[Middleware]], Awaitable[TestClient[web.Request]]]
CLI = Callable[
[Iterable[Middleware]], Awaitable[TestClient[web.Request, web.Application]]
]


async def test_middleware_modifies_response(
Expand Down Expand Up @@ -169,7 +171,7 @@ async def handler(request: web.Request) -> web.Response:

def wrapper(
extra_middlewares: Iterable[Middleware],
) -> Awaitable[TestClient[web.Request]]:
) -> Awaitable[TestClient[web.Request, web.Application]]:
app = web.Application()
app.router.add_route("GET", "/resource1", handler)
app.router.add_route("GET", "/resource2/", handler)
Expand Down
Loading