diff --git a/CHANGES/4912.bugfix b/CHANGES/4912.bugfix new file mode 100644 index 00000000000..6f8adea2309 --- /dev/null +++ b/CHANGES/4912.bugfix @@ -0,0 +1 @@ +Fixed the type annotations in the ``tracing`` module. diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index 631e7d0004d..0c07c642eda 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from typing import TYPE_CHECKING, Awaitable, Callable, Type, Union +from typing import TYPE_CHECKING, Awaitable, Type, TypeVar import attr from multidict import CIMultiDict # noqa @@ -9,29 +9,17 @@ from .signals import Signal if TYPE_CHECKING: # pragma: no cover + from typing_extensions import Protocol + from .client import ClientSession # noqa - _SignalArgs = Union[ - 'TraceRequestStartParams', - 'TraceRequestEndParams', - 'TraceRequestExceptionParams', - 'TraceConnectionQueuedStartParams', - 'TraceConnectionQueuedEndParams', - 'TraceConnectionCreateStartParams', - 'TraceConnectionCreateEndParams', - 'TraceConnectionReuseconnParams', - 'TraceDnsResolveHostStartParams', - 'TraceDnsResolveHostEndParams', - 'TraceDnsCacheHitParams', - 'TraceDnsCacheMissParams', - 'TraceRequestRedirectParams', - 'TraceRequestChunkSentParams', - 'TraceResponseChunkReceivedParams', - ] - _Signal = Signal[Callable[[ClientSession, SimpleNamespace, _SignalArgs], - Awaitable[None]]] -else: - _Signal = Signal + _ParamT_contra = TypeVar('_ParamT_contra', contravariant=True) + + class _SignalCallback(Protocol[_ParamT_contra]): + def __call__(self, + __client_session: ClientSession, + __trace_config_ctx: SimpleNamespace, + __params: _ParamT_contra) -> Awaitable[None]: ... __all__ = ( @@ -54,23 +42,53 @@ def __init__( self, trace_config_ctx_factory: Type[SimpleNamespace]=SimpleNamespace ) -> None: - self._on_request_start = Signal(self) # type: _Signal - self._on_request_chunk_sent = Signal(self) # type: _Signal - self._on_response_chunk_received = Signal(self) # type: _Signal - self._on_request_end = Signal(self) # type: _Signal - self._on_request_exception = Signal(self) # type: _Signal - self._on_request_redirect = Signal(self) # type: _Signal - self._on_connection_queued_start = Signal(self) # type: _Signal - self._on_connection_queued_end = Signal(self) # type: _Signal - self._on_connection_create_start = Signal(self) # type: _Signal - self._on_connection_create_end = Signal(self) # type: _Signal - self._on_connection_reuseconn = Signal(self) # type: _Signal - self._on_dns_resolvehost_start = Signal(self) # type: _Signal - self._on_dns_resolvehost_end = Signal(self) # type: _Signal - self._on_dns_cache_hit = Signal(self) # type: _Signal - self._on_dns_cache_miss = Signal(self) # type: _Signal - - self._trace_config_ctx_factory = trace_config_ctx_factory # type: Type[SimpleNamespace] # noqa + self._on_request_start = Signal( + self + ) # type: Signal[_SignalCallback[TraceRequestStartParams]] + self._on_request_chunk_sent = Signal( + self + ) # type: Signal[_SignalCallback[TraceRequestChunkSentParams]] + self._on_response_chunk_received = Signal( + self + ) # type: Signal[_SignalCallback[TraceResponseChunkReceivedParams]] + self._on_request_end = Signal( + self + ) # type: Signal[_SignalCallback[TraceRequestEndParams]] + self._on_request_exception = Signal( + self + ) # type: Signal[_SignalCallback[TraceRequestExceptionParams]] + self._on_request_redirect = Signal( + self + ) # type: Signal[_SignalCallback[TraceRequestRedirectParams]] + self._on_connection_queued_start = Signal( + self + ) # type: Signal[_SignalCallback[TraceConnectionQueuedStartParams]] + self._on_connection_queued_end = Signal( + self + ) # type: Signal[_SignalCallback[TraceConnectionQueuedEndParams]] + self._on_connection_create_start = Signal( + self + ) # type: Signal[_SignalCallback[TraceConnectionCreateStartParams]] + self._on_connection_create_end = Signal( + self + ) # type: Signal[_SignalCallback[TraceConnectionCreateEndParams]] + self._on_connection_reuseconn = Signal( + self + ) # type: Signal[_SignalCallback[TraceConnectionReuseconnParams]] + self._on_dns_resolvehost_start = Signal( + self + ) # type: Signal[_SignalCallback[TraceDnsResolveHostStartParams]] + self._on_dns_resolvehost_end = Signal( + self + ) # type: Signal[_SignalCallback[TraceDnsResolveHostEndParams]] + self._on_dns_cache_hit = Signal( + self + ) # type: Signal[_SignalCallback[TraceDnsCacheHitParams]] + self._on_dns_cache_miss = Signal( + self + ) # type: Signal[_SignalCallback[TraceDnsCacheMissParams]] + + self._trace_config_ctx_factory = trace_config_ctx_factory def trace_config_ctx( self, @@ -98,63 +116,93 @@ def freeze(self) -> None: self._on_dns_cache_miss.freeze() @property - def on_request_start(self) -> _Signal: + def on_request_start( + self + ) -> 'Signal[_SignalCallback[TraceRequestStartParams]]': return self._on_request_start @property - def on_request_chunk_sent(self) -> _Signal: + def on_request_chunk_sent( + self + ) -> 'Signal[_SignalCallback[TraceRequestChunkSentParams]]': return self._on_request_chunk_sent @property - def on_response_chunk_received(self) -> _Signal: + def on_response_chunk_received( + self + ) -> 'Signal[_SignalCallback[TraceResponseChunkReceivedParams]]': return self._on_response_chunk_received @property - def on_request_end(self) -> _Signal: + def on_request_end( + self + ) -> 'Signal[_SignalCallback[TraceRequestEndParams]]': return self._on_request_end @property - def on_request_exception(self) -> _Signal: + def on_request_exception( + self + ) -> 'Signal[_SignalCallback[TraceRequestExceptionParams]]': return self._on_request_exception @property - def on_request_redirect(self) -> _Signal: + def on_request_redirect( + self + ) -> 'Signal[_SignalCallback[TraceRequestRedirectParams]]': return self._on_request_redirect @property - def on_connection_queued_start(self) -> _Signal: + def on_connection_queued_start( + self + ) -> 'Signal[_SignalCallback[TraceConnectionQueuedStartParams]]': return self._on_connection_queued_start @property - def on_connection_queued_end(self) -> _Signal: + def on_connection_queued_end( + self + ) -> 'Signal[_SignalCallback[TraceConnectionQueuedEndParams]]': return self._on_connection_queued_end @property - def on_connection_create_start(self) -> _Signal: + def on_connection_create_start( + self + ) -> 'Signal[_SignalCallback[TraceConnectionCreateStartParams]]': return self._on_connection_create_start @property - def on_connection_create_end(self) -> _Signal: + def on_connection_create_end( + self + ) -> 'Signal[_SignalCallback[TraceConnectionCreateEndParams]]': return self._on_connection_create_end @property - def on_connection_reuseconn(self) -> _Signal: + def on_connection_reuseconn( + self + ) -> 'Signal[_SignalCallback[TraceConnectionReuseconnParams]]': return self._on_connection_reuseconn @property - def on_dns_resolvehost_start(self) -> _Signal: + def on_dns_resolvehost_start( + self + ) -> 'Signal[_SignalCallback[TraceDnsResolveHostStartParams]]': return self._on_dns_resolvehost_start @property - def on_dns_resolvehost_end(self) -> _Signal: + def on_dns_resolvehost_end( + self + ) -> 'Signal[_SignalCallback[TraceDnsResolveHostEndParams]]': return self._on_dns_resolvehost_end @property - def on_dns_cache_hit(self) -> _Signal: + def on_dns_cache_hit( + self + ) -> 'Signal[_SignalCallback[TraceDnsCacheHitParams]]': return self._on_dns_cache_hit @property - def on_dns_cache_miss(self) -> _Signal: + def on_dns_cache_miss( + self + ) -> 'Signal[_SignalCallback[TraceDnsCacheMissParams]]': return self._on_dns_cache_miss