Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Event Hubs] combine conn str parsing logic #18059

Merged
merged 5 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 61 additions & 32 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils
import six
from azure.core.utils import parse_connection_string as core_parse_connection_string
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved
from azure.core.credentials import AccessToken, AzureSasCredential

from .exceptions import _handle_exception, ClientClosedError, ConnectError
Expand All @@ -43,7 +44,7 @@
_AccessToken = collections.namedtuple("AccessToken", "token expires_on")


def _parse_conn_str(conn_str, kwargs):
def _parse_conn_str(conn_str, **kwargs):
# type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
Expand All @@ -52,38 +53,66 @@ def _parse_conn_str(conn_str, kwargs):
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
elif key.lower() == "hostname":
endpoint = value.rstrip("/")
elif key.lower() == "sharedaccesskeyname":
shared_access_key_name = value
elif key.lower() == "sharedaccesskey":
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
check_case = kwargs.pop("check_case", False) # type: Optional[bool]
conn_settings = core_parse_connection_string(conn_str, case_sensitive_keys=check_case)
if check_case:
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
endpoint = conn_settings.get("Endpoint")
entity_path = conn_settings.get("EntityPath")

# non case sensitive check when parsing connection string for internal use
for key, value in conn_settings.items():
# only sas check is non case sensitive for both conn str properties and internal use
if key.lower() == "sharedaccesssignature":
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
shared_access_signature_expiry = int(
shared_access_signature.split("se=")[1].split("&")[0] # type: ignore
)
except (
IndexError,
TypeError,
ValueError,
): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)

if not check_case:
endpoint = conn_settings.get("endpoint") or conn_settings.get("hostname")
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved
if endpoint:
endpoint = endpoint.rstrip("/")
shared_access_key_name = conn_settings.get("sharedaccesskeyname")
shared_access_key = conn_settings.get("sharedaccesskey")
entity_path = conn_settings.get('entitypath')
swathipil marked this conversation as resolved.
Show resolved Hide resolved
shared_access_signature = conn_settings.get("sharedaccesssignature")

entity = cast(str, eventhub_name or entity_path)

# check that endpoint is valid
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint)
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
host = cast(str, parsed.netloc.strip())

if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved
)
entity = cast(str, eventhub_name or entity_path)
left_slash_pos = cast(str, endpoint).find("//")
if left_slash_pos != -1:
host = cast(str, endpoint)[left_slash_pos + 2 :]
else:
host = str(endpoint)
if shared_access_signature and shared_access_key:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
if not shared_access_signature and not shared_access_key:
raise ValueError(
"At least one of the SharedAccessKey or SharedAccessSignature must be present."
)
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved

return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
Expand Down Expand Up @@ -218,7 +247,7 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
@staticmethod
def _from_connection_string(conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
if token and token_expiry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse # type: ignore

from ._common import DictMixin
from ._client_base import _parse_conn_str


class EventHubConnectionStringProperties(DictMixin):
Expand Down Expand Up @@ -70,39 +66,14 @@ def parse_connection_string(conn_str):
:type conn_str: str
:rtype: ~azure.eventhub.EventHubConnectionStringProperties
"""
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
if any(len(tup) != 2 for tup in conn_settings):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_settings)
shared_access_signature = None
for key, value in conn_settings.items():
if key.lower() == "sharedaccesssignature":
shared_access_signature = value
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
)
if shared_access_signature is not None and shared_access_key is not None:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
endpoint = conn_settings.get("Endpoint")
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint.rstrip("/"))
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
namespace = parsed.netloc.strip()
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, check_case=True)[:-1]
endpoint = "sb://" + fully_qualified_namespace + "/"
props = {
"fully_qualified_namespace": namespace,
"fully_qualified_namespace": fully_qualified_namespace,
"endpoint": endpoint,
"eventhub_name": conn_settings.get("EntityPath"),
"shared_access_signature": shared_access_signature,
"shared_access_key_name": shared_access_key_name,
"shared_access_key": shared_access_key,
"eventhub_name": entity,
"shared_access_signature": signature,
"shared_access_key_name": policy,
"shared_access_key": key,
}
return EventHubConnectionStringProperties(**props)
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __enter__(self):

@staticmethod
def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
if token and token_expiry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def test_eh_parse_malformed_conn_str_no_endpoint(self, **kwargs):
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_eh_parse_malformed_conn_str_no_endpoint_value(self, **kwargs):
conn_str = 'Endpoint=;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_eh_parse_malformed_conn_str_no_netloc(self, **kwargs):
conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
Expand All @@ -57,6 +63,22 @@ def test_eh_parse_conn_str_sas(self, **kwargs):
assert parse_result.fully_qualified_namespace == 'eh-namespace.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_eh_parse_conn_str_whitespace_trailing_semicolon(self, **kwargs):
conn_str = ' Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=; '
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_eh_parse_conn_str_sas_trailing_semicolon(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;'
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_eh_parse_conn_str_no_keyname(self, **kwargs):
conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
Expand All @@ -69,3 +91,27 @@ def test_eh_parse_conn_str_no_key(self, **kwargs):
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

def test_eh_parse_conn_str_no_key_or_sas(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/'
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.'

def test_eh_parse_malformed_conn_str_lowercase_endpoint(self, **kwargs):
conn_str = 'endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_eh_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;sharedaccesskeyname=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

def test_eh_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;sharedaccesskey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'