Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHubs & ServiceBus] raise error for loop param if Python 3.10 #19953

Merged
merged 3 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sdk/eventhub/azure-eventhub/azure/eventhub/aio/_async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------

import sys
import asyncio

def get_dict_with_loop_if_needed(loop):
if sys.version_info >= (3, 10):
if loop:
raise ValueError("Starting Python 3.10, asyncio no longer supports loop as a parameter.")
elif loop:
return {'loop': loop}
return {}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MGMT_STATUS_CODE,
MGMT_STATUS_DESC
)
from ._async_utils import get_dict_with_loop_if_needed
from ._connection_manager_async import get_connection_manager
from ._error_async import _handle_exception

Expand Down Expand Up @@ -129,7 +130,7 @@ def __init__(
credential: Union["AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential],
**kwargs: Any
) -> None:
self._loop = kwargs.pop("loop", None)
self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None))
if isinstance(credential, AzureSasCredential):
self._credential = EventhubAzureSasTokenCredentialAsync(credential) # type: ignore
elif isinstance(credential, AzureNamedKeyCredential):
Expand All @@ -142,7 +143,7 @@ def __init__(
credential=self._credential,
**kwargs
)
self._conn_manager_async = get_connection_manager(loop=self._loop, **kwargs)
self._conn_manager_async = get_connection_manager(**kwargs)

def __enter__(self):
raise TypeError(
Expand Down Expand Up @@ -214,7 +215,7 @@ async def _backoff_async(
if backoff <= self._config.backoff_max and (
timeout_time is None or time.time() + backoff <= timeout_time
): # pylint:disable=no-else-return
await asyncio.sleep(backoff, loop=self._loop)
await asyncio.sleep(backoff, **self._internal_kwargs)
_LOGGER.info(
"%r has an exception (%r). Retrying...",
format(entity_name),
Expand Down Expand Up @@ -379,14 +380,14 @@ def _handler(self):
"""

@property
def _loop(self):
# type: () -> asyncio.AbstractEventLoop
"""The event loop that users pass in to call wrap sync calls to async API.
def _internal_kwargs(self):
# type: () -> dict
"""The dict with an event loop that users may pass in to wrap sync calls to async API.
It's furthur passed to uamqp APIs
"""

@_loop.setter
def _loop(self, value):
@_internal_kwargs.setter
def _internal_kwargs(self, value):
pass

@property
Expand Down Expand Up @@ -439,7 +440,7 @@ async def _open(self) -> None:
)
)
while not await self._handler.client_ready_async():
await asyncio.sleep(0.05, loop=self._loop)
await asyncio.sleep(0.05, **self._internal_kwargs)
self._max_message_size_on_link = (
self._handler.message_handler._link.peer_max_message_size
or constants.MAX_MESSAGE_LENGTH_BYTES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from uamqp import ReceiveClientAsync, Source

from ._client_base_async import ConsumerProducerMixin
from ._async_utils import get_dict_with_loop_if_needed
from .._common import EventData
from ..exceptions import _error_handler
from .._utils import create_properties, event_position_selector
Expand Down Expand Up @@ -61,7 +62,6 @@ class EventHubConsumer(
network bandwidth consumption that is generally a favorable trade-off when considered against periodically
making requests for partition properties using the Event Hub client.
It is set to `False` by default.
:keyword ~asyncio.AbstractEventLoop loop: An event loop.
"""

def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> None:
Expand All @@ -82,7 +82,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N
self._on_event_received = kwargs[
"on_event_received"
] # type: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]]
self._loop = kwargs.get("loop", None)
self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None))
self._client = client
self._source = source
self._offset = event_position
Expand Down Expand Up @@ -147,7 +147,7 @@ def _create_handler(self, auth: "JWTTokenAsync") -> None:
auto_complete=False,
properties=properties,
desired_capabilities=desired_capabilities,
loop=self._loop,
**self._internal_kwargs
)

self._handler._streaming_receive = True # pylint:disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(
network_tracing=network_tracing,
**kwargs
)
self._lock = asyncio.Lock(loop=self._loop)
self._lock = asyncio.Lock(**self._internal_kwargs)
self._event_processors = dict() # type: Dict[Tuple[str, str], EventProcessor]

async def __aenter__(self):
Expand Down Expand Up @@ -198,7 +198,7 @@ def _create_consumer(
prefetch=prefetch,
idle_timeout=self._idle_timeout,
track_last_enqueued_event_properties=track_last_enqueued_event_properties,
loop=self._loop,
**self._internal_kwargs
)
return handler

Expand Down Expand Up @@ -378,7 +378,7 @@ async def _receive(
owner_level=owner_level,
prefetch=prefetch,
track_last_enqueued_event_properties=track_last_enqueued_event_properties,
loop=self._loop,
**self._internal_kwargs
)
self._event_processors[
(self._consumer_group, partition_id or ALL_PARTITIONS)
Expand Down Expand Up @@ -687,7 +687,6 @@ async def close(self) -> None:
await asyncio.gather(
*[p.stop() for p in self._event_processors.values()],
return_exceptions=True,
loop=self._loop
)
self._event_processors = {}
await super(EventHubConsumerClient, self)._close_async()
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
from functools import partial

from azure.eventhub import EventData
from ..._common import EventData
from ..._eventprocessor.common import CloseReason, LoadBalancingStrategy
from ..._eventprocessor._eventprocessor_mixin import EventProcessorMixin
from ..._utils import get_event_links
Expand All @@ -29,6 +29,7 @@
from .checkpoint_store import CheckpointStore
from ._ownership_manager import OwnershipManager
from .utils import get_running_loop
from .._async_utils import get_dict_with_loop_if_needed

if TYPE_CHECKING:
from datetime import datetime
Expand Down Expand Up @@ -109,7 +110,7 @@ def __init__(
track_last_enqueued_event_properties
)
self._id = str(uuid.uuid4())
self._loop = loop or get_running_loop()
self._internal_kwargs = get_dict_with_loop_if_needed(loop)
self._running = False

self._consumers = {} # type: Dict[str, EventHubConsumer]
Expand Down Expand Up @@ -160,7 +161,7 @@ def _create_tasks_for_claimed_ownership(
if partition_id not in self._tasks or self._tasks[partition_id].done():
checkpoint = checkpoints.get(partition_id) if checkpoints else None
if self._running:
self._tasks[partition_id] = self._loop.create_task(
self._tasks[partition_id] = get_running_loop().create_task(
self._receive(partition_id, checkpoint)
)
_LOGGER.info(
Expand Down Expand Up @@ -382,7 +383,7 @@ async def start(self) -> None:
)
await self._process_error(None, err) # type: ignore

await asyncio.sleep(load_balancing_interval, loop=self._loop)
await asyncio.sleep(load_balancing_interval, **self._internal_kwargs)

async def stop(self) -> None:
"""Stop the EventProcessor.
Expand All @@ -401,5 +402,5 @@ async def stop(self) -> None:
await self._cancel_tasks_for_partitions(pids)
_LOGGER.info("EventProcessor %r tasks have been cancelled.", self._id)
while self._tasks:
await asyncio.sleep(1, loop=self._loop)
await asyncio.sleep(1, **self._internal_kwargs)
_LOGGER.info("EventProcessor %r has been stopped.", self._id)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from .._constants import TIMEOUT_SYMBOL
from ._client_base_async import ConsumerProducerMixin
from ._async_utils import get_dict_with_loop_if_needed

if TYPE_CHECKING:
from uamqp.authentication import JWTTokenAsync # pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -56,7 +57,6 @@ class EventHubProducer(
periods of inactivity. The default value is `None`, i.e. no keep alive pings.
:keyword bool auto_reconnect: Whether to automatically reconnect the producer if a retryable error occurs.
Default value is `True`.
:keyword ~asyncio.AbstractEventLoop loop: An event loop. If not specified the default event loop will be used.
"""

def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> None:
Expand All @@ -70,7 +70,7 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N
self.running = False
self.closed = False

self._loop = kwargs.get("loop", None)
self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None))
self._max_message_size_on_link = None
self._client = client
self._target = target
Expand All @@ -92,7 +92,7 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N
self._handler = None # type: Optional[SendClientAsync]
self._outcome = None # type: Optional[constants.MessageSendResult]
self._condition = None # type: Optional[Exception]
self._lock = asyncio.Lock(loop=self._loop)
self._lock = asyncio.Lock(**self._internal_kwargs)
self._link_properties = {
types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000))
}
Expand All @@ -111,7 +111,7 @@ def _create_handler(self, auth: "JWTTokenAsync") -> None:
properties=create_properties(
self._client._config.user_agent # pylint:disable=protected-access
),
loop=self._loop,
**self._internal_kwargs
)

async def _open_with_retry(self) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
ALL_PARTITIONS: self._create_producer()
} # type: Dict[str, Optional[EventHubProducer]]
self._lock = asyncio.Lock(
loop=self._loop
**self._internal_kwargs
) # sync the creation of self._producers
self._max_message_size_on_link = 0
self._partition_ids = None # Optional[List[str]]
Expand Down Expand Up @@ -165,7 +165,7 @@ def _create_producer(
partition=partition_id,
send_timeout=send_timeout,
idle_timeout=self._idle_timeout,
loop=self._loop,
**self._internal_kwargs
)
return handler

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_renewable_lock_duration,
)
from .._common.auto_lock_renewer import SHORT_RENEW_OFFSET, SHORT_RENEW_SCALING_FACTOR
from ._async_utils import get_running_loop
from ._async_utils import get_dict_with_loop_if_needed
from ..exceptions import AutoLockRenewTimeout, AutoLockRenewFailed, ServiceBusError

Renewable = Union[ServiceBusSession, ServiceBusReceivedMessage]
Expand All @@ -41,8 +41,6 @@ class AutoLockRenewer:
:param on_lock_renew_failure: A callback may be specified to be called when the lock is lost on the renewable
that is being registered. Default value is None (no callback).
:type on_lock_renew_failure: Optional[LockRenewFailureCallback]
:param loop: An async event loop.
:type loop: Optional[~asyncio.AbstractEventLoop]

.. admonition:: Example:

Expand All @@ -68,9 +66,9 @@ def __init__(
on_lock_renew_failure: Optional[AsyncLockRenewFailureCallback] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
self._internal_kwargs = get_dict_with_loop_if_needed(loop)
self._shutdown = asyncio.Event()
self._futures = [] # type: List[asyncio.Future]
self._loop = loop or get_running_loop()
self._sleep_time = 1
self._renew_period = 10
self._on_lock_renew_failure = on_lock_renew_failure
Expand Down Expand Up @@ -226,7 +224,7 @@ def register(
on_lock_renew_failure or self._on_lock_renew_failure,
renew_period_override,
),
loop=self._loop,
**self._internal_kwargs
)
self._futures.append(renew_future)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# license information.
# -------------------------------------------------------------------------

import sys

import asyncio
import logging
import functools
Expand Down Expand Up @@ -65,3 +67,12 @@ async def create_authentication(client):
http_proxy=client._config.http_proxy,
transport_type=client._config.transport_type,
)


def get_dict_with_loop_if_needed(loop):
if sys.version_info >= (3, 10):
if loop:
raise ValueError("Starting Python 3.10, asyncio no longer supports loop as a parameter.")
elif loop:
return {'loop': loop}
return {}