From cfafd3033cd1218f3ba2212a3e7d829ec2b63110 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 8 Mar 2023 15:55:53 -0800 Subject: [PATCH] Types in poller --- .../azure/core/polling/_async_poller.py | 4 ++-- .../azure-core/azure/core/polling/_poller.py | 21 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sdk/core/azure-core/azure/core/polling/_async_poller.py b/sdk/core/azure-core/azure/core/polling/_async_poller.py index c50222363531..c4aca6f40e2f 100644 --- a/sdk/core/azure-core/azure/core/polling/_async_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_async_poller.py @@ -71,7 +71,7 @@ async def run(self): # pylint:disable=invalid-overridden-method """ -async def async_poller(client, initial_response, deserialization_callback, polling_method): +async def async_poller(client: Any, initial_response: Any, deserialization_callback: Callable[[Any], PollingReturnType], polling_method: AsyncPollingMethod[PollingReturnType]): """Async Poller for long running operations. .. deprecated:: 1.5.0 @@ -109,7 +109,7 @@ def __init__( self, client: Any, initial_response: Any, - deserialization_callback: Callable, + deserialization_callback: Callable[[Any], PollingReturnType], polling_method: AsyncPollingMethod[PollingReturnType], ): self._polling_method = polling_method diff --git a/sdk/core/azure-core/azure/core/polling/_poller.py b/sdk/core/azure-core/azure/core/polling/_poller.py index 84067f301392..02e082059073 100644 --- a/sdk/core/azure-core/azure/core/polling/_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_poller.py @@ -64,14 +64,16 @@ def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__)) -class NoPolling(PollingMethod): +class NoPolling(PollingMethod, Generic[PollingReturnType]): """An empty poller that returns the deserialized initial response.""" + _deserialization_callback: Callable[[Any], PollingReturnType] + """Deserialization callback passed during initialization""" + def __init__(self): self._initial_response = None - self._deserialization_callback = None - def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable) -> None: + def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable[[Any], PollingReturnType]) -> None: self._initial_response = initial_response self._deserialization_callback = deserialization_callback @@ -92,7 +94,7 @@ def finished(self) -> bool: """ return True - def resource(self) -> Any: + def resource(self) -> PollingReturnType: return self._deserialization_callback(self._initial_response) def get_continuation_token(self) -> str: @@ -130,7 +132,7 @@ def __init__( self, client: Any, initial_response: Any, - deserialization_callback: Callable, + deserialization_callback: Callable[[Any], PollingReturnType], polling_method: PollingMethod[PollingReturnType], ) -> None: self._callbacks: List[Callable] = [] @@ -147,10 +149,11 @@ def __init__( # Prepare thread execution self._thread = None - self._done = None + self._done = threading.Event() self._exception = None - if not self._polling_method.finished(): - self._done = threading.Event() + if self._polling_method.finished(): + self._done.set() + else: self._thread = threading.Thread( target=with_current_context(self._start), name="LROPoller({})".format(uuid.uuid4()), @@ -266,7 +269,7 @@ def add_done_callback(self, func: Callable) -> None: argument, a completed LongRunningOperation. """ # Still use "_done" and not "done", since CBs are executed inside the thread. - if self._done is None or self._done.is_set(): + if self._done.is_set(): func(self._polling_method) # Let's add them still, for consistency (if you wish to access to it for some reasons) self._callbacks.append(func)