Skip to content

Commit

Permalink
[Event Hubs] Update URI used for consumer auth to include consumer gr…
Browse files Browse the repository at this point in the history
…oup (#35626)

* add consumer group to uri used for consumer auth

* changelog

* add consumergroup to auth uri for consumer ops only, not mgmt ops

* lint

* fix sample
  • Loading branch information
swathipil authored Jun 6, 2024
1 parent 61138a7 commit e6f98bc
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 19 deletions.
2 changes: 2 additions & 0 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bugs Fixed

- Fixed a bug where the correct URI was not being used for consumer authentication, causing issues when assigning roles at the consumer group level. ([#35337](https://github.com/Azure/azure-sdk-for-python/issues/35337))

### Other Changes

## 5.12.0 (2024-05-16)
Expand Down
16 changes: 10 additions & 6 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ def __init__(
else:
self._credential = credential # type: ignore
self._auto_reconnect = kwargs.get("auto_reconnect", True)
self._auth_uri = f"sb://{self._address.hostname}{self._address.path}"
self._auth_uri: str
self._eventhub_auth_uri = f"sb://{self._address.hostname}{self._address.path}"
self._config = Configuration(
amqp_transport=self._amqp_transport,
hostname=self._address.hostname,
Expand Down Expand Up @@ -348,29 +349,32 @@ def _from_connection_string(conn_str: str, **kwargs: Any) -> Dict[str, Any]:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

def _create_auth(self) -> Union["uamqp_JWTTokenAuth", JWTTokenAuth]:
def _create_auth(self, *, auth_uri: Optional[str] = None) -> Union["uamqp_JWTTokenAuth", JWTTokenAuth]:
"""
Create an ~uamqp.authentication.SASTokenAuth instance
to authenticate the session.
:return: The auth for the session.
:rtype: JWTTokenAuth or uamqp_JWTTokenAuth
"""
# if auth_uri is not provided, use the default hub one
entity_auth_uri = auth_uri if auth_uri else self._eventhub_auth_uri

try:
# ignore mypy's warning because token_type is Optional
token_type = self._credential.token_type # type: ignore
except AttributeError:
token_type = b"jwt"
if token_type == b"servicebus.windows.net:sastoken":
return self._amqp_transport.create_token_auth(
self._auth_uri,
functools.partial(self._credential.get_token, self._auth_uri),
entity_auth_uri,
functools.partial(self._credential.get_token, entity_auth_uri),
token_type=token_type,
config=self._config,
update_token=True,
)
return self._amqp_transport.create_token_auth(
self._auth_uri,
entity_auth_uri,
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
config=self._config,
Expand Down Expand Up @@ -563,7 +567,7 @@ def _open(self) -> bool:
if not self.running:
if self._handler:
self._handler.close()
auth = self._client._create_auth()
auth = self._client._create_auth(auth_uri=self._client._auth_uri)
self._create_handler(auth)
conn = self._client._conn_manager.get_connection( # pylint: disable=protected-access
endpoint=self._client._address.hostname, auth=auth
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _open(self) -> bool:
if not self.running:
if self._handler:
self._handler.close()
auth = self._client._create_auth()
auth = self._client._create_auth(auth_uri=self._client._auth_uri)
self._create_handler(auth)
conn = self._client._conn_manager.get_connection( # pylint: disable=protected-access
endpoint=self._client._address.hostname, auth=auth
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def __init__(
network_tracing=network_tracing,
**kwargs
)
# consumer auth URI additionally includes consumer group
self._auth_uri = f"sb://{self._address.hostname}{self._address.path}/consumergroups/{self._consumer_group}"
self._lock = threading.Lock()
self._event_processors: Dict[Tuple[str, str], EventProcessor] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(
network_tracing=kwargs.get("logging_enable"),
**kwargs
)

self._auth_uri = f"sb://{self._address.hostname}{self._address.path}"
self._keep_alive = kwargs.get("keep_alive", None)

self._producers: Dict[str, Optional[EventHubProducer]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,14 @@
from .._pyamqp.aio._authentication_async import JWTTokenAuthAsync
try:
from uamqp import (
authentication as uamqp_authentication,
Message as uamqp_Message,
AMQPClientAsync as uamqp_AMQPClientAsync,
)
from uamqp.authentication import JWTTokenAsync
from uamqp.authentication import JWTTokenAsync as uamqp_JWTTokenAsync
except ImportError:
uamqp_authentication = None
uamqp_Message = None
uamqp_AMQPClientAsync = None
JWTTokenAsync = None
uamqp_JWTTokenAsync = None
from azure.core.credentials_async import AsyncTokenCredential

try:
Expand Down Expand Up @@ -109,7 +107,7 @@ def running(self) -> bool:
def running(self, value: bool) -> None:
pass

def _create_handler(self, auth: Union["JWTTokenAsync", JWTTokenAuthAsync]) -> None:
def _create_handler(self, auth: Union["uamqp_JWTTokenAsync", JWTTokenAuthAsync]) -> None:
pass

_MIXIN_BASE = AbstractConsumerProducer
Expand Down Expand Up @@ -268,30 +266,34 @@ def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

async def _create_auth_async(self) -> Union[uamqp_authentication.JWTTokenAsync, JWTTokenAuthAsync]:
async def _create_auth_async(
self, *, auth_uri: Optional[str] = None
) -> Union["uamqp_JWTTokenAsync", JWTTokenAuthAsync]:
"""
Create an ~uamqp.authentication.SASTokenAuthAsync instance to authenticate
the session.
:return: A JWTTokenAuthAsync instance to authenticate the session.
:rtype: ~uamqp.authentication.JWTTokenAsync or JWTTokenAuthAsync
"""
# if auth_uri is not provided, use the default hub one
entity_auth_uri = auth_uri if auth_uri else self._eventhub_auth_uri

try:
# ignore mypy's warning because token_type is Optional
token_type = self._credential.token_type # type: ignore
except AttributeError:
token_type = b"jwt"
if token_type == b"servicebus.windows.net:sastoken":
return await self._amqp_transport.create_token_auth_async(
self._auth_uri,
functools.partial(self._credential.get_token, self._auth_uri),
entity_auth_uri,
functools.partial(self._credential.get_token, entity_auth_uri),
token_type=token_type,
config=self._config,
update_token=True,
)
return await self._amqp_transport.create_token_auth_async(
self._auth_uri,
entity_auth_uri,
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
config=self._config,
Expand Down Expand Up @@ -475,7 +477,7 @@ async def _open(self) -> None:
if not self.running:
if self._handler:
await self._handler.close_async()
auth = await self._client._create_auth_async()
auth = await self._client._create_auth_async(auth_uri=self._client._auth_uri)
self._create_handler(auth)
conn = await self._client._conn_manager_async.get_connection(
endpoint=self._client._address.hostname, auth=auth
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def __init__(
network_tracing=network_tracing,
**kwargs,
)
# consumer auth URI additionally includes consumer group
self._auth_uri = f"sb://{self._address.hostname}{self._address.path}/consumergroups/{self._consumer_group}"
self._lock = asyncio.Lock(**self._internal_kwargs)
self._event_processors: Dict[Tuple[str, str], EventProcessor] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(
network_tracing=kwargs.pop("logging_enable", False),
**kwargs
)
self._auth_uri = f"sb://{self._address.hostname}{self._address.path}"
self._keep_alive = kwargs.get("keep_alive", None)
self._producers: Dict[str, Optional[EventHubProducer]] = {
ALL_PARTITIONS: self._create_producer()
Expand Down

0 comments on commit e6f98bc

Please sign in to comment.