Skip to content

Commit

Permalink
Types in poller
Browse files Browse the repository at this point in the history
  • Loading branch information
lmazuel committed Mar 15, 2023
1 parent 46a38e7 commit cfafd30
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/polling/_async_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def run(self): # pylint:disable=invalid-overridden-method
"""


async def async_poller(client, initial_response, deserialization_callback, polling_method):
async def async_poller(client: Any, initial_response: Any, deserialization_callback: Callable[[Any], PollingReturnType], polling_method: AsyncPollingMethod[PollingReturnType]):
"""Async Poller for long running operations.
.. deprecated:: 1.5.0
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
self,
client: Any,
initial_response: Any,
deserialization_callback: Callable,
deserialization_callback: Callable[[Any], PollingReturnType],
polling_method: AsyncPollingMethod[PollingReturnType],
):
self._polling_method = polling_method
Expand Down
21 changes: 12 additions & 9 deletions sdk/core/azure-core/azure/core/polling/_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@ def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any
raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__))


class NoPolling(PollingMethod):
class NoPolling(PollingMethod, Generic[PollingReturnType]):
"""An empty poller that returns the deserialized initial response."""

_deserialization_callback: Callable[[Any], PollingReturnType]
"""Deserialization callback passed during initialization"""

def __init__(self):
self._initial_response = None
self._deserialization_callback = None

def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable) -> None:
def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable[[Any], PollingReturnType]) -> None:
self._initial_response = initial_response
self._deserialization_callback = deserialization_callback

Expand All @@ -92,7 +94,7 @@ def finished(self) -> bool:
"""
return True

def resource(self) -> Any:
def resource(self) -> PollingReturnType:
return self._deserialization_callback(self._initial_response)

def get_continuation_token(self) -> str:
Expand Down Expand Up @@ -130,7 +132,7 @@ def __init__(
self,
client: Any,
initial_response: Any,
deserialization_callback: Callable,
deserialization_callback: Callable[[Any], PollingReturnType],
polling_method: PollingMethod[PollingReturnType],
) -> None:
self._callbacks: List[Callable] = []
Expand All @@ -147,10 +149,11 @@ def __init__(

# Prepare thread execution
self._thread = None
self._done = None
self._done = threading.Event()
self._exception = None
if not self._polling_method.finished():
self._done = threading.Event()
if self._polling_method.finished():
self._done.set()
else:
self._thread = threading.Thread(
target=with_current_context(self._start),
name="LROPoller({})".format(uuid.uuid4()),
Expand Down Expand Up @@ -266,7 +269,7 @@ def add_done_callback(self, func: Callable) -> None:
argument, a completed LongRunningOperation.
"""
# Still use "_done" and not "done", since CBs are executed inside the thread.
if self._done is None or self._done.is_set():
if self._done.is_set():
func(self._polling_method)
# Let's add them still, for consistency (if you wish to access to it for some reasons)
self._callbacks.append(func)
Expand Down

0 comments on commit cfafd30

Please sign in to comment.