diff --git a/CHANGES/8977.bugfix.rst b/CHANGES/8977.bugfix.rst new file mode 100644 index 00000000000..7d21fe0c3fa --- /dev/null +++ b/CHANGES/8977.bugfix.rst @@ -0,0 +1 @@ +Made ``TestClient.app`` a ``Generic`` so type checkers will know the correct type (avoiding unneeded ``client.app is not None`` checks) -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index f5cdc8b1860..503bb6a9021 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -47,7 +47,7 @@ async def __call__( *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[Request]: ... + ) -> TestClient[Request, Application]: ... @overload async def __call__( self, @@ -55,7 +55,7 @@ async def __call__( *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[_Request]: ... + ) -> TestClient[_Request, None]: ... class AiohttpServer(Protocol): @@ -349,7 +349,7 @@ async def finalize() -> None: @pytest.fixture -def aiohttp_client_cls() -> Type[TestClient[Any]]: +def aiohttp_client_cls() -> Type[TestClient[Any, Any]]: """ Client class to use in ``aiohttp_client`` factory. @@ -377,7 +377,7 @@ def test_login(aiohttp_client): @pytest.fixture def aiohttp_client( - loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any]] + loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any, Any]] ) -> Iterator[AiohttpClient]: """Factory to create a TestClient instance. @@ -393,20 +393,20 @@ async def go( *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[Request]: ... + ) -> TestClient[Request, Application]: ... @overload async def go( __param: BaseTestServer[_Request], *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[_Request]: ... + ) -> TestClient[_Request, None]: ... async def go( __param: Union[Application, BaseTestServer[Any]], *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> TestClient[Any]: + ) -> TestClient[Any, Any]: if isinstance(__param, Application): server_kwargs = server_kwargs or {} server = TestServer(__param, **server_kwargs) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index b02b6664531..88dcf8ebf1b 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -23,6 +23,7 @@ TypeVar, Union, cast, + overload, ) from unittest import IsolatedAsyncioTestCase, mock @@ -72,6 +73,7 @@ else: Self = Any +_ApplicationNone = TypeVar("_ApplicationNone", Application, None) _Request = TypeVar("_Request", bound=BaseRequest) REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -251,7 +253,7 @@ async def _make_runner(self, **kwargs: Any) -> ServerRunner: return ServerRunner(srv, **kwargs) -class TestClient(Generic[_Request]): +class TestClient(Generic[_Request, _ApplicationNone]): """ A test client implementation. @@ -261,7 +263,23 @@ class TestClient(Generic[_Request]): __test__ = False + @overload def __init__( + self: "TestClient[Request, Application]", + server: TestServer, + *, + cookie_jar: Optional[AbstractCookieJar] = None, + **kwargs: Any, + ) -> None: ... + @overload + def __init__( + self: "TestClient[_Request, None]", + server: BaseTestServer[_Request], + *, + cookie_jar: Optional[AbstractCookieJar] = None, + **kwargs: Any, + ) -> None: ... + def __init__( # type: ignore[misc] self, server: BaseTestServer[_Request], *, @@ -300,8 +318,8 @@ def server(self) -> BaseTestServer[_Request]: return self._server @property - def app(self) -> Optional[Application]: - return cast(Optional[Application], getattr(self._server, "app", None)) + def app(self) -> _ApplicationNone: + return getattr(self._server, "app", None) # type: ignore[return-value] @property def session(self) -> ClientSession: @@ -505,7 +523,7 @@ async def get_server(self, app: Application) -> TestServer: """Return a TestServer instance.""" return TestServer(app) - async def get_client(self, server: TestServer) -> TestClient[Request]: + async def get_client(self, server: TestServer) -> TestClient[Request, Application]: """Return a TestClient instance.""" return TestClient(server) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 437936e97b8..cdd573e8f31 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3795,7 +3795,7 @@ async def handler(request: web.Request) -> web.Response: await resp.write(b"1" * 1000) await asyncio.sleep(0.01) - async def request(client: TestClient[web.Request]) -> None: + async def request(client: TestClient[web.Request, web.Application]) -> None: timeout = aiohttp.ClientTimeout(total=0.5) async with await client.get("/", timeout=timeout) as resp: with pytest.raises(asyncio.TimeoutError): diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 92cc1a61ed6..718b95b6512 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -1,6 +1,7 @@ import asyncio import gzip import socket +import sys from typing import Callable, Iterator, Mapping, NoReturn from unittest import mock @@ -20,7 +21,10 @@ make_mocked_request, ) -_TestClient = TestClient[web.Request] +if sys.version_info >= (3, 11): + from typing import assert_type + +_TestClient = TestClient[web.Request, web.Application] _hello_world_str = "Hello, world" _hello_world_bytes = _hello_world_str.encode("utf-8") @@ -71,7 +75,7 @@ def app() -> web.Application: def test_client( loop: asyncio.AbstractEventLoop, app: web.Application ) -> Iterator[_TestClient]: - async def make_client() -> TestClient[web.Request]: + async def make_client() -> TestClient[web.Request, web.Application]: return TestClient(TestServer(app)) client = loop.run_until_complete(make_client()) @@ -239,6 +243,8 @@ async def test_test_client_props() -> None: async with client: assert isinstance(client.port, int) assert client.server is not None + if sys.version_info >= (3, 11): + assert_type(client.app, web.Application) assert client.app is not None assert client.port == 0 @@ -255,6 +261,8 @@ async def hello(request: web.BaseRequest) -> NoReturn: async with client: assert isinstance(client.port, int) assert client.server is not None + if sys.version_info >= (3, 11): + assert_type(client.app, None) assert client.app is None assert client.port == 0 @@ -271,7 +279,7 @@ async def test_test_server_context_manager(loop: asyncio.AbstractEventLoop) -> N def test_client_unsupported_arg() -> None: with pytest.raises(TypeError) as e: - TestClient("string") # type: ignore[arg-type] + TestClient("string") # type: ignore[call-overload] assert ( str(e.value) == "server must be TestServer instance, found type: " diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py index 3159a9d6f36..7a7b2c8e91a 100644 --- a/tests/test_web_middleware.py +++ b/tests/test_web_middleware.py @@ -9,7 +9,9 @@ from aiohttp.test_utils import TestClient from aiohttp.typedefs import Handler, Middleware -CLI = Callable[[Iterable[Middleware]], Awaitable[TestClient[web.Request]]] +CLI = Callable[ + [Iterable[Middleware]], Awaitable[TestClient[web.Request, web.Application]] +] async def test_middleware_modifies_response( @@ -169,7 +171,7 @@ async def handler(request: web.Request) -> web.Response: def wrapper( extra_middlewares: Iterable[Middleware], - ) -> Awaitable[TestClient[web.Request]]: + ) -> Awaitable[TestClient[web.Request, web.Application]]: app = web.Application() app.router.add_route("GET", "/resource1", handler) app.router.add_route("GET", "/resource2/", handler)