Skip to content

Commit

Permalink
[Event Hubs] combine conn str parsing logic (#18059)
Browse files Browse the repository at this point in the history
* user core parser + remove redundancy

* move sas expiry logic + types

* fix error message

* mypy error

* error message for cs parser only
  • Loading branch information
swathipil authored Apr 22, 2021
1 parent 73d0b36 commit 8ac54ee
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 73 deletions.
99 changes: 66 additions & 33 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils
import six
from azure.core.utils import parse_connection_string as core_parse_connection_string
from azure.core.credentials import AccessToken, AzureSasCredential

from .exceptions import _handle_exception, ClientClosedError, ConnectError
Expand All @@ -43,47 +44,79 @@
_AccessToken = collections.namedtuple("AccessToken", "token expires_on")


def _parse_conn_str(conn_str, kwargs):
# type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
def _parse_conn_str(conn_str, **kwargs):
# type: (str, Any) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
elif key.lower() == "hostname":
endpoint = value.rstrip("/")
elif key.lower() == "sharedaccesskeyname":
shared_access_key_name = value
elif key.lower() == "sharedaccesskey":
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
shared_access_signature_expiry = None
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
check_case = kwargs.pop("check_case", False) # type: bool
conn_settings = core_parse_connection_string(conn_str, case_sensitive_keys=check_case)
if check_case:
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
endpoint = conn_settings.get("Endpoint")
entity_path = conn_settings.get("EntityPath")
# non case sensitive check when parsing connection string for internal use
for key, value in conn_settings.items():
# only sas check is non case sensitive for both conn str properties and internal use
if key.lower() == "sharedaccesssignature":
shared_access_signature = value

if not check_case:
endpoint = conn_settings.get("endpoint") or conn_settings.get("hostname")
if endpoint:
endpoint = endpoint.rstrip("/")
shared_access_key_name = conn_settings.get("sharedaccesskeyname")
shared_access_key = conn_settings.get("sharedaccesskey")
entity_path = conn_settings.get("entitypath")
shared_access_signature = conn_settings.get("sharedaccesssignature")

if shared_access_signature:
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
shared_access_signature_expiry = int(
shared_access_signature.split("se=")[1].split("&")[0] # type: ignore
)
except (
IndexError,
TypeError,
ValueError,
): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)

entity = cast(str, eventhub_name or entity_path)

# check that endpoint is valid
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint)
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
host = cast(str, parsed.netloc.strip())

if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
)
entity = cast(str, eventhub_name or entity_path)
left_slash_pos = cast(str, endpoint).find("//")
if left_slash_pos != -1:
host = cast(str, endpoint)[left_slash_pos + 2 :]
else:
host = str(endpoint)
# Only connection string parser should check that only one of sas and shared access
# key exists. For backwards compatibility, client construction should not have this check.
if check_case and shared_access_signature and shared_access_key:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
if not shared_access_signature and not shared_access_key:
raise ValueError(
"At least one of the SharedAccessKey or SharedAccessSignature must be present."
)

return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
Expand Down Expand Up @@ -218,7 +251,7 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
@staticmethod
def _from_connection_string(conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
if token and token_expiry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse # type: ignore

from ._common import DictMixin
from ._client_base import _parse_conn_str


class EventHubConnectionStringProperties(DictMixin):
Expand Down Expand Up @@ -70,39 +66,14 @@ def parse_connection_string(conn_str):
:type conn_str: str
:rtype: ~azure.eventhub.EventHubConnectionStringProperties
"""
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
if any(len(tup) != 2 for tup in conn_settings):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_settings)
shared_access_signature = None
for key, value in conn_settings.items():
if key.lower() == "sharedaccesssignature":
shared_access_signature = value
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
)
if shared_access_signature is not None and shared_access_key is not None:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
endpoint = conn_settings.get("Endpoint")
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint.rstrip("/"))
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
namespace = parsed.netloc.strip()
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, check_case=True)[:-1]
endpoint = "sb://" + fully_qualified_namespace + "/"
props = {
"fully_qualified_namespace": namespace,
"fully_qualified_namespace": fully_qualified_namespace,
"endpoint": endpoint,
"eventhub_name": conn_settings.get("EntityPath"),
"shared_access_signature": shared_access_signature,
"shared_access_key_name": shared_access_key_name,
"shared_access_key": shared_access_key,
"eventhub_name": entity,
"shared_access_signature": signature,
"shared_access_key_name": policy,
"shared_access_key": key,
}
return EventHubConnectionStringProperties(**props)
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __enter__(self):

@staticmethod
def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
if token and token_expiry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def test_eh_parse_malformed_conn_str_no_endpoint(self, **kwargs):
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_eh_parse_malformed_conn_str_no_endpoint_value(self, **kwargs):
conn_str = 'Endpoint=;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_eh_parse_malformed_conn_str_no_netloc(self, **kwargs):
conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
Expand All @@ -57,15 +63,55 @@ def test_eh_parse_conn_str_sas(self, **kwargs):
assert parse_result.fully_qualified_namespace == 'eh-namespace.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_eh_parse_conn_str_whitespace_trailing_semicolon(self, **kwargs):
conn_str = ' Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=; '
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_eh_parse_conn_str_sas_trailing_semicolon(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;'
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_eh_parse_conn_str_no_keyname(self, **kwargs):
conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'
assert "Invalid connection string" in str(e.value)

def test_eh_parse_conn_str_no_key(self, **kwargs):
conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKeyName=test-policy'
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'
assert "Invalid connection string" in str(e.value)

def test_eh_parse_conn_str_no_key_or_sas(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/'
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.'

def test_eh_parse_malformed_conn_str_lowercase_endpoint(self, **kwargs):
conn_str = 'endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_eh_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;sharedaccesskeyname=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert "Invalid connection string" in str(e.value)

def test_eh_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;sharedaccesskey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert "Invalid connection string" in str(e.value)

0 comments on commit 8ac54ee

Please sign in to comment.